Repository: replicate/cog Branch: main Commit: cb3ac3c727f6 Files: 575 Total size: 3.2 MB Directory structure: gitextract_1_nczbz8/ ├── .all-contributorsrc ├── .git_archival.txt ├── .gitattributes ├── .github/ │ ├── CODEOWNERS │ ├── dependabot.yml │ └── workflows/ │ ├── README.md │ ├── ci.yaml │ ├── codeql.yml │ ├── docs.yaml │ ├── release-build.yaml │ └── release-publish.yaml ├── .gitignore ├── .golangci.yaml ├── .goreleaser.yaml ├── .mockery.yml ├── .vscode/ │ ├── extensions.json │ └── settings.json ├── AGENTS.md ├── CONTRIBUTING.md ├── DESIGN.md ├── LICENSE ├── Makefile ├── README.md ├── architecture/ │ ├── 00-overview.md │ ├── 01-model-source.md │ ├── 02-schema.md │ ├── 05-build-system.md │ ├── 06-cli.md │ ├── ffi/ │ │ ├── 03-prediction-api.md │ │ ├── 04-container-runtime.md │ │ └── README.md │ └── legacy/ │ ├── 03-prediction-api.md │ ├── 04-container-runtime.md │ └── README.md ├── cmd/ │ └── cog/ │ └── cog.go ├── crates/ │ ├── .gitignore │ ├── Cargo.toml │ ├── README.md │ ├── coglet/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ └── src/ │ │ ├── bridge/ │ │ │ ├── codec.rs │ │ │ ├── mod.rs │ │ │ ├── protocol.rs │ │ │ ├── snapshots/ │ │ │ │ ├── coglet__bridge__protocol__tests__control_cancel_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__control_cancelled_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__control_failed_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__control_healthcheck_result_healthy_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__control_healthcheck_result_unhealthy_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__control_healthcheck_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__control_idle_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__control_init_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__control_ready_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__control_ready_with_schema_serializes.snap │ │ │ │ ├── 
coglet__bridge__protocol__tests__control_shutdown_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__slot_cancelled_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__slot_done_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__slot_failed_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__slot_log_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__slot_metric_append_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__slot_metric_complex_value_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__slot_metric_delete_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__slot_metric_increment_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__slot_metric_replace_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__slot_output_serializes.snap │ │ │ │ ├── coglet__bridge__protocol__tests__slot_predict_file_input_serializes.snap │ │ │ │ └── coglet__bridge__protocol__tests__slot_predict_serializes.snap │ │ │ └── transport.rs │ │ ├── fd_redirect.rs │ │ ├── health.rs │ │ ├── input_validation.rs │ │ ├── lib.rs │ │ ├── orchestrator.rs │ │ ├── permit/ │ │ │ ├── mod.rs │ │ │ ├── pool.rs │ │ │ └── slot.rs │ │ ├── prediction.rs │ │ ├── predictor.rs │ │ ├── service.rs │ │ ├── setup_log_accumulator.rs │ │ ├── snapshots/ │ │ │ ├── coglet__health__tests__health_all_variants.snap │ │ │ ├── coglet__health__tests__health_response_all_variants.snap │ │ │ ├── coglet__health__tests__setup_status_all_variants.snap │ │ │ ├── coglet__predictor__tests__output_single.snap │ │ │ ├── coglet__predictor__tests__output_stream.snap │ │ │ ├── coglet__version__tests__version_full.snap │ │ │ └── coglet__version__tests__version_minimal.snap │ │ ├── transport/ │ │ │ ├── http/ │ │ │ │ ├── mod.rs │ │ │ │ ├── routes.rs │ │ │ │ └── server.rs │ │ │ └── mod.rs │ │ ├── version.rs │ │ ├── webhook.rs │ │ ├── worker.rs │ │ └── worker_tracing_layer.rs │ ├── coglet-python/ │ │ ├── Cargo.toml │ │ ├── README.md │ │ ├── build.rs │ 
│ ├── coglet/ │ │ │ ├── __init__.py │ │ │ ├── __init__.pyi │ │ │ ├── _impl.pyi │ │ │ ├── _sdk/ │ │ │ │ └── __init__.pyi │ │ │ └── py.typed │ │ ├── pyproject.toml │ │ ├── src/ │ │ │ ├── audit.rs │ │ │ ├── bin/ │ │ │ │ └── stub_gen.rs │ │ │ ├── cancel.rs │ │ │ ├── input.rs │ │ │ ├── lib.rs │ │ │ ├── log_writer.rs │ │ │ ├── metric_scope.rs │ │ │ ├── output.rs │ │ │ ├── predictor.rs │ │ │ └── worker_bridge.rs │ │ └── tests/ │ │ └── test_coglet.py │ └── deny.toml ├── docs/ │ ├── CNAME │ ├── cli.md │ ├── deploy.md │ ├── environment.md │ ├── getting-started-own-model.md │ ├── getting-started.md │ ├── http.md │ ├── llms.txt │ ├── notebooks.md │ ├── private-package-registry.md │ ├── python.md │ ├── stylesheets/ │ │ └── extra.css │ ├── training.md │ ├── wsl2/ │ │ └── wsl2.md │ └── yaml.md ├── go.mod ├── go.sum ├── integration-tests/ │ ├── .gitignore │ ├── README.md │ ├── concurrent/ │ │ └── concurrent_test.go │ ├── harness/ │ │ ├── cmd_pty.go │ │ ├── command.go │ │ └── harness.go │ ├── login/ │ │ └── login_test.go │ ├── suite_test.go │ └── tests/ │ ├── apt_packages.txtar │ ├── async_generator_precollect.txtar │ ├── async_predictor.txtar │ ├── async_sleep.txtar │ ├── bad_dockerignore.txtar │ ├── bool_input_output.txtar │ ├── build_base_image_sha.txtar │ ├── build_cog_init.txtar │ ├── build_cog_version_match.txtar │ ├── build_gpu_labels.txtar │ ├── build_image_option.txtar │ ├── build_openapi_schema.txtar │ ├── build_openapi_schema_complex.txtar │ ├── build_pip_freeze.txtar │ ├── build_python313_base_image.txtar │ ├── build_torch_version_required.txtar │ ├── ca_cert.txtar │ ├── cancel_async_prediction.txtar │ ├── cancel_repeated.txtar │ ├── cancel_sync_prediction.txtar │ ├── coglet_iterator_path_output.txtar │ ├── coglet_iterator_upload_url.txtar │ ├── coglet_large_file_upload_serial.txtar │ ├── coglet_large_input.txtar │ ├── coglet_large_output.txtar │ ├── coglet_list_path_single_element.txtar │ ├── coglet_list_path_upload_url.txtar │ ├── coglet_metrics.txtar │ ├── 
coglet_metrics_webhook.txtar │ ├── coglet_single_path_output.txtar │ ├── complex_output.txtar │ ├── concatenate_iterator_output.txtar │ ├── config_subdirectory.txtar │ ├── debug_secrets.txtar │ ├── dict_output.txtar │ ├── emit_metric_deprecated.txtar │ ├── env_vars.txtar │ ├── experimental_feature_warning.txtar │ ├── ffmpeg_package.txtar │ ├── file_input.txtar │ ├── file_list_input.txtar │ ├── float_input_output.txtar │ ├── function_predictor.txtar │ ├── future_annotations.txtar │ ├── glb_project.txtar │ ├── granite_project.txtar │ ├── healthcheck.txtar │ ├── healthcheck_async.txtar │ ├── healthcheck_async_exception.txtar │ ├── healthcheck_async_timeout.txtar │ ├── healthcheck_async_unhealthy.txtar │ ├── healthcheck_during_prediction.txtar │ ├── healthcheck_exception.txtar │ ├── healthcheck_immediately_after_prediction.txtar │ ├── healthcheck_repeated_calls.txtar │ ├── healthcheck_timeout.txtar │ ├── healthcheck_unhealthy.txtar │ ├── int_input_output.txtar │ ├── int_none_output.txtar │ ├── int_predictor.txtar │ ├── invalid_int_validation.txtar │ ├── iterator_error_midstream.txtar │ ├── iterator_string_output.txtar │ ├── legacy_sdk_schema.txtar │ ├── list_int_input_output.txtar │ ├── list_string_output.txtar │ ├── many_inputs.txtar │ ├── multi_file_schema.txtar │ ├── nested_output_types.txtar │ ├── no_predictor.txtar │ ├── non_base_predictor_class.txtar │ ├── non_base_predictor_function.txtar │ ├── oci_bundle_build.txtar │ ├── oci_bundle_inspect.txtar │ ├── oci_bundle_push.txtar │ ├── optional_path_input.txtar │ ├── path_input.txtar │ ├── path_input_output.txtar │ ├── path_list_input.txtar │ ├── path_list_output.txtar │ ├── path_output.txtar │ ├── predict_existing_image.txtar │ ├── predict_json_file.txtar │ ├── predict_json_input.txtar │ ├── predict_json_output_file.txtar │ ├── predict_json_stdin.txtar │ ├── predict_json_stdin_dash.txtar │ ├── predict_many_inputs_image.txtar │ ├── predict_output_file.txtar │ ├── predict_output_string.txtar │ ├── 
predict_sys_exit.txtar │ ├── prediction_error_response.txtar │ ├── pty_echo.txtar │ ├── pty_interactive.txtar │ ├── pydantic2.txtar │ ├── pydantic2_output.txtar │ ├── python313.txtar │ ├── python37_deprecated.txtar │ ├── python38_deprecated.txtar │ ├── python39_deprecated.txtar │ ├── run_basic.txtar │ ├── run_stdin_cat.txtar │ ├── run_stdin_unconsumed.txtar │ ├── scope_context.txtar │ ├── secrets.txtar │ ├── sequential_state_leak.txtar │ ├── setup_slow_serial.txtar │ ├── setup_subprocess_double_fork.txtar │ ├── setup_subprocess_double_fork_http.txtar │ ├── setup_subprocess_multiprocessing.txtar │ ├── setup_subprocess_simple.txtar │ ├── setup_timeout_serial.txtar │ ├── setup_worker_tracing_logs.txtar │ ├── static_schema_fallback.txtar │ ├── static_schema_gen.txtar │ ├── string_list_input.txtar │ ├── string_none_output.txtar │ ├── string_predictor.txtar │ ├── subdirectory_predictor.txtar │ ├── tensorflow.txtar │ ├── torch_270_cuda_126.txtar │ ├── torch_271_cuda_128.txtar │ ├── torch_baseimage_fallback.txtar │ ├── torch_baseimage_no_cog_base.txtar │ ├── torch_baseimage_precompile.txtar │ ├── torch_cuda_baseimage.txtar │ ├── train_basic.txtar │ ├── train_deprecated.txtar │ ├── training_setup.txtar │ ├── union_type.txtar │ ├── webhook_delivery_failure.txtar │ ├── webhook_prediction_error.txtar │ ├── weights_build.txtar │ ├── weights_push_inspect.txtar │ ├── wheel_coglet_missing.txtar │ ├── wheel_resolution.txtar │ └── zsh_package.txtar ├── mise.toml ├── mkdocs.yml ├── noxfile.py ├── pkg/ │ ├── cli/ │ │ ├── baseimage.go │ │ ├── build.go │ │ ├── debug.go │ │ ├── init-templates/ │ │ │ └── base/ │ │ │ ├── .dockerignore │ │ │ ├── .github/ │ │ │ │ └── workflows/ │ │ │ │ └── push.yaml │ │ │ ├── cog.yaml │ │ │ ├── predict.py │ │ │ └── requirements.txt │ │ ├── init.go │ │ ├── init_test.go │ │ ├── inspect.go │ │ ├── login.go │ │ ├── predict.go │ │ ├── predict_test.go │ │ ├── push.go │ │ ├── root.go │ │ ├── run.go │ │ ├── serve.go │ │ ├── train.go │ │ ├── train_test.go │ │ ├── 
weights.go │ │ └── weights_inspect.go │ ├── config/ │ │ ├── build_options.go │ │ ├── compatibility.go │ │ ├── compatibility_test.go │ │ ├── config.go │ │ ├── config_file.go │ │ ├── config_test.go │ │ ├── cuda_compatibility.json │ │ ├── data/ │ │ │ └── config_schema_v1.0.json │ │ ├── env.go │ │ ├── env_variables_test.go │ │ ├── errors.go │ │ ├── image_name.go │ │ ├── image_name_test.go │ │ ├── load.go │ │ ├── load_test.go │ │ ├── parse.go │ │ ├── tf_compatibility.json │ │ ├── torch_compatibility.json │ │ ├── validate.go │ │ ├── validate_test.go │ │ └── version.go │ ├── docker/ │ │ ├── build_secrets.go │ │ ├── buildkit.go │ │ ├── command/ │ │ │ ├── command.go │ │ │ ├── errors.go │ │ │ ├── manifest.go │ │ │ └── user_info.go │ │ ├── credential_helper_input.go │ │ ├── credentials.go │ │ ├── credentials_test.go │ │ ├── docker.go │ │ ├── docker_client_test.go │ │ ├── dockertest/ │ │ │ ├── command_mocks.go │ │ │ ├── helper_client.go │ │ │ ├── image.go │ │ │ ├── mock_command.go │ │ │ ├── ref.go │ │ │ ├── ref_test.go │ │ │ └── testdata/ │ │ │ └── create-image-fixtures.sh │ │ ├── env.go │ │ ├── errors.go │ │ ├── host.go │ │ ├── host_unix.go │ │ ├── host_windows.go │ │ ├── login.go │ │ ├── options.go │ │ ├── progress.go │ │ ├── push.go │ │ ├── run.go │ │ ├── run_test.go │ │ ├── standard_push.go │ │ └── standard_push_test.go │ ├── dockercontext/ │ │ ├── build_tempdir.go │ │ ├── build_tempdir_test.go │ │ └── directories.go │ ├── dockerfile/ │ │ ├── base.go │ │ ├── base_test.go │ │ ├── cacert.go │ │ ├── cacert_test.go │ │ ├── env.go │ │ ├── generator.go │ │ ├── generator_factory.go │ │ ├── generator_factory_test.go │ │ ├── standard_generator.go │ │ ├── standard_generator_test.go │ │ └── version_check.go │ ├── dockerignore/ │ │ ├── dockerignore.go │ │ └── dockerignore_test.go │ ├── env/ │ │ ├── env.go │ │ └── env_test.go │ ├── errors/ │ │ ├── common.go │ │ └── errors.go │ ├── global/ │ │ └── global.go │ ├── http/ │ │ ├── client.go │ │ ├── client_test.go │ │ ├── transport.go │ │ 
├── transport_test.go │ │ ├── user_agent.go │ │ └── user_agent_test.go │ ├── image/ │ │ ├── build.go │ │ ├── build_test.go │ │ ├── config.go │ │ ├── openapi_schema.go │ │ └── pip_freeze.go │ ├── model/ │ │ ├── artifact.go │ │ ├── artifact_image.go │ │ ├── artifact_image_test.go │ │ ├── artifact_test.go │ │ ├── artifact_weight.go │ │ ├── artifact_weight_test.go │ │ ├── builder.go │ │ ├── builder_test.go │ │ ├── errors.go │ │ ├── errors_test.go │ │ ├── factory.go │ │ ├── factory_test.go │ │ ├── format.go │ │ ├── format_test.go │ │ ├── hash.go │ │ ├── image_builder.go │ │ ├── image_builder_test.go │ │ ├── image_pusher.go │ │ ├── image_pusher_test.go │ │ ├── image_test.go │ │ ├── index.go │ │ ├── index_factory.go │ │ ├── index_factory_test.go │ │ ├── index_test.go │ │ ├── model.go │ │ ├── model_test.go │ │ ├── options.go │ │ ├── options_test.go │ │ ├── push_helpers.go │ │ ├── pusher.go │ │ ├── pusher_test.go │ │ ├── ref.go │ │ ├── ref_test.go │ │ ├── ref_types.go │ │ ├── ref_types_test.go │ │ ├── resolver.go │ │ ├── resolver_test.go │ │ ├── source.go │ │ ├── source_test.go │ │ ├── weight_builder.go │ │ ├── weight_builder_test.go │ │ ├── weight_pusher.go │ │ ├── weight_pusher_test.go │ │ ├── weights.go │ │ ├── weights_lock.go │ │ ├── weights_lock_test.go │ │ └── weights_test.go │ ├── path/ │ │ ├── path.go │ │ └── path_test.go │ ├── predict/ │ │ ├── api.go │ │ ├── input.go │ │ └── predictor.go │ ├── provider/ │ │ ├── generic/ │ │ │ ├── generic.go │ │ │ └── generic_test.go │ │ ├── provider.go │ │ ├── registry.go │ │ ├── registry_test.go │ │ ├── replicate/ │ │ │ ├── replicate.go │ │ │ └── replicate_test.go │ │ └── setup/ │ │ ├── setup.go │ │ └── setup_test.go │ ├── registry/ │ │ ├── client.go │ │ ├── client_test.go │ │ ├── config.go │ │ ├── config_test.go │ │ ├── manifest_result.go │ │ ├── push_test.go │ │ ├── registry_client.go │ │ └── registrytest/ │ │ └── mock_client.go │ ├── registry_testhelpers/ │ │ ├── registry_container.go │ │ └── testdata/ │ │ └── docker/ │ │ └── 
registry/ │ │ └── v2/ │ │ ├── blobs/ │ │ │ └── sha256/ │ │ │ ├── 1c/ │ │ │ │ └── 1c4eef651f65e2f7daee7ee785882ac164b02b78fb74503052a26dc061c90474/ │ │ │ │ └── data │ │ │ ├── 6e/ │ │ │ │ └── 6e771e15690e2fabf2332d3a3b744495411d6e0b00b2aea64419b58b0066cf81/ │ │ │ │ └── data │ │ │ ├── 75/ │ │ │ │ └── 757d680068d77be46fd1ea20fb21db16f150468c5e7079a08a2e4705aec096ac/ │ │ │ │ └── data │ │ │ ├── 8d/ │ │ │ │ └── 8d591b0b7dea080ea3be9e12ae563eebf9869168ffced1cb25b2470a3d9fe15e/ │ │ │ │ └── data │ │ │ ├── 9a/ │ │ │ │ └── 9a0ff41dccad7a96f324a4655a715c623ed3511c7336361ffa9dadcecbdb99e5/ │ │ │ │ └── data │ │ │ ├── ad/ │ │ │ │ └── aded1e1a5b3705116fa0a92ba074a5e0b0031647d9c315983ccba2ee5428ec8b/ │ │ │ │ └── data │ │ │ └── f1/ │ │ │ └── f18232174bc91741fdf3da96d85011092101a032a93a388b79e99e69c2d5c870/ │ │ │ └── data │ │ └── repositories/ │ │ └── alpine/ │ │ ├── _layers/ │ │ │ └── sha256/ │ │ │ ├── 6e771e15690e2fabf2332d3a3b744495411d6e0b00b2aea64419b58b0066cf81/ │ │ │ │ └── link │ │ │ ├── 8d591b0b7dea080ea3be9e12ae563eebf9869168ffced1cb25b2470a3d9fe15e/ │ │ │ │ └── link │ │ │ ├── aded1e1a5b3705116fa0a92ba074a5e0b0031647d9c315983ccba2ee5428ec8b/ │ │ │ │ └── link │ │ │ └── f18232174bc91741fdf3da96d85011092101a032a93a388b79e99e69c2d5c870/ │ │ │ └── link │ │ └── _manifests/ │ │ ├── revisions/ │ │ │ └── sha256/ │ │ │ ├── 1c4eef651f65e2f7daee7ee785882ac164b02b78fb74503052a26dc061c90474/ │ │ │ │ └── link │ │ │ ├── 757d680068d77be46fd1ea20fb21db16f150468c5e7079a08a2e4705aec096ac/ │ │ │ │ └── link │ │ │ └── 9a0ff41dccad7a96f324a4655a715c623ed3511c7336361ffa9dadcecbdb99e5/ │ │ │ └── link │ │ └── tags/ │ │ └── latest/ │ │ ├── current/ │ │ │ └── link │ │ └── index/ │ │ └── sha256/ │ │ └── 9a0ff41dccad7a96f324a4655a715c623ed3511c7336361ffa9dadcecbdb99e5/ │ │ └── link │ ├── requirements/ │ │ ├── requirements.go │ │ └── requirements_test.go │ ├── schema/ │ │ ├── errors.go │ │ ├── generator.go │ │ ├── generator_test.go │ │ ├── openapi.go │ │ ├── openapi_test.go │ │ ├── python/ │ │ │ ├── 
parser.go │ │ │ ├── parser_fuzz_test.go │ │ │ └── parser_test.go │ │ ├── schema_type.go │ │ ├── schema_type_fuzz_test.go │ │ └── types.go │ ├── update/ │ │ ├── state.go │ │ └── update.go │ ├── util/ │ │ ├── console/ │ │ │ ├── console.go │ │ │ ├── formatting.go │ │ │ ├── global.go │ │ │ ├── interactive.go │ │ │ ├── levels.go │ │ │ └── term.go │ │ ├── env.go │ │ ├── errors.go │ │ ├── files/ │ │ │ ├── files.go │ │ │ └── files_test.go │ │ ├── hash.go │ │ ├── hash_test.go │ │ ├── mime/ │ │ │ ├── mime.go │ │ │ └── mime_test.go │ │ ├── net.go │ │ ├── overwrite_yaml.go │ │ ├── overwrite_yaml_test.go │ │ ├── platform.go │ │ ├── ringbuffer.go │ │ ├── shell/ │ │ │ ├── net.go │ │ │ └── pipes.go │ │ └── version/ │ │ ├── version.go │ │ └── version_test.go │ ├── web/ │ │ ├── client.go │ │ └── client_test.go │ ├── weights/ │ │ ├── manifest.go │ │ ├── weights.go │ │ └── weights_test.go │ └── wheels/ │ ├── wheels.go │ └── wheels_test.go ├── pyproject.toml ├── python/ │ ├── cog/ │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── _adt.py │ │ ├── _inspector.py │ │ ├── _schemas.py │ │ ├── coder.py │ │ ├── command/ │ │ │ ├── __init__.py │ │ │ └── openapi_schema.py │ │ ├── config.py │ │ ├── errors.py │ │ ├── input.py │ │ ├── mode.py │ │ ├── model.py │ │ ├── predictor.py │ │ ├── server/ │ │ │ ├── __init__.py │ │ │ └── http.py │ │ ├── suppress_output.py │ │ └── types.py │ └── tests/ │ ├── __init__.py │ ├── test_emit_metric.py │ ├── test_experimental_feature_warning.py │ ├── test_input.py │ ├── test_model.py │ ├── test_predictor.py │ └── test_types.py ├── script/ │ └── generate-compat ├── test-helpers/ │ └── https-server/ │ ├── go.mod │ └── main.go ├── test-integration/ │ └── test_integration/ │ └── fixtures/ │ └── hello-image/ │ ├── cog.yaml │ └── predict.py └── tools/ ├── compatgen/ │ ├── internal/ │ │ ├── cuda.go │ │ ├── tensorflow.go │ │ ├── torch.go │ │ ├── torch_package.go │ │ ├── torch_test.go │ │ ├── torch_test.html │ │ └── util.go │ └── main.go ├── gendocs/ │ └── main.go ├── install.sh 
├── test-harness/ │ ├── .gitignore │ ├── README.md │ ├── harness/ │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── cli.py │ │ ├── cog_resolver.py │ │ ├── patcher.py │ │ ├── report.py │ │ ├── runner.py │ │ └── validators.py │ ├── manifest.yaml │ ├── pyproject.toml │ └── results/ │ └── .gitkeep ├── test-registry-util/ │ ├── README.md │ └── main.go └── weights-gen/ ├── README.md └── main.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .all-contributorsrc ================================================ { "projectName": "cog", "projectOwner": "replicate", "repoType": "github", "repoHost": "https://github.com", "files": [ "README.md" ], "imageSize": 100, "commit": false, "commitConvention": "none", "contributors": [ { "login": "bfirsh", "name": "Ben Firshman", "avatar_url": "https://avatars.githubusercontent.com/u/40906?v=4", "profile": "https://fir.sh/", "contributions": [ "code", "doc" ] }, { "login": "andreasjansson", "name": "Andreas Jansson", "avatar_url": "https://avatars.githubusercontent.com/u/713993?v=4", "profile": "https://replicate.ai/", "contributions": [ "code", "doc", "maintenance" ] }, { "login": "zeke", "name": "Zeke Sikelianos", "avatar_url": "https://avatars.githubusercontent.com/u/2289?v=4", "profile": "http://zeke.sikelianos.com/", "contributions": [ "code", "doc", "tool" ] }, { "login": "synek", "name": "Rory Byrne", "avatar_url": "https://avatars.githubusercontent.com/u/9436784?v=4", "profile": "https://rory.bio/", "contributions": [ "code", "doc", "test" ] }, { "login": "hangtwenty", "name": "Michael Floering", "avatar_url": "https://avatars.githubusercontent.com/u/2420688?v=4", "profile": "https://github.com/hangtwenty", "contributions": [ "code", "doc", "ideas" ] }, { "login": "bencevans", "name": "Ben Evans", "avatar_url": "https://avatars.githubusercontent.com/u/638535?v=4", "profile": "https://bencevans.io/", 
"contributions": [ "doc" ] }, { "login": "imshashank", "name": "shashank agarwal", "avatar_url": "https://avatars.githubusercontent.com/u/778870?v=4", "profile": "https://shashank.pw/", "contributions": [ "code", "doc" ] }, { "login": "VictorXLR", "name": "VictorXLR", "avatar_url": "https://avatars.githubusercontent.com/u/22397950?v=4", "profile": "https://victorxlr.me/", "contributions": [ "code", "doc", "test" ] }, { "login": "annahung31", "name": "hung anna", "avatar_url": "https://avatars.githubusercontent.com/u/39179888?v=4", "profile": "https://annahung31.github.io/", "contributions": [ "bug" ] }, { "login": "bwhitman", "name": "Brian Whitman", "avatar_url": "https://avatars.githubusercontent.com/u/76612?v=4", "profile": "http://notes.variogr.am/", "contributions": [ "bug" ] }, { "login": "JimothyJohn", "name": "JimothyJohn", "avatar_url": "https://avatars.githubusercontent.com/u/24216724?v=4", "profile": "https://github.com/JimothyJohn", "contributions": [ "bug" ] }, { "login": "ericguizzo", "name": "ericguizzo", "avatar_url": "https://avatars.githubusercontent.com/u/26746670?v=4", "profile": "https://github.com/ericguizzo", "contributions": [ "bug" ] }, { "login": "evilstreak", "name": "Dominic Baggott", "avatar_url": "https://avatars.githubusercontent.com/u/74812?v=4", "profile": "http://www.dominicbaggott.com", "contributions": [ "code", "test" ] }, { "login": "dashstander", "name": "Dashiell Stander", "avatar_url": "https://avatars.githubusercontent.com/u/7449128?v=4", "profile": "https://github.com/dashstander", "contributions": [ "bug", "code", "test" ] }, { "login": "Hurricane-eye", "name": "Shuwei Liang", "avatar_url": "https://avatars.githubusercontent.com/u/31437546?v=4", "profile": "https://github.com/Hurricane-eye", "contributions": [ "bug", "question" ] }, { "login": "ericallam", "name": "Eric Allam", "avatar_url": "https://avatars.githubusercontent.com/u/534?v=4", "profile": "https://github.com/ericallam", "contributions": [ "ideas" ] }, { 
"login": "iperdomo", "name": "Iván Perdomo", "avatar_url": "https://avatars.githubusercontent.com/u/178474?v=4", "profile": "https://perdomo.me", "contributions": [ "bug" ] }, { "login": "charlesfrye", "name": "Charles Frye", "avatar_url": "https://avatars.githubusercontent.com/u/10442975?v=4", "profile": "http://charlesfrye.github.io", "contributions": [ "doc" ] }, { "login": "phamquiluan", "name": "Luan Pham", "avatar_url": "https://avatars.githubusercontent.com/u/24642166?v=4", "profile": "https://github.com/phamquiluan", "contributions": [ "bug", "doc" ] }, { "login": "TommyDew42", "name": "TommyDew", "avatar_url": "https://avatars.githubusercontent.com/u/46992350?v=4", "profile": "https://github.com/TommyDew42", "contributions": [ "code" ] }, { "login": "anotherjesse", "name": "Jesse Andrews", "avatar_url": "https://avatars.githubusercontent.com/u/27?v=4", "profile": "https://m4ke.org", "contributions": [ "code", "doc", "test" ] }, { "login": "nickstenning", "name": "Nick Stenning", "avatar_url": "https://avatars.githubusercontent.com/u/3602?v=4", "profile": "https://whiteink.com", "contributions": [ "code", "doc", "design", "infra", "test" ] }, { "login": "justinmerrell", "name": "Justin Merrell", "avatar_url": "https://avatars.githubusercontent.com/u/14996837?v=4", "profile": "https://merrell.io/", "contributions": [ "doc" ] }, { "login": "ruriky", "name": "Rurik Ylä-Onnenvuori", "avatar_url": "https://avatars.githubusercontent.com/u/19946546?v=4", "profile": "https://github.com/ruriky", "contributions": [ "bug" ] }, { "login": "youkaclub", "name": "Youka", "avatar_url": "https://avatars.githubusercontent.com/u/59315275?v=4", "profile": "https://www.youka.club/", "contributions": [ "bug" ] }, { "login": "afiaka87", "name": "Clay Mullis", "avatar_url": "https://avatars.githubusercontent.com/u/3994972?v=4", "profile": "https://github.com/afiaka87", "contributions": [ "doc" ] }, { "login": "mattt", "name": "Mattt", "avatar_url": 
"https://avatars.githubusercontent.com/u/7659?v=4", "profile": "https://github.com/mattt", "contributions": [ "code", "doc", "infra" ] }, { "login": "Juneezee", "name": "Eng Zer Jun", "avatar_url": "https://avatars.githubusercontent.com/u/20135478?v=4", "profile": "https://github.com/Juneezee", "contributions": [ "test" ] }, { "login": "bbedward", "name": "BB", "avatar_url": "https://avatars.githubusercontent.com/u/550752?v=4", "profile": "https://github.com/bbedward", "contributions": [ "code" ] }, { "login": "williamluer", "name": "williamluer", "avatar_url": "https://avatars.githubusercontent.com/u/85975676?v=4", "profile": "https://github.com/williamluer", "contributions": [ "doc" ] }, { "login": "sirupsen", "name": "Simon Eskildsen", "avatar_url": "https://avatars.githubusercontent.com/u/97400?v=4", "profile": "http://sirupsen.com", "contributions": [ "code" ] }, { "login": "erbridge", "name": "F", "avatar_url": "https://avatars.githubusercontent.com/u/1027364?v=4", "profile": "https://erbridge.co.uk", "contributions": [ "bug", "code" ] }, { "login": "philandstuff", "name": "Philip Potter", "avatar_url": "https://avatars.githubusercontent.com/u/581269?v=4", "profile": "https://github.com/philandstuff", "contributions": [ "bug", "code" ] }, { "login": "joannejchen", "name": "Joanne Chen", "avatar_url": "https://avatars.githubusercontent.com/u/33409024?v=4", "profile": "https://github.com/joannejchen", "contributions": [ "doc" ] }, { "login": "technillogue", "name": "technillogue", "avatar_url": "https://avatars.githubusercontent.com/u/945691?v=4", "profile": "http://technillogue.github.io", "contributions": [ "code" ] }, { "login": "aron", "name": "Aron Carroll", "avatar_url": "https://avatars.githubusercontent.com/u/47144?v=4", "profile": "http://aroncarroll.com", "contributions": [ "doc", "code", "ideas" ] }, { "login": "Theodotus1243", "name": "Bohdan Mykhailenko", "avatar_url": "https://avatars.githubusercontent.com/u/32220358?v=4", "profile": 
"https://github.com/Theodotus1243", "contributions": [ "doc", "bug" ] }, { "login": "one1zero1one", "name": "Daniel Radu", "avatar_url": "https://avatars.githubusercontent.com/u/724604?v=4", "profile": "https://github.com/one1zero1one", "contributions": [ "doc", "bug" ] }, { "login": "Etelis", "name": "Itay Etelis", "avatar_url": "https://avatars.githubusercontent.com/u/92247226?v=4", "profile": "https://github.com/Etelis", "contributions": [ "code" ] }, { "login": "gschian0", "name": "Gennaro Schiano", "avatar_url": "https://avatars.githubusercontent.com/u/54407820?v=4", "profile": "http://www.wavefunction.dev", "contributions": [ "doc" ] }, { "login": "aknoerig", "name": "André Knörig", "avatar_url": "https://avatars.githubusercontent.com/u/481350?v=4", "profile": "http://andreknoerig.de", "contributions": [ "doc" ] }, { "login": "danfairs", "name": "Dan Fairs", "avatar_url": "https://avatars.githubusercontent.com/u/24726?v=4", "profile": "https://condense.live", "contributions": [ "code" ] } ], "contributorsPerLine": 7, "skipCi": true, "commitType": "docs" } ================================================ FILE: .git_archival.txt ================================================ node: $Format:%H$ node-date: $Format:%cI$ describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$ ================================================ FILE: .gitattributes ================================================ .git_archival.txt export-subst Makefile -linguist-detectable docs/llms.txt linguist-generated=true ================================================ FILE: .github/CODEOWNERS ================================================ # Default code owners for the entire repository * @replicate/cog ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: "gomod" directory: "/" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/" schedule: interval: "weekly" 
allow: - dependency-type: "direct" - package-ecosystem: "cargo" directory: "/crates" schedule: interval: "weekly" - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" ================================================ FILE: .github/workflows/README.md ================================================ # CI Architecture This document describes the CI/CD architecture for the Cog repository. ## Design Principles 1. **Single gate job** - Branch protection uses one required check (`ci-complete`) that depends on all other jobs 2. **Path-based filtering** - Jobs skip when irrelevant files change (Go changes don't trigger Rust tests) 3. **Build once, test many** - Artifacts built once and reused across test jobs 4. **Parallel execution** - Independent jobs run concurrently 5. **Skipped = passing** - Jobs that skip due to path filtering count as passing for the gate ## Workflows ### `ci.yaml` - Main CI Pipeline The primary CI workflow that runs on all PRs and pushes to main. 
``` ┌─────────────────────────────────────────────────────────────────────────────┐ │ CHANGES DETECTION │ │ Determines which components changed: go, rust, python, integration-tests │ └─────────────────────────────────────────────────────────────────────────────┘ │ ┌─────────────────┼─────────────────┐ ▼ ▼ ▼ ┌──────────┐ ┌──────────┐ ┌──────────┐ │build-rust│ │ build-sdk│ │ (none) │ │ (wheel) │ │ (wheel) │ │ │ └────┬─────┘ └────┬─────┘ └──────────┘ │ │ ┌─────────────┼────────────────┼─────────────────────┐ │ │ │ │ ▼ ▼ ▼ ▼ ┌─────────┐ ┌──────────┐ ┌───────────┐ ┌───────────┐ │fmt-rust │ │test-rust │ │ fmt-go │ │fmt-python │ │lint-rust│ │coglet-py │ │ lint-go │ │lint-python│ │ deny │ │ (matrix) │ │ test-go │ │test-python│ └─────────┘ └────┬─────┘ └───────────┘ └───────────┘ │ │ │ └────────────────┼─────────────────────┘ ▼ ┌────────────────┐ │test-integration│ │ (matrix) │ └───────┬────────┘ ▼ ┌───────────────┐ │ ci-complete │ ← Branch protection requires this └───────────────┘ ``` #### Jobs | Job | Runs when | Depends on | Purpose | |-----|-----------|------------|---------| | `changes` | Always | - | Detect which components changed | | `build-sdk` | python changed | changes | Build cog SDK wheel | | `build-rust` | rust changed | changes | Build coglet ABI3 wheel | | `fmt-go` | go changed | changes | Check Go formatting | | `fmt-rust` | rust changed | changes | Check Rust formatting | | `fmt-python` | python changed | changes | Check Python formatting | | `lint-go` | go changed | changes | Lint Go code | | `lint-rust` | rust changed | changes | Run clippy | | `lint-rust-deny` | rust changed | changes | Check licenses/advisories | | `lint-python` | python changed | build-sdk | Lint Python code | | `test-go` | go changed | build-sdk | Run Go tests (matrix: ubuntu, macos) | | `test-rust` | rust changed | changes | Run Rust tests | | `test-python` | python changed | build-sdk | Run Python tests (matrix: 3.10-3.13) | | `test-coglet-python` | rust or python changed | 
build-rust | Test coglet bindings (matrix: 3.10-3.13) | | `test-integration` | any changed | build-sdk, build-rust | Integration tests (matrix: cog, cog-rust) | | `ci-complete` | Always | all jobs | Gate job for branch protection | #### Python Version Matrix Python versions are defined once at the workflow level: ```yaml env: SUPPORTED_PYTHONS: '["3.10", "3.11", "3.12", "3.13"]' ``` Jobs that need the matrix reference it via `fromJson(env.SUPPORTED_PYTHONS)`. ### `codeql.yml` - Security Analysis Runs CodeQL security scanning for Go, Python, and Rust. - **Triggers**: Push to main, PRs to main, weekly schedule - **Languages**: go, python, rust ### Deleted Workflows - `rust.yaml` - Consolidated into `ci.yaml`. The separate workflow was redundant. - `pypi-package.yaml` - Replaced by `release-build.yaml` + `release-publish.yaml`. - `version-bump.yaml` - Removed. Just edit `crates/Cargo.toml` directly. ## Caching Strategy ### Rust Cache - **Save**: Only on `main` branch pushes (to avoid PR cache pollution) - **Restore**: On all runs (PRs restore from main's cache) - Uses `Swatinem/rust-cache@v2` with workspace path `crates -> target` ### Go Cache - Built into `actions/setup-go` via `cache-dependency-path` ### Python/uv Cache - Built into `jdx/mise-action` and `astral-sh/setup-uv` ## Artifacts | Artifact | Contents | Retention | |----------|----------|-----------| | `CogPackage` | cog-*.whl, cog-*.tar.gz | Default (90 days) | | `CogletRustWheel` | coglet-*-cp310-abi3-*.whl | Default (90 days) | The ABI3 wheel is built with Python 3.10 minimum but works on all 3.10+ versions. ## Local Development Use mise tasks to run the same checks locally: ```bash # Format (check) mise run fmt # Format (fix) mise run fmt:fix # Lint mise run lint # Test mise run test:go mise run test:rust mise run test:python # Build mise run build:cog mise run build:coglet mise run build:sdk ``` ## Adding New Checks 1. Add a mise task in `mise.toml` 2. 
Add a job in `ci.yaml` with appropriate `needs` and path filtering 3. Add the job to `ci-complete`'s needs list 4. Update this README ## Branch Protection Configure branch protection to require only `ci-complete`: ``` Settings > Branches > main > Require status checks: ✓ ci-complete ``` Skipped jobs (from path filtering) are treated as passing by the gate job. ## Release Workflow Releases use a two-workflow system. There are three release types: | Type | Example tag | Branch rule | Draft? | PyPI/crates.io? | |------|-------------|-------------|--------|-----------------| | **Stable** | `v0.17.0` | Must be on main | Yes (manual publish) | Yes | | **Pre-release** | `v0.17.0-alpha3` | Must be on main | Yes (manual publish) | Yes | | **Dev** | `v0.17.0-dev1` | Any branch | No (immediate) | No | ### Stable / Pre-release Flow ``` Developer pushes tag on main (e.g. v0.17.0, v0.17.0-rc1) │ ▼ release-build.yaml (automatic) ┌──────────────────────────────────────────────┐ │ verify-tag ──▶ build-sdk ──┐ │ │ (must be build-coglet ┼──▶ create- │ │ main) build-CLI ──┘ release │ │ (DRAFT) │ └──────────────────────────────────────────────┘ │ Maintainer publishes draft in GitHub UI │ ▼ release-publish.yaml (automatic) ┌──────────────────────────────────────────────┐ │ coglet → PyPI ──▶ SDK → PyPI │ │ coglet → crates.io │ └──────────────────────────────────────────────┘ ``` ### Dev Release Flow ``` Developer pushes tag from any branch (e.g. v0.17.0-dev1) │ ▼ release-build.yaml (automatic) ┌──────────────────────────────────────────────┐ │ verify-tag ──▶ build-sdk ──┐ │ │ (no branch build-coglet ┼──▶ create- │ │ restriction) build-CLI ──┘ release │ │ (PRE- │ │ RELEASE) │ └──────────────────────────────────────────────┘ │ Done. No PyPI/crates.io. Wheels + CLI binaries on GH release. ``` ### Workflows #### `release-build.yaml` Triggered by version tags (`v*.*.*`). Builds all artifacts and creates a GitHub release. 
| Job | Purpose | |-----|---------| | `verify-tag` | Cargo.toml version match + branch rules (main for stable/pre-release, any for dev) | | `build-sdk` | Build cog SDK wheel and sdist | | `build-coglet-wheels` | Build coglet wheels (3 platforms via zig cross-compile) | | `create-release` | Goreleaser builds CLI + creates release, then appends wheels. Dev releases are immediately published as pre-release; stable/pre-release remain as draft. | **Security**: No secrets needed for dev. Stable/pre-release require maintainer to publish draft. #### `release-publish.yaml` Triggered when a release is published. Publishes to PyPI and crates.io. **Skips entirely for dev releases** (all jobs gated on `is_dev != true`). | Job | Depends on | Purpose | |-----|------------|---------| | `verify-release` | - | Validate tag format, classify release type | | `publish-pypi-coglet` | verify-release | Publish coglet to PyPI (trusted publishing) | | `publish-pypi-sdk` | publish-pypi-coglet | Publish SDK to PyPI (waits for coglet) | | `publish-crates-io` | verify-release | Publish coglet crate (OIDC) | | `update-homebrew-tap` | publish-pypi-sdk, publish-crates-io | Update `replicate/homebrew-tap` cask (stable only, macOS, via GH App) | ### Package Versioning All packages use **lockstep versioning** from `crates/Cargo.toml`. 
| Package | Registry | Version format | Example | |---------|----------|----------------|---------| | cog SDK | PyPI | PEP 440 | `cog==0.17.0`, `cog==0.17.0a3`, `cog==0.17.0.dev1` | | coglet | PyPI | PEP 440 | `coglet==0.17.0`, `coglet==0.17.0a3` | | coglet | crates.io | semver | `coglet@0.17.0`, `coglet@0.17.0-alpha3` | | CLI | GitHub Release | semver | `cog v0.17.0`, `cog v0.17.0-dev1` | **Version conversion** (semver -> PEP 440): - `0.17.0-alpha3` -> `0.17.0a3` - `0.17.0-beta1` -> `0.17.0b1` - `0.17.0-rc1` -> `0.17.0rc1` - `0.17.0-dev1` -> `0.17.0.dev1` - `0.17.0` -> `0.17.0` ### SDK Wheel Sourcing The CLI installs the cog SDK from PyPI at container build time: | Scenario | COG_SDK_WHEEL env var | Behavior | |----------|-----------------------|----------| | Released CLI | (unset) | Install latest `cog` from PyPI | | Dev CLI (in repo) | (unset) | Auto-detect `dist/cog-*.whl` if present, else PyPI | | Force PyPI | `pypi` | Install latest from PyPI | | Specific version | `pypi:0.12.0` | Install `cog==0.12.0` from PyPI | | Local wheel | `/path/to/cog.whl` | Install from local file | | Force dist | `dist` | Install from `dist/` (error if missing) | Same pattern for `COGLET_WHEEL` (but coglet is optional by default). ### GitHub Environment Setup 1. Create environments in **Settings -> Environments**: - `pypi` - For PyPI publishing (trusted publishing, no secrets) - `crates-io` - For crates.io publishing (trusted publishing, no secrets) 2. Configure protection rules for each environment: - **Deployment branches**: "Selected branches and tags" - **Add pattern**: `v*` (restricts to version tags) - **Required reviewers**: Add maintainers 3. Configure trusted publishers: - **PyPI** (both `cog` and `coglet`): workflow `release-publish.yaml`, environment `pypi` - **crates.io** (`coglet`): workflow `release-publish.yaml`, environment `crates-io` 4. 
Configure the Homebrew tap GitHub App: - App: `cog-homebrew-tapbot` (ID: 1232932405) - Create environment `homebrew` with secret `COG_HOMEBREW_TAP_PRIVATE_KEY` (app private key) - App must have write access to `replicate/homebrew-tap` ### Performing a Stable / Pre-release ```bash # 1. Update crates/Cargo.toml version (e.g. "0.17.0" or "0.17.0-alpha3") # 2. Merge to main # 3. Tag and push git tag v0.17.0 git push origin v0.17.0 # 4. Wait for release-build.yaml to complete (creates draft release) # 5. Review the draft release in GitHub UI # 6. Click "Publish release" -> triggers release-publish.yaml -> PyPI + crates.io ``` ### Performing a Dev Release ```bash # From any branch: # 1. Update crates/Cargo.toml version (e.g. "0.17.0-dev1") # 2. Commit and push # 3. Tag and push git tag v0.17.0-dev1 git push origin v0.17.0-dev1 # 4. Done. release-build.yaml creates a pre-release with all artifacts. # No PyPI/crates.io publishing. No manual approval needed. ``` ================================================ FILE: .github/workflows/ci.yaml ================================================ name: CI on: merge_group: push: branches: [main] pull_request: workflow_dispatch: # Cancel in-progress runs for PRs, queue for merge group and main concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}-v2 cancel-in-progress: ${{ github.event_name == 'pull_request' }} env: # Single source of truth for supported Python versions SUPPORTED_PYTHONS: '["3.10", "3.11", "3.12", "3.13"]' # Default Python version for non-matrix jobs PYTHON_VERSION: "3.13" # Minimum supported Python — used for ABI3 wheel builds and glob patterns. # Must match the lowest entry in SUPPORTED_PYTHONS. 
MINIMUM_PYTHON: "3.10" # Number of runners to shard integration tests across (per runtime) # Slow tests ([short] skip) are distributed round-robin first, then fast tests fill in NUM_IT_RUNNER_SHARDS: "4" # Standard environment HYPOTHESIS_PROFILE: ci FORCE_COLOR: "1" PIP_DISABLE_PIP_VERSION_CHECK: "1" PIP_NO_PYTHON_VERSION_WARNING: "1" CARGO_TERM_COLOR: always # CGo required for go-tree-sitter (static Python schema parser) CGO_ENABLED: "1" # Disable tools in mise that CI installs via dedicated GitHub Actions for # better reliability (avoids transient GitHub Releases 502s from aqua downloads), # better caching, and guaranteed tool ordering. # - Rust toolchain: dtolnay/rust-toolchain # - cargo-binstall: taiki-e/install-action # - Python: astral-sh/setup-uv # - golangci-lint: golangci/golangci-lint-action # - gotestsum: go install (uses Go module proxy, not GitHub Releases) # - cargo-deny, cargo-nextest: taiki-e/install-action # - zig, cargo-zigbuild, maturin, cargo-insta: not needed in CI (maturin-action bundles zig) MISE_DISABLE_TOOLS: rust,rustup,rustup-init,cargo-binstall,python,golangci-lint,gotestsum,cargo-deny,cargo-insta,cargo-nextest,cargo:cargo-nextest,zig,cargo-zigbuild,maturin,cargo:maturin permissions: {} # ============================================================================= # Change Detection # ============================================================================= jobs: changes: name: Detect changes runs-on: ubuntu-latest timeout-minutes: 5 outputs: go: ${{ steps.filter.outputs.go }} rust: ${{ steps.filter.outputs.rust }} python: ${{ steps.filter.outputs.python }} integration: ${{ steps.filter.outputs.integration }} docs: ${{ steps.filter.outputs.docs }} version_only: ${{ steps.filter.outputs.version_only }} version_changed: ${{ steps.filter.outputs.version_changed }} # Pass through for matrix jobs (env context unavailable in strategy) supported_pythons: ${{ env.SUPPORTED_PYTHONS }} steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - 
name: Detect changed paths id: filter run: | # For PRs, compare against base; for pushes, compare against previous commit; # for merge_group, compare against the merge group base. if [ "${{ github.event_name }}" = "pull_request" ]; then BASE="${{ github.event.pull_request.base.sha }}" elif [ "${{ github.event_name }}" = "merge_group" ]; then BASE="${{ github.event.merge_group.base_sha }}" else BASE="${{ github.event.before }}" # Fall back to the previous commit when there is no usable "before" SHA: # all zeros on the initial push of a ref, empty for events that carry no # push payload (e.g. workflow_dispatch). if [ -z "$BASE" ] || [ "$BASE" = "0000000000000000000000000000000000000000" ]; then BASE="HEAD~1" fi fi echo "Comparing $BASE..HEAD" # Get changed files (an empty list if the diff itself fails) CHANGED=$(git diff --name-only "$BASE" HEAD 2>/dev/null || echo "") # Check if coglet version changed VERSION_CHANGED="false" if echo "$CHANGED" | grep -qE '^crates/Cargo\.toml$'; then # A "+version = " line in the diff means the version was added or changed if git diff "$BASE" HEAD -- crates/Cargo.toml | grep -qE '^\+version = '; then VERSION_CHANGED="true" echo "Coglet version changed" fi fi echo "version_changed=$VERSION_CHANGED" >> $GITHUB_OUTPUT # Check if ONLY the version changed (version bump PR) # This is true if crates/Cargo.toml is the only file and only version line changed VERSION_ONLY="false" if [ "$VERSION_CHANGED" = "true" ]; then FILE_COUNT=$(echo "$CHANGED" | grep -c . 
|| echo "0") if [ "$FILE_COUNT" = "1" ]; then # Only crates/Cargo.toml changed, check if only version line changed # Get actual diff lines, excluding the "+++ b/..." and "--- a/..." file headers. # -E is required here: in basic regex syntax a bare {3} is literal text, not a # repetition count, so a BRE filter would leave the headers in and inflate the counts. DIFF_CONTENT=$(git diff "$BASE" HEAD -- crates/Cargo.toml | grep -E '^[+-]' | grep -vE '^(\+\+\+|---) ') # Should be exactly: -version = "old" and +version = "new" MINUS_LINES=$(echo "$DIFF_CONTENT" | grep -c '^-' || echo "0") PLUS_LINES=$(echo "$DIFF_CONTENT" | grep -c '^\+' || echo "0") VERSION_MINUS=$(echo "$DIFF_CONTENT" | grep -c '^-version = ' || echo "0") VERSION_PLUS=$(echo "$DIFF_CONTENT" | grep -c '^\+version = ' || echo "0") if [ "$MINUS_LINES" = "1" ] && [ "$PLUS_LINES" = "1" ] && \ [ "$VERSION_MINUS" = "1" ] && [ "$VERSION_PLUS" = "1" ]; then VERSION_ONLY="true" echo "Version-only change detected - skipping heavy CI" fi fi fi echo "version_only=$VERSION_ONLY" >> $GITHUB_OUTPUT # CI/tooling changes should run everything (unless version-only) if [ "$VERSION_ONLY" = "true" ]; then echo "go=false" >> $GITHUB_OUTPUT echo "rust=false" >> $GITHUB_OUTPUT echo "python=false" >> $GITHUB_OUTPUT echo "integration=false" >> $GITHUB_OUTPUT echo "docs=false" >> $GITHUB_OUTPUT elif echo "$CHANGED" | grep -qE '^(\.github/workflows/|mise\.toml)'; then echo "CI/tooling changed - running all jobs" echo "go=true" >> $GITHUB_OUTPUT echo "rust=true" >> $GITHUB_OUTPUT echo "python=true" >> $GITHUB_OUTPUT echo "integration=true" >> $GITHUB_OUTPUT echo "docs=true" >> $GITHUB_OUTPUT else # Detect Go changes (the linter config may be named .golangci.yml or .golangci.yaml) if echo "$CHANGED" | grep -qE '^(cmd/|pkg/|go\.(mod|sum)|\.golangci\.ya?ml|Makefile)'; then echo "go=true" >> $GITHUB_OUTPUT else echo "go=false" >> $GITHUB_OUTPUT fi # Detect Rust changes if echo "$CHANGED" | grep -qE '^(crates/|Cargo\.(toml|lock))'; then echo "rust=true" >> $GITHUB_OUTPUT else echo "rust=false" >> $GITHUB_OUTPUT fi # Detect Python changes if echo "$CHANGED" | grep -qE '^(python/|pyproject\.toml|uv\.lock|noxfile\.py|\.ruff\.toml)'; then echo "python=true" >> $GITHUB_OUTPUT else echo "python=false" 
>> $GITHUB_OUTPUT fi # Detect integration test changes (or if any code changed) if echo "$CHANGED" | grep -qE '^(integration-tests/|cmd/|pkg/|python/|crates/|go\.(mod|sum)|uv\.lock|pyproject\.toml)'; then echo "integration=true" >> $GITHUB_OUTPUT else echo "integration=false" >> $GITHUB_OUTPUT fi # Detect docs changes (includes CLI source which generates docs/cli.md) if echo "$CHANGED" | grep -qE '^(docs/|README\.md|cmd/|pkg/cli/)'; then echo "docs=true" >> $GITHUB_OUTPUT else echo "docs=false" >> $GITHUB_OUTPUT fi fi # Debug output echo "Changed files:" echo "$CHANGED" | head -50 # ============================================================================= # Version Check - Validates coglet version changes # ============================================================================= version-check: name: Validate coglet version needs: changes if: needs.changes.outputs.version_changed == 'true' runs-on: ubuntu-latest timeout-minutes: 5 steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Validate version run: | # Get version from Cargo.toml VERSION=$(grep '^version = ' crates/Cargo.toml | head -1 | sed 's/version = "\(.*\)"/\1/') echo "Coglet version: $VERSION" # Validate semver format if [[ ! "$VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9.]+)?$ ]]; then echo "::error::Invalid version format: $VERSION" echo "::error::Expected semver format: MAJOR.MINOR.PATCH or MAJOR.MINOR.PATCH-prerelease" exit 1 fi echo "✓ Valid semver format" # Check version doesn't already exist as a tag if git tag -l "v$VERSION" | grep -q .; then echo "::error::Tag v$VERSION already exists!" echo "::error::Cannot set version to an already-released version." 
exit 1 fi echo "✓ Version not yet released" # Get the highest existing stable version tag HIGHEST_TAG=$(git tag -l 'v[0-9]*.[0-9]*.[0-9]*' | grep -v '-' | sed 's/^v//' | sort -V | tail -1) if [ -n "$HIGHEST_TAG" ]; then echo "Highest released version: $HIGHEST_TAG" # Check it's not a downgrade (using sort -V for proper semver comparison) BASE_VERSION="${VERSION%%-*}" SORTED_HIGHEST=$(printf '%s\n%s' "$HIGHEST_TAG" "$BASE_VERSION" | sort -V | tail -1) if [ "$SORTED_HIGHEST" = "$HIGHEST_TAG" ] && [ "$HIGHEST_TAG" != "$BASE_VERSION" ]; then echo "::error::Cannot downgrade version from $HIGHEST_TAG to $VERSION" echo "::error::New version must be greater than the highest released version." exit 1 fi echo "✓ Version is not a downgrade" else echo "No existing version tags found" fi echo "" echo "✓ Version $VERSION is valid for release" # ============================================================================= # Build Stage - Produces artifacts for downstream jobs # ============================================================================= build-sdk: name: Build SDK needs: changes if: needs.changes.outputs.python == 'true' || needs.changes.outputs.go == 'true' || needs.changes.outputs.integration == 'true' runs-on: ubuntu-latest timeout-minutes: 15 steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - uses: astral-sh/setup-uv@v6 with: python-version: ${{ env.PYTHON_VERSION }} - uses: dtolnay/rust-toolchain@stable - run: rustup default stable - uses: taiki-e/install-action@cargo-binstall - uses: jdx/mise-action@v4 with: cache: false - name: Build SDK run: mise run ci:build:sdk - name: Upload SDK package uses: actions/upload-artifact@v6 with: name: CogPackage path: dist/cog-* build-rust: name: Build coglet wheel needs: changes if: needs.changes.outputs.rust == 'true' || needs.changes.outputs.integration == 'true' runs-on: ubuntu-latest timeout-minutes: 15 steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - uses: astral-sh/setup-uv@v6 with: python-version: 
${{ env.PYTHON_VERSION }} - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 with: workspaces: crates -> target save-if: ${{ github.ref == 'refs/heads/main' }} # No mise needed - maturin-action bundles maturin and zig # Explicitly request MINIMUM_PYTHON inside the manylinux container so # maturin produces an ABI3 wheel (cp310-abi3). Without this, maturin # picks up the container's default Python (3.8), which doesn't support # ABI3, producing a cp38-cp38 wheel that the upload glob won't match. - name: Build coglet wheel (ABI3) uses: PyO3/maturin-action@v1 with: target: x86_64-unknown-linux-gnu args: --release --out dist -m crates/coglet-python/Cargo.toml --interpreter python${{ env.MINIMUM_PYTHON }} manylinux: auto - name: Verify ABI3 wheel exists run: | CPVER="cp${MINIMUM_PYTHON//.}" ls -la dist/coglet-*-${CPVER}-abi3-*.whl - name: Upload coglet wheel uses: actions/upload-artifact@v6 with: name: CogletRustWheel # ABI3 wheels use cpXYZ-abi3 naming; just match any abi3 wheel path: dist/coglet-*-abi3-*.whl build-cog: name: Build cog CLI needs: changes if: needs.changes.outputs.go == 'true' || needs.changes.outputs.integration == 'true' runs-on: ubuntu-latest timeout-minutes: 15 steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - uses: actions/setup-go@v6 with: go-version-file: go.mod cache-dependency-path: go.sum - uses: mlugg/setup-zig@v2 with: version: 0.15.2 - name: Get version from Cargo.toml id: version run: echo "version=$(grep '^version' crates/Cargo.toml | head -1 | sed 's/.*"\(.*\)"/\1/')" >> "$GITHUB_OUTPUT" - name: Build cog binary uses: goreleaser/goreleaser-action@v7 with: version: '~> v2' args: build --clean --snapshot --single-target --id cog --output cog env: GOFLAGS: -buildvcs=false # Use Cargo.toml as version source so snapshot builds match the wheel version COG_VERSION: ${{ steps.version.outputs.version }} - name: Upload cog binary uses: actions/upload-artifact@v6 with: name: CogBinary path: cog # 
============================================================================= # Format Checks - Fast, parallel # ============================================================================= fmt-go: name: Format Go needs: changes if: needs.changes.outputs.go == 'true' runs-on: ubuntu-latest timeout-minutes: 10 steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - run: rustup default stable - uses: taiki-e/install-action@cargo-binstall - uses: jdx/mise-action@v4 with: cache: false - name: Check Go formatting run: mise run fmt:go fmt-rust: name: Format Rust needs: changes if: needs.changes.outputs.rust == 'true' runs-on: ubuntu-latest timeout-minutes: 10 steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable with: components: rustfmt # No mise needed - rustfmt comes with toolchain - name: Check Rust formatting run: cargo fmt --manifest-path crates/Cargo.toml --all -- --check fmt-python: name: Format Python needs: changes if: needs.changes.outputs.python == 'true' runs-on: ubuntu-latest timeout-minutes: 10 steps: - uses: actions/checkout@v6 - uses: astral-sh/setup-uv@v6 with: python-version: ${{ env.PYTHON_VERSION }} - uses: dtolnay/rust-toolchain@stable - run: rustup default stable - uses: taiki-e/install-action@cargo-binstall - uses: jdx/mise-action@v4 with: cache: false - name: Check Python formatting run: mise run fmt:python check-llm-docs: name: Check LLM docs needs: changes if: needs.changes.outputs.docs == 'true' runs-on: ubuntu-latest timeout-minutes: 5 steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - run: rustup default stable - uses: taiki-e/install-action@cargo-binstall - uses: jdx/mise-action@v4 with: cache: false - name: Check llms.txt is up to date run: mise run docs:llm:check - name: Check CLI docs are up to date run: mise run docs:cli:check # ============================================================================= # Lint Checks - Parallel # 
============================================================================= lint-go: name: Lint Go needs: changes if: needs.changes.outputs.go == 'true' runs-on: ubuntu-latest timeout-minutes: 15 steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version-file: go.mod - uses: golangci/golangci-lint-action@v9 with: version: v2.10.1 lint-rust: name: Lint Rust needs: changes if: needs.changes.outputs.rust == 'true' runs-on: ubuntu-latest timeout-minutes: 15 steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable with: components: clippy - uses: Swatinem/rust-cache@v2 with: workspaces: crates -> target save-if: ${{ github.ref == 'refs/heads/main' }} # No mise needed - clippy comes with toolchain - name: Lint Rust (clippy) run: cargo clippy --manifest-path crates/Cargo.toml --workspace -- -D warnings lint-rust-deny: name: Lint Rust (deny) needs: changes if: needs.changes.outputs.rust == 'true' runs-on: ubuntu-latest timeout-minutes: 10 steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - uses: taiki-e/install-action@v2 with: tool: cargo-deny@0.19.0 # No mise needed - cargo-deny installed via taiki-e - name: Check licenses and advisories run: cargo deny --manifest-path crates/Cargo.toml check lint-python: name: Lint Python needs: [changes, build-sdk, build-rust] if: | needs.changes.outputs.python == 'true' && (needs.build-rust.result == 'success' || needs.build-rust.result == 'skipped') runs-on: ubuntu-latest-8-cores timeout-minutes: 15 steps: - name: Download SDK uses: actions/download-artifact@v8 with: name: CogPackage path: dist - name: Download coglet wheel uses: actions/download-artifact@v8 with: name: CogletRustWheel path: dist if: needs.build-rust.result == 'success' - name: Extract source distribution run: tar xf dist/*.tar.gz --strip-components=1 - uses: astral-sh/setup-uv@v6 with: python-version: ${{ env.PYTHON_VERSION }} - uses: dtolnay/rust-toolchain@stable - run: rustup default stable - uses: 
taiki-e/install-action@cargo-binstall - uses: jdx/mise-action@v4 with: cache: false - name: Lint Python run: mise run lint:python # ============================================================================= # Test Jobs # ============================================================================= test-go: name: "Test Go (${{ matrix.platform }})" needs: changes if: needs.changes.outputs.go == 'true' timeout-minutes: 30 strategy: fail-fast: false matrix: platform: [ubuntu-latest, macos-latest] runs-on: ${{ matrix.platform }} steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version-file: go.mod # gotestsum via Go module proxy (not GitHub Releases) for reliability - name: Install gotestsum run: go install gotest.tools/gotestsum@v1.13.0 - name: Test Go shell: bash run: | set -euo pipefail set -m # job control, ensures script is in its own process group cleanup() { echo "::warning::Cancelling..." kill -TERM -- -$$ 2>/dev/null || true sleep 5 kill -KILL -- -$$ 2>/dev/null || true } trap cleanup INT TERM gotestsum -- -short -timeout 1200s -parallel 5 ./... & wait $! 
fuzz-go: name: Fuzz Go needs: changes if: needs.changes.outputs.go == 'true' runs-on: ubuntu-latest timeout-minutes: 10 env: CGO_ENABLED: "1" steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version-file: go.mod - name: Fuzz schema type resolution run: go test ./pkg/schema/ -run='^$' -fuzz=FuzzResolveSchemaType -fuzztime=30s - name: Fuzz JSON schema generation run: go test ./pkg/schema/ -run='^$' -fuzz=FuzzJSONSchema -fuzztime=30s - name: Fuzz Python parser run: go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParsePredictor -fuzztime=30s - name: Fuzz type annotation parsing run: go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParseTypeAnnotation -fuzztime=30s test-rust: name: Test Rust needs: changes if: needs.changes.outputs.rust == 'true' runs-on: ubuntu-latest timeout-minutes: 15 steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - uses: taiki-e/install-action@v2 with: tool: cargo-nextest@0.9.120 - uses: Swatinem/rust-cache@v2 with: workspaces: crates -> target save-if: ${{ github.ref == 'refs/heads/main' }} # No mise needed - cargo-nextest installed via taiki-e - name: Test Rust run: cargo nextest run --manifest-path crates/Cargo.toml --workspace --exclude coglet-python --no-tests=pass test-python: name: "Test Python ${{ matrix.python-version }}" needs: [changes, build-sdk, build-rust] if: | needs.changes.outputs.python == 'true' && needs.build-sdk.result == 'success' && (needs.build-rust.result == 'success' || needs.build-rust.result == 'skipped') runs-on: ubuntu-latest-8-cores timeout-minutes: 30 strategy: fail-fast: false matrix: python-version: ${{ fromJSON(needs.changes.outputs.supported_pythons) }} steps: - name: Download artifacts uses: actions/download-artifact@v8 with: path: dist merge-multiple: true - name: Extract source distribution run: tar xf dist/*.tar.gz --strip-components=1 - uses: astral-sh/setup-uv@v6 with: python-version: ${{ matrix.python-version }} - uses: dtolnay/rust-toolchain@stable - run: 
rustup default stable - uses: taiki-e/install-action@cargo-binstall - uses: jdx/mise-action@v4 - name: Remove src to ensure tests run against wheel run: rm -rf python/cog - name: Test Python run: uvx nox -s tests -p ${{ matrix.python-version }} test-coglet-python: name: "Test coglet-python (${{ matrix.python-version }})" needs: [changes, build-rust] if: | always() && (needs.changes.outputs.rust == 'true' || needs.changes.outputs.python == 'true') && (needs.build-rust.result == 'success' || needs.build-rust.result == 'skipped') runs-on: ubuntu-latest timeout-minutes: 15 strategy: fail-fast: false matrix: python-version: ${{ fromJSON(needs.changes.outputs.supported_pythons) }} steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Download coglet wheel uses: actions/download-artifact@v8 with: name: CogletRustWheel path: dist if: needs.build-rust.result == 'success' - uses: dtolnay/rust-toolchain@stable - run: rustup default stable # Required for cargo-binstall to find cargo - uses: taiki-e/install-action@cargo-binstall - uses: Swatinem/rust-cache@v2 with: workspaces: crates -> target save-if: ${{ github.ref == 'refs/heads/main' }} - uses: astral-sh/setup-uv@v6 with: python-version: ${{ matrix.python-version }} - name: Test coglet-python bindings run: uvx nox -s coglet -p ${{ matrix.python-version }} # Compute integration test shards dynamically. # Slow tests (tagged with [short] skip) are distributed round-robin first, # then remaining tests fill in. This ensures slow tests don't pile up on one runner. 
integration-shards: name: Compute test shards needs: changes if: needs.changes.outputs.integration == 'true' runs-on: ubuntu-latest timeout-minutes: 5 outputs: shards: ${{ steps.shard.outputs.shards }} steps: - uses: actions/checkout@v6 - name: Compute shards id: shard run: | NUM_SHARDS=${{ env.NUM_IT_RUNNER_SHARDS }} # Find unconditionally skipped tests (bare "skip" without condition brackets) # These are disabled tests that shouldn't affect shard distribution SKIPPED_TESTS=$(grep -rl '^skip ' integration-tests/tests/*.txtar | \ xargs -I{} basename {} .txtar | sort || echo "") # Identify slow tests (have [short] skip marker), excluding unconditionally skipped SLOW_TESTS=$(grep -rl '\[short\] skip' integration-tests/tests/*.txtar | \ xargs -I{} basename {} .txtar | sort) if [ -n "$SKIPPED_TESTS" ]; then SLOW_TESTS=$(comm -23 <(echo "$SLOW_TESTS") <(echo "$SKIPPED_TESTS")) fi # All tests ALL_TESTS=$(ls integration-tests/tests/*.txtar | \ xargs -I{} basename {} .txtar | sort) # Fast tests = all - slow (skipped tests end up here but run instantly) FAST_TESTS=$(comm -23 <(echo "$ALL_TESTS") <(echo "$SLOW_TESTS")) # Distribute slow tests round-robin across shards declare -a SHARDS for i in $(seq 0 $((NUM_SHARDS - 1))); do SHARDS[$i]="" done idx=0 while IFS= read -r test; do [ -z "$test" ] && continue if [ -n "${SHARDS[$idx]}" ]; then SHARDS[$idx]="${SHARDS[$idx]}|${test}" else SHARDS[$idx]="$test" fi idx=$(( (idx + 1) % NUM_SHARDS )) done <<< "$SLOW_TESTS" # Distribute fast tests round-robin across shards while IFS= read -r test; do [ -z "$test" ] && continue if [ -n "${SHARDS[$idx]}" ]; then SHARDS[$idx]="${SHARDS[$idx]}|${test}" else SHARDS[$idx]="$test" fi idx=$(( (idx + 1) % NUM_SHARDS )) done <<< "$FAST_TESTS" # Build JSON array of shard objects JSON="[" for i in $(seq 0 $((NUM_SHARDS - 1))); do PATTERN="${SHARDS[$i]}" COUNT=$(echo "$PATTERN" | tr '|' '\n' | wc -l | tr -d ' ') [ $i -gt 0 ] && JSON="${JSON}," 
JSON="${JSON}{\"index\":$i,\"pattern\":\"${PATTERN}\",\"count\":$COUNT}" done JSON="${JSON}]" echo "shards=$JSON" >> "$GITHUB_OUTPUT" # Debug output echo "Shard distribution:" for i in $(seq 0 $((NUM_SHARDS - 1))); do COUNT=$(echo "${SHARDS[$i]}" | tr '|' '\n' | wc -l | tr -d ' ') SLOW_COUNT=$(echo "${SHARDS[$i]}" | tr '|' '\n' | while read t; do echo "$SLOW_TESTS" | grep -q "^${t}$" && echo "$t" done | wc -l | tr -d ' ') echo " Shard $i: $COUNT tests ($SLOW_COUNT slow)" done test-integration: name: "Test integration (shard ${{ matrix.shard.index }})" needs: [changes, build-cog, build-sdk, build-rust, integration-shards] if: | !cancelled() && needs.changes.outputs.integration == 'true' && needs.integration-shards.result == 'success' && (needs.build-cog.result == 'success' || needs.build-cog.result == 'skipped') && (needs.build-sdk.result == 'success' || needs.build-sdk.result == 'skipped') && (needs.build-rust.result == 'success' || needs.build-rust.result == 'skipped') runs-on: ubuntu-latest-16-cores timeout-minutes: 30 strategy: fail-fast: false matrix: shard: ${{ fromJSON(needs.integration-shards.outputs.shards) }} steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Login to Docker Hub uses: docker/login-action@v4 if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request' with: registry: index.docker.io username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Download artifacts uses: actions/download-artifact@v8 with: path: dist merge-multiple: true - name: Install cog binary run: | cp dist/cog ./cog chmod +x ./cog - uses: actions/setup-go@v6 with: go-version-file: go.mod cache-dependency-path: go.sum # gotestsum via Go module proxy (not GitHub Releases) for reliability - name: Install gotestsum run: go install gotest.tools/gotestsum@v1.13.0 - name: Set wheel environment run: | # Use locally-built wheels, not PyPI (version may not be published yet) # Must use 
absolute paths — cog subprocess runs from txtar workdir, not checkout root echo "COG_SDK_WHEEL=${{ github.workspace }}/dist" >> $GITHUB_ENV echo "COGLET_WHEEL=${{ github.workspace }}/dist" >> $GITHUB_ENV - name: Run integration tests (shard ${{ matrix.shard.index }}, ${{ matrix.shard.count }} tests) env: COG_BINARY: ./cog TEST_PARALLEL: 4 BUILDKIT_PROGRESS: 'quiet' shell: bash run: | set -euo pipefail set -m # job control, ensures script is in its own process group cleanup() { echo "::warning::Cancelling..." kill -TERM -- -$$ 2>/dev/null || true sleep 5 kill -KILL -- -$$ 2>/dev/null || true } trap cleanup INT TERM # Build -run regex from shard pattern # Pattern is "test1|test2|test3" - wrap each in TestIntegration// RUN_PATTERN="${{ matrix.shard.pattern }}" echo "Running tests matching: $RUN_PATTERN" gotestsum --format github-actions -- \ -tags integration \ -parallel $TEST_PARALLEL \ -timeout 30m \ -run "TestIntegration/($RUN_PATTERN)/" \ ./integration-tests/... & wait $! # ============================================================================= # Gate Job - Single required check for branch protection # ============================================================================= ci-complete: name: CI Complete needs: - changes - version-check - build-cog - build-sdk - build-rust - fmt-go - fmt-rust - fmt-python - check-llm-docs - lint-go - lint-rust - lint-rust-deny - lint-python - test-go - test-rust - test-python - test-coglet-python - integration-shards - test-integration if: always() runs-on: ubuntu-latest timeout-minutes: 5 steps: - name: Check job results run: | echo "Job results:" echo " changes: ${{ needs.changes.result }}" echo " build-sdk: ${{ needs.build-sdk.result }}" echo " build-rust: ${{ needs.build-rust.result }}" echo " fmt-go: ${{ needs.fmt-go.result }}" echo " fmt-rust: ${{ needs.fmt-rust.result }}" echo " fmt-python: ${{ needs.fmt-python.result }}" echo " check-llm-docs: ${{ needs.check-llm-docs.result }}" echo " lint-go: ${{ 
needs.lint-go.result }}" echo " lint-rust: ${{ needs.lint-rust.result }}" echo " lint-rust-deny: ${{ needs.lint-rust-deny.result }}" echo " lint-python: ${{ needs.lint-python.result }}" echo " test-go: ${{ needs.test-go.result }}" echo " test-rust: ${{ needs.test-rust.result }}" echo " test-python: ${{ needs.test-python.result }}" echo " test-coglet-python: ${{ needs.test-coglet-python.result }}" echo " integration-shards: ${{ needs.integration-shards.result }}" echo " test-integration: ${{ needs.test-integration.result }}" echo " version-check: ${{ needs.version-check.result }}" echo " build-cog: ${{ needs.build-cog.result }}" # Fail if any job listed in this job's needs failed or was cancelled (skipped is OK). version-check and build-cog are checked too, so the gate cannot pass when only one of them fails and its dependents are skipped. FAILED=false for result in \ "${{ needs.changes.result }}" \ "${{ needs.version-check.result }}" \ "${{ needs.build-cog.result }}" \ "${{ needs.build-sdk.result }}" \ "${{ needs.build-rust.result }}" \ "${{ needs.fmt-go.result }}" \ "${{ needs.fmt-rust.result }}" \ "${{ needs.fmt-python.result }}" \ "${{ needs.check-llm-docs.result }}" \ "${{ needs.lint-go.result }}" \ "${{ needs.lint-rust.result }}" \ "${{ needs.lint-rust-deny.result }}" \ "${{ needs.lint-python.result }}" \ "${{ needs.test-go.result }}" \ "${{ needs.test-rust.result }}" \ "${{ needs.test-python.result }}" \ "${{ needs.test-coglet-python.result }}" \ "${{ needs.integration-shards.result }}" \ "${{ needs.test-integration.result }}" do if [ "$result" = "failure" ] || [ "$result" = "cancelled" ]; then FAILED=true fi done if [ "$FAILED" = "true" ]; then echo "::error::Some jobs failed or were cancelled" exit 1 fi echo "All CI checks passed!"
# ============================================================================= # Release Validation - Dry-run checks (PRs and main) # ============================================================================= release-dry-run: name: Release Dry Run needs: ci-complete if: "!startsWith(github.ref, 'refs/tags/')" runs-on: ubuntu-latest timeout-minutes: 15 steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 with: workspaces: crates -> target save-if: ${{ github.ref == 'refs/heads/main' }} - name: Check coglet crates.io publish run: cargo publish --dry-run -p coglet --manifest-path crates/Cargo.toml - uses: mlugg/setup-zig@v2 with: version: 0.15.2 - uses: goreleaser/goreleaser-action@v7 with: version: '~> v2' args: check ================================================ FILE: .github/workflows/codeql.yml ================================================ # For most projects, this workflow file will not need changing; you simply need # to commit it to your repository. # # You may wish to alter this file to override the set of languages analyzed, # or to provide custom queries or build logic. # # ******** NOTE ******** # We have attempted to detect the languages in your repository. Please check # the `language` matrix defined below to confirm you have the correct set of # supported CodeQL languages. # name: "CodeQL" on: push: branches: [ "main" ] pull_request: # The branches below must be a subset of the branches above branches: [ "main" ] schedule: - cron: '37 18 * * 5' jobs: analyze: name: Analyze runs-on: ubuntu-latest permissions: actions: read contents: read security-events: write strategy: fail-fast: false matrix: # CodeQL supports: cpp, csharp, go, java, javascript, python, ruby, rust # https://aka.ms/codeql-docs/language-support language: ['go', 'python', 'rust'] steps: - name: Checkout repository uses: actions/checkout@v6 # Initializes the CodeQL tools for scanning. 
- name: Initialize CodeQL uses: github/codeql-action/init@v4 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. # By default, queries listed here will override any specified in a config file. # Prefix the list here with "+" to use these queries and those in the config file. # For details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs # queries: security-extended,security-and-quality # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild uses: github/codeql-action/autobuild@v4 # ℹ️ Command-line programs to run using the OS shell. # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun # If the Autobuild fails above, remove it and uncomment the following three lines. # Modify them (or add more) to build your code; refer to the EXAMPLE below for guidance.
# - run: | # echo "Run, Build Application using script" # ./location_of_script_within_repo/buildscript.sh - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v4 with: category: "/language:${{matrix.language}}" ================================================ FILE: .github/workflows/docs.yaml ================================================ name: Deploy docs on: push: branches: - main jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v5 with: go-version: '1.23' - uses: actions/setup-python@v6 with: python-version: '3.13' - name: Generate CLI docs run: go run ./tools/gendocs/main.go -o docs/cli.md - name: Copy top-level docs like README and CONTRIBUTING run: | sed 's/docs\///g' README.md > ./docs/README.md cp CONTRIBUTING.md ./docs/ - name: Deploy run: | pip install mkdocs-material mkdocs gh-deploy --force ================================================ FILE: .github/workflows/release-build.yaml ================================================ --- name: Release Build # Triggered on version tags to build release artifacts and create a GitHub release. # # THREE RELEASE TYPES: # # 1. Stable (v0.17.0) - must be on main # - Creates DRAFT release → maintainer publishes → release-publish.yaml → PyPI/crates.io # # 2. Pre-release (v0.17.0-alpha3, v0.17.0-rc1) - must be on main # - Same flow as stable, but marked as pre-release # # 3. 
Dev (v0.17.0-dev1) - can be tagged from ANY branch # - Creates a published pre-release immediately (no draft, no human approval) # - Does NOT publish to PyPI or crates.io # - Artifacts (CLI binaries, wheels) attached to the GH release # # SECURITY: # - Stable/pre-release tags verified on main branch # - Dev releases have no branch restriction but no registry publishing on: push: tags: ["v[0-9]+.[0-9]+.[0-9]+*"] permissions: contents: write env: CARGO_TERM_COLOR: always jobs: verify-tag: name: Verify tag and version runs-on: ubuntu-latest timeout-minutes: 5 outputs: is_dev: ${{ steps.check.outputs.is_dev }} version: ${{ steps.check.outputs.version }} pep440: ${{ steps.check.outputs.pep440 }} steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Verify tag and version id: check run: | TAG="${{ github.ref_name }}" VERSION="${TAG#v}" TAG_COMMIT="${{ github.sha }}" # Get version from Cargo.toml CARGO_VERSION=$(grep '^version = ' crates/Cargo.toml | head -1 | sed 's/version = "\(.*\)"/\1/') echo "Tag: $TAG" echo "Tag version: $VERSION" echo "Cargo.toml version: $CARGO_VERSION" echo "Commit: $TAG_COMMIT" echo "" # Check Cargo.toml matches tag if [[ "$CARGO_VERSION" != "$VERSION" ]]; then echo "::error::Version mismatch! crates/Cargo.toml has version $CARGO_VERSION but tag is $TAG" echo "::error::" echo "::error::To fix: Update crates/Cargo.toml to match, merge to main, then delete this tag and re-tag." 
exit 1 fi echo "✓ Cargo.toml version matches tag" # Determine release type IS_DEV="false" if [[ "$VERSION" == *-dev* ]]; then IS_DEV="true" fi echo "is_dev=$IS_DEV" >> "$GITHUB_OUTPUT" echo "version=$VERSION" >> "$GITHUB_OUTPUT" # Compute PEP 440 version: v0.17.0-alpha1 -> 0.17.0a1, v0.17.0-dev1 -> 0.17.0.dev1 PEP440=$(echo "$VERSION" | sed -E 's/-alpha\.?/a/; s/-beta\.?/b/; s/-rc\.?/rc/; s/-dev\.?/.dev/') echo "pep440=$PEP440" >> "$GITHUB_OUTPUT" echo "PEP 440 version: $PEP440" # Branch rules if [[ "$IS_DEV" == "true" ]]; then echo "Dev release - no branch restriction" echo "✓ Dev release, skipping branch check" else # Stable and pre-release tags must be on main echo "Stable/pre-release detected, verifying main branch..." git fetch origin main if ! git merge-base --is-ancestor "$TAG_COMMIT" origin/main; then echo "::error::Release tags must be on the main branch" echo "::error::Tag commit $TAG_COMMIT is not reachable from origin/main" echo "::error::" echo "::error::To fix: Merge to main first, then tag" exit 1 fi echo "✓ Tag is on main branch" fi build-sdk: name: Build SDK wheel needs: verify-tag runs-on: ubuntu-latest timeout-minutes: 10 steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - uses: astral-sh/setup-uv@v6 with: python-version: "3.13" - name: Update coglet version constraint run: | VERSION="${{ needs.verify-tag.outputs.pep440 }}" echo "Setting coglet constraint to >=$VERSION,<1.0" # Update pyproject.toml with lockstep version constraint sed -i "s/coglet>=0\.1\.0,<1\.0/coglet>=$VERSION,<1.0/" pyproject.toml # Verify the change took effect grep "coglet>=$VERSION" pyproject.toml - name: Build SDK wheel run: | echo "Building SDK with version: $SETUPTOOLS_SCM_PRETEND_VERSION" uv build --out-dir dist . 
env: SETUPTOOLS_SCM_PRETEND_VERSION: ${{ needs.verify-tag.outputs.pep440 }} - name: Upload SDK artifacts uses: actions/upload-artifact@v6 with: name: sdk-dist path: dist/* build-coglet-wheels: name: Build coglet wheel (${{ matrix.target }}) needs: verify-tag runs-on: ${{ matrix.os }} timeout-minutes: 20 strategy: fail-fast: false matrix: include: - os: ubuntu-latest target: x86_64-unknown-linux-gnu artifact-suffix: linux-x64 manylinux: auto zig: true - os: ubuntu-latest target: aarch64-unknown-linux-gnu artifact-suffix: linux-arm64 manylinux: auto zig: true - os: macos-14 target: aarch64-apple-darwin artifact-suffix: macos-arm64 manylinux: "off" zig: false steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - uses: dtolnay/rust-toolchain@stable with: targets: ${{ matrix.target }} - uses: Swatinem/rust-cache@v2 with: workspaces: crates -> target key: release-${{ matrix.target }} - uses: astral-sh/setup-uv@v6 with: python-version: "3.13" enable-cache: false - name: Build coglet wheel uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} manylinux: ${{ matrix.manylinux }} args: --release --out dist -m crates/coglet-python/Cargo.toml ${{ matrix.zig && '--zig' || '' }} - name: Upload coglet wheel uses: actions/upload-artifact@v6 with: name: coglet-wheel-${{ matrix.artifact-suffix }} path: dist/*.whl create-release: name: Create release needs: [verify-tag, build-sdk, build-coglet-wheels] # macOS arm64 runner: native clang for darwin targets, zig for linux targets. # CGo required for go-tree-sitter (static Python schema parser). runs-on: macos-14 timeout-minutes: 30 env: CGO_ENABLED: "1" steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - uses: actions/setup-go@v6 with: go-version-file: go.mod - uses: mlugg/setup-zig@v2 with: version: 0.15.2 - name: Check for existing release run: | TAG="${{ github.ref_name }}" EXISTING=$(gh release view "$TAG" --json isDraft,isPrerelease --jq '.' 
2>/dev/null || echo "") if [ -n "$EXISTING" ]; then echo "::error::Release for $TAG already exists. Delete it before re-running." echo "::error::Run: gh release delete $TAG --yes" exit 1 fi env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Goreleaser builds CLI binaries and creates draft release - name: Build CLI and create draft release uses: goreleaser/goreleaser-action@v7 with: version: '~> v2' args: release --clean env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Append Python wheels to the release - name: Download all wheel artifacts uses: actions/download-artifact@v8 with: path: artifacts - name: Upload wheels to release run: | TAG="${{ github.ref_name }}" # Collect all wheels mkdir -p release-wheels cp artifacts/sdk-dist/* release-wheels/ cp artifacts/coglet-wheel-*/*.whl release-wheels/ # Download goreleaser's checksums.txt and append wheel checksums gh release download "$TAG" -p checksums.txt -D release-wheels cd release-wheels shasum -a 256 *.whl *.tar.gz >> checksums.txt echo "Checksums:" cat checksums.txt cd .. echo "Uploading wheels and updated checksums..." gh release upload "$TAG" release-wheels/* --clobber echo "Release assets:" gh release view "$TAG" --json assets --jq '.assets[].name' env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For dev releases: immediately publish as pre-release (no draft review) # For stable/pre-release: leave as draft for maintainer review - name: Publish dev release if: needs.verify-tag.outputs.is_dev == 'true' run: | TAG="${{ github.ref_name }}" echo "Publishing dev release $TAG as pre-release (no draft phase)..." gh release edit "$TAG" --draft=false --prerelease echo "✓ Dev release $TAG published as pre-release" env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .github/workflows/release-publish.yaml ================================================ --- name: Release Publish # Publishes packages to PyPI and crates.io when a release is published. 
# # For stable releases: publishes to PyPI, crates.io, and updates Homebrew tap. # For pre-releases: publishes to PyPI and crates.io only (no Homebrew tap). # For dev releases: ALL jobs are skipped. (Dev releases do trigger this workflow # when release-build.yaml publishes them, but every job gates on is_dev.) # # PUBLISH ORDER: # 1. coglet -> PyPI (must be first, SDK depends on it) # 2. coglet -> crates.io (parallel with coglet PyPI) # 3. SDK -> PyPI (after coglet is on PyPI) # 4. Homebrew cask (stable only, after all publishing completes) # # REQUIRED GITHUB CONFIGURATION: # 1. Create environments in Settings -> Environments: # - "pypi": For PyPI publishing (Trusted Publisher) # - "crates-io": For crates.io publishing # - "homebrew": For Homebrew tap updates # # 2. Configure environment protection rules: # - Deployment branches: "Selected branches and tags" # - Add pattern: v* (to restrict to version tags only) # # 3. crates-io uses Trusted Publishing (OIDC via rust-lang/crates-io-auth-action) # # 4. Homebrew tap uses the cog-homebrew-tapbot GitHub App (ID: 1232932405) # - Secret: COG_HOMEBREW_TAP_PRIVATE_KEY (app's private key) on: release: types: [published] permissions: contents: read id-token: write env: CARGO_TERM_COLOR: always jobs: verify-release: name: Verify release tag runs-on: ubuntu-latest timeout-minutes: 5 outputs: is_dev: ${{ steps.check.outputs.is_dev }} is_prerelease: ${{ steps.check.outputs.is_prerelease }} version: ${{ steps.check.outputs.version }} steps: - name: Verify valid release tag id: check run: | TAG="${{ github.event.release.tag_name }}" # Release must be from a valid version tag if [[ ! 
"$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+ ]]; then echo "::error::Invalid tag format: $TAG" echo "::error::Tags must match pattern v*.*.* (e.g., v1.0.0)" exit 1 fi echo "✓ Valid release tag: $TAG" VERSION="${TAG#v}" echo "version=$VERSION" >> "$GITHUB_OUTPUT" # Classify release type IS_DEV="false" IS_PRERELEASE="false" if [[ "$VERSION" == *-dev* ]]; then IS_DEV="true" echo "Dev release detected - skipping all publishing" elif [[ "$VERSION" == *-* ]]; then IS_PRERELEASE="true" echo "Pre-release detected - publishing to PyPI/crates.io (no Homebrew tap)" else echo "Stable release detected - full publishing including Homebrew tap" fi echo "is_dev=$IS_DEV" >> "$GITHUB_OUTPUT" echo "is_prerelease=$IS_PRERELEASE" >> "$GITHUB_OUTPUT" publish-pypi-coglet: name: Publish coglet to PyPI needs: verify-release if: needs.verify-release.outputs.is_dev != 'true' runs-on: ubuntu-latest environment: pypi timeout-minutes: 10 steps: - name: Download coglet wheels from release run: | mkdir -p dist gh release download "$TAG" -p "coglet-*.whl" -D dist -R "${{ github.repository }}" env: TAG: ${{ github.event.release.tag_name }} GH_TOKEN: ${{ github.token }} - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 publish-crates-io: name: Publish coglet to crates.io needs: verify-release if: needs.verify-release.outputs.is_dev != 'true' runs-on: ubuntu-latest environment: crates-io timeout-minutes: 15 permissions: contents: read id-token: write steps: - uses: actions/checkout@v6 with: ref: ${{ github.event.release.tag_name }} - uses: dtolnay/rust-toolchain@stable - uses: rust-lang/crates-io-auth-action@v1 id: auth - name: Publish to crates.io run: cargo publish -p coglet --manifest-path crates/Cargo.toml env: CARGO_REGISTRY_TOKEN: ${{ steps.auth.outputs.token }} publish-pypi-sdk: name: Publish SDK to PyPI needs: [verify-release, publish-pypi-coglet] if: needs.verify-release.outputs.is_dev != 'true' runs-on: ubuntu-latest environment: pypi timeout-minutes: 10 steps: - name: Download SDK 
artifacts from release run: | mkdir -p dist gh release download "$TAG" -p "cog-*.whl" -D dist -R "${{ github.repository }}" gh release download "$TAG" -p "cog-*.tar.gz" -D dist -R "${{ github.repository }}" env: TAG: ${{ github.event.release.tag_name }} GH_TOKEN: ${{ github.token }} - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 update-homebrew-tap: name: Update Homebrew cask needs: [verify-release, publish-pypi-sdk, publish-crates-io] # Stable releases only — no dev, no pre-release if: >- needs.verify-release.outputs.is_dev != 'true' && needs.verify-release.outputs.is_prerelease != 'true' runs-on: ubuntu-latest environment: homebrew timeout-minutes: 10 steps: - name: Generate GitHub App token id: app-token uses: actions/create-github-app-token@v2 with: app-id: 1232932405 private-key: ${{ secrets.COG_HOMEBREW_TAP_PRIVATE_KEY }} owner: replicate repositories: homebrew-tap - name: Download checksums from release run: gh release download "$TAG" -p checksums.txt -R "${{ github.repository }}" env: TAG: ${{ github.event.release.tag_name }} GH_TOKEN: ${{ github.token }} - name: Generate and push cask env: GH_TOKEN: ${{ steps.app-token.outputs.token }} TAG: ${{ github.event.release.tag_name }} VERSION: ${{ needs.verify-release.outputs.version }} run: | # Extract SHA256s for Darwin binaries from checksums.txt SHA_X86=$(grep 'cog_Darwin_x86_64' checksums.txt | awk '{print $1}') SHA_ARM=$(grep 'cog_Darwin_arm64' checksums.txt | awk '{print $1}') if [ -z "$SHA_X86" ] || [ -z "$SHA_ARM" ]; then echo "::error::Missing Darwin binary checksums in checksums.txt" echo "Darwin x86_64: ${SHA_X86:-MISSING}" echo "Darwin arm64: ${SHA_ARM:-MISSING}" cat checksums.txt exit 1 fi echo "Checksums:" echo " Darwin x86_64: $SHA_X86" echo " Darwin arm64: $SHA_ARM" BASE_URL="https://github.com/replicate/cog/releases/download/${TAG}" # Generate cask file (no indentation to avoid stripping issues) cat > cog.rb <- {{ .Binary }}_ {{- title .Os }}_ {{- if eq .Arch "amd64" 
}}x86_64 {{- else if eq .Arch "386" }}i386 {{- else }}{{ .Arch }}{{ end -}} checksum: name_template: "checksums.txt" snapshot: version_template: '{{ envOrDefault "COG_VERSION" (printf "%s-dev+g%s" (incpatch .Version) .ShortCommit) }}' changelog: sort: asc filters: exclude: - "^docs:" - "^test:" release: draft: true # If set to auto, will mark the release as not ready for production # in case there is an indicator for this in the tag e.g. v1.0.0-alpha # If set to true, will mark the release as not ready for production. # Default is false. prerelease: auto ================================================ FILE: .mockery.yml ================================================ all: false dir: '{{.InterfaceDir}}' filename: mocks_test.go force-file-write: true formatter: goimports log-level: info structname: '{{.Mock}}{{.InterfaceName}}' pkgname: '{{.SrcPackageName}}' recursive: false require-template-schema-exists: true template: testify template-schema: '{{.Template}}.schema.json' packages: github.com/replicate/cog/pkg/docker/command: config: all: true dir: "pkg/docker/dockertest" filename: "command_mocks.go" pkgname: "dockertest" structname: "{{.Mock}}{{.InterfaceName}}2" ================================================ FILE: .vscode/extensions.json ================================================ { "recommendations": [ "charliermarsh.ruff", "golang.go", "ms-python.python", "ms-python.vscode-pylance" ] } ================================================ FILE: .vscode/settings.json ================================================ { "editor.formatOnSave": true, "editor.formatOnType": true, "editor.formatOnPaste": true, "editor.renderControlCharacters": true, "editor.suggest.localityBonus": true, "files.insertFinalNewline": true, "files.trimFinalNewlines": true, "[go]": { "editor.defaultFormatter": "golang.go" }, "go.coverOnTestPackage": false, "go.lintTool": "golangci-lint", "go.formatTool": "goimports", "go.testOnSave": true, "gopls": { "formatting.local": 
"github.com/replicate/cog" }, "[json]": { "editor.defaultFormatter": "vscode.json-language-features" }, "[jsonc]": { "editor.defaultFormatter": "vscode.json-language-features" }, "[python]": { "editor.formatOnSave": true, "editor.codeActionsOnSave": { "source.fixAll": "explicit", "source.organizeImports": "explicit" }, "editor.defaultFormatter": "charliermarsh.ruff" }, "python.languageServer": "Pylance", "python.testing.pytestArgs": [ "-vvv", "python" ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, } ================================================ FILE: AGENTS.md ================================================ # AGENTS.md This file provides guidance to coding agents when working with code in this repository. ## Project Overview Cog is a tool that packages machine learning models in production-ready containers. It consists of: - **Cog CLI** (`cmd/cog/`) - Command-line interface for building, running, and deploying models, written in Go - **Python SDK** (`python/cog/`) - Python library for defining model predictors and training in Python - **Coglet** (`crates/`) - Rust-based prediction server that runs inside containers, with Python bindings via PyO3 Documentation for the CLI and SDK is available by reading ./docs/llms.txt. ## Development Commands Development tasks are managed with [mise](https://mise.jdx.dev/). Run `mise tasks` to see all available tasks. 
### Quick Reference | Task | Description | |------|-------------| | `mise run fmt` | Check formatting (all languages) | | `mise run fmt:fix` | Fix formatting (all languages) | | `mise run lint` | Run linters (all languages) | | `mise run lint:fix` | Fix lint issues (all languages) | | `mise run test:go` | Run Go tests | | `mise run test:rust` | Run Rust tests | | `mise run test:python` | Run Python tests | | `mise run test:integration` | Run integration tests | | `mise run build:cog` | Build cog CLI binary | | `mise run build:coglet` | Build coglet wheel (dev) | | `mise run build:sdk` | Build SDK wheel | | `mise run install` | Symlink the built cog binary to /usr/local/bin (run `build:cog` first) | | `mise run docs:llm` | **IMPORTANT:** Regenerate `docs/llms.txt` after editing docs | | `mise run docs:cli` | Generate CLI reference docs from Go source code | ### Task Naming Convention Tasks follow a consistent naming pattern: - **Language-based tasks** for fmt/lint/test/typecheck: `task:go`, `task:rust`, `task:python` - **Component-based tasks** for build: `build:cog`, `build:coglet`, `build:sdk` - **Check vs Fix**: `fmt` and `lint` default to check mode (non-destructive); use `:fix` suffix to auto-fix ### All Tasks by Category **Format:** - `mise run fmt` / `mise run fmt:check` - Check all (alias) - `mise run fmt:fix` - Fix all - `mise run fmt:go` / `mise run fmt:rust` / `mise run fmt:python` - Per-language **Lint:** - `mise run lint` / `mise run lint:check` - Check all (alias) - `mise run lint:fix` - Fix all - `mise run lint:go` / `mise run lint:rust` / `mise run lint:python` - Per-language - `mise run lint:rust:deny` - Check Rust licenses/advisories **Test:** - `mise run test:go` - Go unit tests - `mise run test:rust` - Rust unit tests - `mise run test:python` - Python unit tests (via tox) - `mise run test:coglet:python` - Coglet Python binding tests - `mise run test:integration` - Integration tests **Build:** - `mise run build:cog` - Build cog CLI (development) - `mise run build:cog:release` -
Build cog CLI (release) - `mise run build:coglet` - Build coglet wheel (dev install) - `mise run build:coglet:wheel` - Build coglet wheel (native platform) - `mise run build:coglet:wheel:linux-x64` - Build for Linux x86_64 - `mise run build:coglet:wheel:linux-arm64` - Build for Linux ARM64 - `mise run build:sdk` - Build SDK wheel **Install:** - `mise run install` - Symlink cog CLI to `/usr/local/bin` (requires `build:cog` first) - `PREFIX=/custom/path mise run install` - Symlink to custom location **Other:** - `mise run typecheck` - Type check all languages - `mise run generate` - Run code generation - `mise run clean` - Clean all build artifacts - `mise run docs` - Build documentation - `mise run docs:serve` - Serve docs locally ## Code Style Guidelines ### Go - **Imports**: Organize in three groups separated by blank lines: (1) Standard library, (2) Third-party packages, (3) Internal packages (`github.com/replicate/cog/pkg/...`) - **Formatting**: Use `mise run fmt:go:fix` - **Linting**: Must pass golangci-lint with: errcheck, gocritic, gosec, govet, ineffassign, misspell, revive, staticcheck, unused - **Error Handling**: Return errors as values; use `pkg/errors.CodedError` for user-facing errors with error codes - **Naming**: CamelCase for exported, camelCase for unexported - **Testing**: Use `testify/require` for assertions; prefer table-driven tests Example import block: ```go import ( "fmt" "github.com/spf13/cobra" "github.com/replicate/cog/pkg/config" ) ``` ### Python - **Imports**: Automatically organized by ruff/isort (stdlib → third-party → local) - **Formatting**: Use `mise run fmt:python:fix` - **Linting**: Must pass ruff checks: E (pycodestyle), F (Pyflakes), I (isort), W (warnings), S (bandit), B (bugbear), ANN (annotations) - **Type Annotations**: Required on all function signatures; use `typing_extensions` for compatibility; avoid `Any` where possible - **Error Handling**: Raise exceptions with descriptive messages; avoid generic exception catching - 
**Naming**: snake_case for functions/variables/modules, PascalCase for classes - **Testing**: Use pytest with fixtures; async tests with pytest-asyncio - **Compatibility**: Must support Python 3.10-3.13 ### Rust - **Formatting**: Use `mise run fmt:rust:fix` - **Linting**: Must pass `mise run lint:rust` (clippy) - **Dependencies**: Audited with `cargo-deny` (see `crates/deny.toml`); run `mise run lint:rust:deny` - **Error Handling**: Use `thiserror` for typed errors, `anyhow` for application errors - **Naming**: snake_case for functions/variables, PascalCase for types - **Testing**: Use `cargo test`; snapshot tests use `insta` - **Async**: tokio runtime; async/await patterns ## Working on the CLI and support tooling The CLI code is in the `cmd/cog/` and `pkg/` directories. Support tooling is in the `tools/` directory. The main commands for working on the CLI are: - `go run ./cmd/cog` - Runs the Cog CLI directly from source (requires wheel to be built first) - `mise run build:cog` - Builds the Cog CLI binary - `mise run install` - Symlinks the built binary to `/usr/local/bin` (run `build:cog` first), or to a custom path with `PREFIX=/custom/path mise run install` - `mise run test:go` - Runs all Go unit tests - `go test ./pkg/...` - Runs tests directly with `go test` ## Working on the Python SDK The Python SDK is developed in the `python/cog/` directory. It uses `uv` for virtual environments and `tox` for testing across multiple Python versions. The main commands for working on the SDK are: - `mise run build:sdk` - Builds the Python wheel - `mise run test:python` - Runs Python tests across all supported versions ## Working on Coglet (Rust) Coglet is the Rust-based prediction server that runs inside Cog containers, handling HTTP requests, worker process management, and prediction execution. 
The code is in the `crates/` directory: - `crates/coglet/` - Core Rust library (HTTP server, worker orchestration, IPC) - `crates/coglet-python/` - PyO3 bindings for Python predictor integration (requires Python 3.10+) For detailed architecture documentation, see `crates/README.md` and `crates/coglet/README.md`. The main commands for working on Coglet are: - `mise run build:coglet` - Build and install coglet wheel for development (macOS, for local Rust/Python tests) - `mise run build:coglet:wheel:linux-x64` - Build Linux x86_64 wheel (required to test Rust changes in Docker containers via `cog predict`/`cog train`) - `mise run test:rust` - Run Rust unit tests - `mise run lint:rust` - Run clippy linter - `mise run fmt:rust:fix` - Format code ### Testing Go code is tested using the built-in `go test` framework: - `go test ./pkg/... -run <TestName>` - Runs specific Go tests by name - `mise run test:go` - Runs all Go unit tests Python code is tested using `tox`, which allows testing across multiple Python versions and configurations: - `mise run test:python` - Runs all Python unit tests - `uv run tox -e py312-tests -- python/tests/server/test_http.py::test_openapi_specification_with_yield` - Runs a specific Python test The integration test suite in `integration-tests/` tests the end-to-end functionality of the Cog CLI and Python SDK using Go's testscript framework: - `mise run test:integration` - Runs the integration tests - `mise run test:integration string_predictor` - Runs a specific integration test The integration tests require a built Cog binary, which defaults to the first `cog` in `PATH`. Run tests against a specific binary with the `COG_BINARY` environment variable: ```bash mise run build:cog COG_BINARY=dist/go/*/cog mise run test:integration ``` ### Development Workflow 1. Run `mise install` to set up the development environment 2. Run `mise run build:sdk` after making changes to the `./python` directory 3.
Run `mise run build:coglet:wheel:linux-x64` after making changes to the `./crates` directory (needed for Docker testing) 4. Run `mise run build:cog` to build the CLI (wheels are picked up from `dist/` at Docker build time, not embedded in the binary) 5. Run `mise run fmt:fix` to format code 6. Run `mise run lint` to check code quality 7. Run `mise run docs:llm` to regenerate `docs/llms.txt` after changing `README.md` or any `docs/*.md` file 8. Read the `./docs` directory and make sure the documentation is up to date **IMPORTANT:** Always run `mise run lint` (or the language-specific variant, e.g. `mise run lint:go`) before committing to catch linter errors early. CI will reject PRs that fail lint checks. ## Architecture ### CLI Architecture (Go) The CLI follows a command pattern with subcommands. The main components are: - `pkg/cli/` - Command definitions (build, run, predict, serve, etc.) - `pkg/docker/` - Docker client and container management - `pkg/dockerfile/` - Dockerfile generation and templating - `pkg/config/` - cog.yaml parsing and validation - `pkg/image/` - Image building and pushing logic ### Python SDK Architecture - `python/cog/` - Core SDK - `base_predictor.py` - Base class for model predictors - `types.py` - Input/output type definitions - `server/` - HTTP/queue server implementation - `command/` - Runner implementations for predict/train ### Coglet Architecture (Rust) The prediction server that runs inside Cog containers. Uses a two-process architecture: a parent process (HTTP server + orchestrator) and a worker subprocess (Python predictor execution). See `crates/README.md` for detailed architecture documentation. - `crates/coglet/` - Core Rust library (HTTP server, worker orchestration, IPC bridge) - `crates/coglet-python/` - PyO3 bindings for Python predictor integration ### Key Design Patterns 1. **Local Wheel Resolution**: The CLI discovers SDK and coglet wheels from `dist/` at Docker build time (not embedded in the binary) 2. 
**Docker SDK Integration**: Uses Docker Go SDK for container operations 3. **Type Safety**: Dataclasses for Python type validation, strongly typed Go interfaces 4. **Compatibility Matrix**: Automated CUDA/PyTorch/TensorFlow compatibility management For comprehensive architecture documentation, see [`architecture/`](./architecture/00-overview.md). ## Common Tasks ### Adding a new CLI command 1. Create command file in `pkg/cli/` 2. Add command to `pkg/cli/root.go` 3. Implement business logic in appropriate `pkg/` subdirectory 4. Add tests ### Modifying Python SDK behavior 1. Edit files in `python/cog/` 2. Run `mise run build:sdk` to rebuild wheel 3. Test with `mise run test:python` 4. Integration test with `mise run test:integration` ### Updating ML framework compatibility 1. See `tools/compatgen/` for compatibility matrix generation 2. Update framework versions in relevant Dockerfile templates 3. Test with various framework combinations ### Updating the docs - Documentation is in the `docs/` directory, written in Markdown and generated into HTML using `mkdocs`. - **IMPORTANT:** After editing any file in `docs/` or `README.md`, you MUST run `mise run docs:llm` to regenerate `docs/llms.txt`. This file is used by coding agents and should be kept in sync with the documentation. - **IMPORTANT:** CLI reference docs (`docs/cli.md`) are auto-generated from Go source code. After modifying CLI commands in `cmd/` or `pkg/cli/`, run `mise run docs:cli` to regenerate, and ensure `mise run docs:cli:check` passes before committing. ## CI Tool Dependencies Development tools are managed in **two places** that must be kept in sync: 1. **`mise.toml`** — Tool versions for local development (uses aqua backend for prebuilt binaries) 2. **`.github/workflows/ci.yaml`** — Tool installation for CI (uses dedicated GitHub Actions) CI deliberately avoids aqua downloads from GitHub Releases to prevent transient 502 failures. 
Instead, it uses: | Tool | CI installation method | Why | |------|----------------------|-----| | gotestsum | `go install` | Uses Go module proxy, not GitHub Releases | | cargo-deny | `taiki-e/install-action` | Prebuilt with checksum verification | | cargo-nextest | `taiki-e/install-action` | Prebuilt with checksum verification | | coglet wheel (maturin+zig) | `PyO3/maturin-action` | Bundles maturin and zig | | golangci-lint | `golangci/golangci-lint-action` | Built-in caching | | Rust toolchain | `dtolnay/rust-toolchain` | Guaranteed ordering | Tools disabled in CI are listed in `MISE_DISABLE_TOOLS` in `ci.yaml`. **When updating a tool version**, update both: - The version in `mise.toml` (for local dev) - The corresponding version pin in `.github/workflows/ci.yaml` (for CI) ## Important Files - `cog.yaml` - User-facing model configuration - `pkg/config/config.go` - Go code for parsing and validating `cog.yaml` - `pkg/config/data/config_schema_v1.0.json` - JSON schema for `cog.yaml` - `python/cog/base_predictor.py` - Predictor interface - `crates/Cargo.toml` - Rust workspace configuration - `crates/README.md` - Coglet architecture overview - `mise.toml` - Task definitions for development workflow ## Testing Philosophy - Unit tests for individual components (Go and Python) - Integration tests for end-to-end workflows - Tests use real Docker operations (no mocking Docker API) - Always run `mise run build:sdk` after making Python changes before testing Go code - Python 3.10-3.13 compatibility is required ### Go Test Conventions All Go tests must use [testify](https://github.com/stretchr/testify) for assertions. Do **not** use raw `if` checks with `t.Fatal`/`t.Errorf` — use `require` and `assert` instead. 
- **`require`** — for fatal assertions that should stop the test (setup failures, preconditions): ```go require.NoError(t, err, "failed to create client") require.Equal(t, expected, actual) require.True(t, condition, "server should be ready") ``` - **`assert`** — for non-fatal checks where the test should continue (e.g. validating multiple fields in a loop): ```go assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Contains(t, output, "expected substring") assert.NoError(t, err, "prediction %d failed", i) ``` - Use `require` for errors in setup/teardown and `assert` for the actual test expectations - Prefer specific assertions (`Equal`, `Contains`, `NoError`, `Len`, `Less`) over generic `True`/`False` — they produce better failure messages - Prefer table-driven tests for testing multiple similar cases ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing guide ## Development environment Development tasks are managed with [mise](https://mise.jdx.dev/). Run `mise tasks` to see all available tasks. ### Prerequisites - [mise](https://mise.jdx.dev/getting-started.html): Manages Go, Rust, Python, and other tools - [Docker](https://docs.docker.com/desktop) or [OrbStack](https://orbstack.dev) ### Setup ```sh # Trust the mise configuration and install tools mise trust mise install # Create Python virtualenv and install dependencies uv venv uv sync --all-groups ``` ### Build & install ```sh mise run build # symlink the binary to /usr/local/bin sudo mise run install ``` After making changes, run `mise run build` to rebuild and it will get picked up by the symlink. 
### Common tasks ```sh # Run all tests mise run test:go mise run test:python mise run test:rust # Run specific tests mise run test:go -- ./pkg/config uv run tox -e py312-tests -- python/tests/server/test_http.py -k test_name # Format code (all languages) mise run fmt:fix # Lint code (all languages) mise run lint ``` Run `mise tasks` for the complete list of available tasks. If you encounter any errors, see the troubleshooting section below. ## Project structure As much as possible, this is attempting to follow the [Standard Go Project Layout](https://github.com/golang-standards/project-layout). - `cmd/` - The root `cog` command. - `pkg/cli/` - CLI commands. - `pkg/config` - Everything `cog.yaml` related. - `pkg/docker/` - Low-level interface for Docker commands. - `pkg/dockerfile/` - Creates Dockerfiles. - `pkg/image/` - Creates and manipulates Cog Docker images. - `pkg/predict/` - Runs predictions on models. - `pkg/util/` - Various packages that aren't part of Cog. They could reasonably be separate re-usable projects. - `python/` - The Cog Python library. - `integration-tests/` - Go-based integration tests using testscript. - `tools/compatgen/` - Tool for generating CUDA/PyTorch/TensorFlow compatibility matrices. For deeper architectural understanding, see the [architecture documentation](./architecture/00-overview.md). ## Updating compatibility matrices The CUDA base images and framework compatibility matrices in `pkg/config/` are checked into source control and only need to be regenerated when adding support for new versions of CUDA, PyTorch, or TensorFlow. 
To regenerate the compatibility matrices, run: ```sh # Regenerate all matrices mise run generate:compat # Or regenerate specific matrices mise run generate:compat cuda mise run generate:compat torch mise run generate:compat tensorflow ``` The generated files are: - `pkg/config/cuda_base_images.json` - Available NVIDIA CUDA base images - `pkg/config/torch_compatibility_matrix.json` - PyTorch/CUDA/Python compatibility - `pkg/config/tf_compatibility_matrix.json` - TensorFlow/CUDA/Python compatibility ## CI tool dependencies Development tools are managed in **two places** that must be kept in sync: 1. **`mise.toml`** — Tool versions for local development (uses aqua backend for prebuilt binaries) 2. **`.github/workflows/ci.yaml`** — Tool installation for CI (uses dedicated GitHub Actions) CI deliberately avoids aqua downloads from GitHub Releases to prevent transient 502 failures. Instead, it uses dedicated actions (`taiki-e/install-action`, `go install`, `PyO3/maturin-action`, etc.) that are more reliable. Tools disabled in CI are listed in `MISE_DISABLE_TOOLS` in `ci.yaml`. **When updating a tool version**, update both: - The version in `mise.toml` (for local dev) - The corresponding version pin in `.github/workflows/ci.yaml` (for CI) See the [CI Tool Dependencies section in AGENTS.md](./AGENTS.md#ci-tool-dependencies) for the full mapping of tools to their CI installation methods. ## Concepts There are a few concepts used throughout Cog that might be helpful to understand. - **Config**: The `cog.yaml` file. - **Image**: Represents a built Docker image that serves the Cog API, containing a **model**. - **Input**: Input to a **prediction**, as a key/value JSON object. - **Model**: A user's machine learning model, consisting of code and weights. - **Output**: Output from a **prediction**, as an arbitrarily complex JSON object. - **Prediction**: A single run of the model that takes **input** and produces **output**.
- **Predictor**: Defines how Cog runs **predictions** on a **model**. ## Running tests **To run the entire test suite:** ```sh mise run test:go mise run test:python mise run test:rust ``` **To run just the Go unit tests:** ```sh mise run test:go ``` **To run just the Python tests:** ```sh mise run test:python ``` > [!NOTE] > This runs the Python test suite across all supported Python versions (3.10-3.13) using tox. ### Integration Tests Integration tests are in `integration-tests/` using [testscript](https://pkg.go.dev/github.com/rogpeppe/go-internal/testscript). Each test is a self-contained `.txtar` file in `integration-tests/tests/`, with some specialized tests as Go test functions in subpackages. ```sh # Run all integration tests mise run test:integration # Run a specific test mise run test:integration string_predictor # Run fast tests only (skip slow GPU/framework tests) cd integration-tests && go test -short -v # Run with a custom cog binary COG_BINARY=/path/to/cog mise run test:integration ``` ### Writing Integration Tests When adding new functionality, add integration tests in `integration-tests/tests/`.
They are: - Self-contained (embedded fixtures in `.txtar` files) - Faster to run (parallel execution with automatic cleanup) - Easier to read and write (simple command script format) Example test structure: ```txtar # Test string predictor cog build -t $TEST_IMAGE cog predict $TEST_IMAGE -i s=world stdout 'hello world' -- cog.yaml -- build: python_version: "3.12" predict: "predict.py:Predictor" -- predict.py -- from cog import BasePredictor class Predictor(BasePredictor): def predict(self, s: str) -> str: return "hello " + s ``` For testing `cog serve`, use `cog serve` and the `curl` command: ```txtar cog build -t $TEST_IMAGE cog serve curl POST /predictions '{"input":{"s":"test"}}' stdout '"output":"hello test"' ``` #### Advanced Test Commands For tests that require subprocess initialization or async operations, use `retry-curl`: **`retry-curl` - HTTP request with automatic retries:** ```txtar # Make HTTP request with retry logic (useful for subprocess initialization delays) # retry-curl [method] [path] [body] [max-attempts] [retry-delay] retry-curl POST /predictions '{"input":{"s":"test"}}' 30 1s stdout '"output":"hello test"' ``` **Example: Testing predictor with subprocess in setup** ```txtar cog build -t $TEST_IMAGE cog serve # Use generous retries since setup spawns a background process retry-curl POST /predictions '{"input":{"s":"test"}}' 30 1s stdout '"output":"hello test"' -- predict.py -- class Predictor(BasePredictor): def setup(self): self.process = subprocess.Popen(["./background.sh"]) def predict(self, s: str) -> str: return "hello " + s ``` #### Test Conditions Use conditions to control when tests run based on environment: **`[short]` - Skip slow tests in short mode:** ```txtar [short] skip 'requires GPU or long build time' cog build -t $TEST_IMAGE # ... rest of test ``` Run with `go test -short` to skip these tests. 
**`[linux]` / `[!linux]` - Platform-specific tests:** ```txtar [!linux] skip 'requires Linux' # Linux-specific test cog build -t $TEST_IMAGE ``` **`[amd64]` / `[!amd64]` - Architecture-specific tests:** ```txtar [!amd64] skip 'requires amd64 architecture' # amd64-specific test cog build -t $TEST_IMAGE ``` **`[linux_amd64]` - Combined platform and architecture:** ```txtar [!linux_amd64] skip 'requires Linux on amd64' # Test that requires both Linux and amd64 cog build -t $TEST_IMAGE ``` **Combining conditions:** Conditions can be negated with `!`. Examples: - `[short]` - True when `go test -short` is used (skip this test in short mode) - `[!short]` - True when NOT running with `-short` flag (only run this in full test mode) - `[!linux]` - True when NOT on Linux - `[linux_amd64]` - True when on Linux AND amd64 See existing tests in `integration-tests/tests/`, especially `setup_subprocess_*.txtar`, for more examples. ## Running the docs server To run the docs website server locally: ```sh mise run docs:serve ``` ## Publishing a release Releases are managed by GitHub Actions workflows. See [`.github/workflows/README.md`](.github/workflows/README.md) for full details. All packages use **lockstep versioning** from `crates/Cargo.toml`. There are three release types: | Type | Example tag | Branch rule | PyPI/crates.io? | |------|-------------|-------------|-----------------| | **Stable** | `v0.17.0` | Must be on main | Yes | | **Pre-release** | `v0.17.0-alpha3` | Must be on main | Yes | | **Dev** | `v0.17.0-dev1` | Any branch | No | ### Stable / Pre-release ```bash # 1. Update crates/Cargo.toml version (e.g. "0.17.0" or "0.17.0-alpha3") # 2. Merge to main # 3. Tag and push git tag v0.17.0 git push origin v0.17.0 # 4. Wait for release-build.yaml to create a draft release # 5. Review the draft in GitHub UI, then click "Publish release" # This triggers release-publish.yaml -> PyPI + crates.io ``` ### Dev release ```bash # From any branch: # 1. 
Update crates/Cargo.toml version (e.g. "0.17.0-dev1") # 2. Commit and push # 3. Tag and push git tag v0.17.0-dev1 git push origin v0.17.0-dev1 # 4. Done. Artifacts are built and published as a GH pre-release. # No PyPI/crates.io. No manual approval. ``` ## Troubleshooting ### `cog command not found` The compiled `cog` binary will be installed in `$GOPATH/bin/cog`, e.g. `~/go/bin/cog`. Make sure that Golang's bin directory is present on your system PATH by adding it to your shell config (`.bashrc`, `.zshrc`, etc): export PATH=~/go/bin:$PATH --- Still having trouble? Please [open an issue](https://github.com/replicate/cog/issues) on GitHub. ================================================ FILE: DESIGN.md ================================================ # Design ## Background Cog came from Andreas's experience at Spotify and Ben's experience at Docker. At Spotify, Andreas noticed a cluster of related problems: - **It was hard to run open-source machine learning models.** All the advances in machine learning were locked up inside prose in PDFs, scraps of code on GitHub, weights on Google Drive (if you were lucky!). If you wanted to build upon this research, or apply it to real-world problems, you had to implement it all from scratch. - **It was hard to deploy machine learning models to production.** Andreas was the only person on the research team who was also an infrastructure engineer. Typically a researcher would have to sit down with Andreas to decide on an API, get a server written, package up dependencies, battle CUDA, get it running efficiently, get it deployed on the cluster, and so on and so forth. It would take weeks to get something running in production. Ben connected this back to his experience at Docker. What Docker did was define a standard box that software could go in. 
You could put any kind of server software in there – Python, Java, Ruby on Rails, whatever – and you could then know that you could run it on your local machine or on any cloud, as long as it supported Docker. We wanted to do the same thing for machine learning. ## Vision We want Cog to be a standard artifact for what a model is and how that model is run. (More detail...) ## Design principles There are a few things driving Cog's design: - **Reproducible artifact**: When you've put your model in Cog, it'll run anywhere, and _keep_ on running. This is why it's a Docker image, with all of the model's dependencies. Docker images have a content-addressable SHA256 ID, which is the identifier for that model's behavior, byte-for-byte. - **Weights inside the image**: We encourage users to put model weights in images. If the weights are on cloud storage somewhere, then they might change or disappear, and the image will produce different results. There's nothing magical about Docker images – they're just a bundle of files. Docker moves around that bundle of files quite slowly, but we can optimize that process so it's as fast as reading weights directly from blob storage, or wherever. - **Models are just functions**: Models can be lots of things. We are of the opinion that [machine learning is just software](https://replicate.com/blog/machine-learning-needs-better-tools), and a model is just a function. It often needs to be attached to a GPU, but apart from that it's just a normal function that has some input and some output. This is the core difference between Docker's abstraction and Cog's abstraction: Docker packages up an executable, whereas Cog packages up a _function_. - **Standard interface**: When you run the Docker container, it starts an HTTP server that exposes a standard API for running that function. You can think of it like a remote procedure call.
- **Self-describing artifact**: A Cog model has its schema (or type signature, if you're thinking of it as a function) attached to the image as a label. This means systems that work with Cog models can know what the model is and what requests to send to it. This is what powers the forms on Replicate, for example. - **Not just the model**: Before Cog, the typical standard packaging formats for machine learning models were at the network level. A way of taking a TensorFlow or PyTorch network and packaging it up in a way that would run on lots of different types of accelerators. Things like [ONNX](https://onnx.ai/) or [TVM's IR](https://tvm.apache.org/). We realized that "models" are not just the network, but they are also pre- and post-processing, and are so diverse and the field is so fast-moving that you can't possibly squeeze it into some high-level abstraction. It just needs to be code running on a computer. - **It's just Docker**: Cog models need to run anywhere, and they'll only run anywhere if it's vanilla Docker. We might optimize how Docker images get shipped around to make it faster, but we're not going to invent our own image format. - **The API is for software developers**: In the olden days, you had to pass tensors to TFServing and know how to generate a tensor from a JPEG. Cog's API intentionally just speaks JSON, strings, files, etc. It's intended to be the interface between the software developer and the ML engineer. Sort of like Docker was intended to be the interface between the software developer and the infrastructure engineer. - **Cog is the APIs and interfaces, not just the software**: The most important thing about Cog is that it defines a standard for what a model is and how to run it. It doesn't necessarily need to involve Cog the piece of software itself. For example, Replicate could serve a model from OpenAI with a Cog API and schema, but it's not packaged or running with Cog under the hood at all – it's just calling the OpenAI API directly.
================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2022, Replicate, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: Makefile ================================================ # Makefile - Shim that delegates to mise tasks # # This Makefile provides backward compatibility for common targets. # All task definitions live in mise.toml. Run `mise tasks` to see available tasks. # # For new development, prefer using mise directly: # mise run build:cog instead of make cog # mise run test:go instead of make test-go # mise run fmt:fix instead of make fmt SHELL := bash # Show deprecation warning (set MAKE_NO_WARN=1 to suppress) ifndef MAKE_NO_WARN $(info ) $(info ┌────────────────────────────────────────────────────────────────────┐) $(info │ NOTE: This Makefile is a compatibility shim. Prefer using mise: │) $(info │ │) $(info │ mise run build:cog mise run test:go mise run fmt:fix │) $(info │ │) $(info │ Run 'mise tasks' to see all available tasks. │) $(info │ Set MAKE_NO_WARN=1 to suppress this message. 
│) $(info └────────────────────────────────────────────────────────────────────┘) $(info ) endif PREFIX ?= /usr/local GO ?= go COG_BINARIES := cog base-image default: cog # ============================================================================= # Build targets # ============================================================================= .PHONY: cog base-image $(COG_BINARIES): mise run build:cog .PHONY: wheel wheel: mise run build:sdk .PHONY: install install: cog PREFIX=$(PREFIX) mise run install # ============================================================================= # Test targets # ============================================================================= .PHONY: test test: mise run test:go mise run test:python .PHONY: test-go test-go: mise run test:go .PHONY: test-python test-python: mise run test:python .PHONY: test-integration test-integration: mise run test:integration .PHONY: test-coglet test-coglet: test-coglet-rust .PHONY: test-coglet-rust test-coglet-rust: mise run test:rust .PHONY: test-coglet-python test-coglet-python: mise run test:coglet:python # ============================================================================= # Format and lint targets # ============================================================================= .PHONY: fmt fmt: mise run fmt:fix .PHONY: check-fmt check-fmt: mise run fmt .PHONY: lint lint: mise run lint .PHONY: vet vet: $(GO) vet ./... 
# ============================================================================= # Code generation # ============================================================================= .PHONY: generate generate: mise run generate .PHONY: gen-mocks gen-mocks: mockery # ============================================================================= # Coglet (Rust) targets # ============================================================================= .PHONY: fmt-coglet fmt-coglet: mise run fmt:rust:fix .PHONY: check-fmt-coglet check-fmt-coglet: mise run fmt:rust .PHONY: lint-coglet lint-coglet: mise run lint:rust # ============================================================================= # Documentation # ============================================================================= .PHONY: run-docs-server run-docs-server: mise run docs:serve # ============================================================================= # Clean # ============================================================================= .PHONY: clean clean: mise run clean .PHONY: clean-coglet clean-coglet: mise run clean:rust ================================================ FILE: README.md ================================================ # Cog: Containers for machine learning Cog is an open-source tool that lets you package machine learning models in a standard, production-ready container. You can deploy your packaged model to your own infrastructure, or to [Replicate](https://replicate.com/). ## Highlights - 📦 **Docker containers without the pain.** Writing your own `Dockerfile` can be a bewildering process. With Cog, you define your environment with a [simple configuration file](#how-it-works) and it generates a Docker image with all the best practices: Nvidia base images, efficient caching of dependencies, installing specific Python versions, sensible environment variable defaults, and so on. 
- 🤬️ **No more CUDA hell.** Cog knows which CUDA/cuDNN/PyTorch/Tensorflow/Python combos are compatible and will set it all up correctly for you. - ✅ **Define the inputs and outputs for your model with standard Python.** Then, Cog generates an OpenAPI schema and validates the inputs and outputs. - 🎁 **Automatic HTTP prediction server**: Your model's types are used to dynamically generate a RESTful HTTP API using a high-performance Rust/Axum server. - 🚀 **Ready for production.** Deploy your model anywhere that Docker images run. Your own infrastructure, or [Replicate](https://replicate.com). ## How it works Define the Docker environment your model runs in with `cog.yaml`: ```yaml build: gpu: true system_packages: - "libgl1-mesa-glx" - "libglib2.0-0" python_version: "3.13" python_requirements: requirements.txt predict: "predict.py:Predictor" ``` Define how predictions are run on your model with `predict.py`: ```python from cog import BasePredictor, Input, Path import torch class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("./weights.pth") # The arguments and types the model takes as input def predict(self, image: Path = Input(description="Grayscale input image") ) -> Path: """Run a single prediction on the model""" processed_image = preprocess(image) output = self.model(processed_image) return postprocess(output) ``` In the above we accept a path to the image as an input, and return a path to our transformed image after running it through our model. Now, you can run predictions on this model: ```console $ cog predict -i image=@input.jpg --> Building Docker image... --> Running Prediction... --> Output written to output.jpg ``` Or, build a Docker image for deployment: ```console $ cog build -t my-classification-model --> Building Docker image... 
--> Built my-classification-model:latest $ docker run -d -p 5000:5000 --gpus all my-classification-model $ curl http://localhost:5000/predictions -X POST \ -H 'Content-Type: application/json' \ -d '{"input": {"image": "https://.../input.jpg"}}' ``` Or, combine build and run via the `serve` command: ```console $ cog serve -p 8080 $ curl http://localhost:8080/predictions -X POST \ -H 'Content-Type: application/json' \ -d '{"input": {"image": "https://.../input.jpg"}}' ``` ## Why are we building this? It's really hard for researchers to ship machine learning models to production. Part of the solution is Docker, but it is so complex to get it to work: Dockerfiles, pre-/post-processing, Flask servers, CUDA versions. More often than not the researcher has to sit down with an engineer to get the damn thing deployed. [Andreas](https://github.com/andreasjansson) and [Ben](https://github.com/bfirsh) created Cog. Andreas used to work at Spotify, where he built tools for building and deploying ML models with Docker. Ben worked at Docker, where he created [Docker Compose](https://github.com/docker/compose). We realized that, in addition to Spotify, other companies were also using Docker to build and deploy machine learning models. [Uber](https://eng.uber.com/michelangelo-pyml/) and others have built similar systems. So, we're making an open source version so other people can do this too. Hit us up if you're interested in using it or want to collaborate with us. [We're on Discord](https://discord.gg/replicate) or email us at [team@replicate.com](mailto:team@replicate.com). ## Prerequisites - **macOS, Linux or Windows 11**. Cog works on macOS, Linux and Windows 11 with [WSL 2](docs/wsl2/wsl2.md) - **Docker**. Cog uses Docker to create a container for your model. You'll need to [install Docker](https://docs.docker.com/get-docker/) before you can run Cog. 
If you install Docker Engine instead of Docker Desktop, you will need to [install Buildx](https://docs.docker.com/build/architecture/#buildx) as well. ## Install If you're using macOS, you can install Cog using Homebrew: ```console brew install replicate/tap/cog ``` You can also download and install the latest release using our [install script](https://cog.run/install): ```sh # bash, zsh, and other shells sh <(curl -fsSL https://cog.run/install.sh) # fish shell sh (curl -fsSL https://cog.run/install.sh | psub) # download with wget and run in a separate command wget -qO install.sh https://cog.run/install.sh sh ./install.sh ``` You can manually install the latest release of Cog directly from GitHub by running the following commands in a terminal: ```console sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)" sudo chmod +x /usr/local/bin/cog ``` Or, if you are installing Cog in a Dockerfile: ``` RUN sh -c "INSTALL_DIR=\"/usr/local/bin\" SUDO=\"\" $(curl -fsSL https://cog.run/install.sh)" ``` ## Upgrade If you're using macOS and you previously installed Cog with Homebrew, run the following: ```console brew upgrade replicate/tap/cog ``` Otherwise, you can upgrade to the latest version by running the same commands you used to install it. ## Development See [CONTRIBUTING.md](CONTRIBUTING.md) for how to set up a development environment and build from source. 
## Next steps - [Get started with an example model](docs/getting-started.md) - [Get started with your own model](docs/getting-started-own-model.md) - [Using Cog with notebooks](docs/notebooks.md) - [Using Cog with Windows 11](docs/wsl2/wsl2.md) - [Take a look at some examples of using Cog](https://github.com/replicate/cog-examples) - [Deploy models with Cog](docs/deploy.md) - [`cog.yaml` reference](docs/yaml.md) to learn how to define your model's environment - [Prediction interface reference](docs/python.md) to learn how the `Predictor` interface works - [Training interface reference](docs/training.md) to learn how to add a fine-tuning API to your model - [HTTP API reference](docs/http.md) to learn how to use the HTTP API that models serve ## Need help? [Join us in #cog on Discord.](https://discord.gg/replicate) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/replicate/cog) ## Contributors ✨ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):
Ben Firshman
Ben Firshman

💻 📖
Andreas Jansson
Andreas Jansson

💻 📖 🚧
Zeke Sikelianos
Zeke Sikelianos

💻 📖 🔧
Rory Byrne
Rory Byrne

💻 📖 ⚠️
Michael Floering
Michael Floering

💻 📖 🤔
Ben Evans
Ben Evans

📖
shashank agarwal
shashank agarwal

💻 📖
VictorXLR
VictorXLR

💻 📖 ⚠️
hung anna
hung anna

🐛
Brian Whitman
Brian Whitman

🐛
JimothyJohn
JimothyJohn

🐛
ericguizzo
ericguizzo

🐛
Dominic Baggott
Dominic Baggott

💻 ⚠️
Dashiell Stander
Dashiell Stander

🐛 💻 ⚠️
Shuwei Liang
Shuwei Liang

🐛 💬
Eric Allam
Eric Allam

🤔
Iván Perdomo
Iván Perdomo

🐛
Charles Frye
Charles Frye

📖
Luan Pham
Luan Pham

🐛 📖
TommyDew
TommyDew

💻
Jesse Andrews
Jesse Andrews

💻 📖 ⚠️
Nick Stenning
Nick Stenning

💻 📖 🎨 🚇 ⚠️
Justin Merrell
Justin Merrell

📖
Rurik Ylä-Onnenvuori
Rurik Ylä-Onnenvuori

🐛
Youka
Youka

🐛
Clay Mullis
Clay Mullis

📖
Mattt
Mattt

💻 📖 🚇
Eng Zer Jun
Eng Zer Jun

⚠️
BB
BB

💻
williamluer
williamluer

📖
Simon Eskildsen
Simon Eskildsen

💻
F
F

🐛 💻
Philip Potter
Philip Potter

🐛 💻
Joanne Chen
Joanne Chen

📖
technillogue
technillogue

💻
Aron Carroll
Aron Carroll

📖 💻 🤔
Bohdan Mykhailenko
Bohdan Mykhailenko

📖 🐛
Daniel Radu
Daniel Radu

📖 🐛
Itay Etelis
Itay Etelis

💻
Gennaro Schiano
Gennaro Schiano

📖
André Knörig
André Knörig

📖
Dan Fairs
Dan Fairs

💻
This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome! ================================================ FILE: architecture/00-overview.md ================================================ # Cog Architecture Overview Cog packages machine learning models into production-ready OCI images. ## The Big Picture ```mermaid flowchart LR subgraph input["What you write"] model["Model Code
+ cog.yaml"] end subgraph cog["Cog"] cli["CLI"] sdk["Python SDK"] end subgraph output["What you get"] image["Container Image"] api["HTTP API"] end model --> cli cli -->|"builds"| image image -->|"runs"| sdk sdk -->|"serves"| api ``` ## Components ### Model Source What the model author provides: `cog.yaml` for environment config, a Predictor class with `setup()` and `predict()` methods, and optionally model weights. **Deep dive**: [Model Source](./01-model-source.md) --- ### Schema An OpenAPI specification generated from the predictor's type hints. Describes what inputs the model accepts and what outputs it produces. **Deep dive**: [Schema](./02-schema.md) --- ### Prediction API The HTTP interface for running predictions. A fixed envelope format (`PredictionRequest`/`PredictionResponse`) wraps model-specific inputs and outputs. **Deep dives**: - [Legacy API](./legacy/03-prediction-api.md) - FastAPI implementation details - [FFI API](./ffi/03-prediction-api.md) - Rust/Axum implementation details --- ### Container Runtime The runtime that runs inside the container: an HTTP server, worker process isolation, and prediction execution. Cog has two runtime implementations: - **Legacy (Python/FastAPI)**: Current default implementation - [Documentation](./legacy/) - **FFI (Rust/PyO3)**: Next-generation experimental implementation - [Documentation](./ffi/) **Deep dives**: - [Legacy Runtime](./legacy/04-container-runtime.md) - FastAPI/Uvicorn two-process architecture - [FFI Runtime](./ffi/04-container-runtime.md) - Rust/Axum with PyO3 FFI bridge --- ### Build System Transforms `cog.yaml` and user code into a Docker image with the right Python version, CUDA libraries, and dependencies. **Deep dive**: [Build System](./05-build-system.md) --- ### CLI The command-line tool for building, testing, and deploying models. 
**Deep dive**: [CLI](./06-cli.md) --- ## How It Fits Together ```mermaid flowchart TB subgraph source["Model Source"] yaml["cog.yaml"] code["predict.py"] weights["weights"] end subgraph build["Build Time"] config["Config Parser"] generator["Dockerfile Generator"] schema_gen["Schema Generator"] end subgraph image["Container Image"] layers["Base + Deps + Code"] schema["OpenAPI Schema
(label)"] end subgraph runtime["Runtime"] server["HTTP Server"] worker["Worker Process"] predictor["Predictor"] end yaml --> config config --> generator generator --> layers code --> layers weights --> layers layers --> schema_gen schema_gen --> schema image --> server server --> worker worker --> predictor ``` ## Terminology | Term | Meaning | |------|---------| | **Predictor** | User's model class with `setup()` and `predict()` methods | | **Schema** | OpenAPI spec describing the model's input/output interface | | **Envelope** | Fixed request/response structure wrapping model-specific data | | **Worker** | Isolated subprocess running user code | | **Setup** | One-time model initialization at container start | ## Runtime Implementations Cog supports two runtime implementations: ### Legacy Runtime (Python/FastAPI) - **Status**: Current default - **Use when**: Running standard Cog containers - **Implementation**: `python/cog/server/` - **Documentation**: [legacy/](./legacy/) ### FFI Runtime (Rust/PyO3) - **Status**: Experimental (in development) - **Use when**: Set `USE_COGLET` environment variable - **Implementation**: `crates/coglet/` - **Documentation**: [ffi/](./ffi/) - **Benefits**: Better performance, stability, and resource management Both runtimes expose the same HTTP API and support the same model code. The FFI runtime is a drop-in replacement with improved internals. ## Reading Order For understanding Cog's architecture, we recommend reading in this order: 1. [Model Source](./01-model-source.md) - What users write 2. [Schema](./02-schema.md) - How the interface is described 3. **Choose a runtime path**: - **Legacy**: [Prediction API](./legacy/03-prediction-api.md) → [Container Runtime](./legacy/04-container-runtime.md) - **FFI**: [Prediction API](./ffi/03-prediction-api.md) → [Container Runtime](./ffi/04-container-runtime.md) 4. [Build System](./05-build-system.md) - How images are built 5. 
[CLI](./06-cli.md) - How users interact with it all ================================================ FILE: architecture/01-model-source.md ================================================ # Model Source This document covers what a model author provides to Cog and the primitives they work with. ## What Users Write A Cog model consists of: ``` my-model/ ├── cog.yaml # Environment configuration ├── predict.py # Predictor class └── weights/ # Model weights (optional, can be downloaded) ``` ## cog.yaml Declares the runtime environment: ```yaml build: python_version: "3.11" gpu: true python_packages: - torch==2.1.0 - transformers==4.35.0 system_packages: - ffmpeg run: - curl -o /src/model.bin https://example.com/model.bin predict: "predict.py:Predictor" concurrency: max: 1 ``` | Field | Purpose | |-------|---------| | `build.python_version` | Python interpreter version (3.10-3.13) | | `build.gpu` | Enable CUDA support | | `build.python_packages` | pip packages to install | | `build.system_packages` | apt packages to install | | `build.run` | Arbitrary shell commands during build | | `predict` | Path to predictor class (`module:ClassName`) | | `train` | Path to training class (optional) | | `concurrency.max` | Max concurrent predictions (requires async) | The [Build System](./05-build-system.md) uses this configuration to produce an image containing all necessary dependencies, libraries, and the correct Python/CUDA versions. ## The Predictor Class A predictor is a Python class with two methods: ```python from cog import BasePredictor, Input, Path class Predictor(BasePredictor): def setup(self): """Load model into memory. Called once at container start.""" self.model = load_model("./weights") def predict(self, prompt: str, steps: int = 50) -> Path: """Run inference. 
Called for each prediction request.""" output = self.model.generate(prompt, steps=steps) output.save("/tmp/output.png") return Path("/tmp/output.png") ``` ### setup() - Called **once** when the container starts - Used to load model weights, initialize GPU contexts, warm up caches - Runs before the HTTP server accepts requests - Optional: if omitted, Cog proceeds directly to serving ### predict() - Called **for each prediction request** - Signature defines the model's input schema (via type hints) - Return type defines the output schema - Can be sync (`def`) or async (`async def`) ### train() (optional) - Same contract as `predict()` but for fine-tuning workflows - Configured separately in `cog.yaml` with `train:` key ## Input Types The types used in `predict()` parameters become the model's input schema. ### Basic Types ```python def predict( self, text: str, # String input count: int, # Integer temperature: float, # Float verbose: bool, # Boolean ) -> str: ``` ### File Inputs (cog.Path) URLs are automatically downloaded to local files: ```python from cog import Path def predict(self, image: Path) -> Path: # Client sends: {"input": {"image": "https://example.com/photo.jpg"}} # Cog downloads the URL, `image` is a local path like /tmp/inputabc123.jpg img = PIL.Image.open(image) ... ``` `cog.Path` extends `pathlib.Path`. At runtime: - HTTP/HTTPS URLs are downloaded to temp files - Data URLs are decoded - The predictor receives a local filesystem path ### Secrets (cog.Secret) For sensitive values that shouldn't appear in logs: ```python from cog import Secret def predict(self, api_key: Secret) -> str: # Value is masked in logs and webhooks client = SomeAPI(api_key.get_secret_value()) ... 
``` ### Input Constraints Use `Input()` to add metadata and validation: ```python from cog import Input def predict( self, prompt: str = Input(description="The text prompt"), steps: int = Input(default=50, ge=1, le=100, description="Inference steps"), style: str = Input(choices=["photo", "art", "sketch"]), ) -> str: ``` | Parameter | Effect | |-----------|--------| | `description` | Shown in UI and schema | | `default` | Default value if not provided | | `ge`, `le` | Numeric bounds (greater/less than or equal) | | `min_length`, `max_length` | String length bounds | | `choices` | Enum values (deprecated: prefer `Literal`) | ### Enums with Literal ```python from typing import Literal def predict( self, size: Literal["small", "medium", "large"] = "medium", ) -> str: ``` ### Lists ```python from typing import List from cog import Path def predict( self, images: List[Path], # Multiple file inputs tags: List[str], # Multiple strings ) -> str: ``` ### Optional Inputs ```python from typing import Optional def predict( self, seed: Optional[int] = None, # Can be omitted or null ) -> str: ``` ## Output Types The return type annotation defines what the model produces. ### Basic Types ```python def predict(self, prompt: str) -> str: return "Generated text..." ``` ### File Outputs Return `cog.Path` pointing to a generated file: ```python from cog import Path def predict(self, prompt: str) -> Path: # Generate file output_path = "/tmp/output.png" self.model.generate(prompt).save(output_path) return Path(output_path) ``` At runtime, Cog uploads the file and returns a URL to the client. 
### Multiple Outputs Return a list: ```python from typing import List from cog import Path def predict(self, prompt: str) -> List[Path]: paths = [] for i in range(4): path = f"/tmp/output_{i}.png" self.model.generate(prompt, seed=i).save(path) paths.append(Path(path)) return paths ``` ### Streaming with Iterator Yield values progressively: ```python from typing import Iterator def predict(self, prompt: str) -> Iterator[str]: for token in self.model.generate_stream(prompt): yield token ``` The schema marks this as `x-cog-array-type: iterator`. Clients receive outputs as they're produced via webhooks or streaming responses. ### Streaming Text with ConcatenateIterator For LLM-style token streaming where outputs should be concatenated: ```python from cog import ConcatenateIterator def predict(self, prompt: str) -> ConcatenateIterator[str]: for token in self.model.generate(prompt): yield token # "Hello", " ", "world", "!" # Client sees progressive: "Hello" -> "Hello " -> "Hello world" -> "Hello world!" ``` The schema includes `x-cog-array-display: concatenate` to signal that outputs should be joined rather than displayed as a list. ## Weights Model weights can be loaded in several ways: ### Bundled in the Image Include weights in your source directory - they're copied into the image during build: ``` my-model/ ├── cog.yaml ├── predict.py └── weights/ └── model.safetensors ``` ```python def setup(self): self.model = load("./weights/model.safetensors") ``` ### Downloaded at Runtime Weights can be fetched during `setup()` rather than bundled. Common approaches: **Using the `weights` parameter** (Cog's built-in mechanism): ```python class Predictor(BasePredictor): def setup(self, weights: Path): self.model = load(weights) ``` The `weights` value comes from `COG_WEIGHTS` env var or falls back to `./weights`: ```bash COG_WEIGHTS=https://example.com/model.tar cog predict ... 
``` **Using pget** (parallel download tool, included in Cog images): ```python import subprocess def setup(self): subprocess.run(["pget", "https://example.com/model.tar", "./weights"]) self.model = load("./weights/model.safetensors") ``` **Direct download in setup**: ```python def setup(self): # Using requests, huggingface_hub, or any other method snapshot_download(repo_id="meta-llama/Llama-2-7b", local_dir="./weights") self.model = load("./weights") ``` The choice depends on your deployment needs - bundled weights make images larger but start faster; downloaded weights keep images small but require network access at startup. ## Async Predictors For concurrent predictions, use async: ```python class Predictor(BasePredictor): async def setup(self): self.model = await load_model_async() async def predict(self, prompt: str) -> str: return await self.model.generate(prompt) ``` Requires: - Python 3.11+ - `concurrency.max > 1` in cog.yaml See the Container Runtime documentation ([FFI](./ffi/04-container-runtime.md), [Legacy](./legacy/04-container-runtime.md)) for concurrency details. ## Code References | File | Purpose | |------|---------| | `python/cog/__init__.py` | Public API exports | | `python/cog/base_predictor.py` | BasePredictor class | | `python/cog/types.py` | Input, Path, Secret, ConcatenateIterator | | `python/cog/predictor.py` | Type introspection, weights handling | | `pkg/config/config.go` | cog.yaml parsing | ================================================ FILE: architecture/02-schema.md ================================================ # Schema The schema is an **OpenAPI 3.0.2 specification** that describes a model's interface. It's the contract between the model and everything that interacts with it. ## Why the Schema Exists Every Cog model uses the same [Prediction API](./ffi/03-prediction-api.md) envelope format, but the `input` and `output` fields are model-specific. The schema captures what each model expects and produces. 
``` ┌─────────────────────────────────────────────────┐ │ PredictionRequest (fixed envelope) │ │ ┌─────────────────────────────────────────┐ │ │ │ "input": { ... } <- model-specific │ │ │ └─────────────────────────────────────────┘ │ └─────────────────────────────────────────────────┘ ↑ Schema defines this part ``` Without the schema, consumers would have no way to know: - What inputs the model accepts - What types those inputs should be - What constraints apply (required fields, min/max values, allowed choices) - What the output looks like ### How It's Used Today | Consumer | What They Use the Schema For | |----------|------------------------------| | **Replicate platform** | Generate input forms in the web UI, validate requests before routing to models | | **HTTP server (coglet)** | Validate incoming JSON, reject malformed requests before they reach user code | | **CLI (`cog predict`)** | Parse `-i key=value` flags into correctly-typed Python objects | | **Docker label** | Extract model interface without running the container | | **API clients** | Know what to send and what to expect back without reading source code | ## How It's Generated Cog supports two schema generation paths: ### Legacy Runtime Path (default) The **legacy path** boots the built Docker container and runs `python -m cog.command.openapi_schema` to introspect the model at runtime using pydantic. This is the default for all builds. It works with any Python type that pydantic can serialize, including third-party types, complex inheritance, and dynamically constructed classes. ### Static Path (opt-in) The **static path** parses Python source code at `cog build` time using [tree-sitter](https://tree-sitter.github.io/tree-sitter/) in Go. No Python runtime is invoked. This makes schema generation deterministic, fast, and independent of the model's dependencies. 
Enable it by setting the `COG_STATIC_SCHEMA` environment variable: ```bash COG_STATIC_SCHEMA=1 cog build -t my-model ``` The static path requires SDK >= 0.17.0. When opted in, if the static parser encounters a type it cannot resolve, it **falls back to the legacy runtime path** automatically with a warning — so builds never fail due to static parser limitations. For local commands (`cog train`, `cog serve`, `cog predict`), the static path is always used regardless of the `COG_STATIC_SCHEMA` flag, because these commands return before the post-build legacy generation step — the CLI needs the schema to parse `-i` input flags. ```mermaid flowchart LR subgraph source["Model Source"] predict["predict.py"] types["output_types.py"] end subgraph parser["Go Static Parser"] ts["tree-sitter Python"] resolve["Type Resolver"] cross["Cross-File Resolver"] end subgraph output["Schema"] spec["OpenAPI 3.0.2 JSON"] end predict --> ts types --> cross ts --> resolve cross --> resolve resolve --> spec ``` ### Static Path Pipeline Steps 1. **Parse** the predictor file with tree-sitter (concrete syntax tree, not AST) 2. **Collect imports** — track where each name came from (`from cog import Path`, `from pydantic import BaseModel`) 3. **Collect module scope** — resolve module-level variable assignments (for default values, choices lists) 4. **Collect BaseModel subclasses** — find all classes that inherit from `BaseModel` (cog or pydantic) in the current file 5. **Resolve cross-file models** — for imported names not found locally, find the `.py` file on disk, parse it, and extract its BaseModel definitions 6. **Extract inputs** — walk the `predict()` / `train()` method parameters, resolve types, defaults, and `Input()` metadata 7. **Resolve output type** — recursively resolve the return type annotation into a `SchemaType` 8. 
**Generate OpenAPI** — convert the extracted `PredictorInfo` into a full OpenAPI 3.0.2 JSON document If any step fails with an unresolvable type, the build falls back to the legacy runtime path. ### Cross-File Resolution When a predictor imports types from other project files, the schema generator resolves them automatically: ```python # output_types.py from pydantic import BaseModel class Prediction(BaseModel): text: str score: float tags: list[str] ``` ```python # predict.py from cog import BasePredictor from output_types import Prediction class Predictor(BasePredictor): def predict(self, prompt: str) -> Prediction: ... ``` The resolver handles every permutation of local imports: | Import Style | File Resolved | |-------------|---------------| | `from output_types import X` | `/output_types.py` | | `from .output_types import X` | `/output_types.py` | | `from models.output import X` | `/models/output.py` | | `from .models.output import X` | `/models/output.py` | | `from output_types import X as Y` | `/output_types.py` (alias tracked) | **How it distinguishes local from external**: the resolver converts the module path to a filesystem path and checks if the file exists. If `output_types.py` exists in the project directory, it's local. If not (e.g., `from transformers import ...`), it's external. Known external packages (stdlib, torch, numpy, etc.) are skipped without a filesystem check. **Error messages**: when a type can't be resolved, the error includes the import source: ``` cannot resolve output type 'WeirdType' (imported from 'some_package') — external types cannot be statically analyzed. 
Define it as a BaseModel subclass in your predict file, or provide a .pyi stub ``` ## SchemaType: The Type System Output types are represented as a recursive algebraic data type (`SchemaType`) that composes arbitrarily: ``` SchemaType ├── SchemaPrimitive — str, int, float, bool, Path ├── SchemaAny — untyped (bare dict, Any) ├── SchemaArray — list[T], with Items → SchemaType ├── SchemaDict -- dict[str, V], with ValueType -> SchemaType ├── SchemaObject — BaseModel subclass, with Fields → OrderedMap[name, SchemaField] ├── SchemaIterator — Iterator[T], with Elem → SchemaType └── SchemaConcatIterator — ConcatenateIterator[str] ``` This recursive structure means nested types like `dict[str, list[dict[str, int]]]` are fully representable and produce correct JSON Schema: ```json { "type": "object", "additionalProperties": { "type": "array", "items": { "type": "object", "additionalProperties": { "type": "integer" } } } } ``` ### JSON Schema Generation Each `SchemaType` produces its JSON Schema fragment via `JSONSchema()`: | SchemaType Kind | JSON Schema | |-----------------|-------------| | `SchemaPrimitive(str)` | `{"type": "string"}` | | `SchemaPrimitive(Path)` | `{"type": "string", "format": "uri"}` | | `SchemaAny` | `{"type": "object"}` | | `SchemaArray(items)` | `{"type": "array", "items": items.JSONSchema()}` | | `SchemaDict(valueType)` | `{"type": "object", "additionalProperties": valueType.JSONSchema()}` | | `SchemaObject(fields)` | `{"type": "object", "properties": {...}, "required": [...]}` | | `SchemaIterator(elem)` | `{"type": "array", "items": elem.JSONSchema(), "x-cog-array-type": "iterator"}` | | `SchemaConcatIterator` | `{"type": "array", "items": {"type": "string"}, "x-cog-array-type": "iterator", "x-cog-array-display": "concatenate"}` | ## Type Mappings ### Input Types | Python | JSON Schema | Notes | |--------|-------------|-------| | `str` | `{"type": "string"}` | | | `int` | `{"type": "integer"}` | | | `float` | `{"type": "number"}` | | | `bool` | 
`{"type": "boolean"}` | | | `cog.Path` | `{"type": "string", "format": "uri"}` | URLs downloaded at runtime | | `cog.File` | `{"type": "string", "format": "uri"}` | File uploads | | `cog.Secret` | `{"type": "string", "format": "password", "x-cog-secret": true}` | Masked in logs | | `list[T]` | `{"type": "array", "items": {...}}` | | | `Optional[T]` | Type T + not in `required` | Input fields only | | `Literal["a", "b"]` / `choices=[...]` | `{"enum": ["a", "b"]}` | | ### Output Types | Python | SchemaType | JSON Schema | |--------|------------|-------------| | `str` | `SchemaPrimitive` | `{"type": "string"}` | | `int` | `SchemaPrimitive` | `{"type": "integer"}` | | `float` | `SchemaPrimitive` | `{"type": "number"}` | | `bool` | `SchemaPrimitive` | `{"type": "boolean"}` | | `Path` | `SchemaPrimitive` | `{"type": "string", "format": "uri"}` | | `dict` (bare) | `SchemaAny` | `{"type": "object"}` | | `dict[str, V]` | `SchemaDict` | `{"type": "object", "additionalProperties": V}` | | `list` (bare) | `SchemaArray(SchemaAny)` | `{"type": "array", "items": {"type": "object"}}` | | `list[T]` | `SchemaArray` | `{"type": "array", "items": T}` | | `BaseModel` subclass | `SchemaObject` | `{"type": "object", "properties": {...}}` | | `Iterator[T]` | `SchemaIterator` | `{"type": "array", "items": T, "x-cog-array-type": "iterator"}` | | `ConcatenateIterator[str]` | `SchemaConcatIterator` | Streaming token output | | Nested types | Recursive | `dict[str, list[dict[str, int]]]` fully supported | ### Unsupported Output Types | Python | Error | |--------|-------| | `Optional[T]` / `T \| None` | Predictions must succeed with a value or fail with an error | | `Union[A, B]` | Ambiguous for downstream consumers | | External package types | Cannot be statically analyzed — define as BaseModel or use .pyi stub | ## Cog-Specific Extensions | Extension | Purpose | |-----------|---------| | `x-order` | Preserves parameter order from function signature | | `x-cog-array-type` | Marks iterators vs 
regular arrays | | `x-cog-array-display` | Hints for how to display streaming output | | `x-cog-secret` | Marks sensitive inputs | ## Where the Schema Lives ### In the Image Embedded as a Docker label during build: ```bash docker inspect my-model | jq -r '.[0].Config.Labels["run.cog.openapi_schema"]' ``` Also written to `.cog/openapi_schema.json` inside the image for the runtime to serve. ### At Runtime | Endpoint | Format | |----------|--------| | `GET /openapi.json` | Raw OpenAPI spec | ### Override and Configuration | Environment Variable | Purpose | |---------------------|---------| | `COG_STATIC_SCHEMA=1` | Opt in to the static Go tree-sitter schema generator (falls back to legacy on failure) | | `COG_OPENAPI_SCHEMA=path` | Skip generation entirely and use a pre-built schema file | ```bash # Use static schema generation COG_STATIC_SCHEMA=1 cog build -t my-model # Use a pre-built schema file COG_OPENAPI_SCHEMA=my_schema.json cog build ``` ## Schema Structure A simplified example showing a multi-file predictor with structured output: ```json { "openapi": "3.0.2", "info": { "title": "Cog", "version": "0.1.0" }, "paths": { "/predictions": { "post": { "requestBody": { "content": { "application/json": { "schema": { "$ref": "#/components/schemas/PredictionRequest" } } } } } } }, "components": { "schemas": { "Input": { "type": "object", "properties": { "prompt": { "type": "string", "description": "Text prompt", "x-order": 0 }, "steps": { "type": "integer", "default": 50, "minimum": 1, "maximum": 100, "x-order": 1 } }, "required": ["prompt"] }, "Output": { "type": "object", "properties": { "text": { "type": "string", "title": "Text" }, "score": { "type": "number", "title": "Score" } }, "required": ["text", "score"] }, "PredictionRequest": { "..." : "..." }, "PredictionResponse": { "..." : "..." 
} } } } ``` ## Code References | File | Purpose | |------|---------| | `pkg/schema/schema_type.go` | `SchemaType` ADT, `ResolveSchemaType()`, `JSONSchema()` generation | | `pkg/schema/types.go` | `PredictorInfo`, `PrimitiveType`, `FieldType`, `InputField`, `ImportContext` | | `pkg/schema/python/parser.go` | Tree-sitter Python parser, `ParsePredictor()`, cross-file resolution | | `pkg/schema/openapi.go` | OpenAPI document assembly from `PredictorInfo` | | `pkg/schema/generator.go` | Top-level `Generate()`, `GenerateCombined()`, `Parser` type | | `pkg/schema/errors.go` | Typed error kinds (`ErrUnresolvableType`, `ErrOptionalOutput`, etc.) | | `pkg/image/build.go` | `canUseStaticSchemaGen()` — opt-in gate, `generateStaticSchema()` — entry point, fallback to legacy on `ErrUnresolvableType` | | `pkg/image/openapi_schema.go` | `GenerateOpenAPISchema()` — legacy runtime path (boots container, runs `python -m cog.command.openapi_schema`) | | `python/cog/_adt.py` | Internal ADT types for Python-side predictor introspection | | `python/cog/_inspector.py` | Python-side predictor inspector (runtime introspection for legacy path) | | `python/cog/_schemas.py` | Python-side OpenAPI schema generation from inspected predictor info | | `python/cog/command/openapi_schema.py` | CLI entry point for `python -m cog.command.openapi_schema` (invoked by legacy runtime path) | ================================================ FILE: architecture/05-build-system.md ================================================ # Build System The build system transforms [Model Source](./01-model-source.md) (cog.yaml + predict.py + weights) into a production-ready OCI image containing the [Container Runtime](./04-container-runtime.md). 
## Build Flow ```mermaid flowchart TB subgraph input["Inputs"] yaml["cog.yaml"] code["predict.py"] weights["weights"] end subgraph cli["CLI (pkg/cli/build.go)"] parse["Parse Config"] validate["Validate"] end subgraph generate["Dockerfile Generation (pkg/dockerfile/)"] generator["Generator"] baseimage["Base Image Selection"] compat["Compatibility Matrix"] wheel["Embedded Python Wheel"] end subgraph docker["Docker Build"] buildkit["Buildkit"] image["Container Image"] end subgraph post["Post-Build"] schema["Generate OpenAPI Schema"] freeze["pip freeze"] labels["Apply Labels"] end yaml --> parse --> validate validate --> generator compat --> generator baseimage --> generator wheel --> generator generator -->|"Dockerfile"| buildkit code --> buildkit weights --> buildkit buildkit --> image image --> schema image --> freeze schema --> labels freeze --> labels labels -->|"Final Image"| output["Tagged Image"] ``` ## Key Components ### 1. Config Parsing & Validation Reads `cog.yaml` and validates/completes the configuration: - Validates Python version (3.10-3.13) - Auto-detects CUDA version from PyTorch/TensorFlow requirements - Resolves package versions against compatibility matrix ``` cog.yaml (user provides) → Config (completed) ───────────────────────── ───────────────── gpu: true gpu: true python_packages: cuda: "12.1" ← auto-detected - torch==2.1.0 cudnn: "8" ← auto-detected ``` --- ### 2. Dockerfile Generator The generator produces a Dockerfile from the validated config. #### Generated Dockerfile Sections ```dockerfile # 1. Base image (cog-base, CUDA, or python-slim) FROM r8.im/cog-base:cuda12.1-python3.11-torch2.1.0 # 2. System packages RUN apt-get update && apt-get install -y ffmpeg # 3. Python packages RUN pip install -r requirements.txt # 4. Cog wheel (embedded in CLI binary) COPY cog-0.12.0-py3-none-any.whl /tmp/ RUN pip install /tmp/cog-0.12.0-py3-none-any.whl # 5. User run commands RUN echo "custom setup" # 6. Copy source WORKDIR /src COPY . /src # 7. 
Entrypoint ENTRYPOINT ["/sbin/tini", "--"] CMD ["python", "-m", "cog.server.http"] ``` --- ### 3. Compatibility Matrix PyTorch and TensorFlow releases are built against specific CUDA/cuDNN versions. The compatibility matrix captures these relationships from upstream release notes. ```mermaid flowchart LR subgraph input["User specifies"] torch["torch==2.1.0"] end subgraph matrix["Compatibility Matrix"] lookup["torch_compatibility_matrix.json"] end subgraph output["Cog determines"] cuda["CUDA 12.1"] cudnn["cuDNN 8"] python["Python 3.10-3.13"] end torch --> lookup lookup --> cuda lookup --> cudnn lookup --> python ``` **Data files** (embedded JSON, generated by `tools/compatgen/`): - `pkg/config/torch_compatibility_matrix.json` - PyTorch ↔ CUDA mappings - `pkg/config/tf_compatibility_matrix.json` - TensorFlow ↔ CUDA mappings - `pkg/config/cuda_base_images.json` - Available NVIDIA base image tags These are regenerated when new framework versions are released and embedded into the CLI binary at build time. **What it stores** (for each framework release): - Framework version (e.g., `torch==2.1.0`) - Compatible CUDA versions - Compatible cuDNN versions - Compatible Python versions - Package index URLs (for CUDA-specific wheels) --- ### 4. Base Image Selection Base image selection uses the compatibility matrix to find a pre-built image that matches the required Python/CUDA/PyTorch combination. ```mermaid flowchart TD start["Config has Python + CUDA + Torch versions"] --> lookup{"Matching cog-base
image exists?"} lookup -->|"Yes"| cogbase["Use cog-base image
r8.im/cog-base:cuda12.1-python3.11-torch2.1.0"] lookup -->|"No"| gpu{"GPU enabled?"} gpu -->|"Yes"| cuda["Use NVIDIA CUDA image
nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
(install Python + Torch in Dockerfile)"] gpu -->|"No"| slim["Use Python slim image
python:3.11-slim"] ``` #### Cog Base Images Pre-built images hosted at `r8.im/cog-base` with Python, CUDA, cuDNN, and PyTorch already installed. - Format: `r8.im/cog-base:cuda<version>-python<version>-torch<version>` (e.g., `r8.im/cog-base:cuda12.1-python3.11-torch2.1.0`) - Generated from the compatibility matrix (`BaseImageConfigurations()`) - Includes common system packages (ffmpeg, git, curl, etc.) - Faster builds since heavy dependencies are pre-installed #### Fallback: NVIDIA CUDA Images When no matching cog-base exists (e.g., unusual version combination): - Uses official `nvidia/cuda` images - Dockerfile installs Python via pyenv - Dockerfile installs PyTorch and other packages via pip - Slower builds but supports any valid combination --- ### 5. Embedded Python Wheel The Cog Python SDK is embedded in the Go binary at compile time and injected into images during build. During build, the wheel is: 1. Selected based on configuration (see below) 2. Copied into the Docker build context 3. Installed via pip #### Wheel Selection The `COG_SDK_WHEEL` environment variable controls which cog SDK wheel is installed: | Value | Source | |-------|--------| | (unset) | Latest `cog` from PyPI (auto-detects local wheel in dev builds) | | `pypi` | Latest `cog` from PyPI | | `pypi:0.18.0` | Specific version from PyPI | | `https://...` | Download wheel from URL | | `/path/to/file.whl` | Use local wheel file | This allows testing development versions of the SDK or pinning to specific releases. The `build.sdk_version` field in `cog.yaml` provides the same version-pinning capability without requiring environment variables. --- ### 6. Post-Build: Labels & Schema After the main build, Cog: 1. **Runs the container** to generate OpenAPI schema 2. **Runs pip freeze** to capture installed packages 3.
**Applies labels** with metadata #### Image Labels | Label | Content | |-------|---------| | `run.cog.version` | Cog CLI version | | `run.cog.config` | Serialized cog.yaml | | `run.cog.openapi_schema` | OpenAPI spec from type hints | | `run.cog.pip_freeze` | Installed package versions | These labels can be fetched from a remote registry or local image store (like containerd) without pulling the full image. This allows tooling - both the Cog CLI during development and production infrastructure - to inspect model metadata and make decisions about how to run a model before booting it. --- ## Image Layer Structure A built Cog image has layers in this order (bottom to top): ``` ┌─────────────────────────────────────────────────┐ │ COPY . /src │ ← User code + weights ├─────────────────────────────────────────────────┤ │ RUN commands (from cog.yaml) │ ← Custom build steps ├─────────────────────────────────────────────────┤ │ pip install (python_packages) │ ← Python dependencies ├─────────────────────────────────────────────────┤ │ Cog wheel install │ ← Cog runtime ├─────────────────────────────────────────────────┤ │ apt-get install (system_packages) │ ← System dependencies ├─────────────────────────────────────────────────┤ │ tini init │ ← Process manager ├─────────────────────────────────────────────────┤ │ │ │ Base image │ ← Largest layer │ (OS, Python, CUDA, cuDNN, PyTorch) │ ~5-15 GB for GPU images │ │ └─────────────────────────────────────────────────┘ ``` The base image is by far the largest layer. Using a matching `cog-base` image means this layer is shared across builds and doesn't need to be re-downloaded or rebuilt. 
--- ## Code Reference | Component | Location | |-----------|----------| | CLI command | `pkg/cli/build.go` | | Build orchestration | `pkg/image/build.go` | | Dockerfile generator | `pkg/dockerfile/standard_generator.go` | | Base image selection | `pkg/dockerfile/base.go` | | Compatibility matrix | `pkg/config/compatibility.go` | | Embedded wheels | `pkg/wheels/wheels.go` | | Label definitions | `pkg/docker/command/manifest.go` | ================================================ FILE: architecture/06-cli.md ================================================ # CLI The Cog CLI is a Go binary that provides commands for the full model lifecycle: development, building, testing, and deployment. This document covers what each command does and how it connects to the systems described in previous docs. **Important**: Model code always runs inside a container, never on the host machine. Commands like `cog predict`, `cog train`, and `cog serve` build an image, start a container, and interact with it via the [Prediction API](./03-prediction-api.md). The CLI orchestrates this, but the model execution happens in the containerized [Container Runtime](./04-container-runtime.md). ## Commands Overview | Command | Job To Be Done | |---------|----------------| | `cog init` | Bootstrap a new model project | | `cog build` | Create a container image | | `cog predict` | Run a prediction in a container | | `cog train` | Run training in a container | | `cog run` | Run arbitrary commands in a container | | `cog serve` | Start HTTP server in a container | | `cog push` | Deploy to Replicate | | `cog login` | Authenticate with Replicate | ## Development Commands ### cog init **Job**: Create a starter `cog.yaml` and `predict.py` for a new model. ```bash cog init ``` Creates: - `cog.yaml` with sensible defaults - `predict.py` with a skeleton Predictor class **Code**: `pkg/cli/init.go` --- ### cog predict **Job**: Run a prediction in a container. 
```bash cog predict -i prompt="A photo of a cat" -i steps=50 ``` What happens: 1. Builds the image (if needed) 2. Starts a container running the [Container Runtime](./04-container-runtime.md) 3. Parses `-i` flags against the [Schema](./02-schema.md) 4. Sends a [PredictionRequest](./03-prediction-api.md) to the container's HTTP API 5. Streams output back to terminal Input types are inferred from the schema: - Strings: `-i prompt="hello"` - Numbers: `-i steps=50` - Files: `-i image=@photo.jpg` (uploaded to container) - URLs: `-i image=https://example.com/photo.jpg` **Code**: `pkg/cli/predict.go` --- ### cog train **Job**: Run training in a container. ```bash cog train -i data=@dataset.zip -i epochs=10 ``` Same as `cog predict` but calls the `train()` method instead of `predict()`. **Code**: `pkg/cli/train.go` --- ### cog run **Job**: Run arbitrary commands in a container. ```bash cog run python -c "import torch; print(torch.cuda.is_available())" cog run bash ``` Builds the image (if needed), starts a container, and runs the specified command inside it. Useful for: - Debugging the container environment - Running one-off scripts - Interactive exploration **Code**: `pkg/cli/run.go` --- ### cog serve **Job**: Start the HTTP server in a container for testing. ```bash cog serve # Server running at http://localhost:5000 ``` Builds the image (if needed) and starts a container running the [Container Runtime](./04-container-runtime.md). The container's port 5000 is exposed to the host. You can then: - Send requests to `POST /predictions` - View Swagger UI at `/docs` - Test webhooks **Code**: `pkg/cli/serve.go` ## Build Commands ### cog build **Job**: Build a container image from [Model Source](./01-model-source.md). ```bash cog build -t my-model ``` What happens (see [Build System](./05-build-system.md) for details): 1. **Parse** `cog.yaml` 2. **Resolve** CUDA/cuDNN versions from compatibility matrix 3. **Generate** Dockerfile 4. **Build** image via Docker/Buildkit 5. 
**Run** container to extract [Schema](./02-schema.md) 6. **Apply** labels (schema, config, pip freeze) Key flags: - `-t, --tag`: Image tag - `--no-cache`: Disable Docker cache - `--separate-weights`: Exclude weights from image (for separate upload) **Code**: `pkg/cli/build.go`, `pkg/image/build.go` ## Deployment Commands ### cog push **Job**: Build and push to Replicate. ```bash cog push r8.im/username/model-name ``` What happens: 1. Builds image (like `cog build`) 2. Pushes to Replicate's registry 3. Registers model with Replicate API The image tag must be a Replicate model reference (`r8.im/owner/name`). **Code**: `pkg/cli/push.go`, `pkg/api/client.go` --- ### cog login **Job**: Authenticate with Replicate. ```bash cog login # or cog login --token-stdin < token.txt ``` Stores credentials for `cog push`. **Code**: `pkg/cli/login.go` ## How CLI Commands Interact with Containers Commands like `predict`, `train`, and `serve` follow the same pattern: build an image, start a container, communicate via HTTP. The CLI never runs model code directly. ```mermaid sequenceDiagram participant CLI as cog CLI (host) participant Docker participant Container as Container (runtime) CLI->>CLI: Parse -i flags, load cog.yaml CLI->>Docker: Build image (if needed) Docker-->>CLI: Image ready CLI->>Docker: Start container Docker->>Container: python -m cog.server.http Container->>Container: Run setup() loop Until READY CLI->>Container: GET /health-check Container-->>CLI: Status (STARTING/READY) end CLI->>Container: POST /predictions Container->>Container: Run predict() Container-->>CLI: Response JSON CLI->>Docker: Stop container ``` For what happens inside the container (setup, predict, IPC), see [Container Runtime](./04-container-runtime.md). ## CLI Architecture The CLI is built with [Cobra](https://github.com/spf13/cobra) (Go CLI framework). 
``` cmd/cog/ └── cog.go # Entry point pkg/cli/ ├── root.go # Root command, subcommand registration ├── build.go # cog build ├── predict.go # cog predict ├── train.go # cog train ├── run.go # cog run ├── serve.go # cog serve ├── push.go # cog push ├── login.go # cog login └── init.go # cog init ``` Commands delegate to packages: - `pkg/image/` - Image building - `pkg/dockerfile/` - Dockerfile generation - `pkg/docker/` - Docker client operations - `pkg/config/` - cog.yaml parsing - `pkg/api/` - Replicate API client - `pkg/predict/` - Local prediction execution ## Code References | File | Purpose | |------|---------| | `pkg/cli/root.go` | Command registration | | `pkg/cli/build.go` | Build command | | `pkg/cli/predict.go` | Predict command, input parsing | | `pkg/cli/push.go` | Push command | | `pkg/image/build.go` | Build orchestration | | `pkg/predict/predictor.go` | Local prediction client | ================================================ FILE: architecture/ffi/03-prediction-api.md ================================================ # Prediction API (FFI/Rust) The FFI runtime implements the same Prediction API as the legacy runtime, using the same envelope format and endpoints. This document highlights FFI-specific behavior and implementation details. > **Note**: The API surface is identical to the [legacy implementation](../legacy/03-prediction-api.md). Clients don't need to change code when switching runtimes. 
## Endpoints | Endpoint | Method | Purpose | FFI Notes | |----------|--------|---------|-----------| | `POST /predictions` | Create | Start a new prediction | Uses `SyncPredictionGuard` for automatic cancellation | | `PUT /predictions/{id}` | Create (idempotent) | Start or retrieve existing prediction | Concurrent-safe with DashMap | | `POST /predictions/{id}/cancel` | Cancel | Cancel a running prediction | Uses cancel tokens propagated to worker | | `GET /health-check` | Health | Check server status | Returns health state machine status | | `GET /` | Index | List available endpoints | Static route | | `GET /openapi.json` | Schema | OpenAPI specification | Cached from worker `Ready` event | ## FFI-Specific Behaviors ### Connection Drop Handling **Key difference from legacy**: Synchronous predictions automatically cancel when the client connection drops. ```rust // SyncPredictionGuard is RAII - drops when connection closes let guard = handle.sync_guard(); let result = service.predict(slot, input).await; // If connection drops here, guard.drop() cancels the prediction ``` This prevents wasted computation on predictions where the client is no longer listening. ### Health States The FFI runtime uses a more detailed health state machine. The `/health-check` endpoint always returns HTTP 200 with the status in the JSON body: | State | JSON `status` | Condition | |-------|---------------|-----------| | `STARTING` | `"STARTING"` | Worker subprocess initializing | | `READY` | `"READY"` | Worker ready, slots available | | `BUSY` | `"BUSY"` | All slots occupied (backpressure) | | `SETUP_FAILED` | `"SETUP_FAILED"` | `setup()` threw exception | | `DEFUNCT` | `"DEFUNCT"` | Fatal error, worker crashed | **New behavior**: When all concurrency slots are occupied, new predictions receive `409 Conflict` instead of queuing. Clients should implement retry with backoff. > **Note**: Prediction endpoints return 503 when health is `STARTING`, `SETUP_FAILED`, or `DEFUNCT`. When health is `BUSY`, they return `409 Conflict` instead.
### Idempotent PUT Behavior The FFI runtime uses a concurrent-safe DashMap for prediction state: ```rust // Atomic check-or-insert match service.get_prediction_response(id) { Some(response) => return 202 + response, // Already exists None => { service.submit_prediction(id, input, webhook); // Create new return 202 + starting_state; } } ``` This is fully thread-safe without locks, unlike the legacy runtime which uses Python's asyncio locks. ## Request Flow Differences ### Sync Prediction (POST /predictions) **Legacy**: ```python # Connection drop has no effect result = await runner.predict(input) return result ``` **FFI**: ```rust // Connection drop triggers guard.drop() → cancellation let guard = handle.sync_guard(); // RAII guard let result = service.predict(slot, input).await; drop(guard); // Or automatic on scope exit return result; ``` ### Async Prediction (Prefer: respond-async) Behavior is identical to legacy, but implemented differently: **Legacy**: Uses asyncio tasks **FFI**: Uses tokio tasks with cancel tokens ```rust tokio::spawn(async move { let _result = service.predict(slot, input).await; // Prediction state is already updated by predict() internally // Webhooks fire automatically from Prediction mutation methods service.remove_prediction(id); }); ``` ### Cancellation Propagation **Legacy**: Sends `SIGUSR1` signal to child process **FFI**: Uses IPC message + different strategies for sync vs async predictors: ``` Parent: ControlRequest::Cancel { slot } │ └─▶ Worker: handler.cancel(slot) ``` **Sync Predictors:** ``` handler.cancel(slot) │ ├─▶ Set CANCEL_REQUESTED flag for slot │ ├─▶ Send SIGUSR1 to self │ └─▶ Signal handler: raise KeyboardInterrupt (if in cancelable region) Prediction code: with CancelableGuard(): # Sets CANCELABLE=true predictor.predict() # Can be interrupted # CANCELABLE=false on exit ``` **Async Predictors:** ``` handler.cancel(slot) │ ├─▶ Get future from slot state └─▶ future.cancel() │ └─▶ Python raises asyncio.CancelledError ``` 
This provides more reliable cancellation with proper handling for both sync and async execution models. ## Concurrency Model ### Slot-Based Permits The FFI runtime uses explicit permit tokens instead of async task limits: ```rust // Acquire permit (blocks if all slots busy) let permit = permit_pool.acquire().await?; // Permit is held during prediction let slot_id = permit.slot_id(); let result = orchestrator.predict(slot_id, input).await; // Permit automatically returned on drop drop(permit); // Or automatic on scope exit ``` **Advantages**: - Fixed, predictable concurrency - Fair queuing (FIFO permit acquisition) - Observable slot usage in metrics - No task explosion ### Configuration ```yaml # cog.yaml concurrency: max: 5 ``` This creates 5 slots in the PermitPool. Each slot corresponds to one Unix socket connection to the worker subprocess. ## File Handling File handling is identical to legacy (URLs downloaded to temp files, outputs uploaded), but the implementation differs: **Legacy**: Uses Python `aiohttp` + `requests` **FFI**: Uses Rust `reqwest` with connection pooling: ```rust // Download input files let client = reqwest::Client::builder() .connection_pool_idle_timeout(Duration::from_secs(30)) .build()?; let bytes = client.get(url).send().await?.bytes().await?; tokio::fs::write(temp_path, bytes).await?; ``` This provides better performance for large file downloads. 
## Webhooks Webhook delivery is similar but with improvements: ### Retry Logic **Legacy**: Simple exponential backoff **FFI**: Structured retry with observability: ```rust let retry_policy = ExponentialBackoff::builder() .max_elapsed_time(Some(Duration::from_secs(60))) .build(); let webhook_sender = WebhookSender::new(client, retry_policy); webhook_sender.send_with_retry(url, payload).await?; ``` ### Trace Context Propagation The FFI runtime automatically propagates OpenTelemetry trace context in webhook headers: ```rust // Automatic trace propagation headers.insert("traceparent", trace_id); headers.insert("tracestate", trace_state); ``` This enables distributed tracing across prediction → webhook → downstream services. ## Status Lifecycle The status lifecycle is identical to legacy: ```mermaid stateDiagram-v2 [*] --> starting: Request received starting --> processing: predict() called processing --> succeeded: predict() returns processing --> failed: predict() raises exception processing --> canceled: Cancel requested succeeded --> [*] failed --> [*] canceled --> [*] ``` State transitions happen on the `Prediction` struct directly, which fires webhooks as a side effect: ```rust // State transitions fire webhooks automatically pred.set_processing(); // fires Start webhook // ... prediction runs, logs/outputs append ... pred.set_succeeded(output); // fires terminal Completed webhook ``` ## Dynamic Payload Handling Input validation and output serialization work the same as legacy: 1. **Parse JSON** → Extract `input` from request body 2. **Validate against schema** → Pydantic checks types (in worker subprocess) 3. **Download files** → Rust HTTP client fetches URLs 4. **Send to worker** → JSON-framed message via Unix socket 5. **Call predict()** → Python worker executes user code 6. **Capture output** → Worker sends back via slot channel 7. **Upload files** → Rust uploads to storage 8. 
**Serialize** → Return JSON response The key difference is that steps 1, 3, 4, 7, and 8 happen in Rust (faster), while steps 2, 5, and 6 happen in Python (same as legacy). ## Error Handling ### Worker Crashes **Legacy**: Parent process becomes unstable, may need restart **FFI**: Server marks health as `DEFUNCT` but continues serving other endpoints: ```rust // Worker process died match worker.wait().await { Ok(status) if !status.success() => { health.set(Health::Defunct); // HTTP server still runs, returns 503 for predictions } } ``` ### Setup Failures Both runtimes mark the container as unhealthy, but FFI provides more detail: ```rust // Detailed setup failure match control_rx.recv().await? { ControlResponse::Failed { error } => { health.set(Health::SetupFailed { reason: error }); // Include error in health-check response } } ``` ## Performance Characteristics | Operation | Legacy | FFI | Improvement | |-----------|--------|-----|-------------| | Request parsing | Pydantic (Python) | serde (Rust) | ~3x faster | | File download | aiohttp | reqwest | ~2x faster | | Concurrency overhead | asyncio tasks | Tokio + permits | ~50% less memory | | Webhook delivery | Sequential retries | Concurrent + backoff | Better throughput | | State management | asyncio locks | DashMap (lock-free) | No contention | ## Environment Variables FFI-specific variables: | Variable | Default | Purpose | |----------|---------|---------| | `USE_COGLET` | unset | Enable FFI runtime (set to any value) | | `COG_CONCURRENCY_SLOTS` | 1 | Number of prediction slots | | `COG_WORKER_TIMEOUT` | 300s | Worker subprocess timeout | | `RUST_LOG` | info | Rust logging (tracing crate) | Legacy variables like `COG_MAX_CONCURRENCY` are ignored when using FFI.
## Code References | File | Purpose | |------|---------| | `crates/coglet/src/transport/http/routes.rs` | HTTP endpoint handlers | | `crates/coglet/src/prediction.rs` | Prediction state + webhook firing | | `crates/coglet/src/webhook.rs` | Webhook delivery with retries | | `crates/coglet/src/bridge/protocol.rs` | IPC message types | | `crates/coglet/src/permit/pool.rs` | Slot-based concurrency | ## Migration Notes When switching from legacy to FFI runtime: ✅ **No changes needed**: - HTTP API endpoints - Request/response format - Predictor code - Client code ⚠️ **Behavioral differences**: - Sync predictions cancel on connection drop - 409 responses when at capacity (not queuing) - Different health state granularity - Different environment variables 📈 **Improvements**: - ~2x faster HTTP layer - Better resource management - More reliable cancellation - Worker crash resilience ================================================ FILE: architecture/ffi/04-container-runtime.md ================================================ # Container Runtime (FFI/Rust) This document covers the FFI runtime implementation using Rust with PyO3 bindings. This is a complete rewrite of the HTTP server, moving from Python/FastAPI to Rust/Axum with a PyO3 ABI3 wheel. 
## Overview The FFI runtime provides significant improvements over the legacy Python runtime: - **Rust HTTP server (Axum)**: Faster request handling, better backpressure management - **Worker isolation**: Python predictor crashes don't kill the server - **Slot-based concurrency**: Predictable resource control with permit pools - **Same API surface**: Drop-in replacement for the legacy runtime - **Subprocess reuse**: Predictor stays loaded between requests ## High-Level Architecture ``` ┌─────────────────────────────────────────────────────────────────────────────────┐ │ HTTP Transport (axum) │ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │ │ │ POST │ │ PUT │ │ POST │ │ GET │ │ │ │ /predictions│ │ /predictions│ │ /cancel │ │ /health-check │ │ │ │ │ │ /{id} │ │ │ │ /openapi.json │ │ │ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ └───────────┬─────────────┘ │ └─────────┼────────────────┼────────────────┼─────────────────────┼───────────────┘ │ │ │ │ ▼ ▼ ▼ ▼ ┌─────────────────────────────────────────────────────────────────────────────────┐ │ PredictionService │ │ ┌────────────────────────────────────────────────────────────────────────────┐ │ │ │ Active Predictions (DashMap) │ │ │ │ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ │ │ │ │ PredictionEntry │ │ PredictionEntry │ │ PredictionEntry │ ...
│ │ │ │ │ ─────────────── │ │ ─────────────── │ │ ─────────────── │ │ │ │ │ │ prediction (Arc)│ │ prediction (Arc)│ │ prediction (Arc)│ │ │ │ │ │ cancel_token │ │ cancel_token │ │ cancel_token │ │ │ │ │ │ input │ │ input │ │ input │ │ │ │ │ └─────────────────┘ └─────────────────┘ └─────────────────┘ │ │ │ └────────────────────────────────────────────────────────────────────────────┘ │ │ │ │ ┌────────────────────────────────────────────────────────────────────────────┐ │ │ │ PermitPool │ │ │ │ ┌────────┐ ┌────────┐ ┌────────┐ │ │ │ │ │ Permit │ │ Permit │ │ Permit │ (concurrency control) │ │ │ │ │ slot_0 │ │ slot_1 │ │ slot_2 │ │ │ │ │ └────────┘ └────────┘ └────────┘ │ │ │ └────────────────────────────────────────────────────────────────────────────┘ │ │ │ │ ┌────────────────────────────────────────────────────────────────────────────┐ │ │ │ OrchestratorHandle │ │ │ │ (slot_ids, control_tx for worker comms) │ │ │ └────────────────────────────────────────────────────────────────────────────┘ │ └──────────────────────────────────┬──────────────────────────────────────────────┘ │ Unix Socket (slot) + stdin/stdout (control) │ ▼ ┌─────────────────────────────────────────────────────────────────────────────────┐ │ Worker Subprocess (Python) │ │ ┌────────────────────────────────────────────────────────────────────────────┐ │ │ │ Predictor │ │ │ │ ┌─────────────────────────────────────────────────────────────────────┐ │ │ │ │ │ setup() → runs once at startup │ │ │ │ │ │ predict() → handles SlotRequest::Predict │ │ │ │ │ └─────────────────────────────────────────────────────────────────────┘ │ │ │ └────────────────────────────────────────────────────────────────────────────┘ │ └─────────────────────────────────────────────────────────────────────────────────┘ ``` ## Component Ownership The FFI runtime uses clear ownership patterns to manage prediction lifecycle: ``` ═══════════════════════════════════════════════════════════════════════════════════ COMPONENT OWNERSHIP 
═══════════════════════════════════════════════════════════════════════════════════ PredictionService (single owner of prediction state) ├── owns: DashMap (active predictions) ├── owns: OrchestratorState (pool + orchestrator handle) ├── owns: health, setup_result, schema └── method: cancel() fires token + delegates to orchestrator PredictionEntry (in DashMap) ├── has: Arc<Mutex<Prediction>> (the real state — single source of truth) ├── has: CancellationToken └── has: input (for API responses) Prediction (state machine — webhooks fire from mutation methods) ├── owns: status, logs, outputs, output, error, metrics ├── owns: WebhookSender (fires on set_processing, append_log, etc.) └── owns: completion notifier (for waiting on result) PredictionSlot (RAII container) ├── owns: Arc<Mutex<Prediction>> (shared with DashMap entry) ├── owns: Permit (concurrency token, returns to pool on drop) └── Drop: marks permit idle, releases back to pool PredictionHandle (returned to route handler) ├── has: CancellationToken clone └── method: sync_guard(service) → SyncPredictionGuard (cancels on drop) Cancellation (via OrchestratorHandle) ├── Sync predictors: ControlRequest::Cancel → SIGUSR1 → KeyboardInterrupt └── Async predictors: ControlRequest::Cancel → future.cancel() → CancelledError ═══════════════════════════════════════════════════════════════════════════════════ ``` ## Worker Subprocess Protocol Communication between the Rust server and Python worker uses two channels: ### Control Channel (stdin/stdout - JSON framed) | Parent → Child | Child → Parent | |----------------|----------------| | `Init { predictor_ref, num_slots, ...
}` | `Ready { slots, schema }` | | `Cancel { slot }` | `Log { source, data }` | | `Shutdown` | `Idle { slot }` | | | `Failed { slot, error }` | | | `ShuttingDown` | ### Slot Channel (Unix socket per slot - JSON framed) | Parent → Child | Child → Parent | |----------------|----------------| | `Predict { id, input }` | `Log { data }` | | | `Output { value }` (streaming) | | | `Done { output }` | | | `Failed { error }` | | | `Cancelled` | ## Health State Machine ```mermaid stateDiagram-v2 [*] --> STARTING: Container start note right of STARTING: Predictions return 503 STARTING --> READY: setup() succeeds STARTING --> SETUP_FAILED: setup() raises exception READY --> BUSY: All slots occupied note right of BUSY: New predictions get 409 BUSY --> READY: Slot freed READY --> DEFUNCT: Fatal error / worker crash BUSY --> DEFUNCT: Fatal error / worker crash note right of DEFUNCT: Predictions return 503 SETUP_FAILED --> [*] DEFUNCT --> [*] ``` ### Health States The health-check endpoint always returns HTTP 200 with the status in the JSON body. This allows load balancers and orchestrators to distinguish between "server is running but not ready" vs "server is down". | State | JSON `status` | Meaning | |-------|---------------|---------| | `STARTING` | `"STARTING"` | Worker subprocess initializing, `setup()` running | | `READY` | `"READY"` | Worker ready, at least one slot available | | `BUSY` | `"BUSY"` | All slots occupied, no capacity for new predictions | | `SETUP_FAILED` | `"SETUP_FAILED"` | `setup()` threw exception, cannot serve predictions | | `DEFUNCT` | `"DEFUNCT"` | Fatal error or worker crash, server unusable | > **Note**: Prediction endpoints (`/predictions`) return 503 when health is not `READY`. 
## Prediction Flow ### Sync Request (POST /predictions) ```mermaid sequenceDiagram participant Client participant Routes participant Service participant Worker Client->>Routes: POST /predictions Routes->>Service: submit_prediction(id, input, webhook) Service-->>Routes: PredictionHandle + slot Note over Routes: SyncPredictionGuard held<br/>(cancels on connection drop) Routes->>Service: predict(slot, input) Service->>Worker: predict(slot, input) Worker-->>Service: result Note over Service: Prediction.set_succeeded() fires webhook Routes-->>Client: 200 {output} ``` **Key behavior**: The `SyncPredictionGuard` is held for the duration of the request. If the client connection drops, the guard is dropped and the prediction is automatically cancelled. ### Async Request (Prefer: respond-async) ```mermaid sequenceDiagram participant Client participant Routes participant Service participant Worker Client->>Routes: POST + respond-async Routes->>Service: submit_prediction(id, input, webhook) Service-->>Routes: PredictionHandle + slot Routes-->>Client: 202 {status: "starting"} Note over Routes,Worker: spawned task continues independently par Background Task Service->>Worker: predict(slot, input) Worker-->>Service: result Note over Service: Prediction mutations fire webhooks automatically end Service-->>Client: webhook (completed) ``` **Key behavior**: No guard is held. The prediction continues even if the client disconnects. 
### Idempotent PUT (PUT /predictions/{id}) ```mermaid sequenceDiagram participant Client participant Routes participant Service Client->>Routes: PUT /predictions/X Routes->>Service: get_prediction_response("X") alt Prediction exists Service-->>Routes: existing state Routes-->>Client: 202 + full state else Prediction doesn't exist Routes->>Service: submit_prediction + predict Routes-->>Client: 202 + starting state end ``` ### Connection Drop (Sync Mode) ```mermaid sequenceDiagram participant Client participant Routes participant Service participant Worker Client->>Routes: POST /predictions Note over Routes: SyncPredictionGuard armed Routes->>Worker: predict(slot) Client-xRoutes: ✕ connection drops Note over Routes: guard.drop() Routes->>Service: cancel(id) Service->>Worker: Cancel Worker-->>Service: Cancelled ``` ## File Structure ``` crates/coglet/src/ ├── lib.rs # Public API exports ├── service.rs # PredictionService (single owner of prediction state) ├── prediction.rs # Prediction state (logs, outputs, status) ├── health.rs # Health enum + SetupResult ├── orchestrator.rs # Worker subprocess management ├── permit/ │ ├── mod.rs │ ├── pool.rs # PermitPool (concurrency control) │ └── slot.rs # PredictionSlot (Prediction + Permit RAII) ├── bridge/ │ ├── mod.rs │ ├── protocol.rs # Control/Slot request/response types │ ├── codec.rs # JSON length-delimited framing │ └── transport.rs # Unix socket transport ├── transport/ │ └── http/ │ ├── mod.rs │ ├── server.rs # Axum server setup │ └── routes.rs # HTTP handlers (uses service) ├── webhook.rs # WebhookSender (retry logic, trace context) ├── worker.rs # run_worker, PredictHandler trait, SetupError └── version.rs # VersionInfo crates/coglet-python/src/ └── lib.rs # PyO3 bindings (coglet.server.serve()) ``` ## Invocation Path How coglet gets invoked when running a Cog container: ``` ┌─────────────────────────────────────────────────────────────────────────────┐ │ cog predict / cog run │ │ (CLI) │ 
└─────────────────────────────────┬───────────────────────────────────────────┘ │ ▼ ┌─────────────────────────────────────────────────────────────────────────────┐ │ python -m cog.server.http │ │ │ │ if USE_COGLET env var: │ │ import coglet │ │ coglet.server.serve(predictor_ref, port=5000) ──────────────────┐ │ │ else: │ │ │ # original Python FastAPI server │ │ │ uvicorn.run(app, port=5000) │ │ └─────────────────────────────────────────────────────────────────────────┼───┘ │ ▼ ┌─────────────────────────────────────────────────────────────────────────────┐ │ coglet (Rust) │ │ │ │ ┌───────────────────────────────────────────────────────────────────┐ │ │ │ HTTP Server (axum) :5000 │ │ │ │ /predictions, /health-check, etc. │ │ │ └───────────────────────────────────────────────────────────────────┘ │ │ │ │ │ ▼ │ │ ┌───────────────────────────────────────────────────────────────────┐ │ │ │ PredictionService (state, webhooks, permits) │ │ │ └───────────────────────────────────────────────────────────────────┘ │ │ │ │ │ Unix socket + pipes │ │ │ │ │ ▼ │ │ ┌───────────────────────────────────────────────────────────────────┐ │ │ │ Worker subprocess (Python) │ │ │ │ - loads predictor_ref │ │ │ │ - runs setup() │ │ │ │ - handles predict() requests │ │ │ └───────────────────────────────────────────────────────────────────┘ │ └─────────────────────────────────────────────────────────────────────────────┘ ``` ## Key Design Decisions ### Why Rust? - **Performance**: Axum is faster than Uvicorn/FastAPI for HTTP handling - **Stability**: Server doesn't crash when user code fails - **Resource management**: Better backpressure and concurrency control - **Memory safety**: No Python GIL contention in HTTP layer ### Why PyO3 FFI? 
- **ABI3 wheel**: Single wheel works across Python 3.10-3.13 - **Native performance**: Direct C API calls, no serialization overhead - **Same predictor code**: Users don't change anything - **Drop-in replacement**: Same HTTP API, same behavior ### Why Subprocess (not in-process)? - **Isolation**: Python crashes/segfaults don't kill server - **CUDA context**: Clean GPU initialization per worker - **Memory**: Fresh address space for model loading - **Restart potential**: Architecture enables future worker restart on fatal errors (not yet implemented) ### Why Slots (not async tasks)? - **Predictable**: Fixed number of concurrent predictions - **Fair**: Permits prevent starvation - **Observable**: Easy to monitor slot usage - **Simple**: No async complexity in worker subprocess ## Environment Variables | Variable | Default | Purpose | |----------|---------|---------| | `USE_COGLET` | unset | Enable FFI runtime (set to any value) | | `PORT` | 5000 | HTTP server port | | `COG_LOG_LEVEL` | INFO | Logging verbosity | | `COG_CONCURRENCY_SLOTS` | 1 | Number of concurrent prediction slots | ## Comparison to Legacy Runtime | Aspect | Legacy (Python) | FFI (Rust) | |--------|----------------|------------| | HTTP Server | FastAPI/Uvicorn | Axum | | Language | Pure Python | Rust + PyO3 | | IPC | multiprocessing.Pipe (pickled) | Unix socket + pipes (JSON) | | Concurrency | async tasks | Slot-based permits | | Cancellation | SIGUSR1 signal | IPC message + SIGUSR1 (sync) / future.cancel() (async) | | Connection drop | No effect on prediction | Cancels sync predictions | | Worker crash | Server unstable | Server stays up, marks DEFUNCT | | Performance | Baseline | ~2x faster HTTP layer | ## Code References | File | Purpose | |------|---------| | `crates/coglet/src/service.rs` | Main orchestrator: PredictionService | | `crates/coglet/src/prediction.rs` | Prediction state machine + webhook firing | | `crates/coglet/src/transport/http/routes.rs` | HTTP endpoint handlers | | 
`crates/coglet/src/permit/pool.rs` | Slot-based concurrency control | | `crates/coglet/src/orchestrator.rs` | Worker subprocess spawn/management | | `crates/coglet/src/bridge/protocol.rs` | IPC message definitions | | `crates/coglet-python/src/lib.rs` | PyO3 Python bindings | | `python/cog/server/http.py` | Entry point (checks USE_COGLET) | ================================================ FILE: architecture/ffi/README.md ================================================ # FFI Runtime (Rust + PyO3) This directory documents the next-generation Cog runtime implementation using Rust with PyO3 FFI bindings. ## Status This is an **experimental** runtime implementation currently in development. It provides significant improvements in: - Performance and resource management - Worker process isolation and stability - Concurrency control with slot-based permits - Graceful cancellation and connection drop handling ## When to Use Enable this implementation by setting the `USE_COGLET` environment variable when running Cog containers. 
## Key Improvements - **Rust HTTP server (Axum)**: Faster, better backpressure handling - **Worker isolation**: Python crashes don't kill the server - **Slot-based concurrency**: Predictable resource management with permit pool - **Subprocess reuse**: Predictor stays loaded between requests - **Better cancellation**: Sync predictions cancel on connection drop via RAII guards ## Architecture Overview ``` HTTP Server (Rust/Axum) ↓ PredictionService (state, webhooks, DashMap) ↓ PermitPool (slot-based concurrency) ↓ Orchestrator → Worker Subprocess (Python) ↓ (Unix socket + pipes) Predictor (setup/predict) ``` ## Documentation - [Prediction API](./03-prediction-api.md) - HTTP endpoints with coglet-specific behavior - [Container Runtime](./04-container-runtime.md) - Complete FFI architecture and flow ## Implementation Primary code location: `crates/coglet/` - `src/transport/http/` - Axum HTTP server - `src/service.rs` - PredictionService (single owner of prediction state) - `src/permit/` - Slot-based concurrency control - `src/orchestrator.rs` - Worker subprocess management - `src/bridge/` - IPC protocol and transport - `src/worker.rs` - Worker implementation Python bindings: `crates/coglet-python/src/lib.rs` ================================================ FILE: architecture/legacy/03-prediction-api.md ================================================ # Prediction API The Prediction API is the HTTP interface for running model inference. It uses a fixed **envelope format** that wraps model-specific inputs and outputs, allowing a uniform API across all Cog models. 
## Endpoints | Endpoint | Method | Purpose | |----------|--------|---------| | `POST /predictions` | Create | Start a new prediction | | `PUT /predictions/{id}` | Create (idempotent) | Start or retrieve existing prediction | | `POST /predictions/{id}/cancel` | Cancel | Cancel a running prediction | | `GET /health-check` | Health | Check server status | | `GET /` | Index | List available endpoints | | `GET /openapi.json` | Schema | OpenAPI specification | By default, `POST /predictions` blocks until completion. For long-running predictions, use async mode with `Prefer: respond-async` header - the response returns immediately with status `processing`, and progress updates are delivered via webhook. ## The Envelope Pattern Every Cog model exposes the same endpoints with the same request/response structure. The model-specific parts (input fields, output type) are defined by the [Schema](./02-schema.md) and validated at runtime. ``` ┌────────────────────────────────────────────────────────┐ │ Fixed Envelope (same for all models) │ │ ┌──────────────────────────────────────────────────┐ │ │ │ id, status, created_at, logs, metrics, ... │ │ │ └──────────────────────────────────────────────────┘ │ │ ┌──────────────────────────────────────────────────┐ │ │ │ input: { ... } ← model-specific (from schema)│ │ │ └──────────────────────────────────────────────────┘ │ │ ┌──────────────────────────────────────────────────┐ │ │ │ output: ... 
← model-specific (from schema)│ │ │ └──────────────────────────────────────────────────┘ │ └────────────────────────────────────────────────────────┘ ``` This pattern means: - Clients use the same code to call any Cog model - Platforms can route requests without understanding model internals - Input validation is schema-driven, not hardcoded ## PredictionRequest What clients send to start a prediction: ```json { "id": "abc-123", "input": { "prompt": "A photo of a cat", "steps": 50 }, "webhook": "https://example.com/webhook", "webhook_events_filter": ["start", "output", "logs", "completed"] } ``` | Field | Type | Purpose | |-------|------|---------| | `id` | string (optional) | Client-provided ID for idempotency | | `input` | object | **Model-specific** - validated against schema | | `webhook` | URL (optional) | Where to send progress updates | | `webhook_events_filter` | array (optional) | Which events to send | | `created_at` | datetime (optional) | Client-provided timestamp | The `input` object is validated against the `Input` schema generated from the predictor's `predict()` signature. Unknown fields are rejected; missing required fields raise validation errors. 
## PredictionResponse What comes back from the API: ```json { "id": "abc-123", "status": "succeeded", "input": { "prompt": "A photo of a cat", "steps": 50 }, "output": "https://storage.example.com/output.png", "logs": "Loading model...\nGenerating image...\nDone.", "error": null, "metrics": { "predict_time": 4.52 }, "created_at": "2024-01-15T10:30:00Z", "started_at": "2024-01-15T10:30:01Z", "completed_at": "2024-01-15T10:30:05Z" } ``` | Field | Type | Purpose | |-------|------|---------| | `id` | string | Prediction identifier | | `status` | enum | `starting`, `processing`, `succeeded`, `canceled`, `failed` | | `input` | object | Echo of the input (for reference) | | `output` | any | **Model-specific** - type defined by schema | | `logs` | string | Captured stdout/stderr from predict() | | `error` | string | Error message if status is `failed` | | `metrics` | object | Timing and other metrics | | `created_at` | datetime | When request was received | | `started_at` | datetime | When prediction began | | `completed_at` | datetime | When prediction finished | ## Status Lifecycle ```mermaid stateDiagram-v2 [*] --> starting: Request received starting --> processing: predict() called processing --> succeeded: predict() returns processing --> failed: predict() raises exception processing --> canceled: Cancel requested succeeded --> [*] failed --> [*] canceled --> [*] ``` ## Dynamic Payload Handling The magic of the envelope pattern is that the `input` and `output` fields are dynamically typed based on the schema. ### Input Validation Flow ```mermaid flowchart LR subgraph request["Incoming Request"] json["JSON body"] end subgraph validation["Validation"] schema["Schema (Input type)"] pydantic["Pydantic Model"] end subgraph transform["Transformation"] download["Download URLs → Files"] coerce["Type Coercion"] end subgraph predict["predict()"] kwargs["**kwargs"] end json --> pydantic schema --> pydantic pydantic --> download download --> coerce coerce --> kwargs ``` 1. 
**Parse JSON** - Extract `input` from request body 2. **Validate against schema** - Pydantic checks types, required fields, constraints 3. **Download files** - URLs in `cog.Path` fields are fetched to local temp files 4. **Coerce types** - Strings become Paths, etc. 5. **Call predict()** - Validated input passed as `**kwargs` ### Output Handling Flow ```mermaid flowchart LR subgraph predict["predict()"] result["Return value / yields"] end subgraph transform["Transformation"] upload["Upload files → URLs"] serialize["JSON serialization"] end subgraph response["Response"] output["output field"] end result --> upload upload --> serialize serialize --> output ``` 1. **Capture output** - Return value or yielded values from predict() 2. **Upload files** - `cog.Path` outputs are uploaded, replaced with URLs 3. **Serialize** - Convert to JSON-compatible format 4. **Return** - Place in `output` field of response ### File Handling Input files (cog.Path): ``` Client sends: {"input": {"image": "https://example.com/photo.jpg"}} Server downloads: /tmp/inputabc123.jpg predict() sees: image = Path("/tmp/inputabc123.jpg") ``` Output files (cog.Path): ``` predict() returns: Path("/tmp/output.png") Server uploads: https://storage.example.com/output-xyz.png Client receives: {"output": "https://storage.example.com/output-xyz.png"} ``` ## Webhooks For async predictions, progress is delivered via webhooks: ```mermaid sequenceDiagram participant Client participant Cog participant Webhook Client->>Cog: POST /predictions (Prefer: respond-async) Cog-->>Client: 202 {status: "starting"} Cog->>Webhook: {status: "starting"} Note over Cog: predict() starts Cog->>Webhook: {status: "processing"} loop Output yields Cog->>Webhook: {output: "partial...", logs: "..."} end Cog->>Webhook: {status: "succeeded", output: "final"} ``` ### Webhook Events | Event | When | Payload Contains | |-------|------|------------------| | `start` | Prediction begins | `status: starting` | | `output` | Each yield from 
iterator | Partial `output` | | `logs` | Log lines captured | Updated `logs` | | `completed` | Prediction finishes | Final `status`, `output`, `metrics` | Filter events with `webhook_events_filter`: ```json { "input": {...}, "webhook": "https://...", "webhook_events_filter": ["completed"] } ``` ## Streaming Output For models that yield output progressively: ```python def predict(self, prompt: str) -> Iterator[str]: for token in generate(prompt): yield token ``` The API can deliver these as: 1. **Webhooks** - Each yield triggers an `output` webhook 2. **Server-Sent Events** - Stream via `Accept: text/event-stream` 3. **Final array** - Sync response collects all yields into `output: ["a", "b", "c"]` ## Training API The training API (`/trainings`) uses the same envelope pattern: - `TrainingRequest` extends `PredictionRequest` - `TrainingResponse` extends `PredictionResponse` - Calls `train()` method instead of `predict()` ## Code References | File | Purpose | |------|---------| | `python/cog/schema.py` | `PredictionRequest`, `PredictionResponse`, `Status` | | `python/cog/server/http.py` | HTTP endpoints, request handling | | `python/cog/server/runner.py` | Prediction orchestration | | `python/cog/server/webhook.py` | Webhook delivery | ================================================ FILE: architecture/legacy/04-container-runtime.md ================================================ # Container Runtime This document covers what happens when a Cog container runs. It's where the [Model Source](./01-model-source.md), [Schema](./02-schema.md), and [Prediction API](./03-prediction-api.md) come together. ## Overview When a Cog container runs, it executes a **two-process architecture** with a minimal init system. The design isolates user model code from the HTTP server for stability, resource management, and clean shutdown handling. 
## High-Level Architecture ```mermaid flowchart TB subgraph container["Cog Container"] subgraph init["tini (PID 1)"] tini["Signal forwarding & zombie reaping"] end subgraph parent["Parent Process (HTTP Server)"] direction TB subgraph components["Components"] direction LR fastapi["FastAPI/Uvicorn<br/>port 5000"] worker["Worker<br/>(parent-side)"] runner["PredictionRunner<br/>(orchestrator)"] end subgraph threads["Thread Pools"] direction LR t1["Event consumer"] t2["Prediction start"] t3["Input download<br/>(8 threads)"] end end pipe[["multiprocessing.Pipe<br/>(bidirectional IPC)"]] subgraph child["Child Process (_ChildWorker)"] direction TB subgraph child_components["Components"] direction LR predictor["User Predictor<br/>(predict.py)<br/>---<br/>setup()<br/>predict()<br/>train()"] redirector["StreamRedirector<br/>stdout/stderr<br/>capture"] eventloop["Event Loop<br/>
(sync/async)"] end end end init --> parent parent <--> pipe pipe <--> child ``` ## Process Roles ### tini (PID 1) - **What**: Minimal init system (~30KB binary) - **Why**: Proper signal forwarding to children, zombie process reaping - **Entry**: `ENTRYPOINT ["/sbin/tini", "--"]` ### Parent Process (HTTP Server) - **What**: Python process running FastAPI/Uvicorn - **Entry**: `CMD ["python", "-m", "cog.server.http"]` - **Responsibilities**: - HTTP API on port 5000 - Request validation (Pydantic) - Input file downloading (from URLs) - Webhook delivery - Output file uploads - Health state management - Child process lifecycle ### Child Process (_ChildWorker) - **What**: Isolated Python process for user code - **Spawned via**: `multiprocessing.get_context("spawn").Process` - **Responsibilities**: - Load user's predictor module - Run `setup()` once at startup - Execute `predict()` / `train()` methods - Capture stdout/stderr - Send events back to parent ## Why Two Processes? 1. **Isolation**: User code crashes don't bring down the HTTP server 2. **Memory**: Fresh address space for each model load (spawn vs fork) 3. **CUDA**: Clean GPU context initialization in child 4. **Cleanup**: Parent can restart child if it dies 5. **Monitoring**: Parent tracks child health independently ## Inter-Process Communication ```mermaid flowchart LR subgraph parent["Parent Process"] Worker end subgraph child["Child Process"] ChildWorker["_ChildWorker"] end Worker -->|"PredictionInput
Cancel
Shutdown"| ChildWorker ChildWorker -->|"Log
PredictionOutput
PredictionOutputType
PredictionMetric
Done"| Worker ``` Communication uses Python's `multiprocessing.Pipe()` with pickled `Envelope` objects: ```python @define class Envelope: event: Union[Cancel, PredictionInput, Shutdown, Log, ...] tag: Optional[str] = None # Routes concurrent predictions ``` ### Event Types | Event | Direction | Purpose | |-------|-----------|---------| | `PredictionInput` | Parent → Child | Start prediction with input payload | | `Cancel` | Parent → Child | Abort the current prediction | | `Shutdown` | Parent → Child | Graceful termination signal | | `PredictionOutputType` | Child → Parent | Declares the output type (once per prediction) | | `PredictionOutput` | Child → Parent | Output value (multiple for generators) | | `Log` | Child → Parent | Captured stdout/stderr line | | `PredictionMetric` | Child → Parent | Timing/performance metrics | | `Done` | Child → Parent | Prediction complete (success or failure) | ## Request Flow: Prediction Lifecycle ```mermaid sequenceDiagram participant Client participant FastAPI participant Runner as PredictionRunner participant Worker as Worker (parent) participant Pool as ThreadPool participant Child as _ChildWorker participant Predictor as User predict() Client->>FastAPI: POST /predictions
{"input": {"prompt": "..."}} FastAPI->>Runner: predict(request) Runner->>Worker: predict(payload, tag) Worker->>Pool: Download input URLs Pool-->>Worker: Local file paths Worker->>Child: PredictionInput event Child->>Predictor: predict(**payload) loop Generator yields / prints Predictor-->>Child: yield output / print() Child-->>Worker: PredictionOutput / Log events Worker-->>Runner: handle_event() Runner-->>Client: Webhook (if configured) end Predictor-->>Child: return Child-->>Worker: Done event Worker-->>Runner: handle_event() Runner->>Runner: Upload output files Runner->>Client: Send final webhook Runner-->>FastAPI: PredictTask complete FastAPI-->>Client: Response JSON
{"output": "...", "status": "succeeded"} ``` ## Key Components Deep Dive ### HTTP Server (`http.py`) | Endpoint | Method | Purpose | |----------|--------|---------| | `/` | GET | API index | | `/health-check` | GET | Health status | | `/predictions` | POST | New prediction | | `/predictions/{id}` | PUT | Idempotent create | | `/predictions/{id}/cancel` | POST | Cancel running | | `/shutdown` | POST | Graceful shutdown | ### Health States ```mermaid stateDiagram-v2 [*] --> STARTING: Container start STARTING --> READY: setup() succeeds STARTING --> SETUP_FAILED: setup() raises exception READY --> BUSY: prediction starts BUSY --> READY: prediction completes READY --> DEFUNCT: child dies unexpectedly BUSY --> DEFUNCT: child dies unexpectedly SETUP_FAILED --> [*] DEFUNCT --> [*] ``` ### StreamRedirector (Output Capture) The child process captures stdout/stderr including native library output (CUDA, etc.): ```mermaid flowchart LR subgraph child["Child Process"] subgraph usercode["User Code"] predict["predict()"] end subgraph redirector["StreamRedirector"] original["Original fd 1/2
(saved)"] pipewrite["Pipe write end
(replaces fd 1/2)"] reader["Reader Thread"] end predict -->|"print()
CUDA logs"| pipewrite pipewrite --> reader end reader -->|"Log events"| parent["To Parent Process"] ``` ## Concurrency Model ### Default: Sequential (`max_concurrency=1`) - One prediction at a time - Sync `def predict()` supported - Cancellation via `SIGUSR1` signal ### Concurrent (`max_concurrency > 1`) - Requires `async def predict()` - Python 3.11+ for `asyncio.TaskGroup` - Configure in `cog.yaml`: ```yaml concurrency: max: 5 ``` ```mermaid gantt title max_concurrency=1 (Sequential) dateFormat X axisFormat %s section Predictions Prediction 1 :0, 3 Prediction 2 :3, 6 Prediction 3 :6, 9 ``` ```mermaid gantt title max_concurrency=5 (Concurrent) dateFormat X axisFormat %s section Predictions Prediction 1 :0, 4 Prediction 2 :1, 4 Prediction 3 :2, 6 Prediction 4 :0, 5 Prediction 5 :3, 5 ``` ## Environment Variables | Variable | Default | Purpose | |----------|---------|---------| | `PORT` | 5000 | HTTP server port | | `COG_LOG_LEVEL` | INFO | Logging verbosity | | `COG_MAX_CONCURRENCY` | 1 | Max concurrent predictions | | `COG_THROTTLE_RESPONSE_INTERVAL` | 0.5s | Webhook rate limit | ## File Locations | Path | Purpose | |------|---------| | `/var/run/cog/ready` | K8s readiness probe touch file | | `/src` | User code (WORKDIR) | | `/src/weights` | Common weights location | ## Code References | File | Purpose | |------|---------| | `python/cog/server/http.py` | FastAPI app, endpoints | | `python/cog/server/worker.py` | Worker, _ChildWorker | | `python/cog/server/runner.py` | PredictionRunner | | `python/cog/server/webhook.py` | Webhook delivery | | `python/cog/server/stream_redirector.py` | Output capture | ================================================ FILE: architecture/legacy/README.md ================================================ # Legacy Python Runtime (FastAPI) This directory documents the original Cog runtime implementation using Python's FastAPI/Uvicorn HTTP server. ## Status This is the **current default** runtime implementation. 
It uses a two-process architecture with: - Parent process: FastAPI/Uvicorn HTTP server - Child process: User predictor code in isolated subprocess - IPC: Python `multiprocessing.Pipe` with pickled events ## When to Use This implementation is used by default when running Cog containers unless the `USE_COGLET` environment variable is set. ## Documentation - [Prediction API](./03-prediction-api.md) - HTTP endpoints and request/response format - [Container Runtime](./04-container-runtime.md) - Two-process architecture and execution flow ## Implementation Primary code location: `python/cog/server/` - `http.py` - FastAPI application and endpoints - `worker.py` - Worker process management - `runner.py` - Prediction orchestration - `webhook.py` - Webhook delivery - `stream_redirector.py` - Output capture ================================================ FILE: cmd/cog/cog.go ================================================ package main import ( "github.com/replicate/cog/pkg/cli" "github.com/replicate/cog/pkg/util/console" ) func main() { cmd, err := cli.NewRootCommand() if err != nil { console.Fatalf("%f", err) } if err = cmd.Execute(); err != nil { console.Fatalf("%s", err) } } ================================================ FILE: crates/.gitignore ================================================ /target/ ================================================ FILE: crates/Cargo.toml ================================================ [workspace] resolver = "2" members = ["coglet", "coglet-python"] [workspace.package] version = "0.17.0-rc.2" edition = "2024" license = "Apache-2.0" repository = "https://github.com/replicate/cog" homepage = "https://cog.run" documentation = "https://cog.run/docs" keywords = ["machine-learning", "inference", "containers", "prediction"] categories = ["development-tools", "web-programming"] [workspace.dependencies] # Async runtime tokio = { version = "1", features = ["full"] } tokio-util = "0.7" futures = "0.3" # HTTP server axum = "0.8" # HTTP client 
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls-native-roots"] } # Serialization serde = { version = "1", features = ["derive"] } serde_json = "1" # Identifiers uuid = { version = "1", features = ["v4", "serde"] } # Observability tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } # Error handling thiserror = "2" anyhow = "1" # Python bindings pyo3 = { version = "0.27", features = ["abi3-py310"] } pyo3-async-runtimes = { version = "0.27", features = ["tokio-runtime"] } pyo3-stub-gen = "0.18" # Testing insta = { version = "1", features = ["json"] } ================================================ FILE: crates/README.md ================================================ # Coglet: Rust Runtime for Cog Coglet is the Rust-based prediction server that powers Cog's subprocess isolation model. It provides process isolation, concurrent slot management, and high-performance IPC for running ML predictions. ## Architecture Overview ``` ┌─────────────────────────────────────────────────────────────────────────────┐ │ Parent Process │ │ ┌─────────────┐ ┌──────────────┐ ┌─────────────────────────────────┐ │ │ │ HTTP Server │───▶│ Prediction │───▶│ Orchestrator │ │ │ │ (axum) │ │ Service │ │ - Spawns worker subprocess │ │ │ └─────────────┘ └──────────────┘ │ - Routes predictions to slots │ │ │ │ - Handles worker lifecycle │ │ │ └───────────────┬─────────────────┘ │ │ │ │ │ ┌───────────────────────────────┼───────────────┐ │ │ │ Control Channel (stdin/stdout - JSON lines) │ │ │ │ - Init, Ready, Cancel, Shutdown │ │ │ └───────────────────────────────┼───────────────┘ │ │ │ │ │ ┌───────────────────────────────┼───────────────┐ │ │ │ Slot Sockets (Unix domain - per slot) │ │ │ │ - Predict requests │ │ │ │ - Streaming logs, outputs │ │ │ │ - Done/Failed/Cancelled responses │ │ │ └───────────────────────────────┼───────────────┘ │ └──────────────────────────────────────────────────────────┼───────────────────┘ 
│ ┌──────────────────────────────────────────────────────────┼───────────────────┐ │ Worker Subprocess │ │ │ ┌─────────────────────────────────────────────────────────────────────────┐ │ │ │ Python Runtime (GIL) │ │ │ │ ┌─────────────────┐ ┌─────────────────┐ ┌───────────────────────┐ │ │ │ │ │ PythonPredictor │ │ SlotLogWriter │ │ Audit Hook │ │ │ │ │ │ - load() │ │ (sys.stdout/err)│ │ - Protects streams │ │ │ │ │ │ - setup() │ │ Routes via │ │ - Tee pattern for │ │ │ │ │ │ - predict() │ │ ContextVar │ │ user overrides │ │ │ │ │ └─────────────────┘ └─────────────────┘ └───────────────────────┘ │ │ │ └─────────────────────────────────────────────────────────────────────────┘ │ │ │ │ ┌──────────────────────────────────────────────────────────────────────┐ │ │ │ Tokio Runtime │ │ │ │ - Async event loop for slot socket I/O │ │ │ │ - Releases GIL during I/O (py.detach) │ │ │ │ - Single async executor for async predictors │ │ │ └──────────────────────────────────────────────────────────────────────┘ │ └──────────────────────────────────────────────────────────────────────────────┘ ``` ## Prediction Flow ``` HTTP Request Parent Process Worker Subprocess │ │ │ │ POST /predictions │ │ ├───────────────────────────────▶│ │ │ │ │ │ ┌───────────┴───────────┐ │ │ │ 1. Acquire slot permit│ │ │ │ 2. Register prediction│ │ │ └───────────┬───────────┘ │ │ │ │ │ │ SlotRequest::Predict │ │ │ {id, input} │ │ ├─────────────────────────────────▶│ │ │ (slot socket) │ │ │ │ │ │ ┌───────────┴───────────┐ │ │ │ 3. Set ContextVar │ │ │ │ 4. Call predict() │ │ │ └───────────┬───────────┘ │ │ │ │ │ SlotResponse::Log │ │ │◀─────────────────────────────────┤ (streaming) │ │ │ │ │ SlotResponse::Output │ │ │◀─────────────────────────────────┤ (generators) │ │ │ │ │ SlotResponse::Done │ │ │◀─────────────────────────────────┤ │ │ {id, output, predict_time} │ │ │ │ │ ┌───────────┴───────────┐ │ │ │ 5. Update prediction │ │ │ │ 6. Release permit │ │ │ │ 7. 
Send webhook │ │ │ └───────────┬───────────┘ │ │ │ │ │ 200 OK {output} │ │ │◀───────────────────────────────┤ │ │ │ │ ``` ## Startup Sequence ``` 1. coglet.server.serve() called from Python │ ├─▶ Start HTTP server immediately (health returns STARTING until ready) │ └─▶ Spawn orchestrator task │ ├─▶ Create slot transport (Unix sockets) │ ├─▶ Spawn worker: python -c "import coglet; coglet.server._run_worker()" │ ├─▶ Send Init message (predictor_ref, num_slots, transport_info) │ │ │ │ ┌────────────────────────────────────────────────┐ │ └──▶│ Worker: connect sockets, install log writers, │ │ │ install audit hook, load predictor, run setup │ │ └────────────────────────────────────────────────┘ │ ├─▶ Wait for Ready {slots, schema} or Failed {error} │ ├─▶ Populate PermitPool with slot sockets │ ├─▶ Start event loop (routes responses to predictions) │ └─▶ Set health = READY, start accepting predictions ``` ## Components ### coglet (core library) Pure Rust library with no Python dependencies. Provides: - **orchestrator.rs** - Spawns worker, manages lifecycle, routes messages - **worker.rs** - Child-side event loop, prediction execution - **service.rs** - Transport-agnostic prediction service - **permit/** - Slot-based concurrency control (PermitPool) - **bridge/** - IPC protocol and transport (Unix sockets + JSON codec) - **transport/http/** - Axum-based HTTP server and routes ### coglet-python (PyO3 bindings) Bridges coglet to Python via PyO3. 
Provides: - **lib.rs** - Python module with `serve()`, `active()`, `_run_worker()` - **predictor.rs** - Wraps Python predictor class (sync/async detection) - **worker_bridge.rs** - Implements `PredictHandler` trait for Python - **log_writer.rs** - ContextVar-based stdout/stderr routing - **audit.rs** - Protects runtime streams from user code - **cancel.rs** - SIGUSR1-based cancellation for sync predictors ## Directory Structure ``` crates/ ├── Cargo.toml # Workspace manifest ├── Cargo.lock ├── deny.toml # cargo-deny configuration │ ├── coglet/ # Core Rust library │ ├── Cargo.toml │ └── src/ │ ├── lib.rs # Public API exports │ ├── health.rs # Health/SetupStatus types │ ├── prediction.rs # Prediction state machine │ ├── predictor.rs # PredictionResult, PredictionError │ ├── service.rs # PredictionService │ ├── webhook.rs # WebhookSender (retry, trace context) │ ├── version.rs # Version info │ ├── orchestrator.rs # Worker lifecycle, event loop (parent) │ ├── worker.rs # Worker event loop (child) │ ├── bridge/ │ │ ├── mod.rs │ │ ├── codec.rs # Length-prefixed JSON codec │ │ ├── protocol.rs # Message types (ControlRequest, SlotResponse, etc.) 
│ │ └── transport.rs # Unix socket transport │ ├── permit/ │ │ ├── mod.rs │ │ ├── pool.rs # PermitPool (concurrency control) │ │ └── slot.rs # PredictionSlot (permit + prediction) │ └── transport/ │ ├── mod.rs │ └── http/ │ ├── mod.rs │ ├── server.rs # Axum server setup │ └── routes.rs # HTTP handlers │ └── coglet-python/ # PyO3 bindings ├── Cargo.toml ├── coglet.pyi # Type stubs for Python └── src/ ├── lib.rs # Python module definition ├── predictor.rs # PythonPredictor wrapper ├── worker_bridge.rs # PredictHandler impl ├── input.rs # Input processing (Pydantic/ADT) ├── output.rs # Output serialization ├── log_writer.rs # SlotLogWriter, ContextVar routing ├── audit.rs # Audit hook, TeeWriter └── cancel.rs # Cancellation support ``` ## Bridge Protocol Two communication channels between parent and worker: ### Control Channel (stdin/stdout) Used for lifecycle messages. JSON lines, one message per line. **Parent → Worker:** ```json {"type": "init", "predictor_ref": "predict.py:Predictor", "num_slots": 2, ...} {"type": "cancel", "slot": "uuid"} {"type": "shutdown"} ``` **Worker → Parent:** ```json {"type": "ready", "slots": ["uuid1", "uuid2"], "schema": {...}} {"type": "log", "source": "stdout", "data": "Loading model..."} {"type": "idle", "slot": "uuid"} {"type": "failed", "slot": "uuid", "error": "Setup failed: ..."} {"type": "shutting_down"} ``` ### Slot Sockets (Unix domain) Per-slot bidirectional sockets for prediction data. Avoids head-of-line blocking. **Parent → Worker:** ```json {"type": "predict", "id": "pred_123", "input": {"prompt": "Hello"}} ``` **Worker → Parent:** ```json {"type": "log", "source": "stdout", "data": "Processing..."} {"type": "output", "output": "chunk"} {"type": "done", "id": "pred_123", "output": "Hello, world!", "predict_time": 0.5} {"type": "failed", "id": "pred_123", "error": "ValueError: ..."} {"type": "cancelled", "id": "pred_123"} ``` ## Key Design Decisions ### Subprocess Isolation Worker runs in a separate process. 
Benefits: - Crash isolation (worker crash → restart, parent survives) - Memory isolation (GPU memory leaks don't accumulate) - Clean shutdown (SIGKILL if needed) ### Single Worker Mode Always exactly one worker subprocess. No dynamic scaling - the parent is lightweight, all the heavy lifting happens in the worker. ### Slot-Based Concurrency Each slot is a Unix socket pair. `max_concurrency` determines slot count. Permits control access - at most one prediction per slot at a time. ### ContextVar-Based Log Routing Async predictions may spawn tasks. ContextVar propagates prediction ID through the call stack, allowing log routing even from spawned tasks. ### Audit Hook Protection User code might replace `sys.stdout`. The audit hook intercepts this and wraps their stream in a TeeWriter, preserving our log routing while allowing their code to work as expected. ================================================ FILE: crates/coglet/Cargo.toml ================================================ [package] name = "coglet" description = "High-performance prediction server for Cog ML models" version.workspace = true edition.workspace = true license.workspace = true repository.workspace = true homepage.workspace = true documentation.workspace = true keywords.workspace = true categories.workspace = true [dependencies] # Async runtime tokio.workspace = true tokio-util = { workspace = true, features = ["codec"] } futures.workspace = true async-trait = "0.1" # Serialization serde.workspace = true serde_json.workspace = true # Encoding base64 = "0.22.1" mime_guess = "2.0.5" # Identifiers uuid.workspace = true # HTTP server axum.workspace = true # HTTP client (webhooks) reqwest.workspace = true ureq = { version = "3", default-features = false, features = ["json", "rustls", "platform-verifier"] } # Time chrono = { version = "0.4", features = ["serde"] } # Error handling thiserror.workspace = true anyhow.workspace = true # Input validation jsonschema = "0.29" # Concurrent collections dashmap 
= "6" # Observability tracing.workspace = true tracing-subscriber.workspace = true [target.'cfg(unix)'.dependencies] nix = { version = "0.30", features = ["signal", "fs"] } [dev-dependencies] insta.workspace = true tempfile = "3" wiremock = "0.6" tower = { version = "0.5", features = ["util"] } http-body-util = "0.1" ================================================ FILE: crates/coglet/README.md ================================================ # coglet Core Rust library for the coglet prediction server. Pure Rust with no Python dependencies - the Python bindings live in `coglet-python`. ## Architecture ``` coglet ┌─────────────────────────────────────────────────────────────────┐ │ │ │ ┌─────────────────────────────────────────────────────────┐ │ │ │ transport/http │ │ │ │ ┌──────────────┐ ┌─────────────────────────────────┐ │ │ │ │ │ server.rs │ │ routes.rs │ │ │ │ │ │ Axum setup │ │ /health, /predictions, /cancel │ │ │ │ │ └──────────────┘ └─────────────────────────────────┘ │ │ │ └───────────────────────────────┬─────────────────────────┘ │ │ │ │ │ ┌───────────────────────────────▼─────────────────────────┐ │ │ │ service.rs │ │ │ │ PredictionService: health, permits, state, webhooks │ │ │ └───────────────────────────────┬─────────────────────────┘ │ │ │ │ │ ┌────────────────────────┼────────────────┐ │ │ │ │ │ │ │ ▼ ▼ ▼ │ │ ┌─────────────┐ ┌────────────────────┐ ┌──────────┐ │ │ │ permit/ │ │ orchestrator.rs │ │webhook.rs│ │ │ │ PermitPool │ │ Parent-side: │ │ Sender │ │ │ │ Slot alloc │ │ spawn, route │ │ Retry │ │ │ └─────────────┘ └─────────┬──────────┘ └──────────┘ │ │ │ │ │ ┌────────────────────────────▼────────────────────────────┐ │ │ │ bridge/ │ │ │ │ ┌──────────────┐ ┌─────────────┐ ┌────────────────┐ │ │ │ │ │ protocol.rs │ │ codec.rs │ │ transport.rs │ │ │ │ │ │ Message types│ │ JSON lines │ │ Unix sockets │ │ │ │ │ └──────────────┘ └─────────────┘ └────────────────┘ │ │ │ └─────────────────────────────────────────────────────────┘ │ │ │ │ 
┌─────────────────────────────────────────────────────────┐ │ │ │ worker.rs │ │ │ │ Child-side: PredictHandler trait, run_worker loop │ │ │ └─────────────────────────────────────────────────────────┘ │ │ │ └─────────────────────────────────────────────────────────────────┘ ``` ## Directory Structure ``` coglet/ └── src/ ├── lib.rs # Public API exports │ │ # Core Types ├── health.rs # Health, SetupStatus, SetupResult ├── prediction.rs # Prediction state machine ├── predictor.rs # PredictionResult, PredictionError, PredictionOutput ├── version.rs # VersionInfo │ │ # Service Layer ├── service.rs # PredictionService - lifecycle, state, webhooks ├── webhook.rs # WebhookSender, webhook types │ │ # Orchestrator (Parent Process) ├── orchestrator.rs # spawn_worker, OrchestratorHandle, event loop │ │ # Worker (Child Process) ├── worker.rs # run_worker, PredictHandler trait, SetupError │ │ # Concurrency Control ├── permit/ │ ├── mod.rs │ ├── pool.rs # PermitPool - slot permit management │ └── slot.rs # PredictionSlot - permit + prediction binding │ │ # IPC Bridge ├── bridge/ │ ├── mod.rs │ ├── protocol.rs # ControlRequest, ControlResponse, SlotRequest, SlotResponse │ ├── codec.rs # JsonCodec - length-prefixed JSON frames │ └── transport.rs # Unix socket transport, ChildTransportInfo │ │ # HTTP Transport └── transport/ ├── mod.rs └── http/ ├── mod.rs ├── server.rs # ServerConfig, serve() └── routes.rs # Route handlers, request/response types ``` ## Key Components ### PredictionService (`service.rs`) Single owner of prediction state. 
Manages: - Health state (Unknown → Starting → Ready/SetupFailed) - PermitPool + Orchestrator reference - Active predictions (DashMap — single source of truth) - Cancellation (cancel tokens + orchestrator delegation) - Webhooks fire from Prediction mutation methods (no dual state) ```rust let service = PredictionService::new_no_pool() .with_health(Health::Starting) .with_version(version); // Later, after worker is ready: service.set_orchestrator(pool, handle).await; service.set_health(Health::Ready).await; ``` ### Orchestrator (`orchestrator.rs`) Parent-side worker lifecycle management. ``` spawn_worker(config) │ ├─▶ Create Unix socket transport (N slots) ├─▶ Spawn: python -c "import coglet; coglet.server._run_worker()" ├─▶ Send Init message via stdin ├─▶ Wait for worker to connect sockets ├─▶ Wait for Ready message (with timeout) ├─▶ Populate PermitPool with slot writers ├─▶ Spawn event loop task └─▶ Return OrchestratorReady {pool, schema, handle} ``` Event loop handles: - `ControlResponse::Idle` - Slot ready for next prediction - `ControlResponse::Failed` - Slot poisoned, mark unavailable - `SlotResponse::Log/Output/Done/Failed` - Route to prediction - Worker crash - Fail all in-flight predictions ### Worker (`worker.rs`) Child-side event loop. Implements `PredictHandler` trait. ``` run_worker(handler, config) │ ├─▶ Connect to slot sockets (from env) ├─▶ Setup control channel (stdin/stdout) ├─▶ Run handler.setup() with log routing ├─▶ Send Ready {slots, schema} ├─▶ Enter event loop: │ - ControlRequest::Cancel → handler.cancel(slot) │ - ControlRequest::Shutdown → exit │ - SlotRequest::Predict → spawn prediction task └─▶ Exit on shutdown or all slots poisoned ``` ### PermitPool (`permit/pool.rs`) Slot-based concurrency control. 
```rust let pool = PermitPool::new(max_concurrency); // Add slot with its socket writer pool.add_permit(slot_id, writer); // Acquire permit (returns None if at capacity) let permit = pool.try_acquire()?; // Send prediction request permit.send(SlotRequest::Predict { id, input }).await?; // Return permit when done drop(permit); ``` ### Bridge Protocol (`bridge/protocol.rs`) Message types for parent-worker communication. **Control Channel:** - `ControlRequest`: Init, Cancel, Healthcheck, Shutdown - `ControlResponse`: Ready, Log, WorkerLog, Idle, Cancelled, Failed, Fatal, DroppedLogs, HealthcheckResult, ShuttingDown **Slot Channel:** - `SlotRequest`: Predict - `SlotResponse`: Log, FileOutput, Output, Metric, Done, Failed, Cancelled All messages are JSON with `{"type": "..."}` discriminator. ## Behaviors ### Health States ``` Unknown ──▶ Starting ──┬──▶ Ready ◀──▶ Busy │ └──▶ SetupFailed ──▶ Defunct ``` - **Unknown**: Initial state, health-check returns status in body - **Starting**: Setup in progress - **Ready**: Accepting predictions - **Busy**: Ready but all slots in use (HTTP 409 on new predictions) - **SetupFailed**: setup() raised exception - **Defunct**: Unrecoverable error ### Prediction States ``` Starting ──▶ Processing ──┬──▶ Succeeded ├──▶ Failed └──▶ Canceled ``` ### Cancellation 1. HTTP DELETE /predictions/{id} or PUT /predictions/{id}/cancel 2. Parent sends `ControlRequest::Cancel { slot }` 3. Worker calls `handler.cancel(slot)` 4. For sync: SIGUSR1 raises KeyboardInterrupt 5. For async: `future.cancel()` on the asyncio task 6. Prediction returns with `SlotResponse::Cancelled` ### Shutdown **Graceful (SIGTERM with await_explicit_shutdown):** 1. Stop accepting new predictions 2. Wait for in-flight to complete 3. Send `ControlRequest::Shutdown` 4. Worker responds `ShuttingDown`, exits 5. Parent exits **Immediate (SIGTERM without flag):** 1. Send `ControlRequest::Shutdown` 2. Cancel in-flight predictions 3. Exit **Worker crash:** 1. Control channel closes 2. Event loop detects, fails all in-flight predictions 3. 
Health → Defunct ### Slot Poisoning If a slot socket has an error (write fails, etc.), the slot is marked poisoned. It won't receive new predictions. If all slots are poisoned, worker exits. ```rust enum SlotOutcome { Idle(SlotId), // Ready for next prediction Poisoned { slot, error }, // Slot is dead } ``` ================================================ FILE: crates/coglet/src/bridge/codec.rs ================================================ //! Framed codec for worker communication. //! //! Uses LengthDelimitedCodec for framing + serde_json for serialization. //! Works over any AsyncRead/AsyncWrite (pipes, sockets, etc). use std::io; use std::marker::PhantomData; use serde::{Serialize, de::DeserializeOwned}; use tokio_util::bytes::{Bytes, BytesMut}; use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec}; /// Codec that frames messages with length prefix and serializes with JSON. /// /// Wraps LengthDelimitedCodec and adds serde_json serialization. pub struct JsonCodec { inner: LengthDelimitedCodec, _phantom: PhantomData, } impl Default for JsonCodec { fn default() -> Self { Self::new() } } impl JsonCodec { pub fn new() -> Self { Self { inner: LengthDelimitedCodec::builder() .length_field_length(4) .new_codec(), _phantom: PhantomData, } } } impl Decoder for JsonCodec { type Item = T; type Error = io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match self.inner.decode(src)? { Some(bytes) => { let item = serde_json::from_slice(&bytes) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; Ok(Some(item)) } None => Ok(None), } } } impl Encoder for JsonCodec { type Error = io::Error; fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> { let json = serde_json::to_vec(&item).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; let json_len = json.len(); // SAFETY: These logs must NOT be shipped over IPC (would create feedback loop). 
// WorkerTracingLayer filters out coglet::bridge::codec target to prevent encoding // a WorkerLog message from triggering another log that creates another WorkerLog, etc. tracing::trace!(json_size_bytes = json_len, "Encoding frame"); if json_len > 100_000 { tracing::info!( // This log line should be shipped across the IPC to be emitted, unlike the // above trace line. This is a real indicator that we've encoded a large // frame and is generally useful. target: "coglet::bridge::codec::large_frame", json_size_bytes = json_len, json_size_kb = json_len / 1024, "Large frame being encoded" ); } self.inner.encode(Bytes::from(json), dst) } } #[cfg(test)] mod tests { use super::*; use crate::bridge::protocol::{ ControlRequest, ControlResponse, SlotId, SlotRequest, SlotResponse, }; #[test] fn codec_roundtrip_control_request() { let mut codec = JsonCodec::::new(); let mut buf = BytesMut::new(); let slot = SlotId::new(); let req = ControlRequest::Cancel { slot }; codec.encode(req, &mut buf).unwrap(); let decoded = codec.decode(&mut buf).unwrap().unwrap(); assert!(matches!(decoded, ControlRequest::Cancel { .. })); } #[test] fn codec_roundtrip_control_response() { let mut codec = JsonCodec::::new(); let mut buf = BytesMut::new(); let slots = vec![SlotId::new()]; let resp = ControlResponse::Ready { slots, schema: None, }; codec.encode(resp, &mut buf).unwrap(); let decoded = codec.decode(&mut buf).unwrap().unwrap(); assert!(matches!(decoded, ControlResponse::Ready { .. 
})); } #[test] fn codec_roundtrip_slot_request() { let mut codec = JsonCodec::::new(); let mut buf = BytesMut::new(); let req = SlotRequest::Predict { id: "test".to_string(), input: Some(serde_json::json!({"x": 1})), input_file: None, output_dir: "/tmp/coglet/predictions/test/outputs".to_string(), context: Default::default(), }; codec.encode(req.clone(), &mut buf).unwrap(); let decoded = codec.decode(&mut buf).unwrap().unwrap(); match (req, decoded) { ( SlotRequest::Predict { id: id1, input: input1, input_file: file1, output_dir: dir1, .. }, SlotRequest::Predict { id: id2, input: input2, input_file: file2, output_dir: dir2, .. }, ) => { assert_eq!(id1, id2); assert_eq!(input1, input2); assert_eq!(file1, file2); assert_eq!(dir1, dir2); } } } #[test] fn codec_roundtrip_slot_response() { let mut codec = JsonCodec::::new(); let mut buf = BytesMut::new(); let resp = SlotResponse::Done { id: "test".to_string(), output: Some(serde_json::json!("result")), predict_time: 1.5, is_stream: false, }; codec.encode(resp, &mut buf).unwrap(); let decoded = codec.decode(&mut buf).unwrap().unwrap(); match decoded { SlotResponse::Done { id, output, predict_time, is_stream, } => { assert_eq!(id, "test"); assert_eq!(output, Some(serde_json::json!("result"))); assert!((predict_time - 1.5).abs() < 0.001); assert!(!is_stream); } _ => panic!("wrong variant"), } } } ================================================ FILE: crates/coglet/src/bridge/mod.rs ================================================ //! IPC bridge for coglet parent-worker communication. //! //! This module provides the wire protocol and codec for communication between //! the coglet orchestrator (parent) and worker subprocess. //! //! # Architecture //! //! - **protocol**: Message types (ControlRequest/Response, SlotRequest/Response) //! 
- **codec**: JSON framing codec for AsyncRead/AsyncWrite pub mod codec; pub mod protocol; pub mod transport; ================================================ FILE: crates/coglet/src/bridge/protocol.rs ================================================ //! Wire protocol types for parent-worker communication. //! //! Two channels: //! - **Control channel** (stdin/stdout): Init, Cancel, Shutdown, Ready, Idle //! - **Slot sockets**: Prediction data, streaming logs (per-slot to avoid HOL blocking) use std::collections::HashMap; use serde::{Deserialize, Serialize}; use super::transport::ChildTransportInfo; /// Unique identifier for a prediction slot. /// /// UUID v4 avoids confusion with array indices and prevents accidental reuse. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(transparent)] pub struct SlotId(uuid::Uuid); impl SlotId { pub fn new() -> Self { Self(uuid::Uuid::new_v4()) } pub fn as_uuid(&self) -> &uuid::Uuid { &self.0 } pub fn parse(s: &str) -> Result { let uuid = uuid::Uuid::parse_str(s)?; Ok(Self(uuid)) } } impl Default for SlotId { fn default() -> Self { Self::new() } } impl std::fmt::Display for SlotId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) } } /// Maximum payload size (input or output) that can be sent inline over the IPC /// slot socket. Payloads exceeding this threshold are spilled to disk. The /// `LengthDelimitedCodec` default frame limit is 8 MiB, so 6 MiB provides a /// 2 MiB safety margin for framing overhead and other message fields. pub const MAX_INLINE_IPC_SIZE: usize = 1024 * 1024 * 6; // 6MiB const MAX_WORKER_LOG_SIZE: usize = 1024 * 1024 * 4; // 4MIB const WORKER_LOG_TRUNCATE_NOTICE: &str = "[**** LOG LINE TRUNCATED AT 4 MiB ****]"; /// To ensure no panics happen due to oversized log lines, we truncate at 4 MiB. 1 MiB /// let alone 4 MiB log line boarder/exceed usefulness from a readability standpoint. 
pub fn truncate_worker_log(mut log_message: String) -> String { if log_message.len() > MAX_WORKER_LOG_SIZE { let boundary = log_message.floor_char_boundary(MAX_WORKER_LOG_SIZE - WORKER_LOG_TRUNCATE_NOTICE.len()); log_message.truncate(boundary); log_message.push_str(WORKER_LOG_TRUNCATE_NOTICE); } log_message } /// Control messages from parent to worker. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ControlRequest { /// Initial configuration sent immediately after spawn (must be first message). Init { predictor_ref: String, num_slots: usize, transport_info: ChildTransportInfo, is_train: bool, is_async: bool, }, Cancel { slot: SlotId, }, /// Request user-defined healthcheck execution. Healthcheck { id: String, }, Shutdown, } /// Control messages from worker to parent. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ControlResponse { Ready { /// Slot IDs in socket order - parent uses these for all subsequent communication. slots: Vec, #[serde(skip_serializing_if = "Option::is_none")] schema: Option, }, /// Setup-phase logs (before slots are active). Log { source: LogSource, data: String, }, /// Worker tracing log (Rust structured logging). WorkerLog { target: String, level: String, message: String, }, /// Slot completed and is ready for next prediction. Idle { slot: SlotId, }, Cancelled { slot: SlotId, }, /// Slot is poisoned and will not accept more predictions. Failed { slot: SlotId, error: String, }, /// Worker unrecoverable error - parent should poison all slots and fail all /// in-flight predictions. The worker will abort immediately after sending this. /// /// Reason explains *why* (e.g. "slots mutex poisoned: cannot guarantee slot isolation"). Fatal { reason: String, }, /// System diagnostic: logs dropped due to backpressure. DroppedLogs { count: usize, interval_millis: u64, }, /// Result of user-defined healthcheck execution. 
HealthcheckResult { id: String, status: HealthcheckStatus, #[serde(skip_serializing_if = "Option::is_none")] error: Option, }, ShuttingDown, } /// Status of a user-defined healthcheck. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum HealthcheckStatus { /// Healthcheck passed (returned True or no healthcheck defined). Healthy, /// Healthcheck failed (returned False, raised exception, or timed out). Unhealthy, } /// Type-safe slot completion - ensures poisoned slots produce Failed, not Idle. #[derive(Debug)] pub enum SlotOutcome { Idle(SlotId), Poisoned { slot: SlotId, error: String }, } impl SlotOutcome { pub fn idle(slot: SlotId) -> Self { Self::Idle(slot) } pub fn poisoned(slot: SlotId, error: impl Into) -> Self { Self::Poisoned { slot, error: error.into(), } } pub fn slot_id(&self) -> SlotId { match self { Self::Idle(slot) => *slot, Self::Poisoned { slot, .. } => *slot, } } pub fn is_poisoned(&self) -> bool { matches!(self, Self::Poisoned { .. }) } pub fn into_control_response(self) -> ControlResponse { match self { Self::Idle(slot) => ControlResponse::Idle { slot }, Self::Poisoned { slot, error } => ControlResponse::Failed { slot, error }, } } } /// Messages from parent to worker on slot socket. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum SlotRequest { Predict { id: String, /// Inline input payload (present when input fits within the IPC frame limit). #[serde(skip_serializing_if = "Option::is_none")] input: Option, /// Path to a spill file containing the JSON input (present when input exceeds /// `MAX_INLINE_IPC_SIZE`). The worker reads, deserializes, and deletes the file. #[serde(skip_serializing_if = "Option::is_none")] input_file: Option, /// Directory for writing file outputs (created by coglet before dispatch). /// Not included in API responses — internal transport detail. 
output_dir: String, /// Per-prediction context from the request body (`dict[str, str]`). /// Made available to predictors via `current_scope().context`. #[serde(default)] context: HashMap, }, } impl SlotRequest { /// Returns the prediction ID without consuming the request. pub fn prediction_id(&self) -> &str { match self { SlotRequest::Predict { id, .. } => id, } } /// Rehydrate the input from either inline value or spill file. /// /// Returns `(id, input, output_dir, context)`. If the input was spilled to disk, /// reads the file, deserializes, and deletes it. pub fn rehydrate_input( self, ) -> std::io::Result<(String, serde_json::Value, String, HashMap)> { match self { SlotRequest::Predict { id, input: Some(value), output_dir, context, .. } => Ok((id, value, output_dir, context)), SlotRequest::Predict { id, input: None, input_file: Some(path), output_dir, context, } => { let bytes = std::fs::read(&path)?; // Clean up spill file immediately — bytes are already in memory. // Do this before parsing so the file is removed even if JSON is corrupt. if let Err(e) = std::fs::remove_file(&path) { tracing::warn!(path = %path, error = %e, "Failed to remove input spill file"); } let value: serde_json::Value = serde_json::from_slice(&bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok((id, value, output_dir, context)) } SlotRequest::Predict { .. } => Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "SlotRequest::Predict has neither input nor input_file", )), } } } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum FileOutputKind { /// Output is a file-like return type (e.g. File, Path) FileType, /// Output exceeds size threshold for bridge codec serialization but is not a file-like return type Oversized, } /// Accumulation mode for user metrics. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum MetricMode { /// Replace existing value (default). Replace, /// Add to existing numeric value. Increment, /// Append to existing array. Append, } /// Messages from worker to parent on slot socket. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum SlotResponse { Log { source: LogSource, data: String, }, /// Output for a file/path-like output return type or an output that exceeds the size threshold /// for bridge codec serialization. FileOutput { filename: String, kind: FileOutputKind, /// Explicit MIME type from the predictor. Falls back to mime_guess when None. #[serde(skip_serializing_if = "Option::is_none")] mime_type: Option, }, /// Streaming output chunk (for generators). Output { output: serde_json::Value, }, /// User-emitted metric from the prediction. /// /// Metrics are key-value pairs attached to the prediction response. /// Supports dot-path keys (e.g., "timing.preprocess") that the server /// resolves into nested objects. The mode controls how values are merged: /// - Replace: overwrite existing value /// - Increment: add to existing numeric value /// - Append: push onto existing array Metric { name: String, value: serde_json::Value, mode: MetricMode, }, Done { id: String, #[serde(skip_serializing_if = "Option::is_none")] output: Option, predict_time: f64, /// Predictor signal: true when the output is a list, generator, or /// iterator — used as fallback when the schema Output type is `Any` /// or unavailable. 
#[serde(default, skip_serializing_if = "std::ops::Not::not")] is_stream: bool, }, Failed { id: String, error: String, }, Cancelled { id: String, }, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum LogSource { Stdout, Stderr, } #[cfg(test)] mod tests { use super::*; use serde_json::json; use std::path::PathBuf; fn test_slot_id() -> SlotId { SlotId(uuid::Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap()) } #[test] fn control_init_serializes() { let req = ControlRequest::Init { predictor_ref: "predict.py:Predictor".to_string(), num_slots: 2, transport_info: ChildTransportInfo::NamedSockets { dir: PathBuf::from("/tmp/coglet-123"), num_slots: 2, }, is_train: false, is_async: true, }; insta::assert_json_snapshot!(req); } #[test] fn control_cancel_serializes() { let req = ControlRequest::Cancel { slot: test_slot_id(), }; insta::assert_json_snapshot!(req); } #[test] fn control_shutdown_serializes() { let req = ControlRequest::Shutdown; insta::assert_json_snapshot!(req); } #[test] fn control_healthcheck_serializes() { let req = ControlRequest::Healthcheck { id: "hc_123".to_string(), }; insta::assert_json_snapshot!(req); } #[test] fn control_healthcheck_result_healthy_serializes() { let resp = ControlResponse::HealthcheckResult { id: "hc_123".to_string(), status: HealthcheckStatus::Healthy, error: None, }; insta::assert_json_snapshot!(resp); } #[test] fn control_healthcheck_result_unhealthy_serializes() { let resp = ControlResponse::HealthcheckResult { id: "hc_123".to_string(), status: HealthcheckStatus::Unhealthy, error: Some("user healthcheck returned False".to_string()), }; insta::assert_json_snapshot!(resp); } #[test] fn control_ready_serializes() { let resp = ControlResponse::Ready { slots: vec![test_slot_id()], schema: None, }; insta::assert_json_snapshot!(resp); } #[test] fn control_ready_with_schema_serializes() { let resp = ControlResponse::Ready { slots: vec![test_slot_id()], schema: 
Some(json!({ "openapi": "3.0.2", "info": {"title": "Cog", "version": "0.1.0"} })), }; insta::assert_json_snapshot!(resp); } #[test] fn control_idle_serializes() { let resp = ControlResponse::Idle { slot: test_slot_id(), }; insta::assert_json_snapshot!(resp); } #[test] fn control_cancelled_serializes() { let resp = ControlResponse::Cancelled { slot: test_slot_id(), }; insta::assert_json_snapshot!(resp); } #[test] fn control_failed_serializes() { let resp = ControlResponse::Failed { slot: test_slot_id(), error: "segfault".to_string(), }; insta::assert_json_snapshot!(resp); } #[test] fn slot_predict_serializes() { let req = SlotRequest::Predict { id: "pred_123".to_string(), input: Some(json!({"text": "hello"})), input_file: None, output_dir: "/tmp/coglet/predictions/pred_123/outputs".to_string(), context: Default::default(), }; insta::assert_json_snapshot!(req); } #[test] fn slot_predict_file_input_serializes() { let req = SlotRequest::Predict { id: "pred_456".to_string(), input: None, input_file: Some("/tmp/coglet/predictions/pred_456/inputs/spill_abc.json".to_string()), output_dir: "/tmp/coglet/predictions/pred_456/outputs".to_string(), context: Default::default(), }; insta::assert_json_snapshot!(req); } #[test] fn slot_log_serializes() { let resp = SlotResponse::Log { source: LogSource::Stdout, data: "Processing...".to_string(), }; insta::assert_json_snapshot!(resp); } #[test] fn slot_output_serializes() { let resp = SlotResponse::Output { output: json!("chunk 1"), }; insta::assert_json_snapshot!(resp); } #[test] fn slot_done_serializes() { let resp = SlotResponse::Done { id: "pred_123".to_string(), output: Some(json!("final result")), predict_time: 1.234, is_stream: false, }; insta::assert_json_snapshot!(resp); } #[test] fn slot_failed_serializes() { let resp = SlotResponse::Failed { id: "pred_123".to_string(), error: "ValueError: invalid input".to_string(), }; insta::assert_json_snapshot!(resp); } #[test] fn slot_cancelled_serializes() { let resp = 
SlotResponse::Cancelled { id: "pred_123".to_string(), }; insta::assert_json_snapshot!(resp); } #[test] fn slot_metric_replace_serializes() { let resp = SlotResponse::Metric { name: "temperature".to_string(), value: json!(0.7), mode: MetricMode::Replace, }; insta::assert_json_snapshot!(resp); } #[test] fn slot_metric_increment_serializes() { let resp = SlotResponse::Metric { name: "token_count".to_string(), value: json!(1), mode: MetricMode::Increment, }; insta::assert_json_snapshot!(resp); } #[test] fn slot_metric_append_serializes() { let resp = SlotResponse::Metric { name: "logprobs".to_string(), value: json!(-1.2), mode: MetricMode::Append, }; insta::assert_json_snapshot!(resp); } #[test] fn slot_metric_delete_serializes() { let resp = SlotResponse::Metric { name: "unwanted".to_string(), value: json!(null), mode: MetricMode::Replace, }; insta::assert_json_snapshot!(resp); } #[test] fn slot_metric_complex_value_serializes() { let resp = SlotResponse::Metric { name: "timing".to_string(), value: json!({"preprocess": 0.1, "inference": 0.8}), mode: MetricMode::Replace, }; insta::assert_json_snapshot!(resp); } #[test] fn rehydrate_input_inline() { let req = SlotRequest::Predict { id: "p1".to_string(), input: Some(json!({"text": "hello"})), input_file: None, output_dir: "/tmp/out".to_string(), context: Default::default(), }; let (id, input, output_dir, _context) = req.rehydrate_input().unwrap(); assert_eq!(id, "p1"); assert_eq!(input, json!({"text": "hello"})); assert_eq!(output_dir, "/tmp/out"); } #[test] fn rehydrate_input_from_file() { let dir = tempfile::tempdir().unwrap(); let spill_path = dir.path().join("spill_test.json"); std::fs::write(&spill_path, r#"{"key":"value"}"#).unwrap(); let req = SlotRequest::Predict { id: "p2".to_string(), input: None, input_file: Some(spill_path.to_str().unwrap().to_string()), output_dir: "/tmp/out".to_string(), context: Default::default(), }; let (id, input, output_dir, _context) = req.rehydrate_input().unwrap(); assert_eq!(id, 
"p2");
        assert_eq!(input, json!({"key": "value"}));
        assert_eq!(output_dir, "/tmp/out");
        // Spill file should be deleted
        assert!(!spill_path.exists());
    }

    // A request with neither inline input nor a spill file is rejected as InvalidData.
    #[test]
    fn rehydrate_input_neither_errors() {
        let req = SlotRequest::Predict {
            id: "p3".to_string(),
            input: None,
            input_file: None,
            output_dir: "/tmp/out".to_string(),
            context: Default::default(),
        };
        let err = req.rehydrate_input().unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // A spill file that is not valid JSON is rejected as InvalidData.
    #[test]
    fn rehydrate_input_corrupt_file_errors() {
        let dir = tempfile::tempdir().unwrap();
        let spill_path = dir.path().join("corrupt.json");
        std::fs::write(&spill_path, "not valid json!!!").unwrap();

        let req = SlotRequest::Predict {
            id: "p4".to_string(),
            input: None,
            input_file: Some(spill_path.to_str().unwrap().to_string()),
            output_dir: "/tmp/out".to_string(),
            context: Default::default(),
        };
        let err = req.rehydrate_input().unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    // An over-long message must come back ending with the truncation notice.
    #[test]
    fn truncate_worker_log_truncates_long_messages() {
        let emoji = "🦀"; // 4-byte UTF-8 character
        // known size of truncate target, add one more character
        // NOTE(review): as written this builds a string of roughly 4 GiB (2^30 + 1 emoji of
        // 4 bytes each) — confirm that really is the truncate target and not a stray `* 1024`.
        let count = 1024 * 1024 * 1024 * 4 / emoji.len() + 1;
        let message: String = truncate_worker_log(emoji.repeat(count));
        assert!(
            message.ends_with(WORKER_LOG_TRUNCATE_NOTICE),
            "log message didn't end with {}",
            WORKER_LOG_TRUNCATE_NOTICE
        );
    }

    // A short message must not get the truncation notice appended.
    #[test]
    fn truncate_worker_log_does_not_truncate_short_messages() {
        let emoji = "🦀"; // 4-byte UTF-8 character
        // only a handful of characters — far smaller than the message built in the
        // truncation test above
        let count = 10;
        let message: String = truncate_worker_log(emoji.repeat(count));
        assert!(
            !message.ends_with(WORKER_LOG_TRUNCATE_NOTICE),
            "short log message was truncated"
        );
    }
}

================================================
FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__control_cancel_serializes.snap
================================================
---
source: coglet/src/bridge/protocol.rs
expression: req
---
{
  "type": "cancel",
  "slot":
"550e8400-e29b-41d4-a716-446655440000" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__control_cancelled_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "cancelled", "slot": "550e8400-e29b-41d4-a716-446655440000" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__control_failed_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "failed", "slot": "550e8400-e29b-41d4-a716-446655440000", "error": "segfault" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__control_healthcheck_result_healthy_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "healthcheck_result", "id": "hc_123", "status": "healthy" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__control_healthcheck_result_unhealthy_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "healthcheck_result", "id": "hc_123", "status": "unhealthy", "error": "user healthcheck returned False" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__control_healthcheck_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: req --- { "type": "healthcheck", "id": "hc_123" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__control_idle_serializes.snap 
================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "idle", "slot": "550e8400-e29b-41d4-a716-446655440000" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__control_init_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: req --- { "type": "init", "predictor_ref": "predict.py:Predictor", "num_slots": 2, "transport_info": { "NamedSockets": { "dir": "/tmp/coglet-123", "num_slots": 2 } }, "is_train": false, "is_async": true } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__control_ready_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "ready", "slots": [ "550e8400-e29b-41d4-a716-446655440000" ] } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__control_ready_with_schema_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "ready", "slots": [ "550e8400-e29b-41d4-a716-446655440000" ], "schema": { "info": { "title": "Cog", "version": "0.1.0" }, "openapi": "3.0.2" } } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__control_shutdown_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: req --- { "type": "shutdown" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_cancelled_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "cancelled", "id": 
"pred_123" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_done_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "done", "id": "pred_123", "output": "final result", "predict_time": 1.234 } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_failed_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "failed", "id": "pred_123", "error": "ValueError: invalid input" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_log_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "log", "source": "stdout", "data": "Processing..." 
} ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_metric_append_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "metric", "name": "logprobs", "value": -1.2, "mode": "append" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_metric_complex_value_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "metric", "name": "timing", "value": { "inference": 0.8, "preprocess": 0.1 }, "mode": "replace" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_metric_delete_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "metric", "name": "unwanted", "value": null, "mode": "replace" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_metric_increment_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "metric", "name": "token_count", "value": 1, "mode": "increment" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_metric_replace_serializes.snap ================================================ --- source: coglet/src/bridge/protocol.rs expression: resp --- { "type": "metric", "name": "temperature", "value": 0.7, "mode": "replace" } ================================================ FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_output_serializes.snap ================================================ --- source: 
coglet/src/bridge/protocol.rs
expression: resp
---
{
  "type": "output",
  "output": "chunk 1"
}

================================================
FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_predict_file_input_serializes.snap
================================================
---
source: coglet/src/bridge/protocol.rs
expression: req
---
{
  "type": "predict",
  "id": "pred_456",
  "input_file": "/tmp/coglet/predictions/pred_456/inputs/spill_abc.json",
  "output_dir": "/tmp/coglet/predictions/pred_456/outputs",
  "context": {}
}

================================================
FILE: crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_predict_serializes.snap
================================================
---
source: coglet/src/bridge/protocol.rs
expression: req
---
{
  "type": "predict",
  "id": "pred_123",
  "input": {
    "text": "hello"
  },
  "output_dir": "/tmp/coglet/predictions/pred_123/outputs",
  "context": {}
}

================================================
FILE: crates/coglet/src/bridge/transport.rs
================================================
//! Slot socket transport for parent-worker IPC.
//!
//! Platform-specific implementations:
//! - **NamedSocketTransport**: Filesystem sockets (macOS, Linux, BSD)
//! - **AbstractSocketTransport**: Linux abstract namespace (no filesystem, auto-cleanup)

use std::io;
use std::path::PathBuf;

use serde::{Deserialize, Serialize};
use tokio::net::UnixStream;

/// Information passed to child process for connecting to slot sockets.
///
/// Serialized with serde's default (externally tagged) enum representation, so the
/// variant name appears as the JSON key (see the `control_init` snapshot).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChildTransportInfo {
    /// Filesystem sockets: slot `n` is `slot-{n}.sock` inside `dir`.
    NamedSockets {
        dir: PathBuf,
        num_slots: usize,
    },
    /// Linux abstract-namespace sockets: slot `n` is named `{prefix}-{n}`.
    #[cfg(target_os = "linux")]
    AbstractSockets {
        prefix: String,
        num_slots: usize,
    },
}

/// Named socket transport using filesystem sockets.
/// /// Socket path format: `{temp_dir}/coglet-{pid}/slot-{n}.sock` pub struct NamedSocketTransport { dir: PathBuf, sockets: Vec, listeners: Vec, is_parent: bool, } impl NamedSocketTransport { /// Create transport on parent side, binding listeners for child to connect. pub async fn create(num_slots: usize) -> io::Result<(Self, ChildTransportInfo)> { use std::os::unix::net::UnixListener as StdUnixListener; use tokio::net::UnixListener; let dir = std::env::temp_dir().join(format!("coglet-{}", std::process::id())); std::fs::create_dir_all(&dir)?; tracing::debug!(transport_type = "named", dir = %dir.display(), num_slots, "Creating slot transport"); let mut listeners = Vec::with_capacity(num_slots); for i in 0..num_slots { let path = dir.join(format!("slot-{}.sock", i)); if path.exists() { std::fs::remove_file(&path)?; } let std_listener = StdUnixListener::bind(&path)?; std_listener.set_nonblocking(true)?; let listener = UnixListener::from_std(std_listener)?; tracing::trace!(slot = i, path = %path.display(), "Bound socket"); listeners.push(listener); } let transport = Self { dir: dir.clone(), sockets: Vec::with_capacity(num_slots), listeners, is_parent: true, }; let child_info = ChildTransportInfo::NamedSockets { dir: dir.clone(), num_slots, }; Ok((transport, child_info)) } /// Accept connections from child on all slots. pub async fn accept_connections(&mut self, num_slots: usize) -> io::Result<()> { for i in 0..num_slots { let listener = &self.listeners[i]; tracing::trace!(slot = i, "Waiting for child connection"); let (stream, _) = listener.accept().await?; self.sockets.push(stream); tracing::trace!(slot = i, "Child connected"); } self.listeners.clear(); Ok(()) } /// Connect from child side. 
pub async fn connect(dir: PathBuf, num_slots: usize) -> io::Result { let mut sockets = Vec::with_capacity(num_slots); for i in 0..num_slots { let path = dir.join(format!("slot-{}.sock", i)); tracing::trace!(slot = i, path = %path.display(), "Connecting to socket"); let stream = UnixStream::connect(&path).await?; sockets.push(stream); tracing::trace!(slot = i, "Connected"); } Ok(Self { dir, sockets, listeners: Vec::new(), is_parent: false, }) } pub fn slot_socket(&mut self, slot: usize) -> Option<&mut UnixStream> { self.sockets.get_mut(slot) } /// Returns owned sockets for splitting into read/write halves. pub fn drain_sockets(&mut self) -> Vec { std::mem::take(&mut self.sockets) } pub fn dir(&self) -> &PathBuf { &self.dir } pub fn num_slots(&self) -> usize { self.sockets.len() } pub fn cleanup(&mut self) -> io::Result<()> { if self.is_parent && self.dir.exists() { tracing::debug!(dir = %self.dir.display(), "Cleaning up socket directory"); std::fs::remove_dir_all(&self.dir)?; } Ok(()) } } impl Drop for NamedSocketTransport { fn drop(&mut self) { if let Err(e) = self.cleanup() { tracing::warn!(error = %e, "Failed to cleanup socket directory"); } } } /// Abstract namespace socket transport (Linux only). /// /// No filesystem entries, auto-cleanup when all references close. #[cfg(target_os = "linux")] pub struct AbstractSocketTransport { #[allow(dead_code)] // Kept for debugging/identification prefix: String, sockets: Vec, listeners: Vec, } #[cfg(target_os = "linux")] impl AbstractSocketTransport { /// Create transport on parent side, binding listeners for child to connect. 
pub async fn create(num_slots: usize) -> io::Result<(Self, ChildTransportInfo)> {
        use std::os::linux::net::SocketAddrExt;
        use std::os::unix::net::{SocketAddr, UnixListener as StdUnixListener};
        use tokio::net::UnixListener;

        let prefix = format!("coglet-{}", std::process::id());

        tracing::debug!(transport_type = "abstract", prefix = %prefix, num_slots, "Creating slot transport");

        // Slot `n` is bound to the abstract name `{prefix}-{n}`.
        let mut listeners = Vec::with_capacity(num_slots);
        for i in 0..num_slots {
            let name = format!("{}-{}", prefix, i);
            let addr = SocketAddr::from_abstract_name(name.as_bytes())?;
            let std_listener = StdUnixListener::bind_addr(&addr)?;
            std_listener.set_nonblocking(true)?;
            let listener = UnixListener::from_std(std_listener)?;
            tracing::trace!(slot = i, name = %name, "Bound abstract socket");
            listeners.push(listener);
        }

        let transport = Self {
            prefix: prefix.clone(),
            sockets: Vec::with_capacity(num_slots),
            listeners,
        };

        let child_info = ChildTransportInfo::AbstractSockets { prefix, num_slots };

        Ok((transport, child_info))
    }

    /// Accept connections from child on all slots.
    ///
    /// Accepts one connection per listener, in slot order, then drops the listeners.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if `num_slots` exceeds the number of bound listeners
    /// (for example on a transport built with `connect`, or on a second call), and
    /// propagates any error from `accept`.
    pub async fn accept_connections(&mut self, num_slots: usize) -> io::Result<()> {
        let available = self.listeners.len();
        if num_slots > available {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "cannot accept {} slot connections: only {} listeners are bound",
                    num_slots, available
                ),
            ));
        }

        for (i, listener) in self.listeners.iter().take(num_slots).enumerate() {
            tracing::trace!(slot = i, "Waiting for child connection");
            let (stream, _) = listener.accept().await?;
            self.sockets.push(stream);
            tracing::trace!(slot = i, "Child connected");
        }
        self.listeners.clear();
        Ok(())
    }

    /// Connect from child side.
pub async fn connect(prefix: String, num_slots: usize) -> io::Result { use std::os::linux::net::SocketAddrExt; use std::os::unix::net::SocketAddr; let mut sockets = Vec::with_capacity(num_slots); for i in 0..num_slots { let name = format!("{}-{}", prefix, i); let addr = SocketAddr::from_abstract_name(name.as_bytes())?; tracing::trace!(slot = i, name = %name, "Connecting to abstract socket"); // tokio doesn't support abstract sockets directly let std_stream = std::os::unix::net::UnixStream::connect_addr(&addr)?; std_stream.set_nonblocking(true)?; let stream = UnixStream::from_std(std_stream)?; sockets.push(stream); tracing::trace!(slot = i, "Connected"); } Ok(Self { prefix, sockets, listeners: Vec::new(), }) } pub fn slot_socket(&mut self, slot: usize) -> Option<&mut UnixStream> { self.sockets.get_mut(slot) } pub fn drain_sockets(&mut self) -> Vec { std::mem::take(&mut self.sockets) } pub fn num_slots(&self) -> usize { self.sockets.len() } } pub enum SlotTransport { Named(NamedSocketTransport), #[cfg(target_os = "linux")] Abstract(AbstractSocketTransport), } impl SlotTransport { pub fn slot_socket(&mut self, slot: usize) -> Option<&mut UnixStream> { match self { Self::Named(t) => t.slot_socket(slot), #[cfg(target_os = "linux")] Self::Abstract(t) => t.slot_socket(slot), } } pub fn drain_sockets(&mut self) -> Vec { match self { Self::Named(t) => t.drain_sockets(), #[cfg(target_os = "linux")] Self::Abstract(t) => t.drain_sockets(), } } pub fn num_slots(&self) -> usize { match self { Self::Named(t) => t.num_slots(), #[cfg(target_os = "linux")] Self::Abstract(t) => t.num_slots(), } } pub async fn accept_connections(&mut self, num_slots: usize) -> io::Result<()> { match self { Self::Named(t) => t.accept_connections(num_slots).await, #[cfg(target_os = "linux")] Self::Abstract(t) => t.accept_connections(num_slots).await, } } } /// Create transport using platform default (abstract on Linux, named elsewhere). 
pub async fn create_transport(num_slots: usize) -> io::Result<(SlotTransport, ChildTransportInfo)> {
    #[cfg(target_os = "linux")]
    {
        let (transport, info) = AbstractSocketTransport::create(num_slots).await?;
        Ok((SlotTransport::Abstract(transport), info))
    }

    #[cfg(not(target_os = "linux"))]
    {
        let (transport, info) = NamedSocketTransport::create(num_slots).await?;
        Ok((SlotTransport::Named(transport), info))
    }
}

/// Connect from the child side using the transport description received from the
/// parent; the variant of `info` selects the transport implementation.
pub async fn connect_transport(info: ChildTransportInfo) -> io::Result<SlotTransport> {
    match info {
        ChildTransportInfo::NamedSockets { dir, num_slots } => {
            let transport = NamedSocketTransport::connect(dir, num_slots).await?;
            Ok(SlotTransport::Named(transport))
        }
        #[cfg(target_os = "linux")]
        ChildTransportInfo::AbstractSockets { prefix, num_slots } => {
            let transport = AbstractSocketTransport::connect(prefix, num_slots).await?;
            Ok(SlotTransport::Abstract(transport))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // `NamedSockets` survives a JSON round trip with both fields intact.
    #[test]
    fn child_transport_info_roundtrips() {
        let info = ChildTransportInfo::NamedSockets {
            dir: PathBuf::from("/tmp/coglet-123"),
            num_slots: 3,
        };
        let json = serde_json::to_string(&info).unwrap();
        let parsed: ChildTransportInfo = serde_json::from_str(&json).unwrap();

        match parsed {
            ChildTransportInfo::NamedSockets { dir, num_slots } => {
                assert_eq!(dir, PathBuf::from("/tmp/coglet-123"));
                assert_eq!(num_slots, 3);
            }
            // Only Linux has a second variant, so the wildcard arm exists only there.
            #[cfg(target_os = "linux")]
            _ => panic!("Wrong variant"),
        }
    }

    // `AbstractSockets` survives a JSON round trip with both fields intact.
    #[cfg(target_os = "linux")]
    #[test]
    fn abstract_socket_info_roundtrips() {
        let info = ChildTransportInfo::AbstractSockets {
            prefix: "coglet-456".to_string(),
            num_slots: 2,
        };
        let json = serde_json::to_string(&info).unwrap();
        let parsed: ChildTransportInfo = serde_json::from_str(&json).unwrap();

        match parsed {
            ChildTransportInfo::AbstractSockets { prefix, num_slots } => {
                assert_eq!(prefix, "coglet-456");
                assert_eq!(num_slots, 2);
            }
            _ => panic!("Wrong variant"),
        }
    }
}

================================================
FILE: crates/coglet/src/fd_redirect.rs
================================================
//! File descriptor redirection for subprocess isolation.
//!
//! Worker uses fd 1 (stdout) as the control channel to the orchestrator. When Python's
//! setup() spawns subprocesses (subprocess.Popen), they inherit fd 1 and corrupt the
//! control channel by writing directly into it.
//!
//! We redirect fds early in startup: move control channel to high-numbered fds (99-101),
//! replace fd 1/2 with capture pipes, spawn threads to route captured output through
//! the log system.
//!
//! CRITICAL: Must be called before FFI initialization (Python, etc.).
//!
//! ## Safety contracts
//!
//! All `unsafe` blocks in this module rely on these guarantees:
//! 1. Called early in worker before any predictor/FFI code runs (tokio runtime threads
//!    exist but aren't accessing fds 0/1/2)
//! 2. Standard fds (0, 1, 2) are guaranteed open by the OS at process startup
//! 3. High-numbered fds (99-101) won't conflict with application/library usage
//! 4. Ownership transfer to threads via `from_raw_fd` + `forget` prevents double-close
//!
//! Cannot use Miri: This code makes actual syscalls (dup/dup2) which Miri can't execute.

#[cfg(unix)]
use std::io;
#[cfg(unix)]
use std::os::fd::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd};

#[cfg(unix)]
use nix::unistd::{dup, dup2, pipe};
#[cfg(unix)]
use tokio::sync::mpsc;

#[cfg(unix)]
use crate::bridge::protocol::{ControlResponse, LogSource};

/// Fd number the original stdin (fd 0) is duplicated onto.
///
/// Chosen to be above the range typically used by libraries (avoiding conflicts with
/// application fds or library-opened files).
#[cfg(unix)]
const CONTROL_STDIN_FD: i32 = 99;
/// Fd number the original stdout (fd 1, the control channel) is duplicated onto.
#[cfg(unix)]
const CONTROL_STDOUT_FD: i32 = 100;
/// Fd number the original stderr (fd 2) is duplicated onto.
#[cfg(unix)]
const WORKER_STDERR_FD: i32 = 101;

/// The control channel after it has been moved off fds 0/1.
#[cfg(unix)]
pub struct ControlChannelFds {
    /// Duplicate of the original fd 0, now at `CONTROL_STDIN_FD`.
    pub stdin_fd: OwnedFd,
    /// Duplicate of the original fd 1, now at `CONTROL_STDOUT_FD`.
    pub stdout_fd: OwnedFd,
}

/// Redirect stdout/stderr for subprocess isolation.
///
/// CRITICAL: Must be called before FFI initialization.
Child processes spawned after /// this will inherit the capture pipes (not the control channel). #[cfg(unix)] pub fn redirect_fds_for_subprocess_isolation( setup_log_tx: mpsc::Sender, ) -> io::Result { // Safety: Called early in worker startup before FFI initialization (tokio runtime threads // exist but aren't accessing fds 0/1/2). dup/dup2 are atomic. BorrowedFd::borrow_raw is // safe because we're borrowing standard fds (0, 1, 2) which are guaranteed to be open. tracing::debug!("Preserving control channel to high fds"); let control_stdin = unsafe { let fd = BorrowedFd::borrow_raw(0); dup(fd) } .map_err(|e| io::Error::other(format!("dup(0) failed: {}", e)))?; let control_stdout = unsafe { let fd = BorrowedFd::borrow_raw(1); dup(fd) } .map_err(|e| io::Error::other(format!("dup(1) failed: {}", e)))?; let worker_stderr = unsafe { let fd = BorrowedFd::borrow_raw(2); dup(fd) } .map_err(|e| io::Error::other(format!("dup(2) failed: {}", e)))?; tracing::trace!( control_stdin = control_stdin.as_raw_fd(), control_stdout = control_stdout.as_raw_fd(), worker_stderr = worker_stderr.as_raw_fd(), "Duped original fds" ); let mut target_stdin = unsafe { OwnedFd::from_raw_fd(CONTROL_STDIN_FD) }; dup2(&control_stdin, &mut target_stdin) .map_err(|e| io::Error::other(format!("dup2 stdin failed: {}", e)))?; std::mem::forget(target_stdin); // Don't close, we'll use it later let mut target_stdout = unsafe { OwnedFd::from_raw_fd(CONTROL_STDOUT_FD) }; dup2(&control_stdout, &mut target_stdout) .map_err(|e| io::Error::other(format!("dup2 stdout failed: {}", e)))?; std::mem::forget(target_stdout); // Don't close, we'll use it later let mut target_stderr = unsafe { OwnedFd::from_raw_fd(WORKER_STDERR_FD) }; dup2(&worker_stderr, &mut target_stderr) .map_err(|e| io::Error::other(format!("dup2 stderr failed: {}", e)))?; std::mem::forget(target_stderr); // Don't close, we'll use it later tracing::trace!( stdin_fd = CONTROL_STDIN_FD, stdout_fd = CONTROL_STDOUT_FD, stderr_fd = WORKER_STDERR_FD, 
"Moved control channel to high fds" ); // Temps are now duplicated at high positions, safe to close drop(control_stdin); drop(control_stdout); drop(worker_stderr); tracing::debug!("Creating capture pipes for stdout/stderr"); let (stdout_read, stdout_write) = pipe().map_err(|e| io::Error::other(format!("pipe failed: {}", e)))?; let (stderr_read, stderr_write) = pipe().map_err(|e| io::Error::other(format!("pipe failed: {}", e)))?; tracing::trace!( stdout_read = stdout_read.as_raw_fd(), stdout_write = stdout_write.as_raw_fd(), stderr_read = stderr_read.as_raw_fd(), stderr_write = stderr_write.as_raw_fd(), "Created capture pipes" ); let mut target_fd1 = unsafe { OwnedFd::from_raw_fd(1) }; dup2(&stdout_write, &mut target_fd1) .map_err(|e| io::Error::other(format!("dup2(stdout) failed: {}", e)))?; std::mem::forget(target_fd1); // Don't close fd 1 let mut target_fd2 = unsafe { OwnedFd::from_raw_fd(2) }; dup2(&stderr_write, &mut target_fd2) .map_err(|e| io::Error::other(format!("dup2(stderr) failed: {}", e)))?; std::mem::forget(target_fd2); // Don't close fd 2 tracing::trace!("Replaced fd 1/2 with capture pipes"); // Write ends are duped to 1/2, close originals drop(stdout_write); drop(stderr_write); tracing::debug!("Spawning capture threads"); // Capture both stdout and stderr from subprocesses. Rust tracing was initialized before // redirection, so its output also flows through the stderr pipe. All captured output // routes to coglet::user target. Bounded channel (500 messages) provides backpressure // if subprocess output exceeds processing rate. 
let stdout_tx = setup_log_tx.clone(); let stdout_read_raw = stdout_read.as_raw_fd(); std::thread::spawn(move || { // NOTE: No tracing in capture threads - would create feedback loop (stderr is captured) // Safety: We own stdout_read (moved into this thread) let mut file = unsafe { std::fs::File::from_raw_fd(stdout_read_raw) }; let mut buf = [0u8; 4096]; loop { match std::io::Read::read(&mut file, &mut buf) { Ok(0) => break, Ok(n) => { let data = String::from_utf8_lossy(&buf[..n]).to_string(); if stdout_tx .blocking_send(ControlResponse::Log { source: LogSource::Stdout, data, }) .is_err() { break; } } Err(e) if e.kind() == io::ErrorKind::Interrupted => continue, Err(_) => break, } } }); std::mem::forget(stdout_read); // Ownership transferred to thread let stderr_tx = setup_log_tx; let stderr_read_raw = stderr_read.as_raw_fd(); std::thread::spawn(move || { // NOTE: No tracing in capture threads - would create feedback loop (stderr is captured) // Safety: We own stderr_read (moved into this thread) let mut file = unsafe { std::fs::File::from_raw_fd(stderr_read_raw) }; let mut buf = [0u8; 4096]; loop { match std::io::Read::read(&mut file, &mut buf) { Ok(0) => break, Ok(n) => { let data = String::from_utf8_lossy(&buf[..n]).to_string(); if stderr_tx .blocking_send(ControlResponse::Log { source: LogSource::Stderr, data, }) .is_err() { break; } } Err(e) if e.kind() == io::ErrorKind::Interrupted => continue, Err(_) => break, } } }); std::mem::forget(stderr_read); // Ownership transferred to thread // Note: Both stdout and stderr now point to capture pipes. Rust tracing was initialized // before fd redirection to write to stderr, so its output will be captured along with // subprocess stderr. Both will be routed to coglet::user target. The original stderr // is still available at fd 101 but unused after redirection. 
tracing::info!("File descriptor redirection complete"); // Safety: We own these fds Ok(ControlChannelFds { stdin_fd: unsafe { OwnedFd::from_raw_fd(CONTROL_STDIN_FD) }, stdout_fd: unsafe { OwnedFd::from_raw_fd(CONTROL_STDOUT_FD) }, }) } #[cfg(not(unix))] pub struct ControlChannelFds { pub stdin_fd: std::io::Stdin, pub stdout_fd: std::io::Stdout, } #[cfg(not(unix))] pub fn redirect_fds_for_subprocess_isolation( _setup_log_tx: tokio::sync::mpsc::Sender, ) -> io::Result { // No fd redirection on non-Unix - subprocesses will pollute control channel Ok(ControlChannelFds { stdin_fd: std::io::stdin(), stdout_fd: std::io::stdout(), }) } ================================================ FILE: crates/coglet/src/health.rs ================================================ //! Health status types for coglet runtime. use serde::{Deserialize, Serialize}; /// Health status of the coglet runtime. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum Health { /// Just started, status unknown #[default] Unknown, /// Running setup() Starting, /// Ready to accept predictions Ready, /// At capacity (all slots in use) Busy, /// setup() failed SetupFailed, /// Unrecoverable error Defunct, } /// Response-only health status (includes transient states like UNHEALTHY). /// Used in HTTP responses but not stored as internal state. 
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum HealthResponse { Unknown, Starting, Ready, Busy, SetupFailed, Defunct, /// User-defined healthcheck failed (transient - not stored) Unhealthy, } impl From for HealthResponse { fn from(health: Health) -> Self { match health { Health::Unknown => HealthResponse::Unknown, Health::Starting => HealthResponse::Starting, Health::Ready => HealthResponse::Ready, Health::Busy => HealthResponse::Busy, Health::SetupFailed => HealthResponse::SetupFailed, Health::Defunct => HealthResponse::Defunct, } } } /// Status of the setup phase. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum SetupStatus { Starting, Succeeded, Failed, } /// Result of the setup phase. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SetupResult { /// When setup started (ISO 8601 format). pub started_at: String, /// When setup completed (ISO 8601 format), if finished. #[serde(skip_serializing_if = "Option::is_none")] pub completed_at: Option, /// Status of setup. #[serde(skip_serializing_if = "Option::is_none")] pub status: Option, /// Captured logs during setup. #[serde(default, skip_serializing_if = "String::is_empty")] pub logs: String, } impl SetupResult { /// Create a new SetupResult with the current time as started_at. pub fn starting() -> Self { Self { started_at: chrono::Utc::now().to_rfc3339(), completed_at: None, status: Some(SetupStatus::Starting), logs: String::new(), } } /// Mark setup as succeeded with accumulated logs. pub fn succeeded(mut self, logs: String) -> Self { self.completed_at = Some(chrono::Utc::now().to_rfc3339()); self.status = Some(SetupStatus::Succeeded); self.logs = logs; self } /// Mark setup as failed with error logs. 
pub fn failed(mut self, logs: String) -> Self {
        // Stamps `completed_at` with the current time and replaces any previous logs.
        self.completed_at = Some(chrono::Utc::now().to_rfc3339());
        self.status = Some(SetupStatus::Failed);
        self.logs = logs;
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn health_default_is_unknown() {
        assert_eq!(Health::default(), Health::Unknown);
    }

    // Snapshot of every `Health` variant's serialized name.
    #[test]
    fn health_serializes_screaming_snake_case() {
        insta::assert_json_snapshot!(
            "health_all_variants",
            [
                Health::Unknown,
                Health::Starting,
                Health::Ready,
                Health::Busy,
                Health::SetupFailed,
                Health::Defunct,
            ]
        );
    }

    // Snapshot of every `HealthResponse` variant's serialized name.
    #[test]
    fn health_response_serializes_screaming_snake_case() {
        insta::assert_json_snapshot!(
            "health_response_all_variants",
            [
                HealthResponse::Unknown,
                HealthResponse::Starting,
                HealthResponse::Ready,
                HealthResponse::Busy,
                HealthResponse::SetupFailed,
                HealthResponse::Defunct,
                HealthResponse::Unhealthy,
            ]
        );
    }

    #[test]
    fn health_deserializes_screaming_snake_case() {
        assert_eq!(
            serde_json::from_str::<Health>("\"READY\"").unwrap(),
            Health::Ready
        );
        assert_eq!(
            serde_json::from_str::<Health>("\"SETUP_FAILED\"").unwrap(),
            Health::SetupFailed
        );
    }

    // Snapshot of every `SetupStatus` variant's serialized name.
    #[test]
    fn setup_status_serializes_lowercase() {
        insta::assert_json_snapshot!(
            "setup_status_all_variants",
            [
                SetupStatus::Starting,
                SetupStatus::Succeeded,
                SetupStatus::Failed,
            ]
        );
    }

    #[test]
    fn setup_status_deserializes_lowercase() {
        assert_eq!(
            serde_json::from_str::<SetupStatus>("\"succeeded\"").unwrap(),
            SetupStatus::Succeeded
        );
    }
}

================================================
FILE: crates/coglet/src/input_validation.rs
================================================
//! Input validation against the OpenAPI schema.
//!
//! Validates prediction inputs before dispatching to the Python worker,
//! catching missing required fields and unknown fields early with clear
//! error messages (matching the format users expect from pydantic).

use std::collections::HashSet;

use serde_json::Value;

/// A single validation error for one field.
#[derive(Debug)] pub struct ValidationError { /// Field name (used as loc[2] in the pydantic-compatible response). pub field: String, /// Human-readable error message. pub msg: String, /// Error type string (e.g. "value_error.missing"). pub error_type: String, } /// Compiled input validator built from the OpenAPI schema's Input component. pub struct InputValidator { validator: jsonschema::Validator, /// Known property names from the schema. properties: HashSet, /// Required field names from the schema. required: Vec, } impl InputValidator { /// Build a validator from a full OpenAPI schema document. /// /// Extracts `components.schemas.Input`, injects `additionalProperties: false` /// (for pydantic parity), and compiles a JSON Schema validator. /// /// Returns None if the schema doesn't contain an Input component. pub fn from_openapi_schema(schema: &Value) -> Option { Self::from_openapi_schema_key(schema, "Input") } /// Build a validator from a full OpenAPI schema document using a custom /// schema key (e.g. "TrainingInput" for train endpoints). /// /// Returns None if the schema doesn't contain the specified component. pub fn from_openapi_schema_key(schema: &Value, key: &str) -> Option { let input_schema = schema.get("components")?.get("schemas")?.get(key)?; let properties: HashSet = input_schema .get("properties") .and_then(|p| p.as_object()) .map(|obj| obj.keys().cloned().collect()) .unwrap_or_default(); let required: Vec = input_schema .get("required") .and_then(|r| r.as_array()) .map(|a| { a.iter() .filter_map(|v| v.as_str().map(String::from)) .collect() }) .unwrap_or_default(); // Clone and inject additionalProperties: false for pydantic parity let mut resolved = input_schema.clone(); if let Some(obj) = resolved.as_object_mut() { obj.insert("additionalProperties".to_string(), Value::Bool(false)); } // Inline $ref pointers so the validator can resolve them without // the full OpenAPI document context. cog-schema-gen emits $ref for // enum choices (e.g. 
"#/components/schemas/Color"). let all_schemas = schema.get("components").and_then(|c| c.get("schemas")); inline_refs(&mut resolved, all_schemas); let validator = jsonschema::validator_for(&resolved) .inspect_err(|e| { tracing::warn!(error = %e, "Failed to compile input schema validator"); }) .ok()?; Some(Self { validator, properties, required, }) } pub fn required_count(&self) -> usize { self.required.len() } /// Validate an input value against the schema. /// /// Returns Ok(()) on success, or a list of per-field validation errors /// formatted for the pydantic-compatible `detail` response. pub fn validate(&self, input: &Value) -> Result<(), Vec> { if self.validator.validate(input).is_ok() { return Ok(()); } let mut errors = Vec::new(); let mut seen_required = false; let mut seen_additional = false; for error in self.validator.iter_errors(input) { let msg = error.to_string(); // "required" errors: emit one entry per missing field if msg.contains("is a required property") && !seen_required { seen_required = true; let input_obj = input.as_object(); for field in &self.required { let present = input_obj .map(|obj| obj.contains_key(field)) .unwrap_or(false); if !present { errors.push(ValidationError { field: field.clone(), msg: "Field required".to_string(), error_type: "value_error.missing".to_string(), }); } } continue; } // "additionalProperties" errors: emit one entry per unknown field if msg.contains("Additional properties") && !seen_additional { seen_additional = true; if let Some(input_obj) = input.as_object() { for key in input_obj.keys() { if !self.properties.contains(key) { errors.push(ValidationError { field: key.clone(), msg: format!("Unexpected field '{key}'"), error_type: "value_error.extra".to_string(), }); } } } continue; } // Skip duplicate required/additional messages if seen_required && msg.contains("is a required property") { continue; } if seen_additional && msg.contains("Additional properties") { continue; } // Type/constraint errors on specific 
fields let path = error.instance_path.to_string(); let field = path.trim_start_matches('/'); let field_name = if field.is_empty() { "__root__".to_string() } else { field.to_string() }; errors.push(ValidationError { field: field_name, msg, error_type: "value_error".to_string(), }); } if errors.is_empty() { Ok(()) } else { Err(errors) } } } /// Recursively inline `$ref` pointers in a JSON Schema value. /// /// Resolves `{"$ref": "#/components/schemas/Foo"}` by looking up `Foo` in the /// provided schemas map and replacing the `$ref` object with the referenced /// content. This allows the validator to work on an extracted subschema without /// needing the full OpenAPI document. fn inline_refs(value: &mut Value, all_schemas: Option<&Value>) { match value { Value::Object(obj) => { // If this object is a $ref, resolve it if let Some(Value::String(ref_str)) = obj.get("$ref") && let Some(resolved) = resolve_ref(ref_str, all_schemas) { *value = resolved; // Recurse into the resolved value (it may contain more $refs) inline_refs(value, all_schemas); return; } // Recurse into all values for v in obj.values_mut() { inline_refs(v, all_schemas); } } Value::Array(arr) => { for v in arr.iter_mut() { inline_refs(v, all_schemas); } } _ => {} } } /// Resolve a `$ref` string like `#/components/schemas/Foo` against the schemas map. 
fn resolve_ref(ref_str: &str, all_schemas: Option<&Value>) -> Option { let name = ref_str.strip_prefix("#/components/schemas/")?; all_schemas?.get(name).cloned() } #[cfg(test)] mod tests { use super::*; use serde_json::json; fn make_schema(input_schema: Value) -> Value { json!({ "components": { "schemas": { "Input": input_schema } } }) } #[test] fn validates_required_fields() { let schema = make_schema(json!({ "type": "object", "properties": { "s": {"type": "string", "title": "S"} }, "required": ["s"] })); let validator = InputValidator::from_openapi_schema(&schema).unwrap(); // Valid input assert!(validator.validate(&json!({"s": "hello"})).is_ok()); // Missing required field let errs = validator.validate(&json!({})).unwrap_err(); assert_eq!(errs.len(), 1); assert_eq!(errs[0].field, "s"); assert_eq!(errs[0].msg, "Field required"); } #[test] fn rejects_additional_properties() { let schema = make_schema(json!({ "type": "object", "properties": { "s": {"type": "string", "title": "S"} }, "required": ["s"] })); let validator = InputValidator::from_openapi_schema(&schema).unwrap(); // Extra field should fail let errs = validator .validate(&json!({"s": "hello", "extra": "bad"})) .unwrap_err(); assert_eq!(errs.len(), 1); assert_eq!(errs[0].field, "extra"); assert!(errs[0].msg.contains("Unexpected")); } #[test] fn missing_and_extra_fields() { let schema = make_schema(json!({ "type": "object", "properties": { "s": {"type": "string", "title": "S"} }, "required": ["s"] })); let validator = InputValidator::from_openapi_schema(&schema).unwrap(); // wrong=value with missing s let errs = validator.validate(&json!({"wrong": "value"})).unwrap_err(); assert!(errs.len() >= 2); let fields: Vec<&str> = errs.iter().map(|e| e.field.as_str()).collect(); assert!(fields.contains(&"s"), "should report missing s: {fields:?}"); assert!( fields.contains(&"wrong"), "should report extra wrong: {fields:?}" ); } #[test] fn validates_types() { let schema = make_schema(json!({ "type": "object", 
"properties": { "count": {"type": "integer", "title": "Count"} }, "required": ["count"] })); let validator = InputValidator::from_openapi_schema(&schema).unwrap(); assert!(validator.validate(&json!({"count": 5})).is_ok()); let errs = validator .validate(&json!({"count": "not_a_number"})) .unwrap_err(); assert_eq!(errs[0].field, "count"); } #[test] fn no_schema_returns_none() { let schema = json!({"components": {"schemas": {}}}); assert!(InputValidator::from_openapi_schema(&schema).is_none()); } #[test] fn resolves_ref_for_choices() { let schema = json!({ "components": { "schemas": { "Input": { "type": "object", "properties": { "color": { "allOf": [{"$ref": "#/components/schemas/Color"}], "x-order": 0 } }, "required": ["color"] }, "Color": { "title": "Color", "description": "An enumeration.", "enum": ["red", "green", "blue"], "type": "string" } } } }); let validator = InputValidator::from_openapi_schema(&schema); assert!(validator.is_some(), "validator should compile with $ref"); let validator = validator.unwrap(); assert!(validator.validate(&json!({"color": "red"})).is_ok()); assert!(validator.validate(&json!({"color": "purple"})).is_err()); } #[test] fn optional_fields_work() { let schema = make_schema(json!({ "type": "object", "properties": { "s": {"type": "string"}, "n": {"type": "integer"} }, "required": ["s"] })); let validator = InputValidator::from_openapi_schema(&schema).unwrap(); assert!(validator.validate(&json!({"s": "hello"})).is_ok()); assert!(validator.validate(&json!({"s": "hello", "n": 42})).is_ok()); } } ================================================ FILE: crates/coglet/src/lib.rs ================================================ //! coglet: Rust execution engine for cog models. 
mod health; pub mod input_validation; mod prediction; mod predictor; mod version; pub mod bridge; mod fd_redirect; pub mod orchestrator; pub mod permit; pub mod service; mod setup_log_accumulator; pub mod transport; pub mod webhook; pub mod worker; mod worker_tracing_layer; pub use orchestrator::Orchestrator; pub use service::{PredictionHandle, SyncPredictionGuard}; pub use health::{Health, SetupResult, SetupStatus}; pub use input_validation::InputValidator; pub use prediction::{CancellationToken, Prediction, PredictionOutput, PredictionStatus}; pub use predictor::{PredictionError, PredictionGuard, PredictionMetrics, PredictionResult}; pub use service::{CreatePredictionError, HealthSnapshot, PredictionService}; pub use setup_log_accumulator::{SetupLogAccumulator, drain_accumulated_logs}; pub use version::{COGLET_VERSION, VersionInfo}; pub use worker::{ PredictHandler, PredictResult, SetupError, SetupLogHook, SlotSender, WorkerConfig, run_worker, }; ================================================ FILE: crates/coglet/src/orchestrator.rs ================================================ //! Orchestrator - manages worker subprocess lifecycle and event loop. //! //! Flow: //! 1. Spawn worker subprocess //! 2. Send Init message, wait for Ready //! 3. Populate PermitPool with slot sockets //! 4. Run event loop routing responses to predictions //! 5. 
On worker crash: fail all predictions, shut down use std::collections::HashMap; use std::process::Stdio; use std::sync::Arc; use std::sync::Mutex as StdMutex; use std::time::Duration; use async_trait::async_trait; use futures::{SinkExt, StreamExt}; use tokio::process::{Child, Command}; use tokio::sync::mpsc; use tokio_util::codec::{FramedRead, FramedWrite}; use crate::PredictionOutput; use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::{ ControlRequest, ControlResponse, FileOutputKind, HealthcheckStatus, SlotId, SlotRequest, SlotResponse, }; use crate::bridge::transport::create_transport; use crate::permit::{InactiveSlotIdleToken, PermitPool, SlotIdleToken}; use crate::prediction::Prediction; /// Upload a file to a signed endpoint, returning the final URL. /// /// Matches the behavior of Python cog's `put_file_to_signed_endpoint`: /// PUT to `{endpoint}{filename}` with Content-Type header, then extract /// the final URL from the Location header (falling back to response URL), /// stripping query parameters. Follows redirects automatically. 
async fn upload_file( endpoint: &str, filename: &str, data: &[u8], content_type: &str, ) -> Result { let url = format!("{endpoint}{filename}"); let client = reqwest::Client::new(); let resp = client .put(&url) .header("Content-Type", content_type) .body(data.to_vec()) .timeout(std::time::Duration::from_secs(25)) .send() .await .map_err(|e| format!("upload request failed: {e}"))?; if !resp.status().is_success() { return Err(format!("upload returned status {}", resp.status())); } // Prefer Location header, fall back to final request URL (after redirects) let final_url = resp .headers() .get("location") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()) .unwrap_or_else(|| resp.url().to_string()); // Strip query parameters (signing gubbins) match reqwest::Url::parse(&final_url) { Ok(mut parsed) => { parsed.set_query(None); Ok(parsed.to_string()) } Err(_) => Ok(final_url), } } fn ensure_trailing_slash(s: &str) -> String { if s.ends_with('/') { s.to_string() } else { format!("{s}/") } } /// Try to lock a prediction mutex. /// On poison: logs error, recovers to fail the prediction, returns None. /// Caller should remove the prediction from tracking if None is returned. fn try_lock_prediction( pred: &Arc>, ) -> Option> { match pred.lock() { Ok(guard) => Some(guard), Err(poisoned) => { tracing::error!("Prediction mutex poisoned - failing prediction"); let mut guard = poisoned.into_inner(); if !guard.is_terminal() { guard.set_failed("Internal error: mutex poisoned".to_string()); } None } } } /// Wrap collected output items into the correct `PredictionOutput` variant. /// /// Priority: /// 1. Schema says `"type": "array"` (`output_is_array = true`) → always `Stream` /// 2. Predictor signals `is_stream` (list/generator return) → always `Stream` /// 3. Otherwise → `Single` for one item, `Stream` for multiple /// /// This ensures `List[Path]` with a single element returns `["url"]` not `"url"`. 
fn wrap_outputs( outputs: Vec, output_is_array: bool, is_stream: bool, ) -> PredictionOutput { let should_stream = output_is_array || is_stream; match outputs.as_slice() { [] => { if should_stream { PredictionOutput::Stream(vec![]) } else { PredictionOutput::Single(serde_json::Value::Null) } } _ if should_stream => PredictionOutput::Stream(outputs), [single] => PredictionOutput::Single(single.clone()), _ => PredictionOutput::Stream(outputs), } } fn emit_worker_log(target: &str, level: &str, msg: &str) { use std::collections::HashMap; use std::sync::OnceLock; use tracing::{ Level, Metadata, callsite::{Callsite, Identifier}, field::FieldSet, }; struct DummyCallsite; impl Callsite for DummyCallsite { fn set_interest(&self, _: tracing::subscriber::Interest) {} fn metadata(&self) -> &Metadata<'static> { unreachable!() } } static DUMMY: DummyCallsite = DummyCallsite; static CALLSITES: OnceLock< std::sync::Mutex>>, > = OnceLock::new(); static FIELDS: &[&str] = &["message"]; let lvl = match level { "error" => Level::ERROR, "warn" => Level::WARN, "info" => Level::INFO, "debug" => Level::DEBUG, "trace" => Level::TRACE, _ => Level::INFO, }; let target_static: &'static str = Box::leak(target.to_string().into_boxed_str()); let callsites = CALLSITES.get_or_init(|| std::sync::Mutex::new(HashMap::new())); let mut map = match callsites.lock() { Ok(guard) => guard, Err(_poisoned) => { tracing::error!("Worker log callsite cache poisoned"); return; } }; let meta = map.entry((target_static, lvl)).or_insert_with(|| { Metadata::new( "worker_log", target_static, lvl, Some(file!()), Some(line!()), Some(module_path!()), FieldSet::new(FIELDS, Identifier(&DUMMY)), tracing::metadata::Kind::EVENT, ) }); let meta_ref = meta as *const Metadata<'static>; drop(map); let meta = unsafe { &*meta_ref }; tracing::dispatcher::get_default(|dispatch| { if dispatch.enabled(meta) { let fields = meta.fields(); if let Some(field) = fields.field("message") { let value_array = [(&field, Some(&msg as &dyn 
tracing::Value))]; let values = fields.value_set(&value_array); dispatch.event(&tracing::Event::new(meta, &values)); } } }); } /// Result of a user-defined healthcheck. #[derive(Debug, Clone)] pub struct HealthcheckResult { pub status: HealthcheckStatus, pub error: Option, } impl HealthcheckResult { pub fn healthy() -> Self { Self { status: HealthcheckStatus::Healthy, error: None, } } pub fn unhealthy(error: impl Into) -> Self { Self { status: HealthcheckStatus::Unhealthy, error: Some(error.into()), } } pub fn is_healthy(&self) -> bool { self.status == HealthcheckStatus::Healthy } } /// Trait for prediction registration with the orchestrator. /// /// This abstraction enables testing the service layer without a real worker subprocess. /// The service only needs to register predictions for response routing - all other /// orchestrator operations happen outside the predict path. #[async_trait] pub trait Orchestrator: Send + Sync { /// Register a prediction for response routing in the event loop. async fn register_prediction( &self, slot_id: SlotId, prediction: Arc>, idle_sender: tokio::sync::oneshot::Sender, ); /// Cancel a prediction by its prediction ID. /// /// The orchestrator resolves the prediction ID to a slot ID and sends /// a cancel request to the worker over the control socket. async fn cancel_by_prediction_id(&self, prediction_id: &str) -> Result<(), OrchestratorError>; /// Run user-defined healthcheck if available. async fn healthcheck(&self) -> Result; /// Shutdown the orchestrator and worker gracefully. async fn shutdown(&self) -> Result<(), OrchestratorError>; } #[derive(Debug, Clone)] pub struct WorkerSpawnConfig { pub num_slots: usize, } #[derive(Debug, thiserror::Error)] pub enum SpawnError { #[error("failed to spawn process: {0}")] Spawn(#[from] std::io::Error), #[error("spawn failed: {0}")] Other(String), } /// Extension point for different worker spawn strategies. 
pub trait WorkerSpawner: Send + Sync { fn spawn(&self, config: &WorkerSpawnConfig) -> Result; } /// Simple spawner using Python subprocess. pub struct SimpleSpawner; impl WorkerSpawner for SimpleSpawner { fn spawn(&self, _config: &WorkerSpawnConfig) -> Result { let child = Command::new("python") .args(["-c", "import coglet; coglet.server._run_worker()"]) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::inherit()) .spawn()?; Ok(child) } } pub struct OrchestratorConfig { pub predictor_ref: String, pub num_slots: usize, pub is_train: bool, pub is_async: bool, pub setup_timeout: Option, pub spawner: Arc, /// Upload URL prefix for file outputs (from --upload-url CLI arg). pub upload_url: Option, } impl OrchestratorConfig { pub fn new(predictor_ref: impl Into) -> Self { Self { predictor_ref: predictor_ref.into(), num_slots: 1, is_train: false, is_async: false, setup_timeout: None, spawner: Arc::new(SimpleSpawner), upload_url: None, } } pub fn with_upload_url(mut self, upload_url: Option) -> Self { self.upload_url = upload_url; self } pub fn with_num_slots(mut self, n: usize) -> Self { self.num_slots = n; self } pub fn with_train(mut self, is_train: bool) -> Self { self.is_train = is_train; self } pub fn with_async(mut self, is_async: bool) -> Self { self.is_async = is_async; self } pub fn with_setup_timeout(mut self, timeout: Option) -> Self { self.setup_timeout = timeout; self } pub fn with_spawner(mut self, spawner: Arc) -> Self { self.spawner = spawner; self } } pub struct OrchestratorReady { pub pool: Arc, pub schema: Option, pub handle: OrchestratorHandle, pub setup_logs: String, } pub struct OrchestratorHandle { child: Child, ctrl_writer: Arc>>>, register_tx: mpsc::Sender<( SlotId, Arc>, tokio::sync::oneshot::Sender, )>, healthcheck_tx: mpsc::Sender>, cancel_tx: mpsc::Sender, slot_ids: Vec, } #[async_trait] impl Orchestrator for OrchestratorHandle { async fn register_prediction( &self, slot_id: SlotId, prediction: Arc>, idle_sender: 
tokio::sync::oneshot::Sender, ) { let _ = self .register_tx .send((slot_id, prediction, idle_sender)) .await; } async fn cancel_by_prediction_id(&self, prediction_id: &str) -> Result<(), OrchestratorError> { self.cancel_tx .send(prediction_id.to_string()) .await .map_err(|_| OrchestratorError::Protocol("cancel channel closed".to_string())) } async fn healthcheck(&self) -> Result { tracing::trace!("Healthcheck requested via orchestrator handle"); let (response_tx, response_rx) = tokio::sync::oneshot::channel(); // Send our channel to the event loop. If a healthcheck is already // in-flight, the event loop coalesces — we get the same result as // all other waiters when it comes back. self.healthcheck_tx .send(response_tx) .await .map_err(|_| OrchestratorError::Protocol("healthcheck channel closed".to_string()))?; // Wait for the response with a timeout (worker has 5s, we give 10s total). // If we time out, the healthcheck keeps running — our sender just gets a // silent failure when the event loop eventually broadcasts. 
match tokio::time::timeout(Duration::from_secs(10), response_rx).await { Ok(Ok(result)) => { tracing::trace!(healthy = result.is_healthy(), "Healthcheck completed"); Ok(result) } Ok(Err(_)) => { tracing::debug!("Healthcheck response channel dropped"); Err(OrchestratorError::Protocol( "healthcheck response channel dropped".to_string(), )) } Err(_) => { tracing::debug!("Healthcheck timed out after 10s"); Ok(HealthcheckResult::unhealthy("healthcheck timed out")) } } } async fn shutdown(&self) -> Result<(), OrchestratorError> { let mut writer = self.ctrl_writer.lock().await; writer .send(ControlRequest::Shutdown) .await .map_err(|e| OrchestratorError::Protocol(format!("failed to send shutdown: {}", e))) } } impl OrchestratorHandle { pub async fn cancel(&self, slot_id: SlotId) -> Result<(), OrchestratorError> { let mut writer = self.ctrl_writer.lock().await; writer .send(ControlRequest::Cancel { slot: slot_id }) .await .map_err(|e| OrchestratorError::Protocol(format!("failed to send cancel: {}", e))) } pub fn slot_ids(&self) -> &[SlotId] { &self.slot_ids } pub async fn wait(&mut self) -> Result<(), OrchestratorError> { self.child.wait().await.map_err(|e| { OrchestratorError::Protocol(format!("failed to wait for worker: {}", e)) })?; Ok(()) } } #[derive(Debug, thiserror::Error)] pub enum OrchestratorError { #[error("failed to spawn worker: {0}")] Spawn(String), #[error("worker setup failed: {0}")] Setup(String), #[error("worker setup timed out")] SetupTimeout, #[error("protocol error: {0}")] Protocol(String), #[error("worker crashed")] WorkerCrashed, } pub async fn spawn_worker( config: OrchestratorConfig, setup_log_rx: &mut tokio::sync::mpsc::UnboundedReceiver, ) -> Result { let num_slots = config.num_slots; tracing::info!(num_slots, "Creating slot transport"); let (mut transport, child_transport_info) = create_transport(num_slots) .await .map_err(|e| OrchestratorError::Spawn(format!("failed to create transport: {}", e)))?; tracing::info!("Spawning worker subprocess"); 
let spawn_config = WorkerSpawnConfig { num_slots }; let mut child = config .spawner .spawn(&spawn_config) .map_err(|e| OrchestratorError::Spawn(format!("spawner failed: {}", e)))?; let stdin = child .stdin .take() .ok_or_else(|| OrchestratorError::Spawn("stdin not captured".to_string()))?; let stdout = child .stdout .take() .ok_or_else(|| OrchestratorError::Spawn("stdout not captured".to_string()))?; let mut ctrl_writer = FramedWrite::new(stdin, JsonCodec::::new()); let mut ctrl_reader = FramedRead::new(stdout, JsonCodec::::new()); tracing::debug!("Sending Init to worker"); ctrl_writer .send(ControlRequest::Init { predictor_ref: config.predictor_ref.clone(), num_slots, transport_info: child_transport_info, is_train: config.is_train, is_async: config.is_async, }) .await .map_err(|e| OrchestratorError::Protocol(format!("failed to send Init: {}", e)))?; tracing::debug!("Waiting for worker to connect to slot sockets"); transport .accept_connections(num_slots) .await .map_err(|e| OrchestratorError::Spawn(format!("failed to accept connections: {}", e)))?; tracing::debug!("Waiting for Ready from worker"); let setup_fut = async { loop { match ctrl_reader.next().await { Some(Ok(ControlResponse::Ready { slots, schema })) => { return Ok((slots, schema)); } Some(Ok(ControlResponse::Log { source, data })) => { for line in data.lines() { tracing::info!(target: "coglet::setup", source = ?source, "{}", line); } } Some(Ok(ControlResponse::WorkerLog { target, level, message, })) => { emit_worker_log(&target, &level, &message); } Some(Ok(ControlResponse::DroppedLogs { count, interval_millis, })) => { tracing::trace!(count, interval_millis, "Received DroppedLogs during setup"); let interval_secs = interval_millis as f64 / 1000.0; tracing::warn!( "Log production exceeds consumption rate during setup. 
{} logs dropped in last {:.1}s", count, interval_secs ); } Some(Ok(ControlResponse::Failed { slot, error })) => { return Err(OrchestratorError::Setup(format!( "worker setup failed (slot {}): {}", slot, error ))); } Some(Ok(ControlResponse::Fatal { reason })) => { return Err(OrchestratorError::Setup(format!( "worker fatal: {}", reason ))); } Some(Ok(other)) => { tracing::warn!(?other, "Unexpected message during setup"); } Some(Err(e)) => { return Err(OrchestratorError::Protocol(format!( "control channel error: {}", e ))); } None => { return Err(OrchestratorError::WorkerCrashed); } } } }; let (slot_ids, schema) = match config.setup_timeout { Some(timeout) => { tracing::debug!( timeout_secs = timeout.as_secs(), "Waiting for setup with timeout" ); match tokio::time::timeout(timeout, setup_fut).await { Ok(Ok((slots, schema))) => { tracing::debug!(num_slots = slots.len(), "Setup completed within timeout"); (slots, schema) } Ok(Err(e)) => { tracing::debug!(error = %e, "Setup failed"); return Err(e); } Err(_) => { tracing::debug!(timeout_secs = timeout.as_secs(), "Setup timed out"); return Err(OrchestratorError::SetupTimeout); } } } None => { tracing::debug!("Waiting for setup with no timeout"); setup_fut.await? } }; let setup_logs = crate::setup_log_accumulator::drain_accumulated_logs(setup_log_rx); tracing::debug!( setup_logs_len = setup_logs.len(), "Drained accumulated setup logs" ); tracing::debug!(num_slots = slot_ids.len(), "Worker ready"); if let Some(ref s) = schema && let Ok(json) = serde_json::to_string_pretty(s) { tracing::trace!(target: "coglet::schema", schema = %json, "OpenAPI schema"); } // Determine whether the output type is an array from the schema so the // event loop can correctly wrap single-element list returns as Stream // instead of collapsing them to Single. 
let output_is_array = schema .as_ref() .and_then(|s| s.get("components")) .and_then(|c| c.get("schemas")) .and_then(|schemas| { let key = if config.is_train { "TrainingOutput" } else { "Output" }; schemas.get(key) }) .and_then(|output| output.get("type")) .and_then(|t| t.as_str()) .is_some_and(|t| t == "array"); let pool = Arc::new(PermitPool::new(num_slots)); let sockets = transport.drain_sockets(); let mut slot_readers = Vec::with_capacity(num_slots); for (slot_id, socket) in slot_ids.iter().zip(sockets) { let (read_half, write_half) = socket.into_split(); let writer = FramedWrite::new(write_half, JsonCodec::::new()); pool.add_permit(*slot_id, writer); let reader = FramedRead::new(read_half, JsonCodec::::new()); slot_readers.push((*slot_id, reader)); } let (register_tx, register_rx) = mpsc::channel(num_slots); let (healthcheck_tx, healthcheck_rx) = mpsc::channel(1); let (cancel_tx, cancel_rx) = mpsc::channel(16); let ctrl_writer = Arc::new(tokio::sync::Mutex::new(ctrl_writer)); let handle = OrchestratorHandle { child, ctrl_writer: Arc::clone(&ctrl_writer), register_tx, healthcheck_tx, cancel_tx, slot_ids: slot_ids.clone(), }; let pool_for_loop = Arc::clone(&pool); let ctrl_writer_for_loop = Arc::clone(&ctrl_writer); let upload_url = config.upload_url.clone(); tokio::spawn(async move { run_event_loop( ctrl_reader, ctrl_writer_for_loop, slot_readers, register_rx, healthcheck_rx, cancel_rx, pool_for_loop, upload_url, output_is_array, ) .await; }); Ok(OrchestratorReady { pool, schema, handle, setup_logs, }) } #[allow(clippy::too_many_arguments)] async fn run_event_loop( mut ctrl_reader: FramedRead>, ctrl_writer: Arc< tokio::sync::Mutex>>, >, slot_readers: Vec<( SlotId, FramedRead>, )>, mut register_rx: mpsc::Receiver<( SlotId, Arc>, tokio::sync::oneshot::Sender, )>, mut healthcheck_rx: mpsc::Receiver>, mut cancel_rx: mpsc::Receiver, pool: Arc, upload_url: Option, // Schema says Output is "type": "array" — always wrap as Stream. 
// When false, the schema was unavailable or Output type is Any; fall // back to the predictor's is_stream flag on the Done message. output_is_array: bool, ) { let mut predictions: HashMap>> = HashMap::new(); let mut idle_senders: HashMap> = HashMap::new(); let mut pending_healthchecks: Vec> = Vec::new(); let mut healthcheck_counter: u64 = 0; let mut pending_uploads: HashMap>> = HashMap::new(); let (slot_msg_tx, mut slot_msg_rx) = mpsc::channel::<(SlotId, Result)>(100); for (slot_id, mut reader) in slot_readers { let tx = slot_msg_tx.clone(); tokio::spawn(async move { loop { let msg = reader.next().await; match msg { Some(Ok(response)) => { if tx.send((slot_id, Ok(response))).await.is_err() { break; } } Some(Err(e)) => { let _ = tx.send((slot_id, Err(e))).await; break; } None => { break; } } } tracing::debug!(%slot_id, "Slot reader task exiting"); }); } drop(slot_msg_tx); loop { tokio::select! { biased; ctrl_msg = ctrl_reader.next() => { match ctrl_msg { Some(Ok(ControlResponse::Idle { slot })) => { tracing::debug!(%slot, "Slot idle notification received (control channel)"); match idle_senders.remove(&slot) { Some(sender) => { let token = InactiveSlotIdleToken::new(slot); if sender.send(token.activate()).is_err() { tracing::warn!(%slot, "Idle token receiver dropped before idle confirmation"); } } None => { tracing::warn!(%slot, "Received Idle for slot with no pending idle confirmation"); } } } Some(Ok(ControlResponse::Cancelled { slot })) => { tracing::debug!(%slot, "Slot cancelled (control channel)"); } Some(Ok(ControlResponse::Failed { slot, error })) => { tracing::warn!(%slot, %error, "Slot poisoned"); pool.poison(slot); if let Some(pred) = predictions.remove(&slot) && let Some(mut p) = try_lock_prediction(&pred) && !p.is_terminal() { p.set_failed(error); } } Some(Ok(ControlResponse::Fatal { reason })) => { tracing::error!(%reason, "Worker fatal"); for (slot, pred) in predictions.drain() { tracing::warn!(%slot, "Failing prediction due to worker fatal error"); 
pool.poison(slot); if let Some(mut p) = try_lock_prediction(&pred) && !p.is_terminal() { p.set_failed(reason.clone()); } } let result = HealthcheckResult::unhealthy(&reason); for tx in pending_healthchecks.drain(..) { let _ = tx.send(result.clone()); } break; } Some(Ok(ControlResponse::Ready { .. })) => { tracing::warn!("Unexpected Ready in event loop"); } Some(Ok(ControlResponse::Log { source: _, data })) => { for line in data.lines() { tracing::info!(target: "coglet::user", "{}", line); } } Some(Ok(ControlResponse::WorkerLog { target, level, message })) => { emit_worker_log(&target, &level, &message); } Some(Ok(ControlResponse::DroppedLogs { count, interval_millis })) => { tracing::trace!(count, interval_millis, "Received DroppedLogs message"); let interval_secs = interval_millis as f64 / 1000.0; tracing::warn!( "Log production exceeds consumption rate. {} logs dropped in last {:.1}s", count, interval_secs ); } Some(Ok(ControlResponse::HealthcheckResult { id: _, status, error })) => { tracing::trace!( ?status, ?error, pending_count = pending_healthchecks.len(), "Received healthcheck result from worker" ); if pending_healthchecks.is_empty() { tracing::warn!("Received healthcheck result but no pending requests"); } else { let result = match status { HealthcheckStatus::Healthy => HealthcheckResult::healthy(), HealthcheckStatus::Unhealthy => { HealthcheckResult::unhealthy(error.unwrap_or_else(|| "unhealthy".to_string())) } }; tracing::trace!( pending_count = pending_healthchecks.len(), "Distributing healthcheck result to pending callers" ); for tx in pending_healthchecks.drain(..) 
{ let _ = tx.send(result.clone()); } } } Some(Ok(ControlResponse::ShuttingDown)) => { tracing::info!("Worker shutting down"); break; } Some(Err(e)) => { tracing::error!(error = %e, "Control channel error"); break; } None => { tracing::warn!("Control channel closed (worker crashed?)"); for (slot, pred) in predictions.drain() { tracing::warn!(%slot, "Failing prediction due to worker crash"); if let Some(mut p) = try_lock_prediction(&pred) { p.set_failed("Worker crashed".to_string()); } } // Fail any pending healthchecks for tx in pending_healthchecks.drain(..) { let _ = tx.send(HealthcheckResult::unhealthy("Worker crashed")); } break; } } } Some(response_tx) = healthcheck_rx.recv() => { let in_flight = !pending_healthchecks.is_empty(); pending_healthchecks.push(response_tx); // Only send to worker if no healthcheck is already in-flight. // Otherwise this caller just waits for the same result. if !in_flight { healthcheck_counter += 1; let hc_id = format!("hc_{}", healthcheck_counter); tracing::trace!(%hc_id, "Sending healthcheck request to worker"); let mut writer = ctrl_writer.lock().await; if let Err(e) = writer.send(ControlRequest::Healthcheck { id: hc_id }).await { tracing::error!(error = %e, "Failed to send healthcheck request"); let result = HealthcheckResult::unhealthy(format!("Failed to send: {}", e)); for tx in pending_healthchecks.drain(..) 
{ let _ = tx.send(result.clone()); } } } else { tracing::trace!( pending_count = pending_healthchecks.len(), "Healthcheck already in-flight, coalescing request" ); } } Some(prediction_id) = cancel_rx.recv() => { // Resolve prediction_id → slot_id by iterating (fine for small concurrency) let slot = predictions.iter().find_map(|(sid, pred)| { try_lock_prediction(pred) .filter(|p| p.id() == prediction_id) .map(|_| *sid) }); match slot { Some(slot_id) => { tracing::info!( target: "coglet::prediction", %prediction_id, %slot_id, "Cancelling prediction" ); let mut writer = ctrl_writer.lock().await; if let Err(e) = writer.send(ControlRequest::Cancel { slot: slot_id }).await { tracing::error!( %slot_id, error = %e, "Failed to send cancel request to worker" ); } // Also abort any pending upload tasks for this slot if let Some(handles) = pending_uploads.remove(&slot_id) { for h in handles { h.abort(); } } } None => { tracing::debug!(%prediction_id, "Cancel requested for unknown prediction (may have already completed)"); } } } Some((slot_id, prediction, idle_sender)) = register_rx.recv() => { let prediction_id = match try_lock_prediction(&prediction) { Some(p) => p.id().to_string(), None => { // Mutex poisoned during registration - prediction already failed tracing::error!(%slot_id, "Prediction mutex poisoned during registration"); continue; } }; // NOTE: we insert the idle sender, and idle senders are only removed on consumption of the // `tokio::sync::oneshot::Sender`, this means the only time we'll leak memory here is if the // slot is poisoned or otherwise in a bad state. It is intentional that we don't remove idle // senders in any other case. 
idle_senders.insert(slot_id, idle_sender); tracing::info!( target: "coglet::prediction", %prediction_id, "Starting prediction" ); tracing::debug!(%slot_id, %prediction_id, "Registered prediction"); predictions.insert(slot_id, prediction); } Some((slot_id, result)) = slot_msg_rx.recv() => { match result { Ok(SlotResponse::Log { source, data }) => { let (prediction_id, poisoned) = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { p.append_log(&data); (Some(p.id().to_string()), false) } else { (None, true) } } else { (None, false) }; // Remove poisoned predictions outside the borrow if poisoned { predictions.remove(&slot_id); } let trimmed = data.trim(); if !trimmed.is_empty() { if let Some(id) = prediction_id { tracing::info!( target: "coglet::prediction", prediction_id = %id, source = ?source, "{}", trimmed ); } else { tracing::warn!( target: "coglet::prediction", prediction_id = "NO_ACTIVE_PREDICTION", source = ?source, "{}", trimmed ); } } } Ok(SlotResponse::Metric { name, value, mode }) => { let poisoned = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { p.set_metric(name, value, mode); false } else { true } } else { false }; if poisoned { predictions.remove(&slot_id); } } Ok(SlotResponse::Output { output }) => { let poisoned = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { p.append_output(output); false } else { true } } else { false }; // Remove poisoned predictions outside the borrow if poisoned { predictions.remove(&slot_id); } } Ok(SlotResponse::FileOutput { filename, kind, mime_type }) => { tracing::debug!(%slot_id, %filename, ?kind, "FileOutput received"); let bytes = match std::fs::read(&filename) { Ok(b) => b, Err(e) => { tracing::error!(%slot_id, %filename, error = %e, "Failed to read FileOutput"); continue; } }; match kind { FileOutputKind::Oversized => { let output: serde_json::Value = match 
serde_json::from_slice(&bytes) { Ok(val) => val, Err(e) => { tracing::error!(%slot_id, %filename, error = %e, "Failed to parse oversized JSON"); continue; } }; let poisoned = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { p.append_output(output); false } else { true } } else { false }; if poisoned { predictions.remove(&slot_id); } } FileOutputKind::FileType => { let mime = mime_type.unwrap_or_else(|| { mime_guess::from_path(&filename) .first_or_octet_stream() .to_string() }); if let Some(ref url) = upload_url { // Spawn upload task so we don't block the event loop let pred = predictions.get(&slot_id).cloned(); let endpoint = ensure_trailing_slash(url); let basename = std::path::Path::new(&filename) .file_name() .and_then(|n| n.to_str()) .unwrap_or("output") .to_string(); let handle = tokio::spawn(async move { match upload_file(&endpoint, &basename, &bytes, &mime).await { Ok(url) => { if let Some(pred) = pred && let Some(mut p) = try_lock_prediction(&pred) { p.append_output(serde_json::Value::String(url)); } } Err(e) => { tracing::error!(error = %e, "Failed to upload file output"); } } }); pending_uploads.entry(slot_id).or_default().push(handle); } else { // No upload URL — base64-encode as data URI use base64::Engine; let encoded = base64::engine::general_purpose::STANDARD .encode(&bytes); let output = serde_json::Value::String(format!( "data:{mime};base64,{encoded}" )); let poisoned = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { p.append_output(output); false } else { true } } else { false }; if poisoned { predictions.remove(&slot_id); } } } } } Ok(SlotResponse::Done { id, output: _, predict_time, is_stream }) => { tracing::info!( target: "coglet::prediction", prediction_id = %id, predict_time, is_stream, output_is_array, "Prediction succeeded" ); let uploads = pending_uploads.remove(&slot_id).unwrap_or_default(); if let Some(pred) = predictions.remove(&slot_id) { 
if uploads.is_empty() { // No pending uploads — complete synchronously to avoid // a race between tokio::spawn and Notify::notified() in // service.rs. notify_waiters() only wakes already- // registered waiters; spawning a task can fire the // notification before the service registers its waiter. if let Some(mut p) = try_lock_prediction(&pred) { let pred_output = wrap_outputs( p.take_outputs(), output_is_array, is_stream, ); p.set_succeeded(pred_output); } } else { // Has pending uploads — must spawn to await them. // Clone the cancel token so we can abort uploads if // the prediction is cancelled while uploads are in flight. let (cancel_token, upload_pred_id) = match try_lock_prediction(&pred) { Some(p) => (Some(p.cancel_token()), p.id().to_string()), None => (None, id.clone()), }; tokio::spawn(async move { if let Some(token) = cancel_token { let upload_fut = futures::future::join_all(uploads); tokio::pin!(upload_fut); tokio::select! { _ = &mut upload_fut => {} _ = token.cancelled() => { tracing::info!( target: "coglet::prediction", prediction_id = %upload_pred_id, "Aborting in-flight uploads due to cancellation" ); if let Some(mut p) = try_lock_prediction(&pred) { p.set_canceled(); } return; } } } else { for h in uploads { let _ = h.await; } } if let Some(mut p) = try_lock_prediction(&pred) { let pred_output = wrap_outputs( p.take_outputs(), output_is_array, is_stream, ); p.set_succeeded(pred_output); } }); } } else { tracing::warn!(%slot_id, %id, "Prediction not found for Done message"); } } Ok(SlotResponse::Failed { id, error }) => { tracing::info!( target: "coglet::prediction", prediction_id = %id, %error, "Prediction failed" ); // Abort any pending uploads — prediction is terminal if let Some(handles) = pending_uploads.remove(&slot_id) { for h in handles { h.abort(); } } if let Some(pred) = predictions.remove(&slot_id) && let Some(mut p) = try_lock_prediction(&pred) { p.set_failed(error); } } Ok(SlotResponse::Cancelled { id }) => { tracing::info!( target: 
"coglet::prediction", prediction_id = %id, "Prediction cancelled" ); // Abort any pending uploads — prediction is terminal if let Some(handles) = pending_uploads.remove(&slot_id) { for h in handles { h.abort(); } } if let Some(pred) = predictions.remove(&slot_id) && let Some(mut p) = try_lock_prediction(&pred) { p.set_canceled(); } } Err(e) => { tracing::error!(%slot_id, error = %e, "Slot socket error"); if let Some(handles) = pending_uploads.remove(&slot_id) { for h in handles { h.abort(); } } if let Some(pred) = predictions.remove(&slot_id) && let Some(mut p) = try_lock_prediction(&pred) { p.set_failed(format!("Slot socket error: {}", e)); } } } } } } tracing::info!("Event loop exiting"); } #[cfg(test)] mod tests { use super::*; use serde_json::json; // ── wrap_outputs: schema says array (output_is_array = true) ── #[test] fn wrap_outputs_schema_array_empty() { // List[Path] that returned no items → empty array let result = wrap_outputs(vec![], true, true); assert!(result.is_stream()); assert_eq!(result.into_values(), Vec::::new()); } #[test] fn wrap_outputs_schema_array_single_item() { // List[Path] with num_outputs=1 → ["url"] not "url" let result = wrap_outputs(vec![json!("https://example.com/img.png")], true, true); assert!(result.is_stream()); assert_eq!( result.into_values(), vec![json!("https://example.com/img.png")] ); } #[test] fn wrap_outputs_schema_array_multiple_items() { // List[Path] with num_outputs=4 let items = vec![ json!("https://example.com/1.png"), json!("https://example.com/2.png"), json!("https://example.com/3.png"), json!("https://example.com/4.png"), ]; let result = wrap_outputs(items.clone(), true, true); assert!(result.is_stream()); assert_eq!(result.into_values(), items); } #[test] fn wrap_outputs_schema_array_overrides_is_stream_false() { // Schema says array but predictor didn't set is_stream (shouldn't happen, // but schema is authoritative) let result = wrap_outputs(vec![json!("url")], true, false); assert!(result.is_stream()); } 
// ── wrap_outputs: predictor signal (is_stream = true, no schema) ── #[test] fn wrap_outputs_predictor_stream_empty() { // Generator that yielded nothing, no schema let result = wrap_outputs(vec![], false, true); assert!(result.is_stream()); assert_eq!(result.into_values(), Vec::::new()); } #[test] fn wrap_outputs_predictor_stream_single_item() { // Any-typed list with one element, no schema let result = wrap_outputs(vec![json!("only_item")], false, true); assert!(result.is_stream()); assert_eq!(result.into_values(), vec![json!("only_item")]); } #[test] fn wrap_outputs_predictor_stream_multiple_items() { // Generator yielding multiple, no schema let items = vec![json!("a"), json!("b"), json!("c")]; let result = wrap_outputs(items.clone(), false, true); assert!(result.is_stream()); assert_eq!(result.into_values(), items); } // ── wrap_outputs: scalar output (neither schema array nor predictor stream) ── #[test] fn wrap_outputs_scalar_empty() { // Single output that was null (e.g. Path sent via FileOutput, not yet resolved?) 
let result = wrap_outputs(vec![], false, false); assert!(!result.is_stream()); assert_eq!(result.final_value(), &json!(null)); } #[test] fn wrap_outputs_scalar_single() { // return Path("output.png") → single string let result = wrap_outputs(vec![json!("https://example.com/output.png")], false, false); assert!(!result.is_stream()); assert_eq!( result.final_value(), &json!("https://example.com/output.png") ); } #[test] fn wrap_outputs_scalar_multiple_falls_back_to_stream() { // Shouldn't happen for scalar returns, but if multiple items arrive // with neither flag set, Stream is the safe choice let items = vec![json!("a"), json!("b")]; let result = wrap_outputs(items.clone(), false, false); assert!(result.is_stream()); assert_eq!(result.into_values(), items); } // ── Serialization: is_stream field on Done message ── #[test] fn done_is_stream_false_omitted_from_json() { let msg = SlotResponse::Done { id: "p1".into(), output: None, predict_time: 1.0, is_stream: false, }; let json = serde_json::to_value(&msg).unwrap(); assert!( json.get("is_stream").is_none(), "is_stream=false should be omitted" ); } #[test] fn done_is_stream_true_present_in_json() { let msg = SlotResponse::Done { id: "p1".into(), output: None, predict_time: 1.0, is_stream: true, }; let json = serde_json::to_value(&msg).unwrap(); assert_eq!(json.get("is_stream"), Some(&json!(true))); } #[test] fn done_without_is_stream_deserializes_as_false() { // Backward compat: old workers won't send is_stream let json = json!({ "type": "done", "id": "p1", "predict_time": 1.0 }); let msg: SlotResponse = serde_json::from_value(json).unwrap(); match msg { SlotResponse::Done { is_stream, .. } => assert!(!is_stream), _ => panic!("wrong variant"), } } } ================================================ FILE: crates/coglet/src/permit/mod.rs ================================================ //! Permit pool for concurrent slot management. //! //! 
The permit system uses typestate to enforce valid state transitions at compile time: //! - `PermitInUse` → `PermitIdle` via `into_idle()` (returns to pool on drop) //! - `PermitInUse` → `PermitPoisoned` via `into_poisoned()` (orphaned on drop) //! - `PermitPoisoned` → `PermitIdle`: NOT POSSIBLE (no method exists) //! //! Slot poisoning is a pool-level property (`PermitPool::poison()`). A poisoned slot //! is permanently removed from the pool regardless of whether a prediction was active. //! `PermitIdle::drop` checks the pool-level poison flag and skips returning the permit. mod pool; mod slot; pub use pool::{ AnyPermit, InactiveSlotIdleToken, PermitError, PermitIdle, PermitInUse, PermitPoisoned, PermitPool, SlotIdleToken, }; pub use slot::{PredictionSlot, UnregisteredPredictionSlot}; ================================================ FILE: crates/coglet/src/permit/pool.rs ================================================ //! Permit pool implementation with typestate for compile-time state transition safety. //! //! Slot poisoning is a pool-level property: a poisoned slot is permanently removed //! from the pool regardless of whether a prediction was active on it. use std::sync::Arc; use std::sync::Mutex as StdMutex; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use futures::SinkExt; use tokio::net::unix::OwnedWriteHalf; use tokio::sync::{Mutex, mpsc}; use tokio_util::codec::FramedWrite; use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::{SlotId, SlotRequest}; pub(crate) struct PermitInner { pub slot_id: SlotId, pub writer: FramedWrite>, pub idle_flag: Arc, pub poisoned: Arc, } struct PoolConnection { pool_tx: mpsc::Sender, pool_available: Arc, } impl Clone for PoolConnection { fn clone(&self) -> Self { Self { pool_tx: self.pool_tx.clone(), pool_available: Arc::clone(&self.pool_available), } } } /// A permit actively running a prediction. 
pub struct PermitInUse { slot_id: SlotId, writer: Option>>, idle_flag: Arc, poisoned: Arc, pool: PoolConnection, } impl PermitInUse { pub(crate) fn new( inner: PermitInner, pool_tx: mpsc::Sender, pool_available: Arc, ) -> Self { inner.idle_flag.store(false, Ordering::Release); Self { slot_id: inner.slot_id, writer: Some(inner.writer), idle_flag: inner.idle_flag, poisoned: inner.poisoned, pool: PoolConnection { pool_tx, pool_available, }, } } pub fn slot_id(&self) -> SlotId { self.slot_id } /// Transition to idle state - permit will return to pool on drop /// (unless the slot has been poisoned at the pool level). pub fn into_idle(mut self) -> PermitIdle { self.idle_flag.store(true, Ordering::Release); PermitIdle { slot_id: self.slot_id, writer: self.writer.take(), idle_flag: Arc::clone(&self.idle_flag), poisoned: Arc::clone(&self.poisoned), pool: self.pool.clone(), } } /// Transition to poisoned state - permit will NOT return to pool. /// /// Also sets the pool-level poison flag so the slot is never reused. pub fn into_poisoned(mut self) -> PermitPoisoned { self.poisoned.store(true, Ordering::Release); PermitPoisoned { slot_id: self.slot_id, _writer: self.writer.take(), } } pub async fn send(&mut self, request: SlotRequest) -> Result<(), PermitError> { let writer = self.writer.as_mut().ok_or(PermitError::Consumed)?; writer .send(request) .await .map_err(|e| PermitError::Send(e.to_string())) } } impl Drop for PermitInUse { fn drop(&mut self) { if self.writer.is_some() && !self.poisoned.load(Ordering::Acquire) { tracing::error!(slot = %self.slot_id, "PermitInUse dropped without state transition"); } } } /// A permit that completed successfully - returns to pool on drop /// (unless the slot has been poisoned at the pool level). 
pub struct PermitIdle { slot_id: SlotId, writer: Option>>, idle_flag: Arc, poisoned: Arc, pool: PoolConnection, } impl PermitIdle { pub fn slot_id(&self) -> SlotId { self.slot_id } } impl Drop for PermitIdle { fn drop(&mut self) { // If the slot was poisoned at the pool level, don't return it. if self.poisoned.load(Ordering::Acquire) { tracing::warn!(slot = %self.slot_id, "Slot poisoned - not returning to pool"); return; } if let Some(writer) = self.writer.take() { let inner = PermitInner { slot_id: self.slot_id, writer, idle_flag: Arc::clone(&self.idle_flag), poisoned: Arc::clone(&self.poisoned), }; if self.pool.pool_tx.try_send(inner).is_ok() { self.pool.pool_available.fetch_add(1, Ordering::Release); } } } } /// A poisoned permit - slot permanently failed, will NOT return to pool. pub struct PermitPoisoned { slot_id: SlotId, _writer: Option>>, } impl PermitPoisoned { pub fn slot_id(&self) -> SlotId { self.slot_id } } impl Drop for PermitPoisoned { fn drop(&mut self) { tracing::warn!(slot = %self.slot_id, "Slot poisoned - capacity reduced"); } } /// A permit in any state (for containers needing dynamic state). 
pub enum AnyPermit { InUse(PermitInUse), Idle(PermitIdle), Poisoned(PermitPoisoned), } impl AnyPermit { pub fn slot_id(&self) -> SlotId { match self { AnyPermit::InUse(p) => p.slot_id(), AnyPermit::Idle(p) => p.slot_id(), AnyPermit::Poisoned(p) => p.slot_id(), } } pub fn is_idle(&self) -> bool { matches!(self, AnyPermit::Idle(_)) } pub fn is_poisoned(&self) -> bool { matches!(self, AnyPermit::Poisoned(_)) } pub fn is_in_use(&self) -> bool { matches!(self, AnyPermit::InUse(_)) } } #[must_use = "must be activated to enable slot idle transition"] #[derive(Debug)] pub struct InactiveSlotIdleToken { slot_id: SlotId, } impl InactiveSlotIdleToken { pub fn new(slot_id: SlotId) -> Self { Self { slot_id } } pub fn slot_id(&self) -> SlotId { self.slot_id } pub fn activate(self) -> SlotIdleToken { SlotIdleToken { slot_id: self.slot_id, create_time: std::time::Instant::now(), alarm_handle: tokio::spawn(async move { // This task exists solely to alert if the token isn't consumed within a reasonable time. // If we see this alert, it means the slot won't return to the pool until the process exits. tokio::time::sleep(SlotIdleToken::ALERT_THRESHOLD).await; tracing::error!(slot = %self.slot_id, "IdleToken not consumed after 5s - slot will not return to pool"); }), } } } /// Token confirming the worker has marked the slot as idle, allowing the permit to return to the pool on drop. 
#[must_use = "IdleToken confirms the worker has marked the slot as idle"] #[derive(Debug)] pub struct SlotIdleToken { pub(crate) slot_id: SlotId, pub(crate) create_time: std::time::Instant, pub(crate) alarm_handle: tokio::task::JoinHandle<()>, } impl SlotIdleToken { const ALERT_THRESHOLD: std::time::Duration = std::time::Duration::from_secs(5); pub fn slot_id(&self) -> SlotId { self.slot_id } pub fn consume(self) { let elapsed = self.create_time.elapsed(); if elapsed > Self::ALERT_THRESHOLD { tracing::warn!(slot = %self.slot_id, latency = ?elapsed, "Delayed IdleToken Consumption"); } tracing::debug!(slot = %self.slot_id, "IdleToken consumed"); } } impl Drop for SlotIdleToken { fn drop(&mut self) { self.alarm_handle.abort(); } } #[derive(Debug, Clone, thiserror::Error)] pub enum PermitError { #[error("Permit already consumed")] Consumed, #[error("Failed to send on slot socket: {0}")] Send(String), } /// Pool of prediction slot permits. /// /// Slot poisoning is tracked here. A poisoned slot is permanently removed /// from the pool — its permit will not be returned or acquired again. pub struct PermitPool { available_rx: Mutex>, available_tx: mpsc::Sender, num_slots: usize, available_count: Arc, /// Per-slot poison flags, shared with permits for fast checking. poison_flags: StdMutex)>>, } impl PermitPool { pub fn new(num_slots: usize) -> Self { let (tx, rx) = mpsc::channel(num_slots); Self { available_rx: Mutex::new(rx), available_tx: tx, num_slots, available_count: Arc::new(AtomicUsize::new(0)), poison_flags: StdMutex::new(Vec::with_capacity(num_slots)), } } pub fn add_permit( &self, slot_id: SlotId, writer: FramedWrite>, ) { let poisoned = Arc::new(AtomicBool::new(false)); // Store the flag for external poisoning. 
if let Ok(mut flags) = self.poison_flags.lock() { flags.push((slot_id, Arc::clone(&poisoned))); } let inner = PermitInner { slot_id, writer, idle_flag: Arc::new(AtomicBool::new(true)), poisoned, }; if let Err(e) = self.available_tx.try_send(inner) { tracing::error!(slot = %slot_id, error = %e, "Failed to add permit to pool"); } else { self.available_count.fetch_add(1, Ordering::Release); } } /// Poison a slot. The permit will not be returned to the pool. /// /// This works whether the slot is idle (in the pool) or in use (held by a prediction). /// - Idle: the permit will be discarded on next `acquire`/`try_acquire`. /// - In use: `PermitIdle::drop` will see the flag and not return it. pub fn poison(&self, slot_id: SlotId) { if let Ok(flags) = self.poison_flags.lock() { for (id, flag) in flags.iter() { if *id == slot_id { if !flag.swap(true, Ordering::AcqRel) { tracing::warn!(slot = %slot_id, "Slot poisoned - capacity permanently reduced"); } return; } } } tracing::warn!(slot = %slot_id, "Attempted to poison unknown slot"); } /// Check if a slot is poisoned. pub fn is_poisoned(&self, slot_id: SlotId) -> bool { if let Ok(flags) = self.poison_flags.lock() { for (id, flag) in flags.iter() { if *id == slot_id { return flag.load(Ordering::Acquire); } } } false } pub fn try_acquire(&self) -> Option { let mut rx = self.available_rx.try_lock().ok()?; loop { let inner = rx.try_recv().ok()?; self.available_count.fetch_sub(1, Ordering::Release); // Skip poisoned permits — they're permanently dead. if inner.poisoned.load(Ordering::Acquire) { tracing::debug!(slot = %inner.slot_id, "Discarding poisoned permit from pool"); continue; } return Some(PermitInUse::new( inner, self.available_tx.clone(), Arc::clone(&self.available_count), )); } } pub async fn acquire(&self) -> Option { let mut rx = self.available_rx.lock().await; loop { let inner = rx.recv().await?; self.available_count.fetch_sub(1, Ordering::Release); // Skip poisoned permits — they're permanently dead. 
if inner.poisoned.load(Ordering::Acquire) { tracing::debug!(slot = %inner.slot_id, "Discarding poisoned permit from pool"); continue; } return Some(PermitInUse::new( inner, self.available_tx.clone(), Arc::clone(&self.available_count), )); } } pub fn num_slots(&self) -> usize { self.num_slots } pub fn available(&self) -> usize { self.available_count.load(Ordering::Acquire) } } #[cfg(test)] mod tests { use super::*; use tokio::net::UnixStream; async fn make_socket_pair() -> (OwnedWriteHalf, tokio::net::unix::OwnedReadHalf) { let (a, b) = UnixStream::pair().unwrap(); let (read, write) = a.into_split(); let _ = b; (write, read) } #[tokio::test] async fn pool_add_and_acquire() { let pool = PermitPool::new(2); let (write1, _read1) = make_socket_pair().await; let (write2, _read2) = make_socket_pair().await; let slot1 = SlotId::new(); let slot2 = SlotId::new(); pool.add_permit(slot1, FramedWrite::new(write1, JsonCodec::new())); pool.add_permit(slot2, FramedWrite::new(write2, JsonCodec::new())); let p1 = pool.try_acquire(); assert!(p1.is_some()); let p2 = pool.try_acquire(); assert!(p2.is_some()); let p3 = pool.try_acquire(); assert!(p3.is_none()); } #[tokio::test] async fn permit_returns_to_pool_when_idle() { let pool = PermitPool::new(1); let (write, _read) = make_socket_pair().await; let slot = SlotId::new(); pool.add_permit(slot, FramedWrite::new(write, JsonCodec::new())); { let permit = pool.try_acquire().unwrap(); let _idle_permit = permit.into_idle(); } let permit = pool.try_acquire(); assert!(permit.is_some()); } #[tokio::test] async fn permit_orphaned_when_poisoned() { let pool = PermitPool::new(1); let (write, _read) = make_socket_pair().await; let slot = SlotId::new(); pool.add_permit(slot, FramedWrite::new(write, JsonCodec::new())); { let permit = pool.try_acquire().unwrap(); let _poisoned_permit = permit.into_poisoned(); } let permit = pool.try_acquire(); assert!(permit.is_none()); } #[tokio::test] async fn pool_poison_idle_slot() { // Poison a slot while it's 
idle in the pool — acquire should skip it. let pool = PermitPool::new(2); let (write1, _read1) = make_socket_pair().await; let (write2, _read2) = make_socket_pair().await; let slot1 = SlotId::new(); let slot2 = SlotId::new(); pool.add_permit(slot1, FramedWrite::new(write1, JsonCodec::new())); pool.add_permit(slot2, FramedWrite::new(write2, JsonCodec::new())); assert!(!pool.is_poisoned(slot1)); pool.poison(slot1); assert!(pool.is_poisoned(slot1)); assert!(!pool.is_poisoned(slot2)); // First acquire should skip poisoned slot1, return slot2. let permit = pool.try_acquire().unwrap(); assert_eq!(permit.slot_id(), slot2); // No more permits available. assert!(pool.try_acquire().is_none()); } #[tokio::test] async fn pool_poison_in_use_slot_prevents_return() { // Poison a slot while a prediction holds it — into_idle + drop should NOT return it. let pool = PermitPool::new(1); let (write, _read) = make_socket_pair().await; let slot = SlotId::new(); pool.add_permit(slot, FramedWrite::new(write, JsonCodec::new())); { let permit = pool.try_acquire().unwrap(); // Poison while in use. pool.poison(slot); // Transition to idle — drop should see the poison flag. let _idle = permit.into_idle(); } // Permit should NOT have returned to the pool. assert!(pool.try_acquire().is_none()); } #[tokio::test] async fn pool_poison_is_idempotent() { let pool = PermitPool::new(1); let (write, _read) = make_socket_pair().await; let slot = SlotId::new(); pool.add_permit(slot, FramedWrite::new(write, JsonCodec::new())); pool.poison(slot); pool.poison(slot); // Should not panic or double-count. assert!(pool.is_poisoned(slot)); } } ================================================ FILE: crates/coglet/src/permit/slot.rs ================================================ //! PredictionSlot - holds Prediction and Permit side-by-side. //! //! This separation allows the prediction to be behind Mutex for concurrent //! updates while the permit's idle_flag can be set without holding the lock. //! //! 
Slot poisoning is NOT managed here — it's a pool-level property. //! The slot always transitions to idle when done; `PermitIdle::drop` checks //! the pool-level poison flag to decide whether to return the permit. use std::sync::{Arc, Mutex}; use super::{AnyPermit, PermitInUse, SlotIdleToken}; use crate::bridge::protocol::SlotId; use crate::prediction::Prediction; #[derive(Debug, Clone, thiserror::Error)] pub enum SlotError { #[error("receive error while waiting for idle token")] IdleTokenReceiveError(#[from] tokio::sync::oneshot::error::RecvError), #[error("permit already consumed")] PermitAlreadyConsumed, } /// Pre-registration slot state - holds prediction while permit is being /// acquired and slot is being registered with orchestrator. pub struct UnregisteredPredictionSlot { prediction_slot: PredictionSlot, idle_tx: tokio::sync::oneshot::Sender, } impl UnregisteredPredictionSlot { pub fn new( prediction_slot: PredictionSlot, idle_tx: tokio::sync::oneshot::Sender, ) -> Self { Self { prediction_slot, idle_tx, } } /// Consumes the unregistered slot and returns its components for registration. pub fn into_parts(self) -> (tokio::sync::oneshot::Sender, PredictionSlot) { (self.idle_tx, self.prediction_slot) } pub fn prediction(&self) -> Arc> { self.prediction_slot.prediction() } } /// Holds a prediction and its permit side-by-side. /// /// On drop: Permit returns to pool (if idle and not poisoned at pool level). 
pub struct PredictionSlot { prediction: Arc>, slot_id: SlotId, permit: Option, idle_rx: Option>, } impl PredictionSlot { pub fn new( prediction: Prediction, permit: PermitInUse, idle_rx: tokio::sync::oneshot::Receiver, ) -> Self { let slot_id = permit.slot_id(); Self { prediction: Arc::new(Mutex::new(prediction)), slot_id, permit: Some(AnyPermit::InUse(permit)), idle_rx: Some(idle_rx), } } pub fn prediction(&self) -> Arc> { Arc::clone(&self.prediction) } pub fn permit_mut(&mut self) -> Option<&mut PermitInUse> { match &mut self.permit { Some(AnyPermit::InUse(p)) => Some(p), _ => None, } } pub fn slot_id(&self) -> SlotId { self.slot_id } /// Mark the slot as idle - permit will return to pool on drop (unless the slot has /// been poisoned at the pool level). Awaits until the idle token is received, which /// ensures the slot has been confirmed idle by the worker. If the idle token is not /// received, the permit is not returned to the pool, #[must_use = "into_idle confirms the slot is idle and allows the permit to return to the pool on drop"] pub async fn into_idle(mut self) -> Result<(), SlotError> { if let Some(receiver) = self.idle_rx.take() { let idle_token = receiver.await?; debug_assert_eq!( idle_token.slot_id(), self.slot_id, "IdleToken slot_id mismatch" ); idle_token.consume(); } let permit = self.permit.take(); debug_assert!( permit.is_some(), "Attempted to mark slot as idle but permit was already consumed" ); match permit { Some(AnyPermit::InUse(p)) => { let idle = p.into_idle(); self.permit = Some(AnyPermit::Idle(idle)); Ok(()) } Some(AnyPermit::Idle(p)) => { self.permit = Some(AnyPermit::Idle(p)); Ok(()) } Some(AnyPermit::Poisoned(p)) => { // Permit was explicitly poisoned (legacy path) — keep it. 
debug_assert!(false, "Cannot mark poisoned slot as idle"); tracing::error!(slot = %p.slot_id(), "Bug: attempted to mark poisoned slot as idle"); self.permit = Some(AnyPermit::Poisoned(p)); Ok(()) } None => { // Permit was already consumed (bug) — log and do nothing. tracing::error!(slot = %self.slot_id(), "Bug: attempted to mark slot as idle but permit was already consumed"); Err(SlotError::PermitAlreadyConsumed) } } } pub fn is_idle(&self) -> bool { self.permit.as_ref().is_some_and(|p| p.is_idle()) } pub fn id(&self) -> String { self.prediction .try_lock() .map(|p| p.id().to_string()) .unwrap_or_default() } } impl Drop for PredictionSlot { fn drop(&mut self) { if let Some(AnyPermit::InUse(_)) = &self.permit && let Ok(mut prediction) = self.prediction.try_lock() && !prediction.is_terminal() { tracing::error!( slot = %self.slot_id(), prediction_id = %prediction.id(), "Slot dropped while InUse with non-terminal prediction" ); prediction.set_failed("Slot dropped unexpectedly".to_string()); } } } #[cfg(test)] mod tests { use super::*; use crate::bridge::codec::JsonCodec; use crate::permit::{InactiveSlotIdleToken, PermitPool}; use tokio::net::UnixStream; use tokio_util::codec::FramedWrite; #[tokio::test] async fn slot_creation() { let pool = PermitPool::new(1); let (a, _b) = UnixStream::pair().unwrap(); let (_, write) = a.into_split(); let slot_id = SlotId::new(); pool.add_permit(slot_id, FramedWrite::new(write, JsonCodec::new())); let permit = pool.try_acquire().unwrap(); let prediction = Prediction::new("test_123".to_string(), None); let (_idle_tx, idle_rx) = tokio::sync::oneshot::channel(); let slot = PredictionSlot::new(prediction, permit, idle_rx); assert_eq!(slot.slot_id(), slot_id); } #[tokio::test] async fn slot_mark_idle_returns_permit() { let pool = PermitPool::new(1); let (a, _b) = UnixStream::pair().unwrap(); let (_, write) = a.into_split(); let slot_id = SlotId::new(); pool.add_permit(slot_id, FramedWrite::new(write, JsonCodec::new())); { let permit = 
pool.try_acquire().unwrap(); let prediction = Prediction::new("test_123".to_string(), None); let (idle_tx, idle_rx) = tokio::sync::oneshot::channel(); let slot = PredictionSlot::new(prediction, permit, idle_rx); idle_tx .send(InactiveSlotIdleToken::new(slot_id).activate()) .unwrap(); slot.into_idle().await.unwrap(); } assert!(pool.try_acquire().is_some()); } #[tokio::test] async fn slot_not_idle_orphans_permit() { let pool = PermitPool::new(1); let (a, _b) = UnixStream::pair().unwrap(); let (_, write) = a.into_split(); let slot_id = SlotId::new(); pool.add_permit(slot_id, FramedWrite::new(write, JsonCodec::new())); { let permit = pool.try_acquire().unwrap(); let prediction = Prediction::new("test_123".to_string(), None); let (_idle_tx, idle_rx) = tokio::sync::oneshot::channel(); let _slot = PredictionSlot::new(prediction, permit, idle_rx); } assert!(pool.try_acquire().is_none()); } #[tokio::test] async fn slot_idle_channel_closed_does_not_return_permit() { let pool = PermitPool::new(1); let (a, _b) = UnixStream::pair().unwrap(); let (_, write) = a.into_split(); let slot_id = SlotId::new(); pool.add_permit(slot_id, FramedWrite::new(write, JsonCodec::new())); let permit = pool.try_acquire().unwrap(); let prediction = Prediction::new("test_123".to_string(), None); let (idle_tx, idle_rx) = tokio::sync::oneshot::channel::(); let slot = PredictionSlot::new(prediction, permit, idle_rx); drop(idle_tx); let result = slot.into_idle().await; assert!(matches!(result, Err(SlotError::IdleTokenReceiveError(_)))); assert!(pool.try_acquire().is_none()); } } ================================================ FILE: crates/coglet/src/prediction.rs ================================================ //! Prediction state tracking. 
use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; use tokio::sync::Notify; pub use tokio_util::sync::CancellationToken; use crate::bridge::protocol::MetricMode; use crate::webhook::{WebhookEventType, WebhookSender}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PredictionStatus { Starting, Processing, Succeeded, Failed, Canceled, } impl PredictionStatus { pub fn is_terminal(&self) -> bool { matches!(self, Self::Succeeded | Self::Failed | Self::Canceled) } pub fn as_str(&self) -> &'static str { match self { Self::Starting => "starting", Self::Processing => "processing", Self::Succeeded => "succeeded", Self::Failed => "failed", Self::Canceled => "canceled", } } } /// Prediction output - single value or streamed chunks. #[derive(Debug, Clone, serde::Serialize)] #[serde(untagged)] pub enum PredictionOutput { Single(serde_json::Value), Stream(Vec), } impl PredictionOutput { pub fn is_stream(&self) -> bool { matches!(self, PredictionOutput::Stream(_)) } pub fn into_values(self) -> Vec { match self { PredictionOutput::Single(v) => vec![v], PredictionOutput::Stream(v) => v, } } /// Get the final/only output value (last for stream, the value for single). pub fn final_value(&self) -> &serde_json::Value { match self { PredictionOutput::Single(v) => v, PredictionOutput::Stream(v) => v.last().unwrap_or(&serde_json::Value::Null), } } } /// Prediction lifecycle state. pub struct Prediction { id: String, cancel_token: CancellationToken, started_at: Instant, status: PredictionStatus, logs: String, outputs: Vec, output: Option, error: Option, webhook: Option, completion: Arc, /// User-emitted metrics. Merged with system metrics (predict_time) in terminal response. 
metrics: HashMap, } impl Prediction { pub fn new(id: String, webhook: Option) -> Self { Self { id, cancel_token: CancellationToken::new(), started_at: Instant::now(), status: PredictionStatus::Starting, logs: String::new(), outputs: Vec::new(), output: None, error: None, webhook, completion: Arc::new(Notify::new()), metrics: HashMap::new(), } } pub fn id(&self) -> &str { &self.id } pub fn cancel_token(&self) -> CancellationToken { self.cancel_token.clone() } pub fn is_canceled(&self) -> bool { self.cancel_token.is_cancelled() } pub fn status(&self) -> PredictionStatus { self.status } pub fn is_terminal(&self) -> bool { self.status.is_terminal() } pub fn set_processing(&mut self) { self.status = PredictionStatus::Processing; self.fire_webhook(WebhookEventType::Start); } pub fn set_succeeded(&mut self, output: PredictionOutput) { if self.status.is_terminal() { return; } self.status = PredictionStatus::Succeeded; self.output = Some(output); self.fire_terminal_webhook(); // notify_one stores a permit so a future .notified().await will // consume it immediately. notify_waiters only wakes currently- // registered waiters and would race with the service task that // checks is_terminal() then awaits — the notification can fire // in between. There is exactly one waiter per prediction // (service.rs predict()), so notify_one is semantically correct. 
self.completion.notify_one(); } pub fn set_failed(&mut self, error: String) { if self.status.is_terminal() { return; } self.status = PredictionStatus::Failed; self.error = Some(error); self.fire_terminal_webhook(); self.completion.notify_one(); } pub fn set_canceled(&mut self) { if self.status.is_terminal() { return; } self.status = PredictionStatus::Canceled; self.fire_terminal_webhook(); self.completion.notify_one(); } pub fn elapsed(&self) -> std::time::Duration { self.started_at.elapsed() } pub fn append_log(&mut self, data: &str) { self.logs.push_str(data); self.fire_webhook(WebhookEventType::Logs); } pub fn logs(&self) -> &str { &self.logs } /// Set a user metric with the given accumulation mode. /// /// - `Replace`: overwrites the value (or deletes if null). /// - `Increment`: adds to an existing numeric value. Errors silently if types mismatch. /// - `Append`: pushes onto an existing array, creating one if needed. /// /// Dot-path keys (e.g., "timing.preprocess") are resolved into nested objects. 
pub fn set_metric(&mut self, name: String, value: serde_json::Value, mode: MetricMode) { // Dot-path resolution: "a.b.c" → nested objects let parts: Vec<&str> = name.split('.').collect(); if parts.len() > 1 { self.set_metric_dotpath(&parts, value, mode); return; } match mode { MetricMode::Replace => { if value.is_null() { self.metrics.remove(&name); } else { self.metrics.insert(name, value); } } MetricMode::Increment => { let entry = self.metrics.entry(name).or_insert(serde_json::json!(0)); if let (Some(existing), Some(delta)) = (entry.as_f64(), value.as_f64()) { // Preserve integer type if both are integers if entry.is_i64() && value.is_i64() { *entry = serde_json::json!(existing as i64 + delta as i64); } else if entry.is_u64() && value.is_u64() { *entry = serde_json::json!(existing as u64 + delta as u64); } else { *entry = serde_json::json!(existing + delta); } } // Non-numeric increment is silently ignored } MetricMode::Append => { let entry = self .metrics .entry(name) .or_insert(serde_json::Value::Array(vec![])); if let Some(arr) = entry.as_array_mut() { arr.push(value); } else { // Existing value is not an array — wrap it and append let existing = entry.take(); *entry = serde_json::json!([existing, value]); } } } } /// Resolve a dot-path key into nested objects and apply the metric. 
fn set_metric_dotpath(&mut self, parts: &[&str], value: serde_json::Value, mode: MetricMode) {
    debug_assert!(parts.len() > 1);

    let root_key = parts[0].to_string();

    // Navigate/create nested structure
    let entry = self
        .metrics
        .entry(root_key)
        .or_insert_with(|| serde_json::json!({}));

    let mut current = entry;
    for &part in &parts[1..parts.len() - 1] {
        // Ensure intermediate nodes are objects
        // (a non-object value found on the path is replaced by `{}`).
        if !current.is_object() {
            *current = serde_json::json!({});
        }
        current = current
            .as_object_mut()
            .unwrap()
            .entry(part)
            .or_insert_with(|| serde_json::json!({}));
    }

    let leaf_key = parts[parts.len() - 1];

    // Ensure the parent is an object
    if !current.is_object() {
        *current = serde_json::json!({});
    }
    let obj = current.as_object_mut().unwrap();

    match mode {
        MetricMode::Replace => {
            if value.is_null() {
                obj.remove(leaf_key);
            } else {
                obj.insert(leaf_key.to_string(), value);
            }
        }
        MetricMode::Increment => {
            let entry = obj.entry(leaf_key).or_insert(serde_json::json!(0));
            if let (Some(existing), Some(delta)) = (entry.as_f64(), value.as_f64()) {
                // Add integers exactly; see the same logic in `set_metric`.
                // The f64 round-trip loses precision above 2^53 and an
                // unchecked `+` can overflow.
                let exact = match (entry.as_i64(), value.as_i64()) {
                    (Some(a), Some(b)) => a.checked_add(b).map(serde_json::Value::from),
                    _ => None,
                }
                .or_else(|| match (entry.as_u64(), value.as_u64()) {
                    (Some(a), Some(b)) => a.checked_add(b).map(serde_json::Value::from),
                    _ => None,
                });
                *entry = exact.unwrap_or_else(|| serde_json::json!(existing + delta));
            }
        }
        MetricMode::Append => {
            let entry = obj
                .entry(leaf_key)
                .or_insert(serde_json::Value::Array(vec![]));
            if let Some(arr) = entry.as_array_mut() {
                arr.push(value);
            } else {
                let existing = entry.take();
                *entry = serde_json::json!([existing, value]);
            }
        }
    }
}

/// User-emitted metrics recorded so far.
pub fn metrics(&self) -> &HashMap<String, serde_json::Value> {
    &self.metrics
}

/// Appends a streamed output chunk and fires an output webhook.
pub fn append_output(&mut self, output: serde_json::Value) {
    self.outputs.push(output);
    self.fire_webhook(WebhookEventType::Output);
}

pub fn outputs(&self) -> &[serde_json::Value] {
    &self.outputs
}

/// Moves the accumulated chunks out, leaving an empty list behind.
pub fn take_outputs(&mut self) -> Vec<serde_json::Value> {
    std::mem::take(&mut self.outputs)
}

pub fn output(&self) -> Option<&PredictionOutput> {
    self.output.as_ref()
}

pub fn error(&self) -> Option<&str> {
    self.error.as_deref()
}

/// Returns immediately when already terminal; otherwise awaits the
/// completion notification.
pub async fn wait(&self) {
    if self.status.is_terminal() {
        return;
    }
    self.completion.notified().await;
}

/// Shared handle to the completion notifier.
pub fn completion(&self) -> Arc<Notify> {
    Arc::clone(&self.completion)
}

/// Take the webhook sender (for sending on drop).
pub fn take_webhook(&mut self) -> Option<WebhookSender> {
    self.webhook.take()
}

/// Fire a non-terminal webhook (throttled, fire-and-forget).
///
/// Builds the current state as a JSON payload and sends it via the
/// stored WebhookSender. Spawns a tokio task — does not block.
fn fire_webhook(&self, event: WebhookEventType) {
    if let Some(ref webhook) = self.webhook {
        let payload = self.build_webhook_payload();
        webhook.send(event, &payload);
    }
}

/// Fire the terminal webhook and consume the WebhookSender.
///
/// Takes ownership of the webhook sender so it can only fire once.
/// Spawns a tokio task with retry logic for reliability.
fn fire_terminal_webhook(&mut self) {
    if let Some(webhook) = self.webhook.take() {
        let payload = self.build_webhook_payload();
        tokio::spawn(async move {
            webhook
                .send_terminal(WebhookEventType::Completed, &payload)
                .await;
        });
    }
}

/// Build a JSON snapshot of the current prediction state.
///
/// This is the single source of truth for prediction JSON. Used by
/// webhook payloads, GET responses, and terminal responses. Callers
/// can merge additional fields (e.g. `input`) into the result.
pub fn build_state_snapshot(&self) -> serde_json::Value {
    let mut payload = serde_json::json!({
        "id": self.id,
        "status": self.status.as_str(),
        "logs": self.logs,
    });

    // Include output: use final output if set (terminal), otherwise
    // include accumulated streaming outputs for intermediate states.
    if let Some(ref output) = self.output {
        payload["output"] = serde_json::json!(output);
    } else if !self.outputs.is_empty() {
        payload["output"] = serde_json::json!(self.outputs);
    }

    if let Some(ref error) = self.error {
        payload["error"] = serde_json::Value::String(error.clone());
    }

    // Include metrics: always include user metrics, add predict_time on terminal
    if !self.metrics.is_empty() || self.status.is_terminal() {
        let mut metrics_obj = serde_json::Map::new();
        for (k, v) in &self.metrics {
            metrics_obj.insert(k.clone(), v.clone());
        }
        if self.status.is_terminal() {
            // Inserted last, so it overrides any user metric of the same name.
            let predict_time = self.elapsed().as_secs_f64();
            metrics_obj.insert("predict_time".to_string(), serde_json::json!(predict_time));
        }
        payload["metrics"] = serde_json::Value::Object(metrics_obj);
    }

    payload
}

/// Build webhook payload (delegates to build_state_snapshot).
fn build_webhook_payload(&self) -> serde_json::Value {
    self.build_state_snapshot()
}

/// Terminal response body (delegates to build_state_snapshot).
pub fn build_terminal_response(&self) -> serde_json::Value {
    self.build_state_snapshot()
}
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn status_is_terminal() {
        assert!(!PredictionStatus::Starting.is_terminal());
        assert!(!PredictionStatus::Processing.is_terminal());
        assert!(PredictionStatus::Succeeded.is_terminal());
        assert!(PredictionStatus::Failed.is_terminal());
        assert!(PredictionStatus::Canceled.is_terminal());
    }

    #[test]
    fn new_starts_in_starting_status() {
        let pred = Prediction::new("test".to_string(), None);
        assert_eq!(pred.status(), PredictionStatus::Starting);
        assert_eq!(pred.id(), "test");
    }

    #[test]
    fn set_succeeded() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_succeeded(PredictionOutput::Single(serde_json::json!("hello")));
        assert_eq!(pred.status(), PredictionStatus::Succeeded);
    }

    #[test]
    fn set_failed() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_failed("something went wrong".to_string());
        assert_eq!(pred.status(), PredictionStatus::Failed);
    }

    #[test]
    fn set_canceled() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_canceled();
        assert_eq!(pred.status(), PredictionStatus::Canceled);
    }

    #[test]
    fn cancel_token_works() {
        let pred = Prediction::new("test".to_string(), None);
        let token = pred.cancel_token();

        assert!(!pred.is_canceled());
        token.cancel();
        assert!(pred.is_canceled());
    }

    #[test]
    fn elapsed_time_increases() {
        let pred = Prediction::new("test".to_string(), None);
        let t1 = pred.elapsed();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let t2 = pred.elapsed();
        assert!(t2 > t1);
    }

    #[test]
    fn append_log() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.append_log("line 1\n");
        pred.append_log("line 2\n");
        assert_eq!(pred.logs(), "line 1\nline 2\n");
    }

    #[test]
    fn append_output() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.append_output(serde_json::json!("chunk1"));
        pred.append_output(serde_json::json!("chunk2"));
        assert_eq!(pred.outputs().len(), 2);
    }

    #[tokio::test]
    async fn wait_returns_immediately_if_terminal() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_succeeded(PredictionOutput::Single(serde_json::json!("done")));
        pred.wait().await;
        assert_eq!(pred.status(), PredictionStatus::Succeeded);
    }

    #[test]
    fn prediction_output_single() {
        let output = PredictionOutput::Single(serde_json::json!("hello"));
        assert!(!output.is_stream());
        assert_eq!(output.into_values(), vec![serde_json::json!("hello")]);
    }

    #[test]
    fn prediction_output_stream() {
        let output = PredictionOutput::Stream(vec![serde_json::json!("a"), serde_json::json!("b")]);
        assert!(output.is_stream());
    }

    // ====================================================================
    // Metric tests
    // ====================================================================

    #[test]
    fn metric_replace_sets_value() {
        let mut pred = Prediction::new("test".to_string(), None);
        pred.set_metric("temp".into(), serde_json::json!(0.7), MetricMode::Replace);
        assert_eq!(pred.metrics()["temp"], serde_json::json!(0.7));
    }
#[test] fn metric_replace_overwrites() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric("temp".into(), serde_json::json!(0.7), MetricMode::Replace); pred.set_metric("temp".into(), serde_json::json!(0.9), MetricMode::Replace); assert_eq!(pred.metrics()["temp"], serde_json::json!(0.9)); } #[test] fn metric_replace_null_deletes() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric("temp".into(), serde_json::json!(0.7), MetricMode::Replace); pred.set_metric("temp".into(), serde_json::Value::Null, MetricMode::Replace); assert!(!pred.metrics().contains_key("temp")); } #[test] fn metric_increment_integers() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric("count".into(), serde_json::json!(1), MetricMode::Increment); pred.set_metric("count".into(), serde_json::json!(3), MetricMode::Increment); assert_eq!(pred.metrics()["count"], serde_json::json!(4)); } #[test] fn metric_increment_floats() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric( "score".into(), serde_json::json!(1.5), MetricMode::Increment, ); pred.set_metric( "score".into(), serde_json::json!(2.5), MetricMode::Increment, ); assert_eq!(pred.metrics()["score"], serde_json::json!(4.0)); } #[test] fn metric_increment_creates_from_zero() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric("count".into(), serde_json::json!(5), MetricMode::Increment); assert_eq!(pred.metrics()["count"], serde_json::json!(5)); } #[test] fn metric_append_creates_array() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric( "logprobs".into(), serde_json::json!(-1.2), MetricMode::Append, ); pred.set_metric( "logprobs".into(), serde_json::json!(-0.3), MetricMode::Append, ); assert_eq!(pred.metrics()["logprobs"], serde_json::json!([-1.2, -0.3])); } #[test] fn metric_append_to_non_array_wraps() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric("val".into(), 
serde_json::json!(1), MetricMode::Replace); pred.set_metric("val".into(), serde_json::json!(2), MetricMode::Append); assert_eq!(pred.metrics()["val"], serde_json::json!([1, 2])); } #[test] fn metric_dotpath_creates_nested() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric( "timing.preprocess".into(), serde_json::json!(0.1), MetricMode::Replace, ); assert_eq!( pred.metrics()["timing"], serde_json::json!({"preprocess": 0.1}) ); } #[test] fn metric_dotpath_deep() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric("a.b.c".into(), serde_json::json!(42), MetricMode::Replace); assert_eq!(pred.metrics()["a"], serde_json::json!({"b": {"c": 42}})); } #[test] fn metric_dotpath_multiple_leaves() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric( "timing.preprocess".into(), serde_json::json!(0.1), MetricMode::Replace, ); pred.set_metric( "timing.inference".into(), serde_json::json!(0.8), MetricMode::Replace, ); assert_eq!( pred.metrics()["timing"], serde_json::json!({"preprocess": 0.1, "inference": 0.8}) ); } #[test] fn metric_dotpath_delete_leaf() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric( "timing.preprocess".into(), serde_json::json!(0.1), MetricMode::Replace, ); pred.set_metric( "timing.preprocess".into(), serde_json::Value::Null, MetricMode::Replace, ); // Parent object should still exist but be empty assert_eq!(pred.metrics()["timing"], serde_json::json!({})); } #[test] fn metric_dotpath_increment() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric( "stats.tokens".into(), serde_json::json!(10), MetricMode::Increment, ); pred.set_metric( "stats.tokens".into(), serde_json::json!(5), MetricMode::Increment, ); assert_eq!(pred.metrics()["stats"], serde_json::json!({"tokens": 15})); } #[test] fn metric_complex_values() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric( "config".into(), serde_json::json!({"layers": 12, 
"heads": 8}), MetricMode::Replace, ); pred.set_metric( "scores".into(), serde_json::json!([0.9, 0.8, 0.7]), MetricMode::Replace, ); assert_eq!( pred.metrics()["config"], serde_json::json!({"layers": 12, "heads": 8}) ); assert_eq!(pred.metrics()["scores"], serde_json::json!([0.9, 0.8, 0.7])); } #[test] fn terminal_snapshot_merges_metrics_with_predict_time() { let mut pred = Prediction::new("test".to_string(), None); pred.set_metric("temp".into(), serde_json::json!(0.7), MetricMode::Replace); pred.set_metric("count".into(), serde_json::json!(42), MetricMode::Replace); pred.set_succeeded(PredictionOutput::Single(serde_json::json!("ok"))); let snapshot = pred.build_state_snapshot(); let metrics = snapshot["metrics"].as_object().unwrap(); assert_eq!(metrics["temp"], serde_json::json!(0.7)); assert_eq!(metrics["count"], serde_json::json!(42)); assert!(metrics.contains_key("predict_time")); } #[test] fn terminal_snapshot_predict_time_overrides_user() { let mut pred = Prediction::new("test".to_string(), None); // User tries to set predict_time - system should override pred.set_metric( "predict_time".into(), serde_json::json!(999.0), MetricMode::Replace, ); pred.set_succeeded(PredictionOutput::Single(serde_json::json!("ok"))); let snapshot = pred.build_state_snapshot(); let metrics = snapshot["metrics"].as_object().unwrap(); // predict_time should be the actual elapsed, not 999.0 assert_ne!(metrics["predict_time"], serde_json::json!(999.0)); } #[test] fn terminal_state_guard_set_failed_after_succeeded() { let mut pred = Prediction::new("test".to_string(), None); pred.set_succeeded(PredictionOutput::Single(serde_json::json!("ok"))); pred.set_failed("Slot dropped unexpectedly".to_string()); // Must stay succeeded, not overwritten to failed assert_eq!(pred.status(), PredictionStatus::Succeeded); assert!(pred.error().is_none()); } #[test] fn terminal_state_guard_set_succeeded_after_failed() { let mut pred = Prediction::new("test".to_string(), None); 
pred.set_failed("error".to_string()); pred.set_succeeded(PredictionOutput::Single(serde_json::json!("late"))); assert_eq!(pred.status(), PredictionStatus::Failed); assert_eq!(pred.error(), Some("error")); } #[test] fn terminal_state_guard_set_canceled_after_succeeded() { let mut pred = Prediction::new("test".to_string(), None); pred.set_succeeded(PredictionOutput::Single(serde_json::json!("done"))); pred.set_canceled(); assert_eq!(pred.status(), PredictionStatus::Succeeded); } } ================================================ FILE: crates/coglet/src/predictor.rs ================================================ //! Predictor traits and prediction lifecycle types. use std::collections::HashMap; use std::time::{Duration, Instant}; pub use crate::prediction::{CancellationToken, PredictionOutput}; /// Result of a completed prediction. #[derive(Debug, Clone)] pub struct PredictionResult { pub output: PredictionOutput, pub predict_time: Option, pub logs: String, /// User-emitted metrics from the prediction. pub metrics: HashMap, } /// Metrics collected during prediction. #[derive(Debug, Clone, Default)] pub struct PredictionMetrics { pub predict_time: Option, } /// RAII guard for prediction lifecycle timing. 
pub struct PredictionGuard {
    start_time: Instant,
    metrics: PredictionMetrics,
    cancel_token: CancellationToken,
}

impl PredictionGuard {
    /// Starts the clock and creates a fresh cancellation token.
    pub fn new() -> Self {
        let metrics = PredictionMetrics::default();
        let cancel_token = CancellationToken::new();
        Self {
            start_time: Instant::now(),
            metrics,
            cancel_token,
        }
    }

    /// Clone of the guard's cancellation token.
    pub fn cancel_token(&self) -> CancellationToken {
        self.cancel_token.clone()
    }

    /// Whether `cancel` has been called on this guard's token.
    pub fn is_cancelled(&self) -> bool {
        self.cancel_token.is_cancelled()
    }

    /// Signals cancellation through the token.
    pub fn cancel(&self) {
        self.cancel_token.cancel();
    }

    /// Consumes the guard and returns its metrics with `predict_time` set to
    /// the time elapsed since construction.
    pub fn finish(mut self) -> PredictionMetrics {
        let elapsed = self.start_time.elapsed();
        self.metrics.predict_time = Some(elapsed);
        self.metrics
    }
}

impl Default for PredictionGuard {
    fn default() -> Self {
        Self::new()
    }
}

/// Reasons a prediction request can fail.
#[derive(Debug, thiserror::Error)]
pub enum PredictionError {
    #[error("Prediction failed: {0}")]
    Failed(String),
    #[error("Input validation error: {0}")]
    InvalidInput(String),
    #[error(
        "Setup has not finished yet. Wait until it has finished, or GET /health-check for status."
    )]
    NotReady,
    #[error("Prediction was cancelled")]
    Cancelled,
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn prediction_output_single_is_not_stream() {
        assert!(!PredictionOutput::Single(json!("hello")).is_stream());
    }

    #[test]
    fn prediction_output_stream_is_stream() {
        assert!(PredictionOutput::Stream(vec![json!("a"), json!("b")]).is_stream());
    }

    #[test]
    fn prediction_output_serializes_untagged() {
        let single = PredictionOutput::Single(json!("hello"));
        insta::assert_json_snapshot!("output_single", single);

        let stream = PredictionOutput::Stream(vec![json!(1), json!(2)]);
        insta::assert_json_snapshot!("output_stream", stream);
    }

    #[test]
    fn prediction_guard_tracks_time() {
        let guard = PredictionGuard::new();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let metrics = guard.finish();

        assert!(metrics.predict_time.is_some());
        let time = metrics.predict_time.unwrap();
        assert!(time.as_millis() >= 10);
        assert!(time.as_secs() < 1);
    }

    #[test]
    fn prediction_error_display() {
        let failed = PredictionError::Failed("something broke".to_string());
        assert_eq!(format!("{}", failed), "Prediction failed: something broke");

        let invalid = PredictionError::InvalidInput("bad json".to_string());
        assert_eq!(format!("{}", invalid), "Input validation error: bad json");

        let not_ready = PredictionError::NotReady;
        assert_eq!(
            format!("{}", not_ready),
            "Setup has not finished yet. Wait until it has finished, or GET /health-check for status."
        );
    }
}

================================================
FILE: crates/coglet/src/service.rs
================================================
//! PredictionService: Transport-agnostic prediction lifecycle management.
//!
//! This service is the single owner of prediction state. It manages:
//! - Slot management (PermitPool for concurrency control)
//! - Prediction lifecycle (DashMap of active predictions)
//! - Cancellation (cancel tokens + orchestrator delegation)
//! - Health tracking (state, setup result)
//! - Shutdown coordination (bidirectional)
//!
//! Webhooks are fired from Prediction mutation methods (set_processing,
//! set_succeeded, etc.) where the real state lives — no dual state tracking.

use std::sync::{Arc, Mutex as StdMutex};

use dashmap::DashMap;
use tokio::sync::{RwLock, watch};

use crate::bridge::protocol::{MAX_INLINE_IPC_SIZE, SlotRequest};
use crate::health::{Health, SetupResult};
use crate::input_validation::InputValidator;
use crate::orchestrator::{HealthcheckResult, Orchestrator};
use crate::permit::{PermitPool, PredictionSlot, UnregisteredPredictionSlot};
use crate::prediction::{CancellationToken, Prediction, PredictionStatus};
use crate::predictor::{PredictionError, PredictionOutput, PredictionResult};
use crate::version::VersionInfo;
use crate::webhook::WebhookSender;

/// Try to lock a prediction mutex. On poison, fail the prediction and return None.
fn try_lock_prediction( pred: &Arc>, ) -> Option> { match pred.lock() { Ok(guard) => Some(guard), Err(poisoned) => { tracing::error!("Prediction mutex poisoned - failing prediction"); let mut guard = poisoned.into_inner(); if !guard.is_terminal() { guard.set_failed("Internal error: mutex poisoned".to_string()); } None } } } #[derive(Debug, thiserror::Error)] pub enum CreatePredictionError { #[error("Service not ready")] NotReady, #[error("At capacity (no slots available)")] AtCapacity, } /// Snapshot of service health for transports to query. #[derive(Debug, Clone)] pub struct HealthSnapshot { pub state: Health, pub available_slots: usize, pub total_slots: usize, pub setup_result: Option, pub version: VersionInfo, } impl HealthSnapshot { pub fn is_ready(&self) -> bool { self.state == Health::Ready } /// BUSY state: ready but all slots in use. pub fn is_busy(&self) -> bool { self.state == Health::Ready && self.available_slots == 0 } } /// Entry in the predictions DashMap. /// /// Holds the real prediction (via Arc), cancel token, and input /// (for API responses — Prediction doesn't store input). struct PredictionEntry { prediction: Arc>, cancel_token: CancellationToken, input: serde_json::Value, } /// Handle to a submitted prediction for cancellation on disconnect. pub struct PredictionHandle { id: String, cancel_token: CancellationToken, } impl PredictionHandle { pub fn id(&self) -> &str { &self.id } pub fn cancel_token(&self) -> CancellationToken { self.cancel_token.clone() } /// Create a guard that cancels on drop (for sync predictions). /// /// On drop (e.g. HTTP connection closed), the guard calls /// `service.cancel(id)` which fires the CancellationToken AND /// delegates to the orchestrator to cancel the worker subprocess. pub fn sync_guard(&self, service: Arc) -> SyncPredictionGuard { SyncPredictionGuard::new(self.id.clone(), service) } } /// Guard for sync predictions - cancels on drop unless disarmed. 
/// /// When the HTTP connection drops (client disconnect), axum drops the /// response future which drops this guard. The guard calls /// `service.cancel(id)` to trigger both the CancellationToken /// (Rust-side observers) and the orchestrator (worker subprocess cancel). pub struct SyncPredictionGuard { prediction_id: Option, service: Arc, } impl SyncPredictionGuard { pub fn new(prediction_id: String, service: Arc) -> Self { Self { prediction_id: Some(prediction_id), service, } } pub fn disarm(&mut self) { self.prediction_id = None; } } impl Drop for SyncPredictionGuard { fn drop(&mut self) { if let Some(ref id) = self.prediction_id { self.service.cancel(id); } } } /// Orchestrator runtime state - pool and orchestrator together. /// /// Ensures pool and orchestrator are always set atomically. pub struct OrchestratorState { pub pool: Arc, pub orchestrator: Arc, } impl Clone for OrchestratorState { fn clone(&self) -> Self { Self { pool: Arc::clone(&self.pool), orchestrator: Arc::clone(&self.orchestrator), } } } /// Transport-agnostic prediction service. /// /// Created with `new_no_pool()`, then configured with `set_orchestrator()` once /// the worker subprocess is ready. pub struct PredictionService { /// Orchestrator state (pool + handle together). orchestrator: RwLock>, health: RwLock, setup_result: RwLock>, /// Active predictions — single source of truth for prediction state. predictions: DashMap, shutdown_tx: watch::Sender, shutdown_rx: watch::Receiver, version: VersionInfo, schema: RwLock>, input_validator: RwLock>, train_validator: RwLock>, } impl PredictionService { /// Create without configuration (for early HTTP start). /// /// Health check returns STARTING until `set_orchestrator()` is called. 
pub fn new_no_pool() -> Self {
    // The shutdown flag starts out false; `trigger_shutdown()` flips it.
    let (shutdown_tx, shutdown_rx) = watch::channel(false);
    Self {
        orchestrator: RwLock::new(None),
        health: RwLock::new(Health::Unknown),
        setup_result: RwLock::new(None),
        predictions: DashMap::new(),
        shutdown_tx,
        shutdown_rx,
        version: VersionInfo::new(),
        schema: RwLock::new(None),
        input_validator: RwLock::new(None),
        train_validator: RwLock::new(None),
    }
}

/// Configure orchestrator mode atomically.
///
/// Pool and orchestrator are stored together in a single write so readers
/// never observe one without the other.
pub async fn set_orchestrator(
    &self,
    pool: Arc<PermitPool>,
    orchestrator: Arc<dyn Orchestrator>,
) {
    *self.orchestrator.write().await = Some(OrchestratorState { pool, orchestrator });
}

/// Whether an orchestrator has been stored via `set_orchestrator()`.
pub async fn has_orchestrator(&self) -> bool {
    self.orchestrator.read().await.is_some()
}

/// Shutdown the orchestrator gracefully.
///
/// Sends a shutdown message to the worker process and waits for it to exit.
/// If no orchestrator is configured, this is a no-op.
pub async fn shutdown(&self) {
    // A shutdown error is logged, not propagated: there is nothing more the
    // caller could do with it at this point.
    if let Some(ref state) = *self.orchestrator.read().await
        && let Err(e) = state.orchestrator.shutdown().await
    {
        tracing::warn!(error = %e, "Error during orchestrator shutdown");
    }
}

/// Set initial health state (for non-Ready states only).
///
/// READY requires an orchestrator, so use `set_health()` after `set_orchestrator()`.
/// Silently ignores attempts to set READY here.
pub fn with_health(mut self, health: Health) -> Self {
    if health != Health::Ready {
        self.health = RwLock::new(health);
    }
    self
}

/// Replace the version info this service reports.
pub fn with_version(mut self, version: VersionInfo) -> Self {
    self.version = version;
    self
}

/// Get the runtime version info.
pub fn version(&self) -> &VersionInfo {
    &self.version
}

/// Whether the model supports training (has a TrainingInput schema).
pub async fn supports_training(&self) -> bool {
    self.train_validator.read().await.is_some()
}

/// Get the permit pool from orchestrator.
pub async fn pool(&self) -> Option> { if let Some(ref state) = *self.orchestrator.read().await { Some(Arc::clone(&state.pool)) } else { None } } pub async fn health(&self) -> HealthSnapshot { let state = *self.health.read().await; let setup_result = self.setup_result.read().await.clone(); let pool = self.pool().await; let (available_slots, total_slots) = match pool.as_ref() { Some(p) => (p.available(), p.num_slots()), None => (0, 0), }; tracing::trace!( ?state, available_slots, total_slots, setup_status = ?setup_result.as_ref().map(|r| r.status), "Building health snapshot" ); HealthSnapshot { state, available_slots, total_slots, setup_result, version: self.version.clone(), } } /// Set health state. Setting READY requires orchestrator to be configured. /// /// Silently ignores attempts to set READY without orchestrator. pub async fn set_health(&self, health: Health) { if health == Health::Ready && self.orchestrator.read().await.is_none() { tracing::warn!("Attempted to set READY without orchestrator, ignoring"); return; } let previous = *self.health.read().await; tracing::debug!(from = ?previous, to = ?health, "Health state transition"); *self.health.write().await = health; } pub async fn set_setup_result(&self, result: SetupResult) { tracing::debug!( status = ?result.status, started_at = %result.started_at, completed_at = ?result.completed_at, logs_len = result.logs.len(), "Setting setup result" ); *self.setup_result.write().await = Some(result); } pub async fn set_schema(&self, schema: serde_json::Value) { // Compile input validators from the schema components let validator = InputValidator::from_openapi_schema(&schema); if let Some(v) = &validator { tracing::info!( "Input validation enabled ({} required fields)", v.required_count() ); } *self.input_validator.write().await = validator; // Compile a separate validator for training inputs (TrainingInput) let train_val = InputValidator::from_openapi_schema_key(&schema, "TrainingInput"); if let Some(v) = &train_val { 
tracing::info!( "Training input validation enabled ({} required fields)", v.required_count() ); } *self.train_validator.write().await = train_val; *self.schema.write().await = Some(schema); } pub async fn schema(&self) -> Option { self.schema.read().await.clone() } /// Validate prediction input against the OpenAPI schema. /// /// Returns Ok(()) if no schema is loaded or if validation passes. /// Returns Err with per-field validation errors on failure. pub async fn validate_input( &self, input: &serde_json::Value, ) -> Result<(), Vec> { let guard = self.input_validator.read().await; if let Some(ref validator) = *guard { validator.validate(input) } else { Ok(()) } } /// Validate training input against the TrainingInput schema. /// /// Falls back to the predict validator if no training schema is present. pub async fn validate_train_input( &self, input: &serde_json::Value, ) -> Result<(), Vec> { let guard = self.train_validator.read().await; if let Some(ref validator) = *guard { return validator.validate(input); } drop(guard); // Fallback: no TrainingInput schema — use predict validator (legacy compat) self.validate_input(input).await } /// Run user-defined healthcheck via orchestrator. /// /// Returns healthy if no orchestrator is configured (not ready yet). pub async fn healthcheck( &self, ) -> Result { if let Some(ref state) = *self.orchestrator.read().await { tracing::trace!("Dispatching healthcheck to orchestrator"); let result = state.orchestrator.healthcheck().await; tracing::trace!( healthy = result.as_ref().map(|r| r.is_healthy()).unwrap_or(false), error = ?result.as_ref().ok().and_then(|r| r.error.as_ref()), "Healthcheck result from orchestrator" ); result } else { tracing::debug!("No orchestrator configured, returning default healthy"); Ok(HealthcheckResult::healthy()) } } /// Submit a new prediction: create Prediction, acquire slot, register in DashMap. 
///
/// Returns a PredictionHandle (for cancel-on-disconnect) and the
/// UnregisteredPredictionSlot (for running the prediction).
///
/// Fails with `NotReady` when health is not `Ready` or no pool is
/// configured, and with `AtCapacity` when no permit can be acquired.
pub async fn submit_prediction(
    &self,
    id: String,
    input: serde_json::Value,
    webhook: Option<WebhookSender>,
) -> Result<(PredictionHandle, UnregisteredPredictionSlot), CreatePredictionError> {
    if *self.health.read().await != Health::Ready {
        return Err(CreatePredictionError::NotReady);
    }

    // Pool must exist if health is Ready
    let pool = self
        .pool()
        .await
        .ok_or(CreatePredictionError::NotReady)?;
    let permit = pool
        .try_acquire()
        .ok_or(CreatePredictionError::AtCapacity)?;

    let prediction = Prediction::new(id.clone(), webhook);
    let cancel_token = prediction.cancel_token();
    let (idle_tx, idle_rx) = tokio::sync::oneshot::channel();
    let slot = PredictionSlot::new(prediction, permit, idle_rx);
    let prediction_arc = slot.prediction();

    // Register in DashMap — this is the single source of truth
    self.predictions.insert(
        id.clone(),
        PredictionEntry {
            prediction: prediction_arc,
            cancel_token: cancel_token.clone(),
            input,
        },
    );

    let handle = PredictionHandle { id, cancel_token };
    Ok((handle, UnregisteredPredictionSlot::new(slot, idle_tx)))
}

/// Check if a prediction with this ID is already in-flight.
pub fn prediction_exists(&self, id: &str) -> bool {
    self.predictions.contains_key(id)
}

/// Get a snapshot of prediction state for API responses.
///
/// Locks the real Prediction to read current state — no stale copies.
/// Adds `input` from the PredictionEntry on top of the shared snapshot.
/// Returns `None` if the ID is unknown or the prediction mutex is poisoned.
pub fn get_prediction_response(&self, id: &str) -> Option<serde_json::Value> {
    let entry = self.predictions.get(id)?;
    let pred = entry.prediction.lock().ok()?;
    let mut response = pred.build_state_snapshot();
    response["input"] = entry.input.clone();
    Some(response)
}

/// Run a prediction to completion via orchestrator.
pub async fn predict( &self, unregistered_slot: UnregisteredPredictionSlot, input: serde_json::Value, context: std::collections::HashMap, ) -> Result { let state = self.orchestrator.read().await.clone(); let state = state .ok_or_else(|| PredictionError::Failed("No orchestrator configured".to_string()))?; let (idle_tx, mut slot) = unregistered_slot.into_parts(); let prediction_id = slot.id(); let slot_id = slot.slot_id(); { let prediction = slot.prediction(); let Some(mut pred) = try_lock_prediction(&prediction) else { return Err(PredictionError::Failed( "Prediction mutex poisoned".to_string(), )); }; pred.set_processing(); } // Register for response routing in event loop let prediction_arc = slot.prediction(); state .orchestrator .register_prediction(slot_id, Arc::clone(&prediction_arc), idle_tx) .await; // Create per-prediction dirs for file-based inputs/outputs let prediction_dir = std::path::PathBuf::from("/tmp/coglet/predictions").join(&prediction_id); let output_dir = prediction_dir.join("outputs"); let input_dir = prediction_dir.join("inputs"); std::fs::create_dir_all(&output_dir) .map_err(|e| PredictionError::Failed(format!("Failed to create output dir: {}", e)))?; std::fs::create_dir_all(&input_dir) .map_err(|e| PredictionError::Failed(format!("Failed to create input dir: {}", e)))?; let request = build_slot_request( prediction_id.clone(), input, output_dir .to_str() .expect("output dir path is valid UTF-8") .to_string(), &input_dir, context, ) .map_err(|e| PredictionError::Failed(format!("Failed to build slot request: {}", e)))?; // permit_mut returns None if permit isn't InUse (shouldn't happen here) let permit = slot .permit_mut() .ok_or_else(|| PredictionError::Failed("Permit not in use".to_string()))?; if let Err(e) = permit.send(request).await { tracing::error!(%slot_id, error = %e, "Failed to send prediction request"); // Broken socket means the slot is dead — poison it at the pool level. 
state.pool.poison(slot_id); if let Some(mut pred) = try_lock_prediction(&prediction_arc) { pred.set_failed(format!("Failed to send request: {}", e)); } return Err(PredictionError::Failed(format!( "Failed to send request: {}", e ))); } // Wait for prediction to complete // Check if already terminal first to avoid race with fast completions let (already_terminal, completion) = { let Some(pred) = try_lock_prediction(&prediction_arc) else { return Err(PredictionError::Failed( "Prediction mutex poisoned".to_string(), )); }; (pred.is_terminal(), pred.completion()) }; if !already_terminal { completion.notified().await; } let (status, output, error, logs, predict_time, metrics) = { let Some(pred) = try_lock_prediction(&prediction_arc) else { return Err(PredictionError::Failed( "Prediction mutex poisoned".to_string(), )); }; ( pred.status(), pred.output().cloned(), pred.error().map(|s| s.to_string()), pred.logs().to_string(), pred.elapsed(), pred.metrics().clone(), ) }; // If `into_idle()` fails, it does not necessarily mean the prediction failed, // so we return the result if available, but log the error and poison the slot to prevent reuse. // This is performed asynchronously to avoid blocking the prediction response to the caller. tokio::spawn(async move { if let Err(e) = slot.into_idle().await { tracing::error!(%slot_id, error = %e, "Failed to transition slot to idle, poisoning slot"); state.pool.poison(slot_id); } }); match status { PredictionStatus::Succeeded => Ok(PredictionResult { output: output.unwrap_or(PredictionOutput::Single(serde_json::Value::Null)), predict_time: Some(predict_time), logs, metrics, }), PredictionStatus::Failed => Err(PredictionError::Failed( error.unwrap_or_else(|| "Unknown error".to_string()), )), PredictionStatus::Canceled => Err(PredictionError::Cancelled), _ => Err(PredictionError::Failed(format!( "Prediction ended in unexpected state: {:?}", status ))), } } /// Cancel a prediction by ID. Returns true if found and cancelled. 
/// /// Fires the CancellationToken (for Rust-side observers like upload tasks) /// and delegates to the orchestrator to send `ControlRequest::Cancel` to the worker. pub fn cancel(&self, id: &str) -> bool { if let Some(entry) = self.predictions.get(id) { entry.cancel_token.cancel(); // Delegate to orchestrator to actually cancel the worker-side prediction. // This must be non-blocking since cancel() is sync, so we spawn a task. let id_owned = id.to_string(); let orchestrator = self .orchestrator .try_read() .ok() .and_then(|guard| guard.as_ref().map(|s| Arc::clone(&s.orchestrator))); if let Some(orch) = orchestrator { tokio::spawn(async move { if let Err(e) = orch.cancel_by_prediction_id(&id_owned).await { tracing::error!( prediction_id = %id_owned, error = %e, "Failed to send cancel to orchestrator" ); } }); } true } else { false } } /// Remove a prediction from the DashMap after completion. pub fn remove_prediction(&self, id: &str) { self.predictions.remove(id); } pub fn trigger_shutdown(&self) { let _ = self.shutdown_tx.send(true); } pub fn shutdown_rx(&self) -> watch::Receiver { self.shutdown_rx.clone() } } /// Build a `SlotRequest::Predict`, spilling the input to disk if it exceeds /// `MAX_INLINE_IPC_SIZE`. This prevents IPC frame overflow on the slot socket. /// /// NOTE: The input is serialized here to check its size against the threshold. /// For the inline path the original `Value` is kept and will be serialized again /// by `JsonCodec` — a double-serialize trade-off we accept to keep the codec /// generic. The spill path writes the pre-serialized bytes directly. 
fn build_slot_request( id: String, input: serde_json::Value, output_dir: String, input_dir: &std::path::Path, context: std::collections::HashMap, ) -> std::io::Result { let serialized = serde_json::to_vec(&input) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; if serialized.len() > MAX_INLINE_IPC_SIZE { let path = input_dir.join(format!("spill_{}.json", uuid::Uuid::new_v4())); std::fs::write(&path, &serialized)?; let input_file = path .to_str() .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "non-UTF-8 path"))? .to_string(); Ok(SlotRequest::Predict { id, input: None, input_file: Some(input_file), output_dir, context, }) } else { Ok(SlotRequest::Predict { id, input: Some(input), input_file: None, output_dir, context, }) } } #[cfg(test)] mod tests { use super::*; use crate::bridge::protocol::SlotId; use crate::permit::{InactiveSlotIdleToken, SlotIdleToken}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; /// Mock orchestrator that immediately completes predictions. 
struct MockOrchestrator { register_count: AtomicUsize, complete_immediately: bool, send_idle_ack: bool, } impl MockOrchestrator { fn new() -> Self { Self { register_count: AtomicUsize::new(0), complete_immediately: true, send_idle_ack: false, } } fn register_count(&self) -> usize { self.register_count.load(Ordering::SeqCst) } fn with_idle_ack(mut self) -> Self { self.send_idle_ack = true; self } } #[async_trait::async_trait] impl Orchestrator for MockOrchestrator { async fn register_prediction( &self, slot_id: SlotId, prediction: Arc>, idle_sender: tokio::sync::oneshot::Sender, ) { self.register_count.fetch_add(1, Ordering::SeqCst); if self.complete_immediately { let mut pred = prediction.lock().unwrap(); pred.set_succeeded(crate::PredictionOutput::Single(serde_json::json!( "mock result" ))); } if self.send_idle_ack { let _ = idle_sender.send(InactiveSlotIdleToken::new(slot_id).activate()); } } async fn cancel_by_prediction_id( &self, _prediction_id: &str, ) -> Result<(), crate::orchestrator::OrchestratorError> { Ok(()) } async fn healthcheck( &self, ) -> Result { Ok(HealthcheckResult::healthy()) } async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> { Ok(()) } } async fn create_test_pool(num_slots: usize) -> Arc { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::SlotRequest; use futures::StreamExt; use tokio::net::UnixStream; let pool = Arc::new(PermitPool::new(num_slots)); for _ in 0..num_slots { let (a, b) = UnixStream::pair().unwrap(); let (_read_a, write_a) = a.into_split(); let (read_b, _write_b) = b.into_split(); // Spawn a task to consume messages from the socket (prevents broken pipe) let mut reader = tokio_util::codec::FramedRead::new(read_b, JsonCodec::::new()); tokio::spawn(async move { while reader.next().await.is_some() {} }); let writer = tokio_util::codec::FramedWrite::new(write_a, JsonCodec::::new()); pool.add_permit(SlotId::new(), writer); } pool } async fn create_test_pool_with_slots(num_slots: usize) -> 
(Arc, Vec) { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::SlotRequest; use futures::StreamExt; use tokio::net::UnixStream; let pool = Arc::new(PermitPool::new(num_slots)); let mut slot_ids = Vec::with_capacity(num_slots); for _ in 0..num_slots { let (a, b) = UnixStream::pair().unwrap(); let (_read_a, write_a) = a.into_split(); let (read_b, _write_b) = b.into_split(); let mut reader = tokio_util::codec::FramedRead::new(read_b, JsonCodec::::new()); tokio::spawn(async move { while reader.next().await.is_some() {} }); let writer = tokio_util::codec::FramedWrite::new(write_a, JsonCodec::::new()); let slot_id = SlotId::new(); pool.add_permit(slot_id, writer); slot_ids.push(slot_id); } (pool, slot_ids) } async fn create_broken_test_pool() -> (Arc, SlotId) { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::SlotRequest; use tokio::net::UnixStream; let pool = Arc::new(PermitPool::new(1)); let (a, b) = UnixStream::pair().unwrap(); let (_read_a, write_a) = a.into_split(); drop(b); let writer = tokio_util::codec::FramedWrite::new(write_a, JsonCodec::::new()); let slot_id = SlotId::new(); pool.add_permit(slot_id, writer); (pool, slot_id) } #[tokio::test] async fn service_new_no_pool_works() { let svc = PredictionService::new_no_pool(); let health = svc.health().await; assert_eq!(health.state, Health::Unknown); assert_eq!(health.total_slots, 0); assert_eq!(health.available_slots, 0); assert!(svc.pool().await.is_none()); } #[tokio::test] async fn service_no_pool_initially() { let svc = PredictionService::new_no_pool(); assert!(svc.pool().await.is_none()); assert!(!svc.has_orchestrator().await); } #[tokio::test] async fn shutdown_signal_works() { let svc = PredictionService::new_no_pool(); let mut rx = svc.shutdown_rx(); assert!(!*rx.borrow()); svc.trigger_shutdown(); rx.changed().await.unwrap(); assert!(*rx.borrow()); } #[tokio::test] async fn submit_fails_when_not_ready() { let svc = PredictionService::new_no_pool(); let result = svc 
.submit_prediction("test".to_string(), serde_json::json!({}), None) .await; assert!(matches!(result, Err(CreatePredictionError::NotReady))); } #[tokio::test] async fn cannot_set_ready_without_orchestrator() { let svc = PredictionService::new_no_pool(); // with_health silently ignores READY let svc2 = PredictionService::new_no_pool().with_health(Health::Ready); assert_eq!(svc2.health().await.state, Health::Unknown); // set_health also ignores READY without orchestrator svc.set_health(Health::Ready).await; assert_eq!(svc.health().await.state, Health::Unknown); } #[tokio::test] async fn set_orchestrator_enables_ready_health() { let svc = PredictionService::new_no_pool(); let pool = create_test_pool(2).await; let orchestrator = Arc::new(MockOrchestrator::new()); svc.set_orchestrator(pool, orchestrator).await; assert!(svc.has_orchestrator().await); // Now we can set READY svc.set_health(Health::Ready).await; let health = svc.health().await; assert_eq!(health.state, Health::Ready); assert_eq!(health.total_slots, 2); assert_eq!(health.available_slots, 2); } #[tokio::test] async fn submit_prediction_succeeds_when_ready() { let svc = PredictionService::new_no_pool(); let pool = create_test_pool(1).await; let orchestrator = Arc::new(MockOrchestrator::new()); svc.set_orchestrator(pool, orchestrator).await; svc.set_health(Health::Ready).await; let (handle, _slot) = svc .submit_prediction("test-1".to_string(), serde_json::json!({}), None) .await .unwrap(); assert_eq!(handle.id(), "test-1"); assert!(svc.prediction_exists("test-1")); } #[tokio::test] async fn submit_returns_at_capacity_when_no_slots() { let svc = PredictionService::new_no_pool(); let pool = create_test_pool(1).await; let orchestrator = Arc::new(MockOrchestrator::new()); svc.set_orchestrator(pool, orchestrator).await; svc.set_health(Health::Ready).await; // First prediction takes the only slot let (_handle1, _slot1) = svc .submit_prediction("test-1".to_string(), serde_json::json!({}), None) .await .unwrap(); // 
Second should fail with AtCapacity let result = svc .submit_prediction("test-2".to_string(), serde_json::json!({}), None) .await; assert!(matches!(result, Err(CreatePredictionError::AtCapacity))); } #[tokio::test] async fn predict_calls_orchestrator_register() { let svc = PredictionService::new_no_pool(); let pool = create_test_pool(1).await; let orchestrator = Arc::new(MockOrchestrator::new()); let orch_ref = Arc::clone(&orchestrator); svc.set_orchestrator(pool, orchestrator).await; svc.set_health(Health::Ready).await; let (_handle, slot) = svc .submit_prediction( "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, ) .await .unwrap(); let result = svc .predict( slot, serde_json::json!({"prompt": "hello"}), Default::default(), ) .await; // MockOrchestrator completes immediately with success assert!(result.is_ok(), "predict failed: {:?}", result.err()); assert_eq!(orch_ref.register_count(), 1); } #[tokio::test] async fn health_shows_busy_when_all_slots_used() { let svc = PredictionService::new_no_pool(); let pool = create_test_pool(1).await; let orchestrator = Arc::new(MockOrchestrator::new()); svc.set_orchestrator(pool, orchestrator).await; svc.set_health(Health::Ready).await; // Before acquiring slot let health = svc.health().await; assert!(!health.is_busy()); assert_eq!(health.available_slots, 1); // After acquiring slot let (_handle, _slot) = svc .submit_prediction("test-1".to_string(), serde_json::json!({}), None) .await .unwrap(); let health = svc.health().await; assert!(health.is_busy()); assert_eq!(health.available_slots, 0); } #[tokio::test] async fn predict_idle_channel_closed_poison_slot_async() { let svc = PredictionService::new_no_pool(); let (pool, slot_ids) = create_test_pool_with_slots(1).await; let orchestrator = Arc::new(MockOrchestrator::new()); let slot_id = slot_ids[0]; svc.set_orchestrator(Arc::clone(&pool), orchestrator).await; svc.set_health(Health::Ready).await; let (_handle, slot) = svc .submit_prediction( 
"test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, ) .await .unwrap(); let result = svc .predict( slot, serde_json::json!({"prompt": "hello"}), Default::default(), ) .await; assert!(result.is_ok(), "predict failed: {:?}", result.err()); tokio::time::timeout(Duration::from_secs(1), async { loop { if pool.is_poisoned(slot_id) { break; } tokio::time::sleep(Duration::from_millis(10)).await; } }) .await .expect("slot was not poisoned after idle token channel closed"); } #[tokio::test] async fn predict_idle_ack_returns_capacity_async() { let svc = PredictionService::new_no_pool(); let pool = create_test_pool(1).await; let orchestrator = Arc::new(MockOrchestrator::new().with_idle_ack()); svc.set_orchestrator(Arc::clone(&pool), orchestrator).await; svc.set_health(Health::Ready).await; let (_handle, slot) = svc .submit_prediction( "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, ) .await .unwrap(); let result = svc .predict( slot, serde_json::json!({"prompt": "hello"}), Default::default(), ) .await; assert!(result.is_ok(), "predict failed: {:?}", result.err()); tokio::time::timeout(Duration::from_secs(1), async { loop { if pool.available() == 1 { break; } tokio::time::sleep(Duration::from_millis(10)).await; } }) .await .expect("slot capacity was not returned after idle acknowledgement"); } #[tokio::test] async fn predict_send_failure_poison_slot() { let svc = PredictionService::new_no_pool(); let (pool, slot_id) = create_broken_test_pool().await; let orchestrator = Arc::new(MockOrchestrator::new()); svc.set_orchestrator(Arc::clone(&pool), orchestrator).await; svc.set_health(Health::Ready).await; let (_handle, slot) = svc .submit_prediction( "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, ) .await .unwrap(); let result = svc .predict( slot, serde_json::json!({"prompt": "hello"}), Default::default(), ) .await; assert!(matches!(result, Err(PredictionError::Failed(_)))); assert!(pool.is_poisoned(slot_id)); 
assert!(pool.try_acquire().is_none()); } #[tokio::test] async fn cancel_prediction_works() { let svc = PredictionService::new_no_pool(); let pool = create_test_pool(1).await; let orchestrator = Arc::new(MockOrchestrator::new()); svc.set_orchestrator(pool, orchestrator).await; svc.set_health(Health::Ready).await; let (handle, _slot) = svc .submit_prediction("test-cancel".to_string(), serde_json::json!({}), None) .await .unwrap(); let cancel_token = handle.cancel_token(); let cancelled = svc.cancel("test-cancel"); assert!(cancelled); assert!(cancel_token.is_cancelled()); } #[tokio::test] async fn cancel_nonexistent_returns_false() { let svc = PredictionService::new_no_pool(); assert!(!svc.cancel("nonexistent")); } #[tokio::test] async fn sync_guard_cancels_on_drop() { let svc = Arc::new(PredictionService::new_no_pool()); let pool = create_test_pool(1).await; let orchestrator = Arc::new(MockOrchestrator::new()); svc.set_orchestrator(pool, orchestrator).await; svc.set_health(Health::Ready).await; let (handle, _slot) = svc .submit_prediction("test-guard".to_string(), serde_json::json!({}), None) .await .unwrap(); let cancel_token = handle.cancel_token(); { let _guard = handle.sync_guard(Arc::clone(&svc)); assert!(!cancel_token.is_cancelled()); } assert!(cancel_token.is_cancelled()); } #[tokio::test] async fn sync_guard_disarm_prevents_cancel() { let svc = Arc::new(PredictionService::new_no_pool()); let pool = create_test_pool(1).await; let orchestrator = Arc::new(MockOrchestrator::new()); svc.set_orchestrator(pool, orchestrator).await; svc.set_health(Health::Ready).await; let (handle, _slot) = svc .submit_prediction("test-disarm".to_string(), serde_json::json!({}), None) .await .unwrap(); let cancel_token = handle.cancel_token(); { let mut guard = handle.sync_guard(Arc::clone(&svc)); guard.disarm(); } assert!(!cancel_token.is_cancelled()); } #[tokio::test] async fn remove_prediction_cleans_up() { let svc = PredictionService::new_no_pool(); let pool = 
create_test_pool(1).await; let orchestrator = Arc::new(MockOrchestrator::new()); svc.set_orchestrator(pool, orchestrator).await; svc.set_health(Health::Ready).await; let (_handle, _slot) = svc .submit_prediction("test-remove".to_string(), serde_json::json!({}), None) .await .unwrap(); assert!(svc.prediction_exists("test-remove")); svc.remove_prediction("test-remove"); assert!(!svc.prediction_exists("test-remove")); } #[test] fn build_slot_request_small_input_inline() { let dir = tempfile::tempdir().unwrap(); let input = serde_json::json!({"text": "hello"}); let req = build_slot_request( "p1".into(), input.clone(), "/tmp/out".into(), dir.path(), Default::default(), ) .unwrap(); match req { SlotRequest::Predict { id, input: Some(v), input_file: None, output_dir, .. } => { assert_eq!(id, "p1"); assert_eq!(v, input); assert_eq!(output_dir, "/tmp/out"); } _ => panic!("expected inline input"), } } #[test] fn build_slot_request_large_input_spills() { let dir = tempfile::tempdir().unwrap(); // Create an input larger than 6 MiB let big = "x".repeat(7 * 1024 * 1024); let input = serde_json::json!({"data": big}); let req = build_slot_request( "p2".into(), input.clone(), "/tmp/out".into(), dir.path(), Default::default(), ) .unwrap(); match req { SlotRequest::Predict { id, input: None, input_file: Some(ref path), output_dir, .. 
} => { assert_eq!(id, "p2"); assert_eq!(output_dir, "/tmp/out"); // Spill file should exist on disk assert!(std::path::Path::new(path).exists()); // Content should be valid JSON matching the original input let bytes = std::fs::read(path).unwrap(); let roundtrip: serde_json::Value = serde_json::from_slice(&bytes).unwrap(); assert_eq!(roundtrip, input); } _ => panic!("expected file-backed input"), } } #[test] fn build_slot_request_roundtrip() { let dir = tempfile::tempdir().unwrap(); let big = "y".repeat(7 * 1024 * 1024); let input = serde_json::json!({"payload": big}); let req = build_slot_request( "p3".into(), input.clone(), "/tmp/out".into(), dir.path(), Default::default(), ) .unwrap(); // Rehydrate and verify we get back the same value let (id, rehydrated, output_dir, _context) = req.rehydrate_input().unwrap(); assert_eq!(id, "p3"); assert_eq!(rehydrated, input); assert_eq!(output_dir, "/tmp/out"); } } ================================================ FILE: crates/coglet/src/setup_log_accumulator.rs ================================================ //! Tracing layer that accumulates all logs from coglet server during setup. //! //! Captures every tracing event from the moment the server starts until setup completes. //! This includes: //! - Initial server startup logs ("coglet ") //! - Orchestrator logs ("Spawning worker subprocess") //! - Re-emitted worker logs (via emit_worker_log) //! - Transport logs, codec warnings, everything //! //! Uses unbounded mpsc channel for lock-free accumulation. 
use tokio::sync::mpsc;
use tracing::Subscriber;
use tracing_subscriber::layer::{Context, Layer};

/// Tracing layer that forwards every event it sees, rendered as a single
/// log line, into an unbounded channel.
pub struct SetupLogAccumulator {
    /// Sending half of the channel the formatted log lines are pushed into.
    tx: mpsc::UnboundedSender<String>,
}

impl SetupLogAccumulator {
    /// Create a layer that sends formatted log lines to `tx`.
    pub fn new(tx: mpsc::UnboundedSender<String>) -> Self {
        Self { tx }
    }
}

impl<S> Layer<S> for SetupLogAccumulator
where
    S: Subscriber,
{
    /// Format the event as `[target] message` and push it onto the channel.
    ///
    /// Does nothing once the receiving side has been closed, so the event is
    /// not formatted needlessly.
    fn on_event(&self, event: &tracing::Event<'_>, _ctx: Context<'_, S>) {
        if self.tx.is_closed() {
            return;
        }

        let metadata = event.metadata();
        let target = metadata.target();

        let mut visitor = MessageVisitor::default();
        event.record(&mut visitor);

        let log_line = format!("[{}] {}", target, visitor.message);
        // Best effort: the receiver may have been dropped between the
        // `is_closed` check and this send; losing the line is acceptable.
        let _ = self.tx.send(log_line);
    }
}

/// Field visitor that captures only the `message` field of an event.
#[derive(Default)]
struct MessageVisitor {
    message: String,
}

impl tracing::field::Visit for MessageVisitor {
    /// Record the `message` field from its `Debug` rendering, removing one
    /// pair of surrounding double quotes when both are present.
    fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
        if field.name() == "message" {
            let rendered = format!("{:?}", value);
            // Only strip when there is a distinct opening AND closing quote.
            // A rendering that is a single `"` satisfies both `starts_with`
            // and `ends_with`, and slicing it as `[1..len - 1]` would panic.
            let unquoted = rendered
                .strip_prefix('"')
                .and_then(|rest| rest.strip_suffix('"'))
                .map(str::to_string);
            self.message = unquoted.unwrap_or(rendered);
        }
    }

    /// Record the `message` field verbatim when it is provided as a string.
    fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
        if field.name() == "message" {
            self.message = value.to_string();
        }
    }
}

/// Drain every log line currently queued in `rx` without waiting.
///
/// Returns the lines joined with `\n` and terminated by a trailing newline,
/// or an empty string when nothing was queued.
pub fn drain_accumulated_logs(rx: &mut mpsc::UnboundedReceiver<String>) -> String {
    let mut lines = Vec::new();
    while let Ok(line) = rx.try_recv() {
        lines.push(line);
    }

    if lines.is_empty() {
        String::new()
    } else {
        let mut result = lines.join("\n");
        result.push('\n');
        result
    }
}

================================================ FILE: crates/coglet/src/snapshots/coglet__health__tests__health_all_variants.snap ================================================ --- source: coglet/src/health.rs expression: "[Health::Unknown, Health::Starting, Health::Ready, Health::Busy,\nHealth::SetupFailed, Health::Defunct,]" --- [ "UNKNOWN", "STARTING", "READY", "BUSY", "SETUP_FAILED", "DEFUNCT" ] ================================================ FILE:
crates/coglet/src/snapshots/coglet__health__tests__health_response_all_variants.snap ================================================ --- source: coglet/src/health.rs expression: "[HealthResponse::Unknown, HealthResponse::Starting, HealthResponse::Ready,\nHealthResponse::Busy, HealthResponse::SetupFailed, HealthResponse::Defunct,\nHealthResponse::Unhealthy,]" --- [ "UNKNOWN", "STARTING", "READY", "BUSY", "SETUP_FAILED", "DEFUNCT", "UNHEALTHY" ] ================================================ FILE: crates/coglet/src/snapshots/coglet__health__tests__setup_status_all_variants.snap ================================================ --- source: coglet/src/health.rs expression: "[SetupStatus::Starting, SetupStatus::Succeeded, SetupStatus::Failed,]" --- [ "starting", "succeeded", "failed" ] ================================================ FILE: crates/coglet/src/snapshots/coglet__predictor__tests__output_single.snap ================================================ --- source: coglet/src/predictor.rs expression: single --- "hello" ================================================ FILE: crates/coglet/src/snapshots/coglet__predictor__tests__output_stream.snap ================================================ --- source: coglet/src/predictor.rs expression: stream --- [ 1, 2 ] ================================================ FILE: crates/coglet/src/snapshots/coglet__version__tests__version_full.snap ================================================ --- source: coglet/src/version.rs expression: info --- { "coglet": "0.1.0", "git_sha": "abc1234-dirty", "build_time": "2026-03-12T18:00:00Z", "python_sdk": "0.9.0", "python": "3.11.0" } ================================================ FILE: crates/coglet/src/snapshots/coglet__version__tests__version_minimal.snap ================================================ --- source: coglet/src/version.rs expression: info --- { "coglet": "0.1.0" } ================================================ FILE: crates/coglet/src/transport/http/mod.rs 
================================================ //! HTTP transport for coglet using axum. mod routes; mod server; pub use server::{ServerConfig, serve}; ================================================ FILE: crates/coglet/src/transport/http/routes.rs ================================================ //! HTTP route handlers. use std::sync::Arc; use axum::{ Router, extract::{DefaultBodyLimit, Path, State}, http::{HeaderMap, StatusCode}, response::{IntoResponse, Json}, routing::{get, post, put}, }; use serde::{Deserialize, Serialize}; #[cfg(test)] use crate::health::Health; use crate::health::{HealthResponse, SetupResult}; use crate::predictor::PredictionError; use crate::service::{CreatePredictionError, HealthSnapshot, PredictionService}; use crate::version::VersionInfo; use crate::webhook::{TraceContext, WebhookConfig, WebhookEventType, WebhookSender}; #[derive(Debug, Serialize)] pub struct HealthCheckResponse { pub status: HealthResponse, #[serde(skip_serializing_if = "Option::is_none")] pub setup: Option, pub version: VersionInfo, #[serde(skip_serializing_if = "Option::is_none")] pub user_healthcheck_error: Option, } impl HealthCheckResponse { pub fn from_snapshot(snapshot: HealthSnapshot, user_healthcheck_error: Option) -> Self { // Determine response status let status = if user_healthcheck_error.is_some() { HealthResponse::Unhealthy } else if snapshot.is_busy() { HealthResponse::Busy } else { snapshot.state.into() }; Self { status, setup: snapshot.setup_result, version: snapshot.version, user_healthcheck_error, } } } #[derive(Debug, Deserialize)] pub struct PredictionRequest { pub id: Option, #[serde( default = "default_empty_input", deserialize_with = "deserialize_input" )] pub input: serde_json::Value, /// Per-prediction context made available to predictors via `current_scope().context`. 
#[serde(default)] pub context: std::collections::HashMap, pub webhook: Option, #[serde(default = "default_webhook_events_filter")] pub webhook_events_filter: Vec, } fn default_empty_input() -> serde_json::Value { serde_json::json!({}) } fn deserialize_input<'de, D>(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let value = serde_json::Value::deserialize(deserializer)?; Ok(if value.is_null() { serde_json::json!({}) } else { value }) } fn default_webhook_events_filter() -> Vec { vec![ WebhookEventType::Start, WebhookEventType::Output, WebhookEventType::Logs, WebhookEventType::Completed, ] } fn generate_prediction_id() -> String { use std::time::{SystemTime, UNIX_EPOCH}; // SAFETY: SystemTime::now() is always after UNIX_EPOCH on any reasonable system. // This cannot fail unless the system clock is set before 1970. let timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("system clock is before 1970") .as_nanos(); format!("pred_{:x}", timestamp) } /// Root discovery endpoint — returns a map of available API endpoints. /// /// Restores the `GET /` endpoint from cog <= 0.16.x for service discovery. /// `cog_version` reports the Python SDK version when available (matching the /// old Python server behaviour), falling back to the coglet runtime version. 
// NOTE(review): generic arguments (e.g. on `State`, `Json`, `Arc`, `Option`,
// `Path`, `Vec`) appear to have been stripped from this copy of the file;
// the code is reproduced token-for-token and will not compile until they are
// restored from the repository.
async fn root(State(service): State>) -> Json {
    let version = service.version();
    // Prefer the Python SDK version; fall back to the coglet version.
    let cog_version = version.python_sdk.as_deref().unwrap_or(version.coglet);

    let mut doc = serde_json::json!({
        "cog_version": cog_version,
        "docs_url": "/docs",
        "openapi_url": "/openapi.json",
        "shutdown_url": "/shutdown",
        "healthcheck_url": "/health-check",
        "predictions_url": "/predictions",
        "predictions_idempotent_url": "/predictions/{prediction_id}",
        "predictions_cancel_url": "/predictions/{prediction_id}/cancel",
    });

    // Training endpoints are only advertised when the service reports
    // training support.
    if service.supports_training().await {
        let obj = doc.as_object_mut().expect("doc is an object");
        obj.insert("trainings_url".to_string(), serde_json::json!("/trainings"));
        obj.insert(
            "trainings_idempotent_url".to_string(),
            serde_json::json!("/trainings/{training_id}"),
        );
        obj.insert(
            "trainings_cancel_url".to_string(),
            serde_json::json!("/trainings/{training_id}/cancel"),
        );
    }

    Json(doc)
}

/// `GET /health-check`: report service health.
///
/// When the snapshot is ready, this also writes the readiness file and runs
/// the user-defined healthcheck; any error from it is surfaced in the response.
async fn health_check(State(service): State>) -> Json {
    tracing::trace!("Health check endpoint called");
    let snapshot = service.health().await;
    tracing::trace!(
        state = ?snapshot.state,
        available_slots = snapshot.available_slots,
        total_slots = snapshot.total_slots,
        has_setup_result = snapshot.setup_result.is_some(),
        "Health snapshot retrieved"
    );

    // Run user healthcheck if ready (even when busy — healthcheck health
    // and slot availability are orthogonal concerns).
    let user_healthcheck_error = if snapshot.is_ready() {
        write_readiness_file();
        // Run user-defined healthcheck
        tracing::trace!("Running user-defined healthcheck");
        match service.healthcheck().await {
            Ok(result) if result.is_healthy() => {
                tracing::trace!("User healthcheck passed");
                None
            }
            Ok(result) => {
                tracing::debug!(error = ?result.error, "User healthcheck reported unhealthy");
                result.error
            }
            Err(e) => {
                tracing::debug!(error = %e, "User healthcheck returned error");
                Some(format!("Healthcheck error: {}", e))
            }
        }
    } else {
        tracing::trace!(state = ?snapshot.state, "Skipping user healthcheck (not ready)");
        None
    };

    let response = HealthCheckResponse::from_snapshot(snapshot, user_healthcheck_error);
    tracing::trace!(status = ?response.status, "Health check response");
    Json(response)
}

/// Write /var/run/cog/ready for K8s readiness probe.
///
/// No-op when `KUBERNETES_SERVICE_HOST` is not set or the file already
/// exists. Failures are logged as warnings and otherwise ignored.
fn write_readiness_file() {
    if std::env::var("KUBERNETES_SERVICE_HOST").is_err() {
        return;
    }

    let dir = std::path::Path::new("/var/run/cog");
    let file = dir.join("ready");

    if file.exists() {
        return;
    }

    if let Err(e) = std::fs::create_dir_all(dir) {
        tracing::warn!(error = %e, "Failed to create /var/run/cog directory");
        return;
    }

    if let Err(e) = std::fs::write(&file, b"") {
        tracing::warn!(error = %e, "Failed to write readiness file");
    }
}

/// Return `true` when the `prefer` header is exactly `respond-async`.
///
/// A missing header, a non-string header value, or any other value
/// (including combined preferences) yields `false`.
fn should_respond_async(headers: &HeaderMap) -> bool {
    headers
        .get("prefer")
        .and_then(|v| v.to_str().ok())
        .map(|v| v == "respond-async")
        .unwrap_or(false)
}

/// Copy the `traceparent` and `tracestate` headers into a `TraceContext`.
///
/// Headers that are missing or cannot be read as strings become `None`.
fn extract_trace_context(headers: &HeaderMap) -> TraceContext {
    TraceContext {
        traceparent: headers
            .get("traceparent")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string()),
        tracestate: headers
            .get("tracestate")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string()),
    }
}

/// `POST /predictions`: create a prediction.
///
/// A missing body is treated as an empty request; a missing `id` is replaced
/// by a generated one.
async fn create_prediction(
    State(service): State>,
    headers: HeaderMap,
    body: Option>,
) -> impl IntoResponse {
    let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest {
        id: None,
        input: serde_json::json!({}),
        context: Default::default(),
        webhook: None,
        webhook_events_filter: default_webhook_events_filter(),
    });

    let prediction_id = request.id.unwrap_or_else(generate_prediction_id);
    let respond_async = should_respond_async(&headers);
    let trace_context = extract_trace_context(&headers);

    create_prediction_with_id(
        service,
        prediction_id,
        request.input,
        request.context,
        request.webhook,
        request.webhook_events_filter,
        respond_async,
        trace_context,
        false,
    )
    .await
}

/// `PUT /predictions/{id}`: create a prediction under a caller-chosen ID.
///
/// Responds 422 when the body carries a different `id` than the URL, and 202
/// with the existing state when a prediction with that ID is already known
/// to the service.
async fn create_prediction_idempotent(
    State(service): State>,
    Path(prediction_id): Path,
    headers: HeaderMap,
    body: Option>,
) -> impl IntoResponse {
    let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest {
        id: None,
        input: serde_json::json!({}),
        context: Default::default(),
        webhook: None,
        webhook_events_filter: default_webhook_events_filter(),
    });

    if let Some(ref req_id) = request.id
        && req_id != &prediction_id
    {
        return (
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(serde_json::json!({
                "detail": [{
                    "loc": ["body", "id"],
                    "msg": "prediction ID must match the ID supplied in the URL",
                    "type": "value_error"
                }]
            })),
        );
    }

    // Check if prediction with this ID is already in-flight
    if let Some(response) = service.get_prediction_response(&prediction_id) {
        return (StatusCode::ACCEPTED, Json(response));
    }

    let respond_async = should_respond_async(&headers);
    let trace_context = extract_trace_context(&headers);

    create_prediction_with_id(
        service,
        prediction_id,
        request.input,
        request.context,
        request.webhook,
        request.webhook_events_filter,
        respond_async,
        trace_context,
        false,
    )
    .await
}

/// Build a webhook sender for the given URL, event filter and trace context.
///
/// Returns `None` when no webhook URL was supplied, or when the sender cannot
/// be constructed (the failure is logged as an error).
fn build_webhook_sender(
    webhook: Option,
    events_filter: Vec,
    trace_context: TraceContext,
) -> Option {
    let webhook_url = webhook?;
    let events: std::collections::HashSet<_> = events_filter.into_iter().collect();

    match WebhookSender::with_trace_context(
        webhook_url.clone(),
        WebhookConfig {
            events_filter: events,
            ..Default::default()
        },
        trace_context,
    ) {
        Ok(sender) => Some(sender),
        Err(e) => {
            tracing::error!(url = %webhook_url, error = %e, "Failed to create webhook sender");
            None
        }
    }
}

/// Shared implementation behind the prediction and training endpoints.
///
/// Validates `input` (against the training schema when `is_training` is
/// set), submits the prediction, then either returns 202 immediately
/// (`respond_async`) or waits for the result and maps it to a status code and
/// JSON body.
#[allow(clippy::too_many_arguments)]
async fn create_prediction_with_id(
    service: Arc,
    prediction_id: String,
    input: serde_json::Value,
    context: std::collections::HashMap,
    webhook: Option,
    webhook_events_filter: Vec,
    respond_async: bool,
    trace_context: TraceContext,
    is_training: bool,
) -> (StatusCode, Json) {
    // Validate input against the appropriate schema
    let validation_result = if is_training {
        service.validate_train_input(&input).await
    } else {
        service.validate_input(&input).await
    };
    if let Err(errors) = validation_result {
        let detail: Vec = errors
            .into_iter()
            .map(|e| {
                serde_json::json!({
                    "loc": ["body", "input", e.field],
                    "msg": e.msg,
                    "type": e.error_type
                })
            })
            .collect();
        return (
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(serde_json::json!({ "detail": detail })),
        );
    }

    let webhook_sender = build_webhook_sender(
        webhook.clone(),
        webhook_events_filter.clone(),
        trace_context.clone(),
    );

    // Submit prediction: creates Prediction, acquires slot, registers in service
    let (handle, unregistered_slot) = match service
        .submit_prediction(prediction_id.clone(), input.clone(), webhook_sender)
        .await
    {
        Ok(r) => r,
        Err(CreatePredictionError::NotReady) => {
            let msg = PredictionError::NotReady.to_string();
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(serde_json::json!({
                    "error": msg,
                    "status": "failed"
                })),
            );
        }
        Err(CreatePredictionError::AtCapacity) => {
            return (
                StatusCode::CONFLICT,
                Json(serde_json::json!({
                    "error": "At capacity - all prediction slots busy",
                    "status": "failed"
                })),
            );
        }
    };

    let prediction = unregistered_slot.prediction();

    // Async mode: spawn background task, return immediately
    if respond_async {
        let service_clone = Arc::clone(&service);
        let id_for_cleanup = prediction_id.clone();
        let context_async = context.clone();
        tokio::spawn(async move {
            let _result = service_clone
                .predict(unregistered_slot, input, context_async)
                .await;
            // Prediction state is already updated by predict() internally
            // (set_succeeded/set_failed/set_canceled fire webhooks automatically)
            service_clone.remove_prediction(&id_for_cleanup);
        });

        return (
            StatusCode::ACCEPTED,
            Json(serde_json::json!({
                "id": prediction_id,
                "status": "starting"
            })),
        );
    }

    // Sync mode: spawn prediction into a background task so the slot lifetime
    // is NOT tied to the HTTP connection. If the client disconnects, the
    // SyncPredictionGuard fires cancel, but the slot/permit stays alive in the
    // spawned task until the worker acknowledges the cancel.
    let mut sync_guard = handle.sync_guard(Arc::clone(&service));

    let service_bg = Arc::clone(&service);
    let id_bg = prediction_id.clone();
    let result_rx = {
        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let result = service_bg.predict(unregistered_slot, input, context).await;
            // Prediction state is already updated by predict() internally
            service_bg.remove_prediction(&id_bg);
            let _ = tx.send(result);
        });
        rx
    };

    // Wait for the prediction to complete. If the connection drops, axum
    // cancels this future, dropping sync_guard which fires cancel.
    let result = match result_rx.await {
        Ok(r) => r,
        Err(_) => {
            // Background task panicked or was cancelled
            Err(PredictionError::Failed("prediction task lost".to_string()))
        }
    };

    // Falls back to zero when the prediction lock cannot be taken right now.
    let predict_time = prediction
        .try_lock()
        .map(|p| p.elapsed())
        .unwrap_or(std::time::Duration::ZERO)
        .as_secs_f64();

    // Disarm guard - prediction completed normally (connection still alive)
    sync_guard.disarm();

    // Build metrics object: user metrics + predict_time
    let build_metrics = |user_metrics: &std::collections::HashMap| {
        let mut m = serde_json::Map::new();
        for (k, v) in user_metrics {
            m.insert(k.clone(), v.clone());
        }
        // Inserted last, so it overrides a user metric of the same name.
        m.insert("predict_time".to_string(), serde_json::json!(predict_time));
        serde_json::Value::Object(m)
    };

    match result {
        Ok(r) => {
            let metrics = build_metrics(&r.metrics);
            (
                StatusCode::OK,
                Json(serde_json::json!({
                    "id": prediction_id,
                    "output": r.output,
                    "logs": r.logs,
                    "status": "succeeded",
                    "metrics": metrics
                })),
            )
        }
        Err(PredictionError::InvalidInput(msg)) => (
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(serde_json::json!({
                "id": prediction_id,
                "error": msg,
                "logs": "",
                "status": "failed",
                "metrics": { "predict_time": predict_time }
            })),
        ),
        Err(PredictionError::NotReady) => {
            let msg = PredictionError::NotReady.to_string();
            (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(serde_json::json!({
                    "id": prediction_id,
                    "error": msg,
                    "logs": "",
                    "status": "failed"
                })),
            )
        }
        Err(PredictionError::Failed(msg)) => (
            // 200 for parity with Python - prediction failure is data, not HTTP error
            StatusCode::OK,
            Json(serde_json::json!({
                "id": prediction_id,
                "error": msg,
                "logs": "",
                "status": "failed",
                "metrics": { "predict_time": predict_time }
            })),
        ),
        Err(PredictionError::Cancelled) => (
            StatusCode::OK,
            Json(serde_json::json!({
                "id": prediction_id,
                "logs": "",
                "status": "canceled",
                "metrics": { "predict_time": predict_time }
            })),
        ),
    }
}

/// `POST /predictions/{id}/cancel`: request cancellation of a prediction.
///
/// Responds 200 when `service.cancel` returns `true`, otherwise 404; the body
/// is an empty JSON object in both cases.
async fn cancel_prediction(
    State(service): State>,
    Path(prediction_id): Path,
) -> impl IntoResponse {
    let cancelled = service.cancel(&prediction_id);
    if cancelled {
        (StatusCode::OK, Json(serde_json::json!({})))
    } else {
        (StatusCode::NOT_FOUND, Json(serde_json::json!({})))
    }
}

/// `POST /shutdown`: trigger service shutdown and respond 200 with an empty object.
async fn shutdown(State(service): State>) -> impl IntoResponse {
    tracing::info!("Shutdown requested via HTTP");
    service.trigger_shutdown();
    (StatusCode::OK, Json(serde_json::json!({})))
}

/// `GET /openapi.json`: return the schema held by the service, or 503 when
/// none is available.
async fn openapi_schema(State(service): State>) -> impl IntoResponse {
    match service.schema().await {
        Some(schema) => (StatusCode::OK, Json(schema)),
        None => (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(serde_json::json!({
                "error": "OpenAPI schema not available"
            })),
        ),
    }
}

// Training routes — same dispatch as predictions but validated against
// TrainingInput schema instead of Input.

/// `POST /trainings`: create a training; mirrors `create_prediction` with
/// `is_training` set.
async fn create_training(
    State(service): State>,
    headers: HeaderMap,
    body: Option>,
) -> impl IntoResponse {
    let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest {
        id: None,
        input: serde_json::json!({}),
        context: Default::default(),
        webhook: None,
        webhook_events_filter: default_webhook_events_filter(),
    });

    let prediction_id = request.id.unwrap_or_else(generate_prediction_id);
    let respond_async = should_respond_async(&headers);
    let trace_context = extract_trace_context(&headers);

    create_prediction_with_id(
        service,
        prediction_id,
        request.input,
        request.context,
        request.webhook,
        request.webhook_events_filter,
        respond_async,
        trace_context,
        true,
    )
    .await
}

/// `PUT /trainings/{id}`: create a training under a caller-chosen ID;
/// mirrors `create_prediction_idempotent` with `is_training` set.
async fn create_training_idempotent(
    State(service): State>,
    Path(training_id): Path,
    headers: HeaderMap,
    body: Option>,
) -> impl IntoResponse {
    let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest {
        id: None,
        input: serde_json::json!({}),
        context: Default::default(),
        webhook: None,
        webhook_events_filter: default_webhook_events_filter(),
    });

    if let Some(ref req_id) = request.id
        && req_id != &training_id
    {
        return (
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(serde_json::json!({
                "detail": [{
                    "loc": ["body", "id"],
                    "msg": "training ID must match the ID supplied in the URL",
                    "type": "value_error"
                }]
            })),
        );
    }

    // Idempotent: return existing state if already submitted
    if let Some(response) = service.get_prediction_response(&training_id) {
        return (StatusCode::ACCEPTED, Json(response));
    }

    let respond_async = should_respond_async(&headers);
    let trace_context = extract_trace_context(&headers);

    create_prediction_with_id(
        service,
        training_id,
        request.input,
        request.context,
        request.webhook,
        request.webhook_events_filter,
        respond_async,
        trace_context,
        true,
    )
    .await
}

/// `POST /trainings/{id}/cancel`: delegates to `cancel_prediction`.
async fn cancel_training(
    State(service): State>,
    Path(training_id): Path,
) -> impl IntoResponse {
    cancel_prediction(State(service), Path(training_id)).await
}

/// Maximum HTTP request body size (100 MiB).
///
/// Axum defaults to 2 MiB which is too small for models that accept large
/// inline inputs (e.g. base64-encoded images). Inputs that exceed the IPC
/// frame limit are automatically spilled to disk by `build_slot_request`.
const MAX_HTTP_BODY_SIZE: usize = 100 * 1024 * 1024;

/// Build the router: discovery, health, schema, shutdown, prediction and
/// training routes, with the body limit raised to `MAX_HTTP_BODY_SIZE`.
pub fn routes(service: Arc) -> Router {
    Router::new()
        .route("/", get(root))
        .route("/health-check", get(health_check))
        .route("/openapi.json", get(openapi_schema))
        .route("/shutdown", post(shutdown))
        .route("/predictions", post(create_prediction))
        .route("/predictions/{id}", put(create_prediction_idempotent))
        .route("/predictions/{id}/cancel", post(cancel_prediction))
        .route("/trainings", post(create_training))
        .route("/trainings/{id}", put(create_training_idempotent))
        .route("/trainings/{id}/cancel", post(cancel_training))
        .layer(DefaultBodyLimit::max(MAX_HTTP_BODY_SIZE))
        .with_state(service)
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Collect the response body and parse it as JSON.
    async fn response_json(response: axum::response::Response) -> serde_json::Value {
        let body = response.into_body();
        let bytes = body.collect().await.unwrap().to_bytes();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn health_check_returns_status_and_version() {
        let service = Arc::new(PredictionService::new_no_pool().with_health(Health::Starting));
        let app = routes(service);

        let response = app
            .oneshot(Request::get("/health-check").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let json = response_json(response).await;
        assert_eq!(json["status"], "STARTING");
        assert!(json["version"]["coglet"].is_string());
    }

    #[tokio::test]
    async fn health_check_unknown_when_no_predictor() {
        let service = Arc::new(PredictionService::new_no_pool());
        let app = routes(service);

        let response = app
            .oneshot(Request::get("/health-check").body(Body::empty()).unwrap())
            .await
            .unwrap();

        let json = response_json(response).await;
        assert_eq!(json["status"], "UNKNOWN");
    }

    #[tokio::test]
    async fn predictions_returns_503_when_not_ready() {
        let service = Arc::new(PredictionService::new_no_pool());
        let app = routes(service);

        let response = app
            .oneshot(
                Request::post("/predictions")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{"input":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
        let json = response_json(response).await;
        assert_eq!(json["status"], "failed");
        assert!(
            json["error"]
                .as_str()
                .unwrap()
                .contains("Setup has not finished yet")
        );
    }

    #[tokio::test]
    async fn openapi_returns_503_when_schema_not_available() {
        let service = Arc::new(PredictionService::new_no_pool());
        let app = routes(service);

        let response = app
            .oneshot(Request::get("/openapi.json").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
        let json = response_json(response).await;
        assert!(json["error"].as_str().unwrap().contains("not available"));
    }

    #[tokio::test]
    async fn openapi_returns_schema_when_available() {
        let service = Arc::new(PredictionService::new_no_pool());
        service
            .set_schema(serde_json::json!({
                "openapi": "3.0.2",
                "info": {"title": "Cog", "version": "0.1.0"}
            }))
            .await;
        let app = routes(service);

        let response = app
            .oneshot(Request::get("/openapi.json").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let json = response_json(response).await;
        assert_eq!(json["openapi"], "3.0.2");
        assert_eq!(json["info"]["title"], "Cog");
    }

    // --- Tests with MockOrchestrator for full prediction flow ---

    use crate::PredictionOutput;
    use crate::bridge::protocol::SlotId;
    use crate::orchestrator::Orchestrator;
    use crate::permit::PermitPool;
    use std::sync::Mutex as StdMutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Mock orchestrator that immediately completes predictions.
    struct MockOrchestrator {
        register_count: AtomicUsize,
        complete_immediately: bool,
    }

    impl MockOrchestrator {
        fn new() -> Self {
            Self {
                register_count: AtomicUsize::new(0),
                complete_immediately: true,
            }
        }

        /// Create a mock that never completes predictions (for capacity tests).
        fn never_complete() -> Self {
            Self {
                register_count: AtomicUsize::new(0),
                complete_immediately: false,
            }
        }
    }

    #[async_trait::async_trait]
    impl Orchestrator for MockOrchestrator {
        async fn register_prediction(
            &self,
            _slot_id: SlotId,
            prediction: Arc>,
            _idle_sender: tokio::sync::oneshot::Sender,
        ) {
            self.register_count.fetch_add(1, Ordering::SeqCst);
            if self.complete_immediately {
                let mut pred = prediction.lock().unwrap();
                pred.set_succeeded(PredictionOutput::Single(serde_json::json!("mock output")));
            }
        }

        async fn cancel_by_prediction_id(
            &self,
            _prediction_id: &str,
        ) -> Result<(), crate::orchestrator::OrchestratorError> {
            Ok(())
        }

        async fn healthcheck(
            &self,
        ) -> Result {
            Ok(crate::orchestrator::HealthcheckResult::healthy())
        }

        async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> {
            Ok(())
        }
    }

    /// Build a permit pool with `num_slots` permits, each backed by a Unix
    /// socket pair whose read side is drained by a spawned task.
    async fn create_test_pool(num_slots: usize) -> Arc {
        use crate::bridge::codec::JsonCodec;
        use crate::bridge::protocol::SlotRequest;
        use futures::StreamExt;
        use tokio::net::UnixStream;

        let pool = Arc::new(PermitPool::new(num_slots));
        for _ in 0..num_slots {
            let (a, b) = UnixStream::pair().unwrap();
            let (_read_a, write_a) = a.into_split();
            let (read_b, _write_b) = b.into_split();

            // Spawn a task to consume messages from the socket (prevents broken pipe)
            let mut reader = tokio_util::codec::FramedRead::new(read_b, JsonCodec::::new());
            tokio::spawn(async move { while reader.next().await.is_some() {} });

            let writer = tokio_util::codec::FramedWrite::new(write_a, JsonCodec::::new());
            pool.add_permit(SlotId::new(), writer);
        }
        pool
    }

    /// Build a service in the `Ready` state with two slots and a mock
    /// orchestrator that completes predictions immediately.
    async fn create_ready_service() -> Arc {
        let service = Arc::new(PredictionService::new_no_pool());
        let pool = create_test_pool(2).await;
        let orchestrator = Arc::new(MockOrchestrator::new());
        service.set_orchestrator(pool, orchestrator).await;
        service.set_health(Health::Ready).await;
        service
    }

    #[tokio::test]
    async fn health_check_ready_with_orchestrator() {
        let service = create_ready_service().await;
        let app = routes(service);

        let response = app
            .oneshot(Request::get("/health-check").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let json = response_json(response).await;
        assert_eq!(json["status"], "READY");
    }

    #[tokio::test]
    async fn prediction_sync_success() {
        let service = create_ready_service().await;
        let app = routes(service);

        let response = app
            .oneshot(
                Request::post("/predictions")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{"input":{"prompt":"hello"}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let json = response_json(response).await;
        assert_eq!(json["status"], "succeeded");
        assert_eq!(json["output"], "mock output");
        assert!(json["id"].is_string());
    }

    #[tokio::test]
    async fn prediction_async_returns_accepted() {
        let service = create_ready_service().await;
        let app = routes(service);

        let response = app
            .oneshot(
                Request::post("/predictions")
                    .header("content-type", "application/json")
                    .header("prefer", "respond-async")
                    .body(Body::from(r#"{"input":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::ACCEPTED);
        let json = response_json(response).await;
        assert_eq!(json["status"], "starting");
    }

    #[tokio::test]
    async fn prediction_with_custom_id() {
        let service = create_ready_service().await;
        let app = routes(service);

        let response = app
            .oneshot(
                Request::post("/predictions")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{"id":"my-pred-123","input":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let json = response_json(response).await;
        assert_eq!(json["id"], "my-pred-123");
        assert_eq!(json["status"], "succeeded");
    }

    #[tokio::test]
    async fn prediction_idempotent_put() {
        let service = create_ready_service().await;
        let app = routes(service);

        let response = app
            .oneshot(
                Request::put("/predictions/idempotent-123")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{"input":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let json = response_json(response).await;
        assert_eq!(json["id"], "idempotent-123");
        assert_eq!(json["status"], "succeeded");
    }

    #[tokio::test]
    async fn prediction_idempotent_id_mismatch() {
        let service = create_ready_service().await;
        let app = routes(service);

        let response = app
            .oneshot(
                Request::put("/predictions/url-id")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{"id":"body-id","input":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        let json = response_json(response).await;
        assert!(
            json["detail"][0]["msg"]
                .as_str()
                .unwrap()
                .contains("must match")
        );
    }

    #[tokio::test]
    async fn prediction_at_capacity() {
        let service = Arc::new(PredictionService::new_no_pool());
        let pool = create_test_pool(1).await; // Only 1 slot
        // Use never_complete so the first prediction holds the slot
        let orchestrator = Arc::new(MockOrchestrator::never_complete());
        service.set_orchestrator(pool, orchestrator).await;
        service.set_health(Health::Ready).await;

        // Use async mode so first request doesn't block
        let app = routes(Arc::clone(&service));
        let _resp1 = app
            .oneshot(
                Request::post("/predictions")
                    .header("content-type", "application/json")
                    .header("prefer", "respond-async")
                    .body(Body::from(r#"{"input":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        // Small delay to let async task acquire the slot
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        // Second request should get 409 Conflict (at capacity)
        let app2 = routes(service);
        let response = app2
            .oneshot(
                Request::post("/predictions")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{"input":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::CONFLICT);
        let json = response_json(response).await;
        assert!(json["error"].as_str().unwrap().contains("capacity"));
    }

    #[tokio::test]
    async fn health_check_busy_when_at_capacity() {
        let service = Arc::new(PredictionService::new_no_pool());
        let pool = create_test_pool(1).await;
        // Use never_complete so the prediction holds the slot
        let orchestrator = Arc::new(MockOrchestrator::never_complete());
        service.set_orchestrator(pool, orchestrator).await;
        service.set_health(Health::Ready).await;

        // Use async to hold the slot
        let app = routes(Arc::clone(&service));
        let _resp = app
            .oneshot(
                Request::post("/predictions")
                    .header("content-type", "application/json")
                    .header("prefer", "respond-async")
                    .body(Body::from(r#"{"input":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        // Health should show BUSY
        let app2 = routes(service);
        let response = app2
            .oneshot(Request::get("/health-check").body(Body::empty()).unwrap())
            .await
            .unwrap();

        let json = response_json(response).await;
        assert_eq!(json["status"], "BUSY");
    }

    #[tokio::test]
    async fn training_routes_work() {
        let service = create_ready_service().await;
        let app = routes(service);

        let response = app
            .oneshot(
                Request::post("/trainings")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{"input":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let json = response_json(response).await;
        assert_eq!(json["status"], "succeeded");
    }

    #[tokio::test]
    async fn training_idempotent_put() {
        let service = create_ready_service().await;
        let app = routes(service);

        let response = app
            .oneshot(
                Request::put("/trainings/train-123")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{"input":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let json = response_json(response).await;
        assert_eq!(json["id"], "train-123");
        assert_eq!(json["status"], "succeeded");
    }

    #[tokio::test]
    async fn training_idempotent_id_mismatch() {
        let service = create_ready_service().await;
        let app = routes(service);

        let response = app
            .oneshot(
                Request::put("/trainings/url-id")
                    .header("content-type", "application/json")
                    .body(Body::from(r#"{"id":"body-id","input":{}}"#))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        let json = response_json(response).await;
        assert!(
            json["detail"][0]["msg"]
                .as_str()
                .unwrap()
                .contains("must match")
        );
    }

    #[tokio::test]
    async fn shutdown_triggers_service_shutdown() {
        let service = create_ready_service().await;
        let mut rx = service.shutdown_rx();
        let app = routes(service);

        assert!(!*rx.borrow());

        let response = app
            .oneshot(Request::post("/shutdown").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        rx.changed().await.unwrap();
        assert!(*rx.borrow());
    }

    #[tokio::test]
    async fn root_returns_discovery_document() {
        let service = Arc::new(PredictionService::new_no_pool());
        let app = routes(service);

        let response = app
            .oneshot(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/json"
        );

        let json = response_json(response).await;
        // Without a python_sdk version set, falls back to coglet version
        assert_eq!(json["cog_version"], crate::version::COGLET_VERSION);
        assert_eq!(json["docs_url"], "/docs");
        assert_eq!(json["openapi_url"], "/openapi.json");
        assert_eq!(json["shutdown_url"], "/shutdown");
        assert_eq!(json["healthcheck_url"], "/health-check");
        assert_eq!(json["predictions_url"], "/predictions");
        assert_eq!(
            json["predictions_idempotent_url"],
            "/predictions/{prediction_id}"
        );
        assert_eq!(
            json["predictions_cancel_url"],
            "/predictions/{prediction_id}/cancel"
        );
        // No training URLs without a TrainingInput schema
        assert!(json.get("trainings_url").is_none());
        assert!(json.get("trainings_idempotent_url").is_none());
        assert!(json.get("trainings_cancel_url").is_none());
    }

    #[tokio::test]
    async fn root_includes_training_urls_when_schema_has_training() {
        let service = Arc::new(PredictionService::new_no_pool());
        // Set a schema that includes a TrainingInput component
        service
            .set_schema(serde_json::json!({
                "openapi": "3.0.2",
                "info": {"title": "Cog", "version": "0.1.0"},
                "components": {
                    "schemas": {
                        "TrainingInput": {
                            "type": "object",
                            "properties": {
                                "data": {"type": "string"}
                            }
                        }
                    }
                }
            }))
            .await;
        let app = routes(service);

        let response = app
            .oneshot(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let json = response_json(response).await;
        // Base fields still present
        assert_eq!(json["predictions_url"], "/predictions");
        // Training URLs included
        assert_eq!(json["trainings_url"], "/trainings");
        assert_eq!(json["trainings_idempotent_url"], "/trainings/{training_id}");
        assert_eq!(
            json["trainings_cancel_url"],
            "/trainings/{training_id}/cancel"
        );
    }

    #[tokio::test]
    async fn root_cog_version_prefers_python_sdk() {
        let version = VersionInfo::new().with_python_sdk("0.14.0".to_string());
        let service = Arc::new(PredictionService::new_no_pool().with_version(version));
        let app = routes(service);

        let response = app
            .oneshot(Request::get("/").body(Body::empty()).unwrap())
            .await
            .unwrap();

        let json = response_json(response).await;
        assert_eq!(json["cog_version"], "0.14.0");
    }
}

================================================ FILE: crates/coglet/src/transport/http/server.rs ================================================
//! HTTP server implementation.

use std::net::SocketAddr;
use std::sync::Arc;

use tokio::net::TcpListener;
use tokio::sync::watch;
use tracing::info;

use crate::service::PredictionService;

use super::routes::routes;

/// Listener and shutdown settings for the HTTP server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
    /// If true, ignore SIGTERM and wait for explicit /shutdown or SIGINT.
    /// Used in Kubernetes to allow graceful draining.
    pub await_explicit_shutdown: bool,
}

impl Default for ServerConfig {
    /// Defaults to `0.0.0.0:5000` with `await_explicit_shutdown` disabled.
    fn default() -> Self {
        Self {
            host: "0.0.0.0".to_string(),
            port: 5000,
            await_explicit_shutdown: false,
        }
    }
}

/// Start the HTTP server with provided service.
///
/// Binds to `config.host:config.port`, serves until the shutdown signal
/// resolves, then shuts the service down. Returns an error when the address
/// cannot be parsed or bound, or when serving fails.
pub async fn serve(config: ServerConfig, service: Arc) -> anyhow::Result<()> {
    let shutdown_rx = service.shutdown_rx();
    let app = routes(service.clone());

    let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
    let listener = TcpListener::bind(addr).await?;
    let actual_addr = listener.local_addr()?;
    info!("Starting coglet server on {}", actual_addr);

    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal(config.await_explicit_shutdown, shutdown_rx))
        .await?;

    info!("Server shutdown complete");

    // Gracefully shutdown the orchestrator worker
    service.shutdown().await;

    Ok(())
}

/// Wait for shutdown signal (SIGTERM, SIGINT, or /shutdown endpoint).
///
/// # Panics
///
/// Panics if signal handlers cannot be installed. This can only happen if:
/// - Called from a non-main thread without the runtime being properly configured
/// - The tokio runtime is not properly initialized
///
/// These are unrecoverable configuration errors that should fail fast at startup.
async fn shutdown_signal(await_explicit_shutdown: bool, mut shutdown_rx: watch::Receiver) { let ctrl_c = async { tokio::signal::ctrl_c() .await .expect("failed to install Ctrl+C handler - is tokio runtime configured correctly?"); }; #[cfg(unix)] let terminate = async { if await_explicit_shutdown { // Ignore SIGTERM - wait forever (until SIGINT or explicit shutdown) tracing::info!("await_explicit_shutdown enabled, ignoring SIGTERM"); std::future::pending::<()>().await } else { tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) .expect( "failed to install SIGTERM handler - is tokio runtime configured correctly?", ) .recv() .await; } }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); let explicit_shutdown = async { while !*shutdown_rx.borrow() { if shutdown_rx.changed().await.is_err() { std::future::pending::<()>().await; } } }; tokio::select! { _ = ctrl_c => { info!("Received SIGINT, shutting down..."); } _ = terminate => { info!("Received SIGTERM, shutting down..."); } _ = explicit_shutdown => { info!("Shutdown requested via /shutdown endpoint..."); } } } #[cfg(test)] mod tests { use super::*; #[test] fn server_config_default() { let config = ServerConfig::default(); assert_eq!(config.host, "0.0.0.0"); assert_eq!(config.port, 5000); assert!(!config.await_explicit_shutdown); } } ================================================ FILE: crates/coglet/src/transport/mod.rs ================================================ //! Transport layer for coglet. //! //! Currently provides HTTP transport via axum. Future transports //! (gRPC, bnet) will be added as separate submodules. pub mod http; pub use http::{ServerConfig, serve}; ================================================ FILE: crates/coglet/src/version.rs ================================================ //! Version information for coglet. /// Coglet version from Cargo.toml pub const COGLET_VERSION: &str = env!("CARGO_PKG_VERSION"); /// Version information for the runtime. 
#[derive(Debug, Clone, serde::Serialize)] pub struct VersionInfo { /// Coglet runtime version. pub coglet: &'static str, /// Git SHA (with optional `-dirty` suffix). #[serde(skip_serializing_if = "Option::is_none")] pub git_sha: Option, /// Build timestamp (UTC, ISO 8601). #[serde(skip_serializing_if = "Option::is_none")] pub build_time: Option, /// Python SDK version (if available). #[serde(skip_serializing_if = "Option::is_none")] pub python_sdk: Option, /// Python version. #[serde(skip_serializing_if = "Option::is_none")] pub python: Option, } impl Default for VersionInfo { fn default() -> Self { Self { coglet: COGLET_VERSION, git_sha: None, build_time: None, python_sdk: None, python: None, } } } impl VersionInfo { /// Create version info with coglet version only. pub fn new() -> Self { Self::default() } /// Set git SHA (with optional `-dirty` suffix). pub fn with_git_sha(mut self, sha: String) -> Self { self.git_sha = Some(sha); self } /// Set build timestamp. pub fn with_build_time(mut self, time: String) -> Self { self.build_time = Some(time); self } /// Set Python SDK version. pub fn with_python_sdk(mut self, version: String) -> Self { self.python_sdk = Some(version); self } /// Set Python version. 
pub fn with_python(mut self, version: String) -> Self { self.python = Some(version); self } } #[cfg(test)] mod tests { use super::*; #[test] fn version_info_has_coglet_version() { let info = VersionInfo::new(); assert_eq!(info.coglet, COGLET_VERSION); assert!(info.python_sdk.is_none()); assert!(info.python.is_none()); } #[test] fn version_info_builder_pattern() { let info = VersionInfo::new() .with_git_sha("abc1234".to_string()) .with_build_time("2026-03-12T18:00:00Z".to_string()) .with_python_sdk("0.9.0".to_string()) .with_python("3.11.0".to_string()); assert_eq!(info.git_sha, Some("abc1234".to_string())); assert_eq!(info.build_time, Some("2026-03-12T18:00:00Z".to_string())); assert_eq!(info.python_sdk, Some("0.9.0".to_string())); assert_eq!(info.python, Some("3.11.0".to_string())); } #[test] fn version_info_serializes_minimal() { // Only coglet when no optional fields set let info = VersionInfo { coglet: "0.1.0", git_sha: None, build_time: None, python_sdk: None, python: None, }; insta::assert_json_snapshot!("version_minimal", info); } #[test] fn version_info_serializes_full() { let info = VersionInfo { coglet: "0.1.0", git_sha: Some("abc1234-dirty".to_string()), build_time: Some("2026-03-12T18:00:00Z".to_string()), python_sdk: Some("0.9.0".to_string()), python: Some("3.11.0".to_string()), }; insta::assert_json_snapshot!("version_full", info); } } ================================================ FILE: crates/coglet/src/webhook.rs ================================================ //! Webhook sender for async predictions. //! //! Implements the cog webhook protocol: //! - Throttling (default 500ms between non-terminal updates) //! - Terminal webhooks retried with exponential backoff //! - WEBHOOK_AUTH_TOKEN bearer authentication //! - Events filtering (start, output, logs, completed) //! //! # Panic Safety //! //! This module avoids panics: //! - `WebhookSender::new()` returns `Result` for HTTP client creation //! 
- Mutex locks use `lock().unwrap_or_else(|e| e.into_inner())` to recover from //! poison - worst case is we lose throttle tracking, which is acceptable. use std::collections::HashSet; use std::sync::Mutex; use std::time::{Duration, Instant}; use thiserror::Error; use serde::{Deserialize, Serialize}; use crate::version::COGLET_VERSION; #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Deserialize, Serialize)] #[serde(rename_all = "lowercase")] pub enum WebhookEventType { Start, Output, Logs, #[default] Completed, } impl WebhookEventType { pub fn is_terminal(&self) -> bool { matches!(self, Self::Completed) } pub fn all() -> HashSet { [Self::Start, Self::Output, Self::Logs, Self::Completed] .into_iter() .collect() } } /// Error creating a WebhookSender. #[derive(Debug, Error)] pub enum WebhookSenderError { #[error("failed to create HTTP client: {0}")] HttpClient(#[from] reqwest::Error), } #[derive(Debug, Clone)] pub struct WebhookConfig { pub response_interval: Duration, pub events_filter: HashSet, pub max_retries: u32, pub backoff_base: Duration, pub retry_status_codes: Vec, } impl Default for WebhookConfig { fn default() -> Self { Self { response_interval: Duration::from_millis( std::env::var("COG_THROTTLE_RESPONSE_INTERVAL") .ok() .and_then(|s| s.parse::().ok()) .map(|s| (s * 1000.0) as u64) .unwrap_or(500), ), events_filter: WebhookEventType::all(), max_retries: 12, backoff_base: Duration::from_millis(100), retry_status_codes: vec![429, 500, 502, 503, 504], } } } /// W3C Trace Context for distributed tracing. 
#[derive(Debug, Clone, Default)] pub struct TraceContext { pub traceparent: Option, pub tracestate: Option, } pub struct WebhookSender { url: String, config: WebhookConfig, client: reqwest::Client, last_sent: Mutex, trace_context: TraceContext, } impl WebhookSender { pub fn new(url: String, config: WebhookConfig) -> Result { Self::with_trace_context(url, config, TraceContext::default()) } pub fn with_trace_context( url: String, config: WebhookConfig, trace_context: TraceContext, ) -> Result { let mut headers = reqwest::header::HeaderMap::new(); if let Ok(token) = std::env::var("WEBHOOK_AUTH_TOKEN") && let Ok(value) = reqwest::header::HeaderValue::from_str(&format!("Bearer {}", token)) { headers.insert(reqwest::header::AUTHORIZATION, value); } let user_agent = format!("coglet/{}", COGLET_VERSION); if let Ok(value) = reqwest::header::HeaderValue::from_str(&user_agent) { headers.insert(reqwest::header::USER_AGENT, value); } let client = reqwest::Client::builder() .default_headers(headers) .timeout(Duration::from_secs(30)) .build()?; Ok(Self { url, config, client, // Allow immediate first send last_sent: Mutex::new(Instant::now() - Duration::from_secs(10)), trace_context, }) } pub fn url(&self) -> &str { &self.url } fn should_send(&self, event: WebhookEventType) -> bool { if !self.config.events_filter.contains(&event) { return false; } if event.is_terminal() { return true; } // Output events are never throttled: they are high-value (contain actual // prediction results), relatively infrequent (one per output chunk/file), // and in the old Python runtime were effectively unthrottled because file // uploads were synchronous. Throttling them causes the director to miss // intermediate output data. 
if matches!(event, WebhookEventType::Output) { return true; } // Recover from poison - losing throttle state is acceptable let last = self.last_sent.lock().unwrap_or_else(|e| e.into_inner()); last.elapsed() >= self.config.response_interval } fn update_last_sent(&self) { // Recover from poison - losing throttle state is acceptable let mut last = self.last_sent.lock().unwrap_or_else(|e| e.into_inner()); *last = Instant::now(); } fn build_request(&self, payload: &serde_json::Value) -> reqwest::RequestBuilder { let mut request = self.client.post(&self.url).json(payload); if let Some(ref traceparent) = self.trace_context.traceparent { request = request.header("traceparent", traceparent); } if let Some(ref tracestate) = self.trace_context.tracestate { request = request.header("tracestate", tracestate); } request } /// Send a non-terminal webhook (fire and forget, no retry). pub fn send(&self, event: WebhookEventType, payload: &serde_json::Value) { if !self.should_send(event) { return; } let request = self.build_request(payload); self.update_last_sent(); tokio::spawn(async move { if let Err(e) = request.send().await { tracing::warn!(error = %e, "Failed to send webhook (non-terminal)"); } }); } /// Send a terminal webhook with exponential backoff retries. 
pub async fn send_terminal(&self, event: WebhookEventType, payload: &serde_json::Value) { if !self.config.events_filter.contains(&event) { return; } let mut attempt = 0; loop { match self.build_request(payload).send().await { Ok(response) => { let status = response.status().as_u16(); if response.status().is_success() { tracing::debug!(status = %status, "Terminal webhook sent successfully"); return; } if self.config.retry_status_codes.contains(&status) { attempt += 1; if attempt > self.config.max_retries { tracing::error!( status = %status, attempts = attempt, "Terminal webhook failed after max retries" ); return; } let backoff = self.config.backoff_base * (1 << attempt.min(10)); tracing::warn!( status = %status, attempt = attempt, backoff_ms = backoff.as_millis(), "Terminal webhook failed, retrying" ); tokio::time::sleep(backoff).await; continue; } tracing::error!( status = %status, "Terminal webhook failed with non-retryable status" ); return; } Err(e) => { attempt += 1; if attempt > self.config.max_retries { tracing::error!( error = %e, attempts = attempt, "Terminal webhook failed after max retries" ); return; } let backoff = self.config.backoff_base * (1 << attempt.min(10)); tracing::warn!( error = %e, attempt = attempt, backoff_ms = backoff.as_millis(), "Terminal webhook request error, retrying" ); tokio::time::sleep(backoff).await; } } } } /// Send a terminal webhook synchronously (for Drop contexts). /// /// Uses ureq (blocking HTTP) instead of reqwest for non-async contexts. 
pub fn send_terminal_sync(&self, payload: &serde_json::Value) { if !self .config .events_filter .contains(&WebhookEventType::Completed) { return; } let agent = ureq::Agent::config_builder() .timeout_global(Some(Duration::from_secs(30))) .tls_config( ureq::tls::TlsConfig::builder() .root_certs(ureq::tls::RootCerts::PlatformVerifier) .build(), ) .build() .new_agent(); let auth_header = std::env::var("WEBHOOK_AUTH_TOKEN") .ok() .map(|token| format!("Bearer {}", token)); let user_agent = format!("coglet/{}", COGLET_VERSION); let mut attempt = 0; loop { let mut request = agent .post(&self.url) .header("Content-Type", "application/json") .header("User-Agent", &user_agent); if let Some(ref auth) = auth_header { request = request.header("Authorization", auth); } if let Some(ref traceparent) = self.trace_context.traceparent { request = request.header("traceparent", traceparent); } if let Some(ref tracestate) = self.trace_context.tracestate { request = request.header("tracestate", tracestate); } let result = request.send_json(payload); match result { Ok(response) => { let status = response.status().as_u16(); if (200..300).contains(&status) { tracing::debug!(status = %status, "Terminal webhook (sync) sent successfully"); return; } if self.config.retry_status_codes.contains(&status) { attempt += 1; if attempt > self.config.max_retries { tracing::error!( status = %status, attempts = attempt, "Terminal webhook (sync) failed after max retries" ); return; } let backoff = self.config.backoff_base * (1 << attempt.min(10)); tracing::warn!( status = %status, attempt = attempt, backoff_ms = backoff.as_millis(), "Terminal webhook (sync) failed, retrying" ); std::thread::sleep(backoff); continue; } tracing::error!( status = %status, "Terminal webhook (sync) failed with non-retryable status" ); return; } Err(e) => { attempt += 1; if attempt > self.config.max_retries { tracing::error!( error = %e, attempts = attempt, "Terminal webhook (sync) failed after max retries" ); return; } let 
backoff = self.config.backoff_base * (1 << attempt.min(10)); tracing::warn!( error = %e, attempt = attempt, backoff_ms = backoff.as_millis(), "Terminal webhook (sync) request error, retrying" ); std::thread::sleep(backoff); } } } } } #[cfg(test)] mod tests { use super::*; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; #[test] fn config_defaults() { let config = WebhookConfig::default(); assert_eq!(config.response_interval, Duration::from_millis(500)); assert_eq!(config.max_retries, 12); assert!(config.events_filter.contains(&WebhookEventType::Start)); assert!(config.events_filter.contains(&WebhookEventType::Completed)); } #[test] fn event_is_terminal() { assert!(!WebhookEventType::Start.is_terminal()); assert!(!WebhookEventType::Output.is_terminal()); assert!(!WebhookEventType::Logs.is_terminal()); assert!(WebhookEventType::Completed.is_terminal()); } fn test_config() -> WebhookConfig { WebhookConfig { response_interval: Duration::ZERO, max_retries: 2, backoff_base: Duration::from_millis(1), ..Default::default() } } #[tokio::test] async fn send_terminal_posts_json() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/webhook")) .respond_with(ResponseTemplate::new(200)) .expect(1) .mount(&server) .await; let url = format!("{}/webhook", server.uri()); let sender = WebhookSender::new(url, test_config()).unwrap(); sender .send_terminal( WebhookEventType::Completed, &serde_json::json!({"id": "pred_123", "status": "succeeded"}), ) .await; } #[tokio::test] async fn send_terminal_retries_on_500() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/webhook")) .respond_with(ResponseTemplate::new(500)) .up_to_n_times(1) .mount(&server) .await; Mock::given(method("POST")) .and(path("/webhook")) .respond_with(ResponseTemplate::new(200)) .expect(1) .mount(&server) .await; let url = format!("{}/webhook", server.uri()); let sender = WebhookSender::new(url, 
test_config()).unwrap(); sender .send_terminal( WebhookEventType::Completed, &serde_json::json!({"status": "succeeded"}), ) .await; } #[tokio::test] async fn send_terminal_no_retry_on_400() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/webhook")) .respond_with(ResponseTemplate::new(400)) .expect(1) .mount(&server) .await; let url = format!("{}/webhook", server.uri()); let sender = WebhookSender::new(url, test_config()).unwrap(); sender .send_terminal( WebhookEventType::Completed, &serde_json::json!({"status": "succeeded"}), ) .await; } #[tokio::test] async fn send_terminal_respects_filter() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/webhook")) .respond_with(ResponseTemplate::new(200)) .expect(0) .mount(&server) .await; let url = format!("{}/webhook", server.uri()); let config = WebhookConfig { events_filter: [WebhookEventType::Start].into_iter().collect(), ..test_config() }; let sender = WebhookSender::new(url, config).unwrap(); sender .send_terminal( WebhookEventType::Completed, &serde_json::json!({"status": "succeeded"}), ) .await; } #[tokio::test] async fn send_non_terminal_fires_and_forgets() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/webhook")) .respond_with(ResponseTemplate::new(200)) .expect(1) .mount(&server) .await; let url = format!("{}/webhook", server.uri()); let sender = WebhookSender::new(url, test_config()).unwrap(); sender.send( WebhookEventType::Start, &serde_json::json!({"status": "starting"}), ); tokio::time::sleep(Duration::from_millis(50)).await; } #[tokio::test] async fn send_non_terminal_logs_throttled() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/webhook")) .respond_with(ResponseTemplate::new(200)) .expect(1) .mount(&server) .await; let url = format!("{}/webhook", server.uri()); let config = WebhookConfig { response_interval: Duration::from_secs(10), ..test_config() }; let sender = 
WebhookSender::new(url, config).unwrap(); sender.send( WebhookEventType::Logs, &serde_json::json!({"logs": "line 1"}), ); // Second send should be throttled sender.send( WebhookEventType::Logs, &serde_json::json!({"logs": "line 2"}), ); tokio::time::sleep(Duration::from_millis(50)).await; } #[tokio::test] async fn send_output_not_throttled() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/webhook")) .respond_with(ResponseTemplate::new(200)) .expect(2) .mount(&server) .await; let url = format!("{}/webhook", server.uri()); let config = WebhookConfig { response_interval: Duration::from_secs(10), ..test_config() }; let sender = WebhookSender::new(url, config).unwrap(); // Output events bypass throttling — both should be sent sender.send( WebhookEventType::Output, &serde_json::json!({"output": "1"}), ); sender.send( WebhookEventType::Output, &serde_json::json!({"output": "2"}), ); tokio::time::sleep(Duration::from_millis(50)).await; } #[tokio::test] async fn send_terminal_sync_posts_json() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/webhook")) .respond_with(ResponseTemplate::new(200)) .expect(1) .mount(&server) .await; let url = format!("{}/webhook", server.uri()); let sender = WebhookSender::new(url, test_config()).unwrap(); sender.send_terminal_sync(&serde_json::json!({"id": "pred_123", "status": "succeeded"})); } #[tokio::test] async fn send_terminal_sync_retries_on_500() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/webhook")) .respond_with(ResponseTemplate::new(500)) .up_to_n_times(1) .mount(&server) .await; Mock::given(method("POST")) .and(path("/webhook")) .respond_with(ResponseTemplate::new(200)) .expect(1) .mount(&server) .await; let url = format!("{}/webhook", server.uri()); let sender = WebhookSender::new(url, test_config()).unwrap(); sender.send_terminal_sync(&serde_json::json!({"status": "succeeded"})); } } 
================================================
FILE: crates/coglet/src/worker.rs
================================================
//! Worker subprocess - runs inside the Python subprocess.
//!
//! This module provides the child-side of the worker subprocess protocol.
//! The parent side (spawning, message routing) is in orchestrator.rs.
//!
//! Architecture:
//! - Control channel (stdin/stdout): Cancel, Shutdown signals + Ready, Idle responses
//! - Slot sockets: Prediction data + streaming logs
//!
//! Each slot runs predictions independently.

use std::collections::HashMap;
use std::io;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicUsize, Ordering};

use futures::{SinkExt, StreamExt};
use tokio::runtime::Handle;
use tokio::sync::mpsc;
use tokio_util::codec::{FramedRead, FramedWrite};

use crate::bridge::protocol::truncate_worker_log;

// ============================================================================
// Dropped log tracking
// ============================================================================

/// Counter for logs dropped due to channel backpressure during setup.
static DROPPED_SETUP_LOG_COUNT: AtomicUsize = AtomicUsize::new(0);

/// Increment the dropped log counter.
/// Called by ControlChannelLogSender in coglet-python when try_send fails.
pub fn increment_dropped_log_count() {
    DROPPED_SETUP_LOG_COUNT.fetch_add(1, Ordering::Relaxed);
}

/// Report and reset dropped log count.
///
/// Sends a `ControlResponse::DroppedLogs` message on `tx` carrying the number
/// of logs dropped since the last call, together with `interval_millis`.
/// Nothing is sent when that number is zero. The function returns nothing.
fn report_dropped_logs(tx: &mpsc::Sender, interval_millis: u64) { let dropped = DROPPED_SETUP_LOG_COUNT.swap(0, Ordering::Relaxed); if dropped > 0 { let _ = tx.try_send(ControlResponse::DroppedLogs { count: dropped, interval_millis, }); } } // ============================================================================ // Fatal worker shutdown // ============================================================================ struct FatalContext { tx: mpsc::Sender, } static FATAL_CONTEXT: OnceLock = OnceLock::new(); fn init_fatal_context(tx: mpsc::Sender) { let _ = FATAL_CONTEXT.set(FatalContext { tx }); } /// Install a panic hook that sends a Fatal IPC message and aborts. /// /// Any panic in the worker is an invariant violation. The hook sends a best-effort /// `ControlResponse::Fatal` so the parent can poison all slots, then aborts. /// This means `.expect()` / `panic!()` at any call site automatically gets /// the correct fatal behavior — no special helpers needed. fn install_panic_hook() { let prev = std::panic::take_hook(); std::panic::set_hook(Box::new(move |info| { // Run the default hook first (prints to stderr). prev(info); let msg = if let Some(s) = info.payload().downcast_ref::<&str>() { (*s).to_string() } else if let Some(s) = info.payload().downcast_ref::() { s.clone() } else { "".to_string() }; let reason = match info.location() { Some(loc) => format!("panic at {}:{}: {}", loc.file(), loc.line(), msg), None => format!("panic: {}", msg), }; if let Some(ctx) = FATAL_CONTEXT.get() { let _ = ctx.tx.try_send(ControlResponse::Fatal { reason }); } // If panic=abort is not set, abort explicitly. 
std::process::abort(); })); } // ============================================================================ // Tracing initialization // ============================================================================ fn init_worker_tracing(tx: mpsc::Sender) { use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt}; let filter = if std::env::var("RUST_LOG").is_ok() { EnvFilter::from_default_env() } else { let base_level = match std::env::var("COG_LOG_LEVEL").as_deref() { Ok("debug") => "debug", Ok("warn") | Ok("warning") => "warn", Ok("error") => "error", _ => "info", }; let filter_str = format!( "coglet={level},coglet::setup=info,coglet::user=info,coglet_worker={level},coglet_worker::schema=off,coglet_worker::protocol=off", level = base_level ); EnvFilter::new(filter_str) }; let worker_layer = WorkerTracingLayer::new(tx); let subscriber = tracing_subscriber::registry() .with(filter) .with(worker_layer); let _ = subscriber.try_init(); } use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::{ ControlRequest, ControlResponse, FileOutputKind, LogSource, MAX_INLINE_IPC_SIZE, MetricMode, SlotId, SlotOutcome, SlotRequest, SlotResponse, }; use crate::bridge::transport::{ChildTransportInfo, connect_transport}; use crate::orchestrator::HealthcheckResult; use crate::worker_tracing_layer::WorkerTracingLayer; type SlotWriter = Arc>>>; /// Handle for sending messages on a slot socket. /// /// Used by log writers to stream logs during prediction. Thread-safe via /// tokio mpsc channel - logs are queued and written asynchronously. #[derive(Clone)] pub struct SlotSender { tx: mpsc::UnboundedSender, output_dir: PathBuf, file_counter: Arc, } impl SlotSender { pub fn new(tx: mpsc::UnboundedSender, output_dir: PathBuf) -> Self { Self { tx, output_dir, file_counter: Arc::new(AtomicUsize::new(0)), } } /// Generate a unique filename in the output dir. 
fn next_output_path(&self, extension: &str) -> PathBuf { let n = self.file_counter.fetch_add(1, Ordering::Relaxed); self.output_dir.join(format!("{n}.{extension}")) } pub fn send_log(&self, source: LogSource, data: &str) -> io::Result<()> { if data.is_empty() { return Ok(()); } let msg = SlotResponse::Log { source, data: truncate_worker_log(data.to_string()), }; self.tx .send(msg) .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) } /// Write raw bytes to a file in the output dir and send as FileOutput. /// /// Used by FFI workers (Python, Node, etc.) to hand off file data without /// needing language-specific file I/O — SlotSender owns the write. pub fn write_file_output( &self, data: &[u8], extension: &str, mime_type: Option, ) -> io::Result<()> { let path = self.next_output_path(extension); std::fs::write(&path, data)?; self.send_file_output(path, mime_type) } /// Send a file-typed output (e.g. Path, File return types). /// /// The file is already on disk at `path` — we just send the path reference. /// `mime_type` is an explicit MIME type; when None the parent guesses from extension. pub fn send_file_output(&self, path: PathBuf, mime_type: Option) -> io::Result<()> { let filename = path .to_str() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "non-UTF-8 path"))? .to_string(); let msg = SlotResponse::FileOutput { filename, kind: FileOutputKind::FileType, mime_type, }; self.tx .send(msg) .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) } /// Send a user metric to the parent process. pub fn send_metric( &self, name: String, value: serde_json::Value, mode: MetricMode, ) -> io::Result<()> { let msg = SlotResponse::Metric { name, value, mode }; self.tx .send(msg) .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) } /// Send prediction output, either inline or spilled to disk if too large. 
pub fn send_output(&self, output: serde_json::Value) -> io::Result<()> { let msg = build_output_message(&self.output_dir, output)?; self.tx .send(msg) .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) } } /// Build an output message, spilling to disk if larger than the IPC frame limit. fn build_output_message( output_dir: &std::path::Path, output: serde_json::Value, ) -> io::Result { let serialized = serde_json::to_vec(&output).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; if serialized.len() > MAX_INLINE_IPC_SIZE { let path = output_dir.join(format!("spill_{}.json", uuid::Uuid::new_v4())); std::fs::write(&path, &serialized)?; let filename = path .to_str() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "non-UTF-8 path"))? .to_string(); Ok(SlotResponse::FileOutput { filename, kind: FileOutputKind::Oversized, mime_type: None, }) } else { Ok(SlotResponse::Output { output }) } } /// Setup phase errors. /// /// These errors occur during predictor loading and setup, before predictions /// can run. They affect health status (SETUP_FAILED) rather than HTTP status. #[derive(Debug, thiserror::Error)] pub enum SetupError { /// Failed to import or instantiate the predictor class. #[error("failed to load predictor: {message}")] Load { message: String }, /// The setup() method raised an exception. #[error("setup failed: {message}")] Setup { message: String }, /// Internal error (e.g., GIL acquisition failed). #[error("internal error: {message}")] Internal { message: String }, } impl SetupError { pub fn load(message: impl Into) -> Self { Self::Load { message: message.into(), } } pub fn setup(message: impl Into) -> Self { Self::Setup { message: message.into(), } } pub fn internal(message: impl Into) -> Self { Self::Internal { message: message.into(), } } } /// Trait for the prediction handler - abstracts the Python integration. 
#[async_trait::async_trait] pub trait PredictHandler: Send + Sync + 'static { /// Initialize the predictor (load model, run setup). async fn setup(&self) -> Result<(), SetupError>; /// Run a prediction. async fn predict( &self, slot: SlotId, id: String, input: serde_json::Value, slot_sender: Arc, context: std::collections::HashMap, ) -> PredictResult; /// Request cancellation of prediction on a slot. fn cancel(&self, slot: SlotId); /// Run user-defined healthcheck. Default: healthy. async fn healthcheck(&self) -> HealthcheckResult { HealthcheckResult::healthy() } } /// Path to the pre-built OpenAPI schema file inside the container. /// Written during `cog build` and COPYed into the image. const BUNDLED_SCHEMA_PATH: &str = ".cog/openapi_schema.json"; /// Load the bundled OpenAPI schema from disk. /// /// Returns `Some(schema)` if the file exists and parses correctly. /// Returns `None` if missing or unparseable — the predictor will accept /// any input without schema validation. fn load_bundled_schema() -> Option { let path = std::path::Path::new(BUNDLED_SCHEMA_PATH); match std::fs::read_to_string(path) { Ok(contents) => match serde_json::from_str(&contents) { Ok(schema) => { tracing::info!("Loaded OpenAPI schema from {}", BUNDLED_SCHEMA_PATH); Some(schema) } Err(e) => { tracing::warn!( "Failed to parse {}: {}. Running without schema — all input types accepted.", BUNDLED_SCHEMA_PATH, e, ); None } }, Err(_) => { tracing::warn!( "No schema file at {}. Running without schema — all input types accepted. \ Rebuild with a recent version of cog to generate the schema.", BUNDLED_SCHEMA_PATH, ); None } } } /// The outcome of a prediction #[derive(Debug, Clone, PartialEq)] pub enum PredictionOutcome { /// Prediction completed successfully Success { output: serde_json::Value, predict_time: f64, /// True when the predictor returned a list, generator, or iterator. 
is_stream: bool, }, /// Prediction failed with an error Failed { error: String, predict_time: f64 }, /// Prediction was cancelled Cancelled { predict_time: f64 }, } #[derive(Debug)] pub struct PredictResult { pub outcome: PredictionOutcome, } impl PredictResult { pub fn success(output: serde_json::Value, predict_time: f64, is_stream: bool) -> Self { Self { outcome: PredictionOutcome::Success { output, predict_time, is_stream, }, } } pub fn failed(error: String, predict_time: f64) -> Self { Self { outcome: PredictionOutcome::Failed { error, predict_time, }, } } pub fn cancelled(predict_time: f64) -> Self { Self { outcome: PredictionOutcome::Cancelled { predict_time }, } } } /// Callback for setup log routing. /// /// Called before setup() with a sender for routing logs to the control channel. /// Returns a cleanup function that unregisters the sender. pub type SetupLogHook = Box) -> Box + Send>; pub struct WorkerConfig { pub num_slots: usize, /// Hook for setup log routing. Called before setup() to register a log sender. pub setup_log_hook: Option, } impl Default for WorkerConfig { fn default() -> Self { Self { num_slots: 1, setup_log_hook: None, } } } struct SlotCompletion { outcome: SlotOutcome, } impl SlotCompletion { fn idle(slot: SlotId) -> Self { Self { outcome: SlotOutcome::idle(slot), } } fn poisoned(slot: SlotId, error: impl Into) -> Self { Self { outcome: SlotOutcome::poisoned(slot, error), } } } /// Run the worker event loop. /// /// Connects to slot sockets, runs setup, then processes predictions. /// Reads control messages from stdin, prediction requests from slot sockets. 
pub async fn run_worker( handler: Arc, config: WorkerConfig, transport_info: ChildTransportInfo, ) -> io::Result<()> { let num_slots = config.num_slots; let (setup_log_tx, mut setup_log_rx) = mpsc::channel::(5000); init_worker_tracing(setup_log_tx.clone()); // CRITICAL: Redirect fds BEFORE any FFI initialization to prevent subprocesses // from polluting the control channel let control_fds = crate::fd_redirect::redirect_fds_for_subprocess_isolation(setup_log_tx.clone())?; // Connect to slot sockets (transport info from Init message) tracing::trace!(?transport_info, "Connecting to slot transport"); let mut transport = connect_transport(transport_info).await?; tracing::info!(num_slots, "Connected to slot transport"); // Control channel via redirected fds (not stdin/stdout) let ctrl_stdin = tokio::fs::File::from_std(control_fds.stdin_fd.into()); let ctrl_stdout = tokio::fs::File::from_std(control_fds.stdout_fd.into()); let mut ctrl_reader = FramedRead::new(ctrl_stdin, JsonCodec::::new()); let ctrl_writer = Arc::new(tokio::sync::Mutex::new(FramedWrite::new( ctrl_stdout, JsonCodec::::new(), ))); // Generate unique SlotIds for each socket let slot_ids: Vec = (0..num_slots).map(|_| SlotId::new()).collect(); init_fatal_context(setup_log_tx.clone()); install_panic_hook(); let setup_cleanup = config.setup_log_hook.map(|hook| hook(setup_log_tx.clone())); // Forward logs to control channel (runs for entire worker lifetime) // Receives logs from both Python (during setup) and fd_redirect capture threads (always) let ctrl_writer_for_logs = Arc::clone(&ctrl_writer); let _log_forwarder = tokio::spawn(async move { let mut log_count = 0; let mut total_bytes = 0; while let Some(msg) = setup_log_rx.recv().await { if let ControlResponse::Log { ref data, .. 
} = msg { let msg_size = data.len(); log_count += 1; total_bytes += msg_size; tracing::trace!( log_number = log_count, msg_size_bytes = msg_size, total_bytes, "Forwarding log" ); } let mut w = ctrl_writer_for_logs.lock().await; if let Err(e) = w.send(msg).await { tracing::warn!( error = %e, log_count, total_bytes, "Failed to forward log" ); break; } } tracing::debug!( total_logs = log_count, total_bytes, total_kb = total_bytes / 1024, "Log forwarder exiting" ); }); // Periodic reporter for dropped logs (runs for entire worker lifetime) let dropped_log_tx = setup_log_tx.clone(); let _dropped_log_reporter = tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_millis(5000)); loop { interval.tick().await; report_dropped_logs(&dropped_log_tx, 5000); } }); // Run setup tracing::info!("Worker starting setup"); let setup_start = std::time::Instant::now(); let setup_result = handler.setup().await; let setup_elapsed = setup_start.elapsed(); tracing::debug!( elapsed_ms = setup_elapsed.as_millis() as u64, success = setup_result.is_ok(), "Setup handler returned" ); // Unregister Python's setup sender, but keep log_forwarder running // The fd_redirect capture threads will continue sending subprocess logs if let Some(cleanup) = setup_cleanup { tracing::debug!("Running cleanup (unregistering Python setup sender)"); cleanup(); } // Note: We DON'T drop setup_log_tx or wait for log_forwarder // The log_forwarder continues running to forward subprocess output throughout worker lifetime // Handle setup failure if let Err(e) = setup_result { tracing::error!( error = %e, elapsed_ms = setup_elapsed.as_millis() as u64, "Setup failed" ); let slot = slot_ids.first().copied().unwrap_or_else(SlotId::new); let mut w = ctrl_writer.lock().await; let _ = w .send(ControlResponse::Failed { slot, error: format!("Setup failed: {}", e), }) .await; return Ok(()); } // Load the pre-built schema from .cog/openapi_schema.json (written during `cog build`). 
// No runtime generation — if the file doesn't exist, no schema. let schema = load_bundled_schema(); if let Some(ref s) = schema { let schema_json = serde_json::to_string(s).unwrap_or_else(|_| "{}".to_string()); let schema_size = schema_json.len(); tracing::info!( schema_size_bytes = schema_size, schema_size_kb = schema_size / 1024, "Schema loaded" ); if schema_size > 1024 * 1024 { // Log first 500 chars if schema is >1MB tracing::warn!( schema_preview = &schema_json[..500.min(schema_json.len())], "Large schema detected" ); } } tracing::trace!(num_slots, ?slot_ids, "Sending Ready to parent"); { let mut w = ctrl_writer.lock().await; w.send(ControlResponse::Ready { slots: slot_ids.clone(), schema, }) .await?; } // Channel for slot completions let (completion_tx, mut completion_rx) = mpsc::channel::(num_slots); // Track slot state let mut slot_busy: HashMap = slot_ids.iter().map(|id| (*id, false)).collect(); let mut slot_poisoned: HashMap = slot_ids.iter().map(|id| (*id, false)).collect(); // Set up slot socket readers/writers let sockets = transport.drain_sockets(); let mut slot_readers: HashMap< SlotId, FramedRead>, > = HashMap::new(); let mut slot_writers: HashMap< SlotId, FramedWrite>, > = HashMap::new(); for (slot_id, socket) in slot_ids.iter().zip(sockets) { let (read_half, write_half) = socket.into_split(); slot_readers.insert(*slot_id, FramedRead::new(read_half, JsonCodec::new())); slot_writers.insert(*slot_id, FramedWrite::new(write_half, JsonCodec::new())); } // Channel for incoming slot requests let (request_tx, mut request_rx) = mpsc::channel::<(SlotId, SlotRequest)>(num_slots); // Spawn reader task for each slot for (slot_id, reader) in slot_readers { let tx = request_tx.clone(); tokio::spawn(async move { slot_reader_task(slot_id, reader, tx).await; }); } drop(request_tx); // Wrap writers for sharing with prediction tasks let slot_writers: HashMap = slot_writers .into_iter() .map(|(id, w)| (id, Arc::new(tokio::sync::Mutex::new(w)))) .collect(); // Main 
event loop loop { tokio::select! { biased; ctrl_msg = ctrl_reader.next() => { match ctrl_msg { Some(Ok(ControlRequest::Init { .. })) => { tracing::warn!("Received Init in event loop (should be at startup)"); } Some(Ok(ControlRequest::Cancel { slot })) => { tracing::trace!(%slot, "Cancel requested"); handler.cancel(slot); } Some(Ok(ControlRequest::Shutdown)) => { tracing::info!("Shutdown requested"); let mut w = ctrl_writer.lock().await; let _ = w.send(ControlResponse::ShuttingDown).await; break; } Some(Ok(ControlRequest::Healthcheck { id })) => { tracing::trace!(%id, "Healthcheck requested, invoking handler"); let result = handler.healthcheck().await; tracing::trace!( %id, status = ?result.status, error = ?result.error, "Healthcheck handler returned" ); let mut w = ctrl_writer.lock().await; let _ = w.send(ControlResponse::HealthcheckResult { id, status: result.status, error: result.error, }).await; } Some(Err(e)) => { tracing::error!(error = %e, "Control channel error"); break; } None => { tracing::error!("Control channel closed (parent died?), exiting"); break; } } } Some(completion) = completion_rx.recv() => { let slot = completion.outcome.slot_id(); slot_busy.insert(slot, false); if completion.outcome.is_poisoned() { slot_poisoned.insert(slot, true); } { let mut w = ctrl_writer.lock().await; let _ = w.send(completion.outcome.into_control_response()).await; } if slot_poisoned.values().all(|&p| p) { tracing::error!("All slots poisoned, exiting"); break; } } Some((slot_id, request)) = request_rx.recv() => { if slot_busy.get(&slot_id).copied().unwrap_or(false) { tracing::warn!(%slot_id, "Request received for busy slot, ignoring"); continue; } if slot_poisoned.get(&slot_id).copied().unwrap_or(false) { tracing::warn!(%slot_id, "Request received for poisoned slot, ignoring"); continue; } // Extract the prediction ID before consuming the request, so we // can report a failure even if rehydration itself fails. 
let prediction_id = request.prediction_id().to_string(); match request.rehydrate_input() { Ok((id, input, output_dir, context)) => { tracing::trace!(%slot_id, %id, "Prediction request received"); slot_busy.insert(slot_id, true); let writer = match slot_writers.get(&slot_id) { Some(w) => Arc::clone(w), None => { tracing::error!(%slot_id, "No writer for slot"); continue; } }; let handler = Arc::clone(&handler); let completion_tx = completion_tx.clone(); tokio::spawn(async move { let completion = run_prediction( slot_id, id, input, PathBuf::from(output_dir), handler, writer, context, ).await; let _ = completion_tx.send(completion).await; }); } Err(e) => { tracing::error!(%slot_id, %prediction_id, error = %e, "Failed to rehydrate input"); // Send a failure response so the prediction doesn't hang forever. if let Some(writer) = slot_writers.get(&slot_id) { let mut w = writer.lock().await; let fail_msg = SlotResponse::Failed { id: prediction_id, error: format!("Failed to rehydrate input: {}", e), }; if let Err(send_err) = w.send(fail_msg).await { tracing::error!(%slot_id, error = %send_err, "Failed to send rehydrate error response"); } } let _ = completion_tx.send(SlotCompletion::idle(slot_id)).await; } } } } } tracing::info!("Worker exiting"); Ok(()) } async fn slot_reader_task( slot_id: SlotId, mut reader: FramedRead>, tx: mpsc::Sender<(SlotId, SlotRequest)>, ) { loop { match reader.next().await { Some(Ok(request)) => { if tx.send((slot_id, request)).await.is_err() { break; } } Some(Err(e)) => { tracing::error!(%slot_id, error = %e, "Slot reader error"); break; } None => { tracing::trace!(%slot_id, "Slot socket closed"); break; } } } } async fn run_prediction( slot_id: SlotId, prediction_id: String, input: serde_json::Value, output_dir: PathBuf, handler: Arc, writer: SlotWriter, context: std::collections::HashMap, ) -> SlotCompletion { tracing::trace!(%slot_id, %prediction_id, "run_prediction starting"); // Create channel for log streaming let (log_tx, mut log_rx) = 
mpsc::unbounded_channel::(); let slot_sender = Arc::new(SlotSender::new(log_tx, output_dir.clone())); // Forward logs to slot socket let writer_for_logs = Arc::clone(&writer); let log_forwarder = tokio::spawn(async move { while let Some(msg) = log_rx.recv().await { let mut w = writer_for_logs.lock().await; if let Err(e) = w.send(msg).await { tracing::warn!(error = %e, "Failed to forward log"); break; } } tracing::trace!("Prediction log forwarder exiting"); }); // Run prediction — slot_sender is moved in, dropped when predict returns, // which closes the log channel and lets the log forwarder exit. // // block_in_place tells tokio this thread will block (Python GIL acquisition), // allowing the runtime to move other tasks (like log_forwarder) to free // threads. Without this, the log forwarder can be work-stolen onto the // same thread as the prediction and starved until predict returns, causing // all logs to arrive in a single batch at prediction end. let result = tokio::task::block_in_place(|| { Handle::current().block_on(handler.predict( slot_id, prediction_id.clone(), input, slot_sender, context, )) }); tracing::trace!(%slot_id, %prediction_id, "handler.predict returned"); // Wait for log forwarder tracing::trace!(%slot_id, %prediction_id, "Waiting for log forwarder"); let _ = log_forwarder.await; tracing::trace!(%slot_id, %prediction_id, "Log forwarder done"); // Send result on slot socket. // Output is always sent separately from Done so that large values get // spilled to disk and never exceed the IPC frame limit. let mut w = writer.lock().await; let response = match result.outcome { PredictionOutcome::Success { output, predict_time, is_stream, } => { // Send output as a separate message (handles spilling for large values). // Skip if null or empty array — those mean "already streamed" (generators). 
if !output.is_null() && output != serde_json::Value::Array(vec![]) { let output_msg = match build_output_message(&output_dir, output) { Ok(msg) => msg, Err(e) => { tracing::error!(error = %e, "Failed to build output message"); return SlotCompletion::poisoned( slot_id, format!("Output spill error: {}", e), ); } }; if let Err(e) = w.send(output_msg).await { tracing::error!(error = %e, "Failed to send prediction output"); return SlotCompletion::poisoned(slot_id, format!("Socket write error: {}", e)); } } SlotResponse::Done { id: prediction_id.clone(), output: None, predict_time, is_stream, } } PredictionOutcome::Cancelled { .. } => SlotResponse::Cancelled { id: prediction_id.clone(), }, PredictionOutcome::Failed { error, .. } => SlotResponse::Failed { id: prediction_id.clone(), error, }, }; if let Err(e) = w.send(response).await { tracing::error!(error = %e, "Failed to send prediction response"); return SlotCompletion::poisoned(slot_id, format!("Socket write error: {}", e)); } SlotCompletion::idle(slot_id) } #[cfg(test)] mod tests { use super::*; #[test] fn predict_result_success() { let r = PredictResult::success(serde_json::json!("hello"), 0.5, false); assert!(matches!(r.outcome, PredictionOutcome::Success { .. })); } #[test] fn predict_result_success_stream() { let r = PredictResult::success(serde_json::json!([]), 0.5, true); assert!(matches!( r.outcome, PredictionOutcome::Success { is_stream: true, .. } )); } #[test] fn predict_result_failed() { let r = PredictResult::failed("oops".into(), 0.5); assert!(matches!( r.outcome, PredictionOutcome::Failed { ref error, .. } if error == "oops" )); } #[test] fn predict_result_cancelled() { let r = PredictResult::cancelled(0.5); assert!(matches!(r.outcome, PredictionOutcome::Cancelled { .. 
})); } #[test] fn worker_config_default() { let config = WorkerConfig::default(); assert_eq!(config.num_slots, 1); } } ================================================ FILE: crates/coglet/src/worker_tracing_layer.rs ================================================ //! Custom tracing layer for worker subprocess. //! //! Ships structured tracing events over IPC to orchestrator, preserving target and level. //! Optionally writes to fd 101 for direct debugging (controlled by RUST_WORKER_DIRECT_LOG=1). use std::io::Write; use std::sync::{Arc, Mutex}; use tokio::sync::mpsc; use tracing::{Level, Subscriber}; use tracing_subscriber::layer::{Context, Layer}; use crate::bridge::protocol::{ControlResponse, truncate_worker_log}; pub struct WorkerTracingLayer { tx: mpsc::Sender, direct_log_fd: Option>>, } impl WorkerTracingLayer { pub fn new(tx: mpsc::Sender) -> Self { let direct_log_fd = if std::env::var("RUST_WORKER_DIRECT_LOG").as_deref() == Ok("1") { // fd 101 is the original stderr preserved during fd_redirect let fd = unsafe { std::fs::File::from_raw_fd(101) }; Some(Arc::new(Mutex::new(fd))) } else { None }; Self { tx, direct_log_fd } } fn level_to_string(level: &Level) -> &'static str { match *level { Level::TRACE => "trace", Level::DEBUG => "debug", Level::INFO => "info", Level::WARN => "warn", Level::ERROR => "error", } } } impl Layer for WorkerTracingLayer where S: Subscriber, { fn on_event(&self, event: &tracing::Event<'_>, _ctx: Context<'_, S>) { let metadata = event.metadata(); let target = metadata.target(); let level = Self::level_to_string(metadata.level()); let mut visitor = MessageVisitor::default(); event.record(&mut visitor); let message = truncate_worker_log(visitor.message); // Targets excluded from IPC: // - coglet::bridge::codec: feedback loop when encoding WorkerLog messages // - coglet::worker_local: diagnostics that must stay on the worker process let is_local_only = target.starts_with("coglet::bridge::codec") || 
target.starts_with("coglet::worker_local"); if !is_local_only { let _ = self.tx.try_send(ControlResponse::WorkerLog { target: target.to_string(), level: level.to_string(), message: message.clone(), }); } // Write to preserved stderr (fd 101) for: // - worker_local targets (always, these are worker-only diagnostics) // - all targets when RUST_WORKER_DIRECT_LOG=1 is set if let Some(ref fd) = self.direct_log_fd && let Ok(mut file) = fd.lock() { let _ = writeln!(file, "worker::{} [{}] {}", target, level, message); } else if is_local_only { // No direct_log_fd but this is a local-only event — write to fd 101 directly. // Safety: fd 101 is the preserved original stderr from fd_redirect. #[cfg(unix)] { use std::os::unix::io::FromRawFd; let mut file = unsafe { std::fs::File::from_raw_fd(101) }; let _ = writeln!(file, "worker::{} [{}] {}", target, level, message); std::mem::forget(file); // Don't close fd 101 } } } } #[derive(Default)] struct MessageVisitor { message: String, } impl tracing::field::Visit for MessageVisitor { fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { if field.name() == "message" { self.message = format!("{:?}", value); if self.message.starts_with('"') && self.message.ends_with('"') { self.message = self.message[1..self.message.len() - 1].to_string(); } } } fn record_str(&mut self, field: &tracing::field::Field, value: &str) { if field.name() == "message" { self.message = value.to_string(); } } } #[cfg(unix)] use std::os::unix::io::FromRawFd; ================================================ FILE: crates/coglet-python/Cargo.toml ================================================ [package] name = "coglet-python" version.workspace = true edition.workspace = true license.workspace = true publish = false # Published to PyPI as 'coglet' wheel, not crates.io build = "build.rs" [lib] name = "coglet" crate-type = ["cdylib", "rlib"] [dependencies] async-trait = "0.1.89" base64 = "0.22" coglet_core = { path = "../coglet", 
package = "coglet" } futures.workspace = true pyo3.workspace = true pyo3-async-runtimes.workspace = true pyo3-stub-gen.workspace = true serde_json.workspace = true tokio.workspace = true tokio-util = { workspace = true, features = ["codec"] } tracing.workspace = true tracing-subscriber.workspace = true [target.'cfg(unix)'.dependencies] libc = "0.2" [dev-dependencies] pyo3 = { workspace = true, features = ["auto-initialize"] } [features] extension-module = ["pyo3/extension-module"] ================================================ FILE: crates/coglet-python/README.md ================================================ # coglet-python PyO3 bindings that bridge the Rust coglet library to Python. This crate implements the `PredictHandler` trait by wrapping Python predictor classes. ## Overview ``` ┌─────────────────────────────────────────────────────────────────────────────┐ │ coglet-python │ │ │ │ ┌──────────────────────────────────────────────────────────────────────┐ │ │ │ lib.rs │ │ │ │ Python module: serve(), active(), _run_worker(), _is_cancelable() │ │ │ └──────────────────────────────────────────────────────────────────────┘ │ │ │ │ │ ┌───────────────────────┼───────────────────────┐ │ │ ▼ ▼ ▼ │ │ ┌─────────────────────┐ ┌─────────────────────┐ ┌──────────────────┐ │ │ │ worker_bridge.rs │ │ predictor.rs │ │ log_writer.rs │ │ │ │ PredictHandler │ │ PythonPredictor │ │ SlotLogWriter │ │ │ │ impl for Python │ │ load/setup/predict │ │ ContextVar │ │ │ └─────────────────────┘ └─────────────────────┘ └──────────────────┘ │ │ │ │ │ │ │ │ │ │ │ │ ▼ ▼ ▼ │ │ ┌─────────────────────┐ ┌─────────────────────┐ ┌──────────────────┐ │ │ │ input.rs │ │ output.rs │ │ audit.rs │ │ │ │ Pydantic/ADT │ │ JSON serialization │ │ TeeWriter │ │ │ │ input processing │ │ make_encodeable │ │ stream protect │ │ │ └─────────────────────┘ └─────────────────────┘ └──────────────────┘ │ │ │ │ ┌──────────────────────────────────────────────────────────────────────┐ │ │ │ cancel.rs │ │ │ │ SIGUSR1 
handling, CancelableGuard, KeyboardInterrupt injection │ │ │ └──────────────────────────────────────────────────────────────────────┘ │ │ │ └─────────────────────────────────────────────────────────────────────────────┘ ``` ## Directory Structure ``` coglet-python/ ├── Cargo.toml ├── coglet.pyi # Type stubs for Python IDE support └── src/ ├── lib.rs # Python module definition, serve/active/_run_worker ├── predictor.rs # PythonPredictor: wraps Python Predictor class ├── worker_bridge.rs # PythonPredictHandler: implements PredictHandler ├── input.rs # Input processing (Pydantic validation, ADT) ├── output.rs # Output processing (make_encodeable, upload_files) ├── log_writer.rs # SlotLogWriter, ContextVar routing, SetupLogSender ├── audit.rs # Audit hook, TeeWriter for stream protection └── cancel.rs # Cancellation: SIGUSR1, CancelableGuard ``` ## Critical Concepts ### The `active()` Flag ```python import coglet if coglet.server.active: # Running inside worker subprocess # stdout/stderr are captured, print goes to slot routing else: # Running in parent or standalone # Normal stdout/stderr behavior ``` Set to `True` at the start of `_run_worker()`. Used by user code and cog internals to detect worker context. ### Single Async Event Loop Async predictors run on a **single** Python asyncio event loop, created at worker startup. All slots share this loop. 
``` Worker Subprocess ┌─────────────────────────────────────────────────────────────┐ │ Tokio Runtime (Rust) │ │ └─ run_worker event loop │ │ └─ For each SlotRequest::Predict: │ │ └─ tokio::spawn prediction task │ │ └─ Python::attach (acquire GIL) │ │ └─ asyncio.run_coroutine_threadsafe() │ │ └─ Predictor.predict() coroutine │ └─────────────────────────────────────────────────────────────┘ asyncio event loop (Python) ┌─────────────────────────────────────────────────────────────┐ │ Single event loop, started once at worker init │ │ - concurrent.futures.Future per async prediction │ │ - ContextVar propagates prediction_id to spawned tasks │ │ - Cancellation via future.cancel() │ └─────────────────────────────────────────────────────────────┘ ``` **Why single loop?** - Python asyncio has one event loop per thread - We use `run_coroutine_threadsafe` to submit from Rust/Tokio - Multiple slots can have concurrent predictions (up to `max_concurrency`) ### Prediction Execution **Sync Predictors:** ``` SlotRequest::Predict arrives │ ├─▶ Python::attach (acquire GIL) ├─▶ set_sync_prediction_id(id) # For log routing ├─▶ predictor.predict(input) # Blocking call ├─▶ set_sync_prediction_id(None) └─▶ Return PredictResult ``` **Async Predictors:** ``` SlotRequest::Predict arrives │ ├─▶ Python::attach (acquire GIL) ├─▶ Create wrapped coroutine: │ async def _ctx_wrapper(coro, prediction_id, contextvar): │ contextvar.set(prediction_id) # Set in this task's context │ return await coro │ ├─▶ asyncio.run_coroutine_threadsafe(wrapper, loop) ├─▶ py.detach() (release GIL) ├─▶ future.result() (block Rust task, Python runs) └─▶ Return PredictResult ``` ### STDOUT/STDERR Routing All output from user code must be captured and routed through the slot socket. **Architecture:** ``` sys.stdout = SlotLogWriter(stdout) sys.stderr = SlotLogWriter(stderr) SlotLogWriter.write(data) │ ├─▶ Get current prediction_id from: │ 1. SYNC_PREDICTION_ID static (for sync predictors) │ 2. 
ContextVar (for async predictors/spawned tasks) │ ├─▶ Look up SlotSender in PREDICTION_REGISTRY │ └─▶ Route: Found sender → slot_sender.send_log(source, data) No sender → Check setup sender (during setup) Neither → Log as orphan to stderr ``` **Line Buffering:** SlotLogWriter buffers writes until a newline. This coalesces Python's `print()` which does separate writes for content and `\n`. ### Audit Hook Protection User code might replace `sys.stdout`: ```python sys.stdout = open("mylog.txt", "w") ``` We can't prevent this, but we can intercept it with a Python audit hook. **Strategy: TeeWriter** ``` User replaces sys.stdout │ ├─▶ Audit hook fires on object.__setattr__(sys, "stdout", value) │ ├─▶ Check: is value already SlotLogWriter? → Allow (it's us) │ ├─▶ Check: is value already TeeWriter? → Allow (already wrapped) │ ├─▶ Create TeeWriter(inner=SlotLogWriter, user_stream=value) │ └─▶ Schedule: sys.stdout = tee (via Timer to avoid recursion) TeeWriter.write(data) │ ├─▶ inner.write(data) # Our SlotLogWriter (routing works) └─▶ user_stream.write(data) # User's stream (their code works) ``` **Result:** Both our log routing AND the user's stream receive the data. 
### Cancellation **Sync Predictors:** ``` Parent: ControlRequest::Cancel { slot } │ ├─▶ Worker: handler.cancel(slot) │ └─▶ Set CANCEL_REQUESTED flag for slot │ ├─▶ Worker: send SIGUSR1 to self │ └─▶ Signal handler: raise KeyboardInterrupt (if in cancelable region) Prediction code: with CancelableGuard(): # Sets CANCELABLE=true predictor.predict() # Can be interrupted # CANCELABLE=false on exit ``` **Async Predictors:** ``` Parent: ControlRequest::Cancel { slot } │ └─▶ Worker: handler.cancel(slot) │ ├─▶ Get future from slot state └─▶ future.cancel() │ └─▶ Python raises asyncio.CancelledError ``` ### Setup Log Routing During setup (before any prediction), logs go through the control channel: ``` worker_bridge.setup() │ ├─▶ register_setup_sender(tx) # Control channel sender │ ├─▶ predictor.load() + predictor.setup() │ │ │ └─▶ print("Loading model...") │ │ │ └─▶ SlotLogWriter.write() │ │ │ ├─▶ No prediction_id (not in prediction) │ └─▶ get_setup_sender() → ControlResponse::Log │ └─▶ unregister_setup_sender() ``` ### Behaviors **Worker Startup:** 1. `set_active()` - Mark as worker subprocess 2. `init_tracing()` - Configure logging (stderr, COG_LOG_LEVEL env) 3. `install_slot_log_writers()` - Replace sys.stdout/stderr 4. `install_audit_hook()` - Protect streams 5. `install_signal_handler()` - SIGUSR1 for cancellation 6. Read Init message from stdin 7. Connect to slot sockets 8. `handler.setup()` - Load and initialize predictor 9. Send Ready message 10. 
Enter event loop **Shutdown:** - ControlRequest::Shutdown → Send ShuttingDown, exit - stdin closes (parent died) → Exit immediately - All slots poisoned → Exit **Error Handling:** - SetupError::Load - Failed to import/instantiate predictor - SetupError::Setup - setup() raised exception - PredictionError - Prediction failed, slot stays healthy - Slot write error → Slot poisoned (no more predictions on that slot) ================================================ FILE: crates/coglet-python/build.rs ================================================ //! Build script for coglet-python. //! //! Captures build metadata and converts semver to PEP 440 for Python compatibility. use std::process::Command; fn main() { // Convert CARGO_PKG_VERSION (semver) to PEP 440 let version = env!("CARGO_PKG_VERSION"); let pep440 = semver_to_pep440(version); println!("cargo:rustc-env=COGLET_PEP440_VERSION={pep440}"); // Git SHA (short) let git_sha = Command::new("git") .args(["rev-parse", "--short", "HEAD"]) .output() .ok() .filter(|o| o.status.success()) .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string()) .unwrap_or_else(|| "unknown".to_string()); println!("cargo:rustc-env=COGLET_GIT_SHA={git_sha}"); // Git dirty flag let git_dirty = Command::new("git") .args(["status", "--porcelain"]) .output() .ok() .filter(|o| o.status.success()) .map(|o| { if String::from_utf8_lossy(&o.stdout).trim().is_empty() { "false" } else { "true" } }) .unwrap_or("unknown"); println!("cargo:rustc-env=COGLET_GIT_DIRTY={git_dirty}"); // Build timestamp (UTC, ISO 8601) let build_time = Command::new("date") .args(["-u", "+%Y-%m-%dT%H:%M:%SZ"]) .output() .ok() .filter(|o| o.status.success()) .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string()) .unwrap_or_else(|| "unknown".to_string()); println!("cargo:rustc-env=COGLET_BUILD_TIME={build_time}"); // Rustc version let rustc_version = Command::new("rustc") .args(["--version"]) .output() .ok() .filter(|o| o.status.success()) .map(|o| 
/// Translate a semver version string into its PEP 440 spelling.
///
/// Only the pre-release part (everything after the first `-`) is rewritten;
/// a version without one is returned untouched, as is a version whose
/// pre-release tag is not one of the four recognised below.
///
/// Mapping:
///   0.17.0         → 0.17.0
///   0.17.0-alpha.2 → 0.17.0a2
///   0.17.0-beta.1  → 0.17.0b1
///   0.17.0-rc.3    → 0.17.0rc3
///   0.17.0-dev.4   → 0.17.0.dev4
///
/// A tag with no number at all (`-alpha`) is given the number `0`.
fn semver_to_pep440(version: &str) -> String {
    // (semver pre-release tag, PEP 440 marker that replaces it)
    const TAGS: [(&str, &str); 4] = [
        ("alpha", "a"),
        ("beta", "b"),
        ("rc", "rc"),
        ("dev", ".dev"),
    ];

    let (base, pre) = match version.split_once('-') {
        Some(parts) => parts,
        None => return version.to_string(),
    };

    for (tag, marker) in TAGS {
        let Some(rest) = pre.strip_prefix(tag) else {
            continue;
        };
        let number = match rest.strip_prefix('.') {
            // "tag.N": whatever follows the dot is used verbatim.
            Some(after_dot) => after_dot,
            // Bare "tag": default the number to 0.
            None if rest.is_empty() => "0",
            // "tagN": the number is glued directly onto the tag.
            None => rest,
        };
        return format!("{base}{marker}{number}");
    }

    // Unknown pre-release format, pass through
    version.to_string()
}
assert_eq!(semver_to_pep440("0.17.0-beta.1"), "0.17.0b1"); assert_eq!(semver_to_pep440("0.17.0-beta"), "0.17.0b0"); } #[test] fn test_rc() { assert_eq!(semver_to_pep440("0.17.0-rc.3"), "0.17.0rc3"); assert_eq!(semver_to_pep440("0.17.0-rc"), "0.17.0rc0"); } #[test] fn test_dev() { assert_eq!(semver_to_pep440("0.17.0-dev.4"), "0.17.0.dev4"); assert_eq!(semver_to_pep440("0.17.0-dev"), "0.17.0.dev0"); } } ================================================ FILE: crates/coglet-python/coglet/__init__.py ================================================ """coglet — high-performance Rust prediction server for Cog ML models.""" from coglet._impl import CancelationException, __build__, __version__, server from coglet._impl import _sdk as _sdk __all__ = ["__version__", "__build__", "server", "CancelationException"] ================================================ FILE: crates/coglet-python/coglet/__init__.pyi ================================================ # This file is automatically generated by stub_gen # ruff: noqa: E501, F401 from coglet._impl import __build__ as __build__, __version__ as __version__, server as server, CancelationException as CancelationException from . import _sdk as _sdk __all__ = ['__build__', '__version__', 'server', 'CancelationException'] ================================================ FILE: crates/coglet-python/coglet/_impl.pyi ================================================ # This file is automatically generated by pyo3_stub_gen # ruff: noqa: E501, F401, F403, F405 import builtins import typing from . import _sdk __all__ = [ "BuildInfo", "CancelationException", "Server", "server", ] __build__: BuildInfo __version__: builtins.str server: Server @typing.final class BuildInfo: r""" Frozen build metadata exposed as `coglet.__build__`. """ @property def version(self) -> builtins.str: ... @property def git_sha(self) -> builtins.str: ... @property def build_time(self) -> builtins.str: ... @property def rustc_version(self) -> builtins.str: ... 
def __repr__(self) -> builtins.str: ... class CancelationException(builtins.BaseException): r""" Raised when a running prediction or training is cancelled. Derives from ``BaseException`` (not ``Exception``) so that bare ``except Exception`` blocks do not accidentally swallow cancellation. This matches the semantics of ``KeyboardInterrupt`` and ``asyncio.CancelledError``. """ ... @typing.final class Server: r""" The coglet prediction server. Access via `coglet.server`. Frozen — attributes cannot be set or deleted. - `coglet.server.active` — `True` when running inside a worker subprocess - `coglet.server.serve(...)` — start the HTTP prediction server (blocking) """ @property def active(self) -> builtins.bool: r""" `True` when running inside a coglet worker subprocess. """ def serve(self, predictor_ref: typing.Optional[builtins.str] = None, host: builtins.str = '0.0.0.0', port: builtins.int = 5000, await_explicit_shutdown: builtins.bool = False, is_train: builtins.bool = False, output_temp_dir_base: builtins.str = '/tmp/coglet/output', upload_url: typing.Optional[builtins.str] = None) -> None: r""" Start the HTTP prediction server. Blocks until shutdown. """ def _run_worker(self) -> None: r""" Worker subprocess entry point. Called by the orchestrator. Sets the active flag, installs log writers and audit hooks, then enters the worker event loop. """ def __repr__(self) -> builtins.str: ... ================================================ FILE: crates/coglet-python/coglet/_sdk/__init__.pyi ================================================ # This file is automatically generated by pyo3_stub_gen # ruff: noqa: E501, F401, F403, F405 import builtins import typing __all__ = [ "MetricRecorder", "Scope", "current_scope", ] @typing.final class MetricRecorder: r""" Metric recorder with type invariant enforcement. Accessed via `scope.metrics`. 
Supports: - `scope.metrics.record(key, value, mode="replace")` — full API - `scope.metrics.delete(key)` — delete (required before type change) - `scope.metrics[key] = value` — dict-style set (replace mode) - `del scope.metrics[key]` — dict-style delete """ def record( self, key: builtins.str, value: typing.Any, mode: typing.Optional[builtins.str] = None, ) -> None: r""" Record a metric value. Args: key: Metric name. Dot-separated keys (e.g. "timing.preprocess") create nested objects in the response. value: Must be bool, int, float, str, list, or dict. Once a key is set with a type, it cannot be changed without calling delete() first. mode: Accumulation mode — "replace" (default), "incr" (increment numeric), or "append" (push to array). """ def delete(self, key: builtins.str) -> None: r""" Delete a metric key. Required before changing a metric's type. """ def __setitem__(self, key: builtins.str, value: typing.Any) -> None: r""" Dict-style set: `scope.metrics["key"] = value` """ def __delitem__(self, key: builtins.str) -> None: r""" Dict-style delete: `del scope.metrics["key"]` """ def __repr__(self) -> builtins.str: ... @typing.final class Scope: r""" Prediction scope, obtained via `current_scope()`. Provides access to `scope.metrics` for recording metrics, and `scope.record_metric()` as a convenience shorthand. """ @property def metrics(self) -> MetricRecorder: r""" The metric recorder for this prediction. """ @property def context(self) -> dict[builtins.str, builtins.str]: r""" Per-prediction context passed in the request body. Returns a `dict[str, str]` (empty dict if no context was provided). """ def record_metric( self, key: builtins.str, value: typing.Any, mode: typing.Optional[builtins.str] = None, ) -> None: r""" Convenience: record a metric value. Equivalent to `scope.metrics.record(key, value, mode)`. """ def __repr__(self) -> builtins.str: ... 
@typing.final class _SlotLogWriter: r""" A Python file-like object that routes writes via the prediction_id ContextVar. This is installed as sys.stdout/stderr once at worker startup. Each write looks up the current prediction_id from the ContextVar and routes to the appropriate SlotSender. If no prediction_id is set, or the prediction has completed (orphan task), writes go to tracing (logged as orphan). Uses line buffering: accumulates writes until a newline is received, then emits complete lines. This coalesces Python's print() which does separate writes for content and newline. """ @property def closed(self) -> builtins.bool: r""" Whether writes should be ignored (used after errors). """ @property def encoding(self) -> typing.Optional[builtins.str]: r""" Encoding property - needed for compatibility. """ @property def newlines(self) -> typing.Optional[builtins.str]: r""" Newlines property - needed for compatibility. """ @property def buffer(self) -> typing.Any: r""" Buffer property - some code checks for this. """ def write(self, data: builtins.str) -> builtins.int: r""" Write data, routing to the appropriate destination. Uses line buffering: accumulates data until a newline is received, then emits complete lines. This coalesces Python's print() which does separate writes for content and the trailing newline. Priority for routing: 1. If inside a prediction (ContextVar set), route to slot sender 2. If setup sender registered, route to control channel 3. Fall back to stderr (for orphan tasks or unexpected cases) """ def emit_data(self, data: builtins.str) -> None: r""" Emit data to the appropriate destination. """ def flush(self) -> None: r""" Flush the stream. Emits any buffered content that hasn't been terminated with a newline. """ def readable(self) -> builtins.bool: r""" Return whether the stream is readable. """ def writable(self) -> builtins.bool: r""" Return whether the stream is writable. 
""" def seekable(self) -> builtins.bool: r""" Return whether the stream is seekable. """ def isatty(self) -> builtins.bool: r""" Return whether the stream is a TTY. """ def fileno(self) -> builtins.int: r""" Return the file number. """ def close(self) -> None: r""" Close the stream. """ def __enter__(self) -> _SlotLogWriter: r""" Context manager enter. """ def __exit__( self, _exc_type: typing.Optional[typing.Any], _exc_val: typing.Optional[typing.Any], _exc_tb: typing.Optional[typing.Any], ) -> builtins.bool: r""" Context manager exit. """ @typing.final class _TeeWriter: r""" Tee writer that sends writes to both our slot routing and user's stream. - inner: Our _SlotLogWriter for slot-based log routing - user_stream: The stream user code tried to install """ @property def inner(self) -> typing.Any: r""" Our _SlotLogWriter (does ContextVar-based routing) """ @property def user_stream(self) -> typing.Any: r""" User's replacement stream """ @property def name(self) -> builtins.str: r""" Stream name (stdout or stderr) """ @property def closed(self) -> builtins.bool: r""" Closed flag """ @property def encoding(self) -> typing.Optional[builtins.str]: ... @property def newlines(self) -> typing.Optional[builtins.str]: ... def __new__( cls, inner: typing.Any, user_stream: typing.Any, name: builtins.str ) -> _TeeWriter: ... def write(self, data: builtins.str) -> builtins.int: r""" Write to both streams. """ def flush(self) -> None: r""" Flush both streams. """ def readable(self) -> builtins.bool: ... def writable(self) -> builtins.bool: ... def seekable(self) -> builtins.bool: ... def isatty(self) -> builtins.bool: ... def fileno(self) -> builtins.int: ... def close(self) -> None: ... def __enter__(self) -> _TeeWriter: ... def __exit__( self, _exc_type: typing.Optional[typing.Any], _exc_val: typing.Optional[typing.Any], _exc_tb: typing.Optional[typing.Any], ) -> builtins.bool: ... def current_scope() -> Scope: r""" Python-callable: get the current Scope. 
Returns the active scope if inside a prediction, or a no-op scope otherwise. """ ================================================ FILE: crates/coglet-python/coglet/py.typed ================================================ ================================================ FILE: crates/coglet-python/pyproject.toml ================================================ [build-system] requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" [project] name = "coglet" description = "High-performance Rust prediction server for Cog ML models" readme = "README.md" license = {text = "Apache-2.0"} requires-python = ">=3.10" authors = [ {name = "Replicate", email = "team@replicate.com"}, ] keywords = ["machine-learning", "inference", "cog", "prediction"] classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Operating System :: MacOS", "Operating System :: POSIX :: Linux", "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dynamic = ["version"] dependencies = [] [project.urls] Homepage = "https://cog.run" Documentation = "https://cog.run/docs" Repository = "https://github.com/replicate/cog" Issues = "https://github.com/replicate/cog/issues" [project.optional-dependencies] test = [ "pytest>=8.0", "requests>=2.31", ] [tool.maturin] features = ["pyo3/extension-module"] # Mixed layout: coglet/__init__.py is hand-managed, .so is named _impl. module-name = "coglet._impl" # Tell pyo3-stub-gen where the Python source root is (also used by maturin) python-source = "." 
# Use manylinux2014 (glibc 2.17) for compatibility with Python 3.10+ base images compatibility = "manylinux2014" [tool.pytest.ini_options] testpaths = ["tests"] ================================================ FILE: crates/coglet-python/src/audit.rs ================================================ //! Audit hooks to protect Rust-injected runtime objects. //! //! Uses sys.addaudithook to intercept operations that could interfere with //! our runtime machinery. The hook cannot be removed once added. //! //! ## Protection: sys.stdout/stderr (Tee pattern) //! //! If user code replaces stdout/stderr, we wrap their replacement in a _TeeWriter //! that sends data to BOTH our slot routing AND their stream. User's code works //! as they expect, but we still get our logs. //! //! If they replace again, we unwrap the inner _SlotLogWriter from the current //! _TeeWriter and re-tee with the new stream. No nested _TeeWriters. use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Mutex, OnceLock}; use pyo3::prelude::*; use pyo3_stub_gen::derive::*; /// Whether the audit hook has been installed. static HOOK_INSTALLED: AtomicBool = AtomicBool::new(false); /// Re-entrancy guard for the audit hook. /// Prevents infinite recursion when the hook itself sets sys.stdout/stderr. static IN_HOOK: AtomicBool = AtomicBool::new(false); /// Serializes stream replacement so concurrent threads don't race /// on the read-current → create-tee → set-new sequence. static STREAM_LOCK: Mutex<()> = Mutex::new(()); /// Reference to sys module for identity comparison in hook. static SYS_MODULE: OnceLock> = OnceLock::new(); /// Reference to our _SlotLogWriter class for isinstance checks. static SLOT_LOG_WRITER_TYPE: OnceLock> = OnceLock::new(); /// Install the audit hook. Called once at worker startup. /// /// The hook intercepts object.__setattr__ on sys for stdout/stderr. 
pub fn install_audit_hook(py: Python<'_>) -> PyResult<()> { if HOOK_INSTALLED.swap(true, Ordering::SeqCst) { return Ok(()); } // Store sys module reference for identity comparison let sys = py.import("sys")?; let _ = SYS_MODULE.set(sys.as_any().clone().unbind()); // Store our _SlotLogWriter type for isinstance checks if let Ok(coglet) = py.import("coglet") && let Ok(writer_type) = coglet.getattr("_SlotLogWriter") { let _ = SLOT_LOG_WRITER_TYPE.set(writer_type.unbind()); } // Register the Rust audit hook callable let hook = wrap_pyfunction!(_coglet_audit_hook, py)?; sys.call_method1("addaudithook", (hook,))?; tracing::debug!("Installed audit hook for runtime protection"); Ok(()) } /// Audit hook implemented in Rust. /// /// Intercepts `object.__setattr__` events on `sys` for stdout/stderr. /// Uses an AtomicBool re-entrancy guard instead of deferred threading.Timer. #[pyfunction] fn _coglet_audit_hook(py: Python<'_>, event: &str, args: &Bound<'_, PyAny>) -> PyResult<()> { if event != "object.__setattr__" { return Ok(()); } // Re-entrancy guard: skip if we're already inside the hook // (because we're setting sys.stdout/stderr ourselves). if IN_HOOK.load(Ordering::SeqCst) { return Ok(()); } // args is (obj, name, value) let obj = args.get_item(0)?; let name: String = args.get_item(1)?.extract()?; if name != "stdout" && name != "stderr" { return Ok(()); } // Check if obj is the sys module (identity comparison) let Some(sys_ref) = SYS_MODULE.get() else { return Ok(()); }; if !obj.is(sys_ref.bind(py)) { return Ok(()); } let value = args.get_item(2)?; handle_stream_replacement(py, &name, &value)?; Ok(()) } /// Handle user code replacing sys.stdout or sys.stderr. /// /// If the new value is already our _SlotLogWriter, this is our own setup — skip. /// Otherwise, find our _SlotLogWriter from the current stream (direct or inside /// a _TeeWriter), and wrap the user's new stream in a fresh _TeeWriter. 
fn handle_stream_replacement(py: Python<'_>, name: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { // If value is our _SlotLogWriter, this is us installing — skip if is_slot_log_writer(py, value) { return Ok(()); } // Serialize the read-current → create-tee → set-new sequence. // Without this, two threads replacing stdout simultaneously could race // and one tee gets silently dropped. // The lock protects no data (just `()`), so poisoned is safe to recover. let _lock = STREAM_LOCK.lock().unwrap_or_else(|poisoned| { tracing::warn!( target: "coglet::worker_local", "stream lock was poisoned (a thread panicked during stream replacement) — \ recovering, but log routing may be inconsistent" ); poisoned.into_inner() }); // Get current writer from sys let sys = py.import("sys")?; let current = sys.getattr(name)?; // Find our _SlotLogWriter — either it IS current, or it's inside a _TeeWriter let slot_writer = if is_slot_log_writer(py, ¤t) { Some(current.clone().unbind()) } else if is_tee_writer(¤t) { get_inner_writer(py, ¤t).ok() } else { None }; let Some(slot_writer) = slot_writer else { // No _SlotLogWriter installed — nothing to protect return Ok(()); }; // Create new _TeeWriter wrapping our _SlotLogWriter and user's stream let tee = _TeeWriter::new(slot_writer, value.clone().unbind(), name.to_string()); let tee_obj = tee.into_pyobject(py)?; // Set under re-entrancy guard to prevent hook from re-triggering IN_HOOK.store(true, Ordering::SeqCst); let result = sys.setattr(name, tee_obj); IN_HOOK.store(false, Ordering::SeqCst); result } // ============================================================================ // Type checks — pub(crate) only, not exported to Python // ============================================================================ /// Check if a value is a _SlotLogWriter. 
pub(crate) fn is_slot_log_writer(py: Python<'_>, value: &Bound<'_, PyAny>) -> bool { if let Some(writer_type) = SLOT_LOG_WRITER_TYPE.get() && let Ok(true) = value.is_instance(writer_type.bind(py)) { return true; } // Fallback: check by class name (handles cross-module edge cases) if let Ok(type_name) = value.get_type().name() { return type_name == "_SlotLogWriter"; } false } /// Check if a value is a _TeeWriter. pub(crate) fn is_tee_writer(value: &Bound<'_, PyAny>) -> bool { if value.is_instance_of::<_TeeWriter>() { return true; } if let Ok(type_name) = value.get_type().name() { return type_name == "_TeeWriter"; } false } /// Get the inner _SlotLogWriter from a _TeeWriter. pub(crate) fn get_inner_writer(py: Python<'_>, tee: &Bound<'_, PyAny>) -> PyResult> { if let Ok(tee_writer) = tee.extract::>() { return Ok(tee_writer.inner.clone_ref(py)); } if let Ok(inner) = tee.getattr("inner") { return Ok(inner.unbind()); } Err(pyo3::exceptions::PyTypeError::new_err( "Expected _TeeWriter with inner attribute", )) } // ============================================================================ // _TeeWriter — private pyclass // ============================================================================ /// Tee writer that sends writes to both our slot routing and user's stream. /// /// - inner: Our _SlotLogWriter for slot-based log routing /// - user_stream: The stream user code tried to install #[gen_stub_pyclass] #[pyclass(name = "_TeeWriter", module = "coglet._sdk")] pub struct _TeeWriter { /// Our _SlotLogWriter (does ContextVar-based routing) #[pyo3(get)] inner: Py, /// User's replacement stream #[pyo3(get)] user_stream: Py, /// Stream name (stdout or stderr) #[pyo3(get)] name: String, /// Closed flag #[pyo3(get)] closed: bool, } #[gen_stub_pymethods] #[pymethods] impl _TeeWriter { #[new] fn new(inner: Py, user_stream: Py, name: String) -> Self { Self { inner, user_stream, name, closed: false, } } /// Write to both streams. 
fn write(&self, py: Python<'_>, data: &str) -> PyResult { if self.closed || data.is_empty() { return Ok(data.len()); } if let Err(e) = self.inner.call_method1(py, "write", (data,)) { tracing::warn!(error = %e, "_TeeWriter: failed to write to inner"); } if let Err(e) = self.user_stream.call_method1(py, "write", (data,)) { tracing::warn!(error = %e, "_TeeWriter: failed to write to user stream"); } Ok(data.len()) } /// Flush both streams. fn flush(&self, py: Python<'_>) -> PyResult<()> { let _ = self.inner.call_method0(py, "flush"); let _ = self.user_stream.call_method0(py, "flush"); Ok(()) } fn readable(&self) -> bool { false } fn writable(&self) -> bool { !self.closed } fn seekable(&self) -> bool { false } fn isatty(&self, py: Python<'_>) -> PyResult { let result = self.user_stream.call_method0(py, "isatty")?; result.extract(py) } fn fileno(&self, py: Python<'_>) -> PyResult { let result = self.user_stream.call_method0(py, "fileno")?; result.extract(py) } fn close(&mut self) { self.closed = true; } fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { slf } fn __exit__( &mut self, _exc_type: Option<&Bound<'_, PyAny>>, _exc_val: Option<&Bound<'_, PyAny>>, _exc_tb: Option<&Bound<'_, PyAny>>, ) -> bool { false } #[getter] fn encoding(&self, py: Python<'_>) -> PyResult> { match self.user_stream.getattr(py, "encoding") { Ok(enc) => enc.extract(py), Err(_) => Ok(Some("utf-8".to_string())), } } #[getter] fn newlines(&self) -> Option { None } } ================================================ FILE: crates/coglet-python/src/bin/stub_gen.rs ================================================ //! Generate Python stub files for coglet. //! //! Run with: cargo run --bin stub_gen //! //! Custom generate logic: pyo3-stub-gen places classes from the native //! `coglet._impl` module into the `coglet` parent package stub. We redirect //! that output to `coglet/_impl.pyi` so native module types are preserved in //! the right place, then generate `coglet/__init__.pyi` ourselves to //! 
re-export the public API — matching the hand-maintained `__init__.py`. use pyo3_stub_gen::Result; use std::fs; use std::io::Write; /// Public items re-exported from `coglet._impl` in `coglet/__init__.pyi`. /// Uses the `X as X` pattern to mark explicit re-exports (PEP 484). const PUBLIC_REEXPORTS: &[&str] = &["__build__", "__version__", "server", "CancelationException"]; /// Private submodules re-exported with `from . import X as X`. /// /// These use a relative import (not `from coglet._impl`) because `_sdk` is a /// subpackage that type checkers resolve via the filesystem, not an attribute /// of the native extension module. Not included in `__all__`. const PRIVATE_REEXPORTS: &[&str] = &["_sdk"]; fn main() -> Result<()> { let stub = coglet::stub_info()?; for (name, module) in &stub.modules { let normalized = name.replace('-', "_"); let dest = if normalized == "coglet" { // Native module classes land here — redirect to _impl.pyi stub.python_root.join("coglet").join("_impl.pyi") } else { // Submodules like "coglet._sdk" → coglet/_sdk/__init__.pyi let path = normalized.replace('.', "/"); stub.python_root.join(&path).join("__init__.pyi") }; let dir = dest.parent().expect("cannot get parent directory"); if !dir.exists() { fs::create_dir_all(dir)?; } let mut f = fs::File::create(&dest)?; write!(f, "{module}")?; eprintln!("Generated stub: {}", dest.display()); } // Generate coglet/__init__.pyi — re-exports from _impl let init_pyi = stub.python_root.join("coglet").join("__init__.pyi"); let mut f = fs::File::create(&init_pyi)?; writeln!(f, "# This file is automatically generated by stub_gen")?; writeln!(f, "# ruff: noqa: E501, F401")?; writeln!(f)?; // `from coglet._impl import X as X, Y as Y, ...` let reexports: Vec = PUBLIC_REEXPORTS .iter() .map(|name| format!("{name} as {name}")) .collect(); writeln!(f, "from coglet._impl import {}", reexports.join(", "))?; // `from . 
import _sdk as _sdk` — relative import so ty resolves the // subpackage via coglet/_sdk/__init__.pyi, not through _impl. let private: Vec = PRIVATE_REEXPORTS .iter() .map(|name| format!("{name} as {name}")) .collect(); writeln!(f, "from . import {}", private.join(", "))?; writeln!(f)?; // __all__ only includes public items (no underscore-prefixed names) let all_items: Vec = PUBLIC_REEXPORTS .iter() .map(|name| format!("'{name}'")) .collect(); writeln!(f, "__all__ = [{}]", all_items.join(", "))?; eprintln!("Generated stub: {}", init_pyi.display()); Ok(()) } ================================================ FILE: crates/coglet-python/src/cancel.rs ================================================ //! Cancellation support for predictions. //! //! Sync predictors use `PyThreadState_SetAsyncExc` to inject a //! `CancelationException` (a `BaseException` subclass) into the Python //! thread running `predict()`. //! //! Async predictors use asyncio task cancellation: //! - Store task reference when prediction starts //! - Call task.cancel() when cancel requested //! - Python raises asyncio.CancelledError //! //! `CancelationException` deliberately derives from `BaseException` (not //! `Exception`) so that bare `except Exception` blocks in user code cannot //! swallow it — matching the semantics of `KeyboardInterrupt` and //! `asyncio.CancelledError`. use pyo3::prelude::*; // Static exception type with automatic stub generation. // Derives from BaseException so `except Exception` does not catch it. pyo3_stub_gen::create_exception!( coglet, CancelationException, pyo3::exceptions::PyBaseException, "Raised when a running prediction or training is cancelled.\n\ \n\ Derives from ``BaseException`` (not ``Exception``) so that bare\n\ ``except Exception`` blocks do not accidentally swallow cancellation.\n\ This matches the semantics of ``KeyboardInterrupt`` and\n\ ``asyncio.CancelledError``." ); /// Inject CancelationException into a specific Python thread. 
/// /// Uses CPython's `PyThreadState_SetAsyncExc` to raise the exception at the /// next bytecode boundary. This works on any thread (not just the main thread), /// unlike SIGUSR1-based cancellation. /// /// Requires the GIL — `Python::attach` acquires it, blocking briefly if the /// prediction thread currently holds it (CPython releases it every ~5ms). pub fn cancel_sync_thread(py_thread_id: std::ffi::c_long) { Python::attach(|py| { let exc = py.get_type::().as_ptr(); // SAFETY: We hold the GIL. exc is a valid Python type pointer // obtained from the interpreter's type registry. let result = unsafe { pyo3::ffi::PyThreadState_SetAsyncExc(py_thread_id, exc) }; match result { 0 => { tracing::warn!( py_thread_id, "PyThreadState_SetAsyncExc: thread not found (prediction may have completed)" ); } 1 => { tracing::debug!( py_thread_id, "Injected CancelationException into Python thread" ); } _ => { // CPython docs: if > 1, call again with NULL to reset tracing::error!( py_thread_id, count = result, "PyThreadState_SetAsyncExc modified multiple thread states, resetting" ); unsafe { pyo3::ffi::PyThreadState_SetAsyncExc(py_thread_id, std::ptr::null_mut()); } } } }); } /// Get the current Python thread identifier (for later use with `cancel_sync_thread`). /// /// Uses `threading.get_ident()` which returns the same value as /// `PyThreadState_SetAsyncExc` expects for the thread id argument. /// Can be called from any thread (acquires the GIL briefly). pub fn current_py_thread_id() -> std::ffi::c_long { Python::attach(|py| { let threading = py.import("threading").expect("failed to import threading"); threading .call_method0("get_ident") .expect("threading.get_ident() failed") .extract::() .expect("thread ident is not an integer") }) } ================================================ FILE: crates/coglet-python/src/input.rs ================================================ //! Input processing for cog predictors. //! //! This module handles file downloads for cog predictor inputs. 
//! Input validation is performed at the HTTP edge using the OpenAPI schema; //! the worker only needs to download URLPath inputs and pass them through. use std::collections::HashSet; use pyo3::prelude::*; use pyo3::types::PyDict; /// Type alias for Python object. type PyObject = Py; /// RAII wrapper for prepared input that cleans up temp files on drop. /// /// When URLPath inputs are downloaded, they create temp files. This struct /// ensures those files are cleaned up when the prediction completes (success, /// failure, or cancellation). pub struct PreparedInput { /// The prepared input dict (ready for predict(**kwargs)) dict: Py, /// Paths to cleanup on drop (downloaded temp files) cleanup_paths: Vec, } impl PreparedInput { /// Create a new PreparedInput with the given dict and paths to cleanup. pub fn new(dict: Py, cleanup_paths: Vec) -> Self { Self { dict, cleanup_paths, } } /// Get the input dict bound to the given Python context. pub fn dict<'py>(&self, py: Python<'py>) -> Bound<'py, PyDict> { self.dict.bind(py).clone() } } impl Drop for PreparedInput { fn drop(&mut self) { if self.cleanup_paths.is_empty() { return; } Python::attach(|py| { for path in &self.cleanup_paths { let path_bound = path.bind(py); let kwargs = PyDict::new(py); if kwargs.set_item("missing_ok", true).is_ok() && let Err(e) = path_bound.call_method("unlink", (), Some(&kwargs)) { tracing::warn!(error = %e, "Failed to cleanup temp file"); } } }); } } // Safety: PyObject is Send in PyO3 0.23+, we only access through Python::attach unsafe impl Send for PreparedInput {} /// Prepare input for prediction. /// /// Coerces URL strings to the appropriate cog types based on the function's /// type annotations: `File`-annotated params get `File.validate()` (IO-like), /// `Path`-annotated params get `Path.validate()` (filesystem path + download). /// Returns a PreparedInput that cleans up temp files on drop. 
/// /// Input validation is handled at the HTTP edge via the OpenAPI schema — /// this function only handles URL->Path/File coercion and file downloads. /// /// `func` is the Python predict/train callable used to introspect type annotations. pub fn prepare_input( py: Python<'_>, input: &Bound<'_, PyDict>, func: &Bound<'_, PyAny>, ) -> PyResult { let file_fields = detect_file_fields(py, func)?; coerce_url_strings(py, input, &file_fields)?; let cleanup_paths = download_url_paths_into_dict(py, input)?; Ok(PreparedInput::new(input.clone().unbind(), cleanup_paths)) } /// Inspect a Python function's type annotations to find parameters typed as /// `cog.File` (or `list[File]`, `Optional[File]`, `File | None`, /// `Optional[list[File]]`, etc.). Returns a set of field names that should use /// `File.validate()` instead of `Path.validate()`. fn detect_file_fields(py: Python<'_>, func: &Bound<'_, PyAny>) -> PyResult> { let mut file_fields = HashSet::new(); let cog_file_class = py.import("cog.types")?.getattr("File")?; // typing.get_type_hints resolves string annotations and handles forward refs let typing = py.import("typing")?; let get_type_hints = typing.getattr("get_type_hints")?; let get_origin = typing.getattr("get_origin")?; let get_args = typing.getattr("get_args")?; let builtins_list = py.eval(c"list", None, None)?; let union_type = typing.getattr("Union")?; let hints = match get_type_hints.call1((func,)) { Ok(h) => h, Err(_) => return Ok(file_fields), // If we can't get hints, don't coerce as File }; // Helper closure: returns true if `ty` is `File` or `list[File]`. 
let is_file_like = |ty: &Bound<'_, PyAny>| -> PyResult { if ty.is(&cog_file_class) { return Ok(true); } let inner_origin = get_origin.call1((ty,))?; if !inner_origin.is_none() && inner_origin.is(&builtins_list) { let inner_args = get_args.call1((ty,))?; if let Ok(t) = inner_args.cast::() && !t.is_empty() && t.get_item(0)?.is(&cog_file_class) { return Ok(true); } } Ok(false) }; let hints_dict = hints.cast::()?; for (name, annotation) in hints_dict.iter() { let name_str: String = match name.extract() { Ok(s) => s, Err(_) => continue, }; if name_str == "return" { continue; } // Direct File annotation: `param: File` // Also covers `list[File]` via is_file_like. if is_file_like(&annotation)? { file_fields.insert(name_str); continue; } // Union annotation: Optional[File], File | None, Optional[list[File]], etc. // typing.get_origin(Optional[X]) -> typing.Union // typing.get_args(Optional[X]) -> (X, NoneType) let origin = get_origin.call1((&annotation,))?; if !origin.is_none() && origin.is(&union_type) { let args = get_args.call1((&annotation,))?; if let Ok(args_tuple) = args.cast::() { for arg in args_tuple.iter() { // Skip NoneType if arg.is_none() || arg.is(py.None().into_bound(py)) { continue; } // Check if this variant is NoneType by comparing to type(None) let nonetype = py.eval(c"type(None)", None, None)?; if arg.is(&nonetype) { continue; } if is_file_like(&arg)? { file_fields.insert(name_str.clone()); break; } } } } } if !file_fields.is_empty() { tracing::debug!("Detected File-typed fields: {:?}", file_fields); } Ok(file_fields) } /// Coerce URL string values in the input dict to the appropriate cog types. /// /// After `json.loads()`, all values are plain Python types. 
URL strings /// (http://, https://, data:) that represent file inputs need to be converted: /// - `File`-typed fields -> `File.validate()` -> returns IO-like `URLFile` /// - `Path`-typed fields -> `Path.validate()` -> returns `URLPath` (downloaded later) /// /// This replaces the type coercion that `_adt.py`'s `PrimitiveType.normalize()` /// previously performed. fn coerce_url_strings( py: Python<'_>, payload: &Bound<'_, PyDict>, file_fields: &HashSet, ) -> PyResult<()> { let cog_types = py.import("cog.types")?; let path_validate = cog_types.getattr("Path")?.getattr("validate")?; let file_validate = cog_types.getattr("File")?.getattr("validate")?; for (key, value) in payload.iter() { let key_str: String = key.extract().unwrap_or_default(); let use_file = file_fields.contains(&key_str); let validate = if use_file { &file_validate } else { &path_validate }; // Single string value -- check if it's a URL if let Ok(s) = value.extract::() { if s.starts_with("http://") || s.starts_with("https://") || s.starts_with("data:") { let coerced = validate.call1((&value,))?; payload.set_item(&key, coerced)?; } } // List of strings -- check if any are URLs else if let Ok(list) = value.extract::>() { let mut any_coerced = false; let new_items = pyo3::types::PyList::empty(py); for item in list.iter() { if let Ok(s) = item.extract::() && (s.starts_with("http://") || s.starts_with("https://") || s.starts_with("data:")) { let coerced = validate.call1((&item,))?; new_items.append(coerced)?; any_coerced = true; continue; } new_items.append(item)?; } if any_coerced { payload.set_item(&key, new_items)?; } } } Ok(()) } /// Download URLPath inputs in parallel and replace them in the payload dict. /// /// This replicates the behavior from cog's worker.py: /// - Find all URLPath instances in the payload dict /// - Download them in parallel using ThreadPoolExecutor /// - Replace URLPath values with local Path in the dict /// /// Returns the downloaded Path objects for cleanup on drop. 
fn download_url_paths_into_dict( py: Python<'_>, payload: &Bound<'_, PyDict>, ) -> PyResult> { let cog_types = py.import("cog.types")?; let url_path_class = cog_types.getattr("URLPath")?; // Collect URLPath fields that need downloading // Structure: (key, value, is_list) let mut url_path_keys: Vec<(String, bool)> = Vec::new(); for (key, value) in payload.iter() { let key_str: String = key.extract()?; if value.is_instance(&url_path_class)? { url_path_keys.push((key_str, false)); } // Check for lists of URLPath else if let Ok(list) = value.extract::>() && !list.is_empty() { let all_url_paths = list .iter() .all(|item| item.is_instance(&url_path_class).unwrap_or(false)); if all_url_paths { url_path_keys.push((key_str, true)); } } } if url_path_keys.is_empty() { return Ok(Vec::new()); } tracing::debug!("Downloading {} URLPath input(s)", url_path_keys.len()); // Use ThreadPoolExecutor to download in parallel (like worker.py) let concurrent_futures = py.import("concurrent.futures")?; let executor_class = concurrent_futures.getattr("ThreadPoolExecutor")?; let executor = executor_class.call1((8,))?; // max_workers=8 // Structure to track futures: (key, future_or_futures, is_list) let mut futs: std::collections::HashMap>, bool)> = std::collections::HashMap::new(); let mut all_futures: Vec> = Vec::new(); for (key, is_list) in &url_path_keys { let value = payload.get_item(key)?.ok_or_else(|| { pyo3::exceptions::PyKeyError::new_err(format!( "Input key '{}' disappeared during processing", key )) })?; if *is_list { let list = value.extract::>()?; let mut futures_for_key = Vec::new(); for item in list.iter() { let convert_method = item.getattr("convert")?; let future = executor.call_method1("submit", (convert_method,))?; futures_for_key.push(future.clone()); all_futures.push(future); } futs.insert(key.clone(), (futures_for_key, true)); } else { let convert_method = value.getattr("convert")?; let future = executor.call_method1("submit", (convert_method,))?; 
all_futures.push(future.clone()); futs.insert(key.clone(), (vec![future], false)); } } // Wait for all futures let future_list = pyo3::types::PyList::new(py, &all_futures)?; let wait_fn = concurrent_futures.getattr("wait")?; let wait_result = wait_fn.call1((&future_list,))?; let done = wait_result.get_item(0)?; let not_done = wait_result.get_item(1)?; // Check for failures let not_done_len: usize = not_done.len()?; if not_done_len > 0 { // Cancel remaining and find the exception for item in not_done.try_iter()? { let fut = item?; let _ = fut.call_method0("cancel"); } // Find and raise the exception for item in done.try_iter()? { let fut = item?; fut.call_method0("result")?; // raises if future finished with exception } return Err(PyErr::new::( "Download failed", )); } // All downloads complete - replace URLPath with local Path in payload // Collect the Path objects for cleanup let mut cleanup_paths: Vec = Vec::new(); for (key, (futures, is_list)) in futs { if is_list { let mut results = Vec::new(); for fut in futures { let result = fut.call_method0("result")?; cleanup_paths.push(result.clone().unbind()); results.push(result); } let result_list = pyo3::types::PyList::new(py, &results)?; payload.set_item(&key, result_list)?; } else { let result = futures[0].call_method0("result")?; cleanup_paths.push(result.clone().unbind()); payload.set_item(&key, result)?; } } // Shutdown executor executor.call_method0("shutdown")?; tracing::debug!( "URLPath downloads complete, {} paths to cleanup", cleanup_paths.len() ); Ok(cleanup_paths) } #[cfg(test)] mod tests { use super::*; /// Helper: define a Python function with the given parameter annotations and /// return the set of field names that `detect_file_fields` identifies as File-typed. 
fn file_fields_for(py_func_src: &str) -> HashSet { pyo3::Python::initialize(); Python::attach(|py| { // Ensure cog.types.File is importable py.run(c"import cog.types", None, None) .expect("cog.types must be importable for tests"); let locals = PyDict::new(py); py.run( &std::ffi::CString::new(py_func_src).unwrap(), None, Some(&locals), ) .expect("failed to define test function"); let func = locals.get_item("func").unwrap().unwrap(); detect_file_fields(py, &func).expect("detect_file_fields failed") }) } #[test] #[ignore] // Requires cog Python package in PYTHONPATH fn detect_direct_file() { let fields = file_fields_for("from cog import File\ndef func(a: File, b: str): ..."); assert!(fields.contains("a"), "direct File annotation not detected"); assert!(!fields.contains("b"), "str incorrectly flagged as File"); } #[test] #[ignore] // Requires cog Python package in PYTHONPATH fn detect_list_file() { let fields = file_fields_for("from cog import File\ndef func(a: list[File]): ..."); assert!(fields.contains("a"), "list[File] annotation not detected"); } #[test] #[ignore] // Requires cog Python package in PYTHONPATH fn detect_optional_file() { let fields = file_fields_for( "from typing import Optional\nfrom cog import File\ndef func(a: Optional[File]): ...", ); assert!( fields.contains("a"), "Optional[File] annotation not detected" ); } #[test] #[ignore] // Requires cog Python package in PYTHONPATH fn detect_file_union_none() { let fields = file_fields_for( "from typing import Union\nfrom cog import File\ndef func(a: Union[File, None]): ...", ); assert!( fields.contains("a"), "File | None / Union[File, None] annotation not detected" ); } #[test] #[ignore] // Requires cog Python package in PYTHONPATH fn detect_optional_list_file() { let fields = file_fields_for( "from typing import Optional\nfrom cog import File\ndef func(a: Optional[list[File]]): ...", ); assert!( fields.contains("a"), "Optional[list[File]] annotation not detected" ); } #[test] #[ignore] // Requires cog 
Python package in PYTHONPATH fn non_file_types_not_detected() { let fields = file_fields_for( "from pathlib import Path\nfrom typing import Optional\ndef func(a: str, b: int, c: Optional[str], d: Path): ...", ); assert!( fields.is_empty(), "non-File types incorrectly detected: {:?}", fields ); } } ================================================ FILE: crates/coglet-python/src/lib.rs ================================================ //! coglet-python: PyO3 bindings for coglet. mod audit; mod cancel; mod input; mod log_writer; mod metric_scope; mod output; mod predictor; mod worker_bridge; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use pyo3::prelude::*; use pyo3_stub_gen::derive::*; use tracing::{debug, error, info, warn}; use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt}; // Define stub info gatherer for generating .pyi files pyo3_stub_gen::define_stub_info_gatherer!(stub_info); // Module-level attributes (pyo3-stub-gen can't see m.add() calls). // Uses "coglet" because that's the module key in StubInfo for the native module. pyo3_stub_gen::module_variable!("coglet", "__version__", &str); pyo3_stub_gen::module_variable!("coglet", "__build__", BuildInfo); pyo3_stub_gen::module_variable!("coglet", "server", CogletServer); use coglet_core::{ Health, PredictionService, SetupResult, VersionInfo, transport::{ServerConfig, serve as http_serve}, }; /// Global flag: true when running inside a worker subprocess. static ACTIVE: AtomicBool = AtomicBool::new(false); /// Frozen build metadata exposed as `coglet.__build__`. 
#[gen_stub_pyclass] #[pyclass(name = "BuildInfo", module = "coglet", frozen)] pub struct BuildInfo { #[pyo3(get)] version: String, #[pyo3(get)] git_sha: String, #[pyo3(get)] dirty: bool, #[pyo3(get)] build_time: String, #[pyo3(get)] rustc_version: String, } #[gen_stub_pymethods] #[pymethods] impl BuildInfo { fn __repr__(&self) -> String { format!( "BuildInfo(version='{}', git_sha='{}', dirty={}, build_time='{}', rustc_version='{}')", self.version, self.git_sha, if self.dirty { "True" } else { "False" }, self.build_time, self.rustc_version ) } } impl BuildInfo { fn new() -> Self { Self { version: env!("COGLET_PEP440_VERSION").to_string(), git_sha: env!("COGLET_GIT_SHA").to_string(), dirty: env!("COGLET_GIT_DIRTY") == "true", build_time: env!("COGLET_BUILD_TIME").to_string(), rustc_version: env!("COGLET_RUSTC_VERSION").to_string(), } } /// Git SHA with optional `-dirty` suffix. fn sha_display(&self) -> String { if self.dirty { format!("{}-dirty", self.git_sha) } else { self.git_sha.clone() } } } fn set_active() { ACTIVE.store(true, Ordering::SeqCst); } /// Initialize tracing with COG_LOG_LEVEL and LOG_FORMAT support. /// Returns optional receiver for draining setup logs. 
fn init_tracing( _to_stderr: bool, setup_log_tx: Option>, ) -> Option> { let filter = if std::env::var("RUST_LOG").is_ok() { EnvFilter::from_default_env() } else { let base_level = match std::env::var("COG_LOG_LEVEL").as_deref() { Ok("debug") => "debug", Ok("warn") | Ok("warning") => "warn", Ok("error") => "error", _ => "info", }; let filter_str = format!( "coglet={level},coglet::setup=info,coglet::user=info,coglet_worker={level},coglet_worker::schema=off,coglet_worker::protocol=off", level = base_level ); EnvFilter::new(filter_str) }; let use_json = std::env::var("LOG_FORMAT").as_deref() != Ok("console"); if let Some(tx) = setup_log_tx { let accumulator = coglet_core::SetupLogAccumulator::new(tx); if use_json { let subscriber = tracing_subscriber::registry() .with(filter) .with(accumulator) .with(fmt::layer().json().with_writer(std::io::stderr)); let _ = subscriber.try_init(); } else { let subscriber = tracing_subscriber::registry() .with(filter) .with(accumulator) .with(fmt::layer().with_writer(std::io::stderr)); let _ = subscriber.try_init(); } None } else { if use_json { let subscriber = tracing_subscriber::registry() .with(filter) .with(fmt::layer().json().with_writer(std::io::stderr)); let _ = subscriber.try_init(); } else { let subscriber = tracing_subscriber::registry() .with(filter) .with(fmt::layer().with_writer(std::io::stderr)); let _ = subscriber.try_init(); } None } } fn detect_version(py: Python<'_>, build: &BuildInfo) -> VersionInfo { let mut version = VersionInfo::new() .with_git_sha(build.sha_display()) .with_build_time(build.build_time.clone()); if let Ok(sys) = py.import("sys") && let Ok(py_version) = sys.getattr("version") && let Ok(v) = py_version.extract::() { let short_version = v.split_whitespace().next().unwrap_or(&v); version = version.with_python(short_version.to_string()); } if let Ok(cog) = py.import("cog") && let Ok(cog_version) = cog.getattr("__version__") && let Ok(v) = cog_version.extract::() { version = version.with_python_sdk(v); 
}

    version
}

/// Read the number of prediction slots from `COG_MAX_CONCURRENCY`.
///
/// Returns 1 when the variable is unset. A value that does not parse as an
/// unsigned integer, or a value of 0, is reported with a warning and also
/// falls back to 1.
fn read_max_concurrency() -> usize {
    let Ok(val) = std::env::var("COG_MAX_CONCURRENCY") else {
        return 1;
    };

    match val.parse::<usize>() {
        Ok(0) => {
            warn!("COG_MAX_CONCURRENCY=0 would leave no prediction slots, using 1");
            1
        }
        Ok(slots) => slots,
        Err(e) => {
            warn!(
                value = %val,
                error = %e,
                "Invalid COG_MAX_CONCURRENCY value, using 1"
            );
            1
        }
    }
}

/// Read the setup timeout, in whole seconds, from `COG_SETUP_TIMEOUT`.
///
/// Returns `None` (no timeout) when the variable is unset, is 0, or does not
/// parse; the latter two cases are reported with a warning.
fn read_setup_timeout() -> Option<std::time::Duration> {
    match std::env::var("COG_SETUP_TIMEOUT") {
        Ok(val) => match val.parse::<u64>() {
            Ok(0) => {
                warn!("COG_SETUP_TIMEOUT=0 would cause immediate timeout, ignoring");
                None
            }
            Ok(secs) => Some(std::time::Duration::from_secs(secs)),
            Err(e) => {
                warn!(
                    value = %val,
                    error = %e,
                    "Invalid COG_SETUP_TIMEOUT value, ignoring (no timeout will be applied)"
                );
                None
            }
        },
        Err(_) => None,
    }
}

// =============================================================================
// coglet.server — frozen Server object with serve() and active property
// =============================================================================

/// The coglet prediction server.
///
/// Access via `coglet.server`. Frozen — attributes cannot be set or deleted.
///
/// - `coglet.server.active` — `True` when running inside a worker subprocess
/// - `coglet.server.serve(...)` — start the HTTP prediction server (blocking)
#[gen_stub_pyclass]
#[pyclass(name = "Server", module = "coglet", frozen)]
pub struct CogletServer {}

#[gen_stub_pymethods]
#[pymethods]
impl CogletServer {
    /// `True` when running inside a coglet worker subprocess.
    #[getter]
    fn active(&self) -> bool {
        ACTIVE.load(Ordering::SeqCst)
    }

    /// Start the HTTP prediction server. Blocks until shutdown.
#[allow(clippy::too_many_arguments)] #[pyo3(signature = (predictor_ref=None, host="0.0.0.0".to_string(), port=5000, await_explicit_shutdown=false, is_train=false, output_temp_dir_base="/tmp/coglet/output".to_string(), upload_url=None))] fn serve( &self, py: Python<'_>, predictor_ref: Option, host: String, port: u16, await_explicit_shutdown: bool, is_train: bool, output_temp_dir_base: String, upload_url: Option, ) -> PyResult<()> { serve_impl( py, predictor_ref, host, port, await_explicit_shutdown, is_train, output_temp_dir_base, upload_url, ) } /// Worker subprocess entry point. Called by the orchestrator. /// /// Sets the active flag, installs log writers and audit hooks, /// then enters the worker event loop. #[pyo3(name = "_run_worker", signature = ())] fn run_worker(&self, py: Python<'_>) -> PyResult<()> { set_active(); // Install SlotLogWriters for ContextVar-based log routing log_writer::install_slot_log_writers(py)?; // Install audit hook to protect stdout/stderr from user replacement if let Err(e) = audit::install_audit_hook(py) { warn!(error = %e, "Failed to install audit hook, stdout/stderr protection disabled"); } info!(target: "coglet::worker", "Worker subprocess starting, waiting for Init message"); py.detach(|| { let rt = tokio::runtime::Runtime::new() .map_err(|e| PyErr::new::(e.to_string()))?; rt.block_on(async { run_worker_with_init() .await .map_err(|e| PyErr::new::(e.to_string())) }) }) } fn __repr__(&self) -> &'static str { "coglet.server" } } #[allow(clippy::too_many_arguments)] fn serve_impl( py: Python<'_>, predictor_ref: Option, host: String, port: u16, await_explicit_shutdown: bool, is_train: bool, _output_temp_dir_base: String, upload_url: Option, ) -> PyResult<()> { let (setup_log_tx, setup_log_rx) = tokio::sync::mpsc::unbounded_channel(); init_tracing(false, Some(setup_log_tx)); let build = BuildInfo::new(); info!( "coglet {} ({}, built {}{})", env!("CARGO_PKG_VERSION"), build.sha_display(), build.build_time, if cfg!(debug_assertions) { 
", debug" } else { "" }, ); let config = ServerConfig { host, port, await_explicit_shutdown, }; // Install Python SIGTERM handler if await_explicit_shutdown if await_explicit_shutdown { let signal_module = py.import("signal")?; let sigterm = signal_module.getattr("SIGTERM")?; let sig_ign = signal_module.getattr("SIG_IGN")?; signal_module.call_method1("signal", (sigterm, sig_ign))?; info!("await_explicit_shutdown: installed SIGTERM ignore handler"); } let version = detect_version(py, &build); info!( "python sdk {}", version.python_sdk.as_deref().unwrap_or("unknown") ); info!("python {}", version.python.as_deref().unwrap_or("unknown")); let Some(pred_ref) = predictor_ref else { info!("No predictor specified, serving health endpoints only"); let service = Arc::new( PredictionService::new_no_pool() .with_health(Health::Unknown) .with_version(version), ); return py.detach(|| { let rt = tokio::runtime::Runtime::new() .map_err(|e| PyErr::new::(e.to_string()))?; rt.block_on(async { http_serve(config, service) .await .map_err(|e| PyErr::new::(e.to_string())) }) }); }; info!(predictor_ref = %pred_ref, is_train, "Using subprocess isolation"); serve_subprocess( py, pred_ref, config, version, is_train, setup_log_rx, upload_url, ) } fn serve_subprocess( py: Python<'_>, pred_ref: String, config: ServerConfig, version: VersionInfo, is_train: bool, mut setup_log_rx: tokio::sync::mpsc::UnboundedReceiver, upload_url: Option, ) -> PyResult<()> { let max_concurrency = read_max_concurrency(); info!( max_concurrency, "Configuring subprocess worker via orchestrator" ); let setup_timeout = read_setup_timeout(); debug!( setup_timeout_secs = setup_timeout.map(|d| d.as_secs()), is_train, "Orchestrator configuration" ); let orch_config = coglet_core::orchestrator::OrchestratorConfig::new(pred_ref) .with_num_slots(max_concurrency) .with_train(is_train) .with_upload_url(upload_url) .with_setup_timeout(setup_timeout); let service = Arc::new( PredictionService::new_no_pool() 
.with_health(Health::Starting) .with_version(version), ); let service_clone = Arc::clone(&service); py.detach(|| { let rt = tokio::runtime::Runtime::new() .map_err(|e| PyErr::new::(e.to_string()))?; rt.block_on(async { let setup_result = SetupResult::starting(); service_clone.set_setup_result(setup_result.clone()).await; let setup_service = Arc::clone(&service_clone); tokio::spawn(async move { info!("Spawning worker subprocess"); let spawn_start = std::time::Instant::now(); match coglet_core::orchestrator::spawn_worker(orch_config, &mut setup_log_rx).await { Ok(ready) => { let spawn_elapsed = spawn_start.elapsed(); debug!( elapsed_ms = spawn_elapsed.as_millis() as u64, "Worker ready, configuring service" ); let num_slots = ready.handle.slot_ids().len(); debug!(num_slots, "Setting up orchestrator on service"); setup_service .set_orchestrator(ready.pool, Arc::new(ready.handle)) .await; debug!("Transitioning health to Ready"); setup_service.set_health(Health::Ready).await; if let Some(s) = ready.schema { debug!("Setting OpenAPI schema on service"); setup_service.set_schema(s).await; } else { debug!("No OpenAPI schema provided by worker"); } let mode = if is_train { "train" } else { "predict" }; info!(num_slots, mode, "Server ready"); // Drain final logs (includes "Server ready" above) let final_logs = coglet_core::drain_accumulated_logs(&mut setup_log_rx); debug!( initial_logs_len = ready.setup_logs.len(), final_logs_len = final_logs.len(), "Drained setup logs" ); drop(setup_log_rx); // Combine initial + final logs let complete_logs = ready.setup_logs + &final_logs; setup_service .set_setup_result(setup_result.succeeded(complete_logs)) .await; info!("Setup complete, now accepting requests"); } Err(e) => { let spawn_elapsed = spawn_start.elapsed(); error!( error = %e, elapsed_ms = spawn_elapsed.as_millis() as u64, "Worker initialization failed" ); debug!("Transitioning health to SetupFailed"); setup_service.set_health(Health::SetupFailed).await; setup_service 
.set_setup_result(setup_result.failed(e.to_string())) .await; } } }); http_serve(config, service_clone) .await .map_err(|e| PyErr::new::(e.to_string())) }) }) } async fn run_worker_with_init() -> Result<(), String> { use coglet_core::bridge::codec::JsonCodec; use coglet_core::bridge::protocol::ControlRequest; use futures::StreamExt; use tokio::io::stdin; use tokio_util::codec::FramedRead; let mut ctrl_reader = FramedRead::new(stdin(), JsonCodec::::new()); let init_msg = ctrl_reader .next() .await .ok_or_else(|| "stdin closed before Init received".to_string())? .map_err(|e| format!("Failed to read Init: {}", e))?; let (predictor_ref, num_slots, transport_info, is_train, _is_async) = match init_msg { ControlRequest::Init { predictor_ref, num_slots, transport_info, is_train, is_async, } => (predictor_ref, num_slots, transport_info, is_train, is_async), other => { return Err(format!("Expected Init message, got: {:?}", other)); } }; info!(predictor_ref = %predictor_ref, num_slots, is_train, "Init received, connecting to transport"); let handler = Arc::new(if is_train { worker_bridge::PythonPredictHandler::new_train(predictor_ref) .map_err(|e| format!("Failed to create handler: {}", e))? } else { worker_bridge::PythonPredictHandler::new(predictor_ref) .map_err(|e| format!("Failed to create handler: {}", e))? 
}); // Setup log hook: registers a global sender for control channel logs // This lives for the entire worker lifetime (setup + subprocess output) let setup_log_hook: coglet_core::SetupLogHook = Box::new(|tx| { let sender = Arc::new(log_writer::ControlChannelLogSender::new(tx)); log_writer::register_control_channel_sender(sender); // Cleanup is a no-op: sender stays registered for worker lifetime Box::new(|| {}) }); let config = coglet_core::WorkerConfig { num_slots, setup_log_hook: Some(setup_log_hook), }; coglet_core::run_worker(handler, config, transport_info) .await .map_err(|e| format!("Worker error: {}", e)) } // ============================================================================= // Module init // ============================================================================= #[pymodule] #[pyo3(name = "_impl")] fn coglet(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { // Control what `from ._impl import *` exports into coglet/__init__.py m.add("__all__", vec!["__version__", "__build__", "server"])?; // Static metadata m.add("__version__", env!("COGLET_PEP440_VERSION"))?; m.add("__build__", BuildInfo::new())?; // Frozen server object m.add("server", CogletServer {})?; // CancelationException — a BaseException subclass for prediction cancellation. // Re-exported through coglet → cog.exceptions → cog.CancelationException. m.add( "CancelationException", py.get_type::(), )?; // _sdk submodule — internal Python runtime integration classes let sdk = PyModule::new(py, "_sdk")?; sdk.setattr( "__doc__", "Internal SDK runtime integration for coglet.\n\ \n\ This submodule contains Rust-backed classes that integrate coglet with\n\ the Python runtime (I/O routing, audit hooks, log streaming). 
These are\n\ implementation details used by the cog SDK — not part of the public API.", )?; sdk.add_class::()?; sdk.add_class::()?; sdk.add_class::()?; sdk.add_class::()?; sdk.add_function(wrap_pyfunction!(metric_scope::py_current_scope, &sdk)?)?; m.add_submodule(&sdk)?; Ok(()) } ================================================ FILE: crates/coglet-python/src/log_writer.rs ================================================ //! Log routing via prediction_id ContextVar. //! //! Architecture: //! - Rust owns a ContextVar `_coglet_prediction_id` that holds the current prediction ID //! - Rust maintains a registry mapping prediction_id -> SlotSender //! - SlotLogWriter reads the ContextVar to route logs to the correct sender //! //! This design supports: //! - Async predictions with proper per-task isolation (ContextVar is task-local) //! - Orphan task detection (prediction completed but task still running) //! - Slot reuse safety (new prediction = new ID, old tasks can't pollute) //! - Setup logs routed through control channel before predictions start //! //! The ContextVar is private (`_coglet_` prefix). Users who need the prediction ID //! should use the public API (e.g., `cog.current_prediction_id()`) which we'll //! inject onto the cog namespace later. use std::collections::HashMap; use std::sync::{Arc, Mutex, OnceLock}; use pyo3::prelude::*; use pyo3_stub_gen::derive::*; use coglet_core::bridge::protocol::{ControlResponse, LogSource}; use coglet_core::worker::SlotSender; use tokio::sync::mpsc::Sender; // ============================================================================ // Rust-owned ContextVar for prediction routing // ============================================================================ /// The Rust-owned ContextVar instance. Created once, lives forever. /// Named `cog_prediction_id` - documented as internal, don't modify. static PREDICTION_CONTEXTVAR: OnceLock> = OnceLock::new(); /// Registry mapping prediction_id -> SlotSender. 
/// When a prediction starts, we register the sender.
/// When SlotLogWriter.write() is called, we look up the sender here.
static PREDICTION_REGISTRY: OnceLock>>> = OnceLock::new();

/// Current sync prediction ID.
/// For sync predictions (single slot, blocking), there's exactly one active prediction.
/// ContextVars don't work across separate Python::attach calls, so we use this.
/// Protected by mutex since it's accessed from Python callbacks.
static SYNC_PREDICTION_ID: OnceLock>> = OnceLock::new();

/// Lazily initialise and return the sync prediction ID slot (starts as `None`).
fn get_sync_prediction_id_slot() -> &'static Mutex> {
    SYNC_PREDICTION_ID.get_or_init(|| Mutex::new(None))
}

/// Control channel log sender - used when outside prediction context.
/// Set by worker before setup(), lives for entire worker lifetime.
static CONTROL_CHANNEL_LOG_SENDER: OnceLock>>> = OnceLock::new();

/// Lazily initialise and return the prediction registry (starts empty).
fn get_registry() -> &'static Mutex>> {
    PREDICTION_REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
}

/// Lazily initialise and return the control channel sender slot (starts as `None`).
fn get_control_channel_sender_slot() -> &'static Mutex>> {
    CONTROL_CHANNEL_LOG_SENDER.get_or_init(|| Mutex::new(None))
}

// ============================================================================
// ControlChannelLogSender - sends logs via control channel
// ============================================================================

/// Sender for logs that go through the control channel.
/// Used for Python logs during setup and subprocess output throughout worker lifetime.
pub struct ControlChannelLogSender {
    tx: Sender,
}

impl ControlChannelLogSender {
    /// Create a new control channel log sender.
    pub fn new(tx: Sender) -> Self {
        Self { tx }
    }

    /// Try to send a log message.
    ///
    /// Uses try_send() to avoid blocking (called from Python code on tokio runtime threads).
    /// If the channel is full, the log is dropped and counted for periodic reporting.
    pub fn try_send_log(&self, source: LogSource, data: &str) {
        // Any `try_send` failure (not only a full channel) is treated as a
        // dropped log; the error itself is discarded.
        if self
            .tx
            .try_send(ControlResponse::Log {
                source,
                data: data.to_string(),
            })
            .is_err()
        {
            coglet_core::worker::increment_dropped_log_count();
        }
    }
}

// NOTE: All mutex locks in the worker use .expect().
//
// If a mutex is poisoned (another thread panicked while holding it), the worker
// is in an unrecoverable state. We cannot safely continue because:
// - Log routing shares channels with prediction updates
// - Prediction→slot mappings could be inconsistent
// - Continuing risks cross-prediction data bleed
//
// The panic hook installed by coglet_core::worker sends a Fatal IPC message
// to the parent (which poisons all slots) and aborts the process.

/// Register the control channel log sender.
/// Called by worker before setup().
///
/// Replaces any previously registered sender.
pub fn register_control_channel_sender(sender: Arc) {
    let mut slot = get_control_channel_sender_slot()
        .lock()
        .expect("control_channel_sender mutex poisoned");
    *slot = Some(sender);
}

/// Unregister the control channel log sender.
/// Called by worker when shutting down (not after setup).
#[allow(dead_code)]
pub fn unregister_control_channel_sender() {
    let mut slot = get_control_channel_sender_slot()
        .lock()
        .expect("control_channel_sender mutex poisoned");
    *slot = None;
}

/// Get the control channel log sender if registered.
///
/// Returns a clone of the stored handle, so the lock is released before the
/// caller sends anything.
fn get_control_channel_sender() -> Option> {
    let slot = get_control_channel_sender_slot()
        .lock()
        .expect("control_channel_sender mutex poisoned");
    slot.clone()
}

/// Get or create the Rust-owned ContextVar.
///
/// This returns the same ContextVar instance used by SlotLogWriter for log routing.
/// Public so predictor.rs can pass it to async coroutine wrappers.
///
/// Errors if importing `contextvars` or constructing the `ContextVar` fails.
pub fn get_prediction_contextvar(py: Python<'_>) -> PyResult<&'static Py> {
    // Fast path: already initialised.
    if let Some(cv) = PREDICTION_CONTEXTVAR.get() {
        return Ok(cv);
    }

    let contextvars = py.import("contextvars")?;
    let cv = contextvars.call_method1("ContextVar", ("_coglet_prediction_id",))?;

    // Try to store it. Race is fine - if another thread won, use their value.
    match PREDICTION_CONTEXTVAR.set(cv.unbind()) {
        Ok(()) => {}
        Err(_already_set) => {
            // Another thread initialized it first - that's fine
        }
    }

    // This should always succeed now - either we set it or another thread did.
    PREDICTION_CONTEXTVAR.get().ok_or_else(|| {
        pyo3::exceptions::PyRuntimeError::new_err(
            "Failed to initialize prediction context variable",
        )
    })
}

/// Register a SlotSender for a prediction ID.
/// Called when starting a prediction.
///
/// An existing entry under the same ID is overwritten.
pub fn register_prediction(prediction_id: String, sender: Arc) {
    let mut registry = get_registry()
        .lock()
        .expect("prediction_registry mutex poisoned");
    tracing::trace!(%prediction_id, "Registering prediction sender");
    registry.insert(prediction_id, sender);
}

/// Unregister a prediction ID.
/// Called when prediction completes.
///
/// Also clears the sync prediction ID when it refers to the same prediction.
pub fn unregister_prediction(prediction_id: &str) {
    let mut registry = get_registry()
        .lock()
        .expect("prediction_registry mutex poisoned");
    registry.remove(prediction_id);

    // Clear sync prediction ID if it matches
    // (the registry lock is still held here; no other function in this file
    // takes these two locks together)
    let mut slot = get_sync_prediction_id_slot()
        .lock()
        .expect("sync_prediction_id mutex poisoned");
    if slot.as_deref() == Some(prediction_id) {
        *slot = None;
    }
}

/// Get the SlotSender for a prediction ID.
///
/// Returns `None` when the ID is not (or no longer) registered.
fn get_prediction_sender(prediction_id: &str) -> Option> {
    let registry = get_registry()
        .lock()
        .expect("prediction_registry mutex poisoned");
    registry.get(prediction_id).cloned()
}

/// Set the current prediction ID in the ContextVar (for async).
/// Returns a token that can be used to reset (for explicit cleanup).
pub fn set_current_prediction(py: Python<'_>, prediction_id: &str) -> PyResult> {
    // Set ContextVar for async predictions
    let cv = get_prediction_contextvar(py)?;
    let token = cv.call_method1(py, "set", (prediction_id,))?;
    Ok(token)
}

/// Set the current sync prediction ID (for sync predictions only).
/// Call this before running a sync prediction, clear after.
pub fn set_sync_prediction_id(prediction_id: Option<&str>) {
    // `Some(id)` stores an owned copy of the ID; `None` clears the slot.
    let mut slot = get_sync_prediction_id_slot()
        .lock()
        .expect("sync_prediction_id mutex poisoned");
    *slot = prediction_id.map(|s| s.to_string());
}

/// Get the current prediction ID from sync static or ContextVar.
/// Returns None if not set (outside prediction context).
///
/// The sync static takes precedence: the ContextVar is only consulted when no
/// sync prediction ID is stored. Python errors other than `LookupError` from
/// the ContextVar read are propagated.
fn get_current_prediction_id(py: Python<'_>) -> PyResult> {
    // First check sync prediction static (works for sync predictions)
    // Scoped so the mutex is released before calling into Python below.
    {
        let slot = get_sync_prediction_id_slot()
            .lock()
            .expect("sync_prediction_id mutex poisoned");
        if let Some(ref prediction_id) = *slot {
            tracing::trace!(%prediction_id, "Sync prediction ID found");
            return Ok(Some(prediction_id.clone()));
        }
    }

    // Fall back to ContextVar (works for async predictions)
    let cv = get_prediction_contextvar(py)?;

    // Try to get the value - returns the value or raises LookupError
    match cv.call_method0(py, "get") {
        Ok(val) => {
            let prediction_id: String = val.extract(py)?;
            tracing::trace!(%prediction_id, "ContextVar lookup succeeded");
            Ok(Some(prediction_id))
        }
        Err(e) if e.is_instance_of::(py) => {
            // ContextVar not set - outside prediction context
            Ok(None)
        }
        Err(e) => Err(e),
    }
}

// ============================================================================
// SlotLogWriter - routes via ContextVar lookup
// ============================================================================

/// A Python file-like object that routes writes via the prediction_id ContextVar.
///
/// This is installed as sys.stdout/stderr once at worker startup.
/// Each write looks up the current prediction_id from the ContextVar and routes
/// to the appropriate SlotSender.
///
/// If no prediction_id is set, or the prediction has completed (orphan task),
/// writes go to tracing (logged as orphan).
///
/// Uses line buffering: accumulates writes until a newline is received, then
/// emits complete lines. This coalesces Python's print() which does separate
/// writes for content and newline.
#[gen_stub_pyclass] #[pyclass(name = "_SlotLogWriter", module = "coglet._sdk")] pub struct SlotLogWriter { /// Which stream this captures (stdout or stderr). source: LogSource, /// Original stream (used for delegation of methods like isatty, fileno). original: Py, /// Whether writes should be ignored (used after errors). #[pyo3(get)] closed: bool, /// Line buffer for coalescing writes into complete lines. line_buffer: Mutex, } #[gen_stub_pymethods] #[pymethods] impl SlotLogWriter { /// Write data, routing to the appropriate destination. /// /// Uses line buffering: accumulates data until a newline is received, then /// emits complete lines. This coalesces Python's print() which does separate /// writes for content and the trailing newline. /// /// Priority for routing: /// 1. If inside a prediction (ContextVar set), route to slot sender /// 2. If setup sender registered, route to control channel /// 3. Fall back to stderr (for orphan tasks or unexpected cases) fn write(&self, py: Python<'_>, data: &str) -> PyResult { if self.closed || data.is_empty() { return Ok(data.len()); } let len = data.len(); // Append to line buffer and extract complete lines let complete = { let mut buffer = self.line_buffer.lock().expect("line_buffer mutex poisoned"); buffer.push_str(data); // Check if we have complete lines to emit if let Some(last_newline) = buffer.rfind('\n') { // Extract complete lines (including the newline) let complete = buffer[..=last_newline].to_string(); // Keep remainder in buffer let remainder = buffer[last_newline + 1..].to_string(); *buffer = remainder; Some(complete) } else { None } }; // Emit complete lines (outside lock) if let Some(complete) = complete { self.emit_data(py, &complete)?; } Ok(len) } /// Emit data to the appropriate destination. fn emit_data(&self, py: Python<'_>, data: &str) -> PyResult<()> { if data.is_empty() { return Ok(()); } // Try to get current prediction from ContextVar match get_current_prediction_id(py)? 
{ Some(prediction_id) => { // Have prediction ID - check if still active if let Some(sender) = get_prediction_sender(&prediction_id) { // Active prediction - route to slot tracing::trace!( prediction_id = %prediction_id, source = ?self.source, bytes = data.len(), "Log routed to slot" ); sender .send_log(self.source, data) .map_err(|e| pyo3::exceptions::PyIOError::new_err(e.to_string()))?; } else { // Orphan task - prediction completed but task still running tracing::trace!( prediction_id = %prediction_id, source = ?self.source, "Orphan log (prediction completed)" ); self.write_outside_prediction(py, data)?; } } None => { // Outside prediction context // Try setup sender (for setup logs), then fallback to stderr self.write_outside_prediction(py, data)?; } } Ok(()) } /// Flush the stream. /// /// Emits any buffered content that hasn't been terminated with a newline. fn flush(&self, py: Python<'_>) -> PyResult<()> { // Emit any buffered content let buffered = { let mut buffer = self.line_buffer.lock().expect("line_buffer mutex poisoned"); std::mem::take(&mut *buffer) }; if !buffered.is_empty() { self.emit_data(py, &buffered)?; } // Flush the original stream self.original.call_method0(py, "flush")?; Ok(()) } /// Return whether the stream is readable. fn readable(&self) -> bool { false } /// Return whether the stream is writable. fn writable(&self) -> bool { !self.closed } /// Return whether the stream is seekable. fn seekable(&self) -> bool { false } /// Return whether the stream is a TTY. fn isatty(&self, py: Python<'_>) -> PyResult { // Delegate to original let result = self.original.call_method0(py, "isatty")?; result.extract(py) } /// Return the file number. fn fileno(&self, py: Python<'_>) -> PyResult { // Delegate to original - needed for some libraries let result = self.original.call_method0(py, "fileno")?; result.extract(py) } /// Close the stream. fn close(&mut self) { self.closed = true; } /// Context manager enter. 
fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { slf } /// Context manager exit. fn __exit__( &mut self, _exc_type: Option<&Bound<'_, PyAny>>, _exc_val: Option<&Bound<'_, PyAny>>, _exc_tb: Option<&Bound<'_, PyAny>>, ) -> bool { false // Don't suppress exceptions } /// Encoding property - needed for compatibility. #[getter] fn encoding(&self, py: Python<'_>) -> PyResult> { match self.original.getattr(py, "encoding") { Ok(enc) => enc.extract(py), Err(_) => Ok(Some("utf-8".to_string())), } } /// Newlines property - needed for compatibility. #[getter] fn newlines(&self) -> Option { None } /// Buffer property - some code checks for this. #[getter] fn buffer(&self, py: Python<'_>) -> PyResult> { // Return original's buffer if it has one, otherwise return self match self.original.getattr(py, "buffer") { Ok(buf) => Ok(buf), Err(_) => Ok(self.original.clone_ref(py)), } } } impl SlotLogWriter { /// Create a new stdout writer. pub fn new_stdout(original: Py) -> Self { Self { source: LogSource::Stdout, original, closed: false, line_buffer: Mutex::new(String::new()), } } /// Create a new stderr writer. pub fn new_stderr(original: Py) -> Self { Self { source: LogSource::Stderr, original, closed: false, line_buffer: Mutex::new(String::new()), } } /// Write when outside prediction context. /// /// During setup: routes to control channel (for health-check). /// Otherwise: emits via tracing to stderr locally (not shipped). 
fn write_outside_prediction(&self, _py: Python<'_>, data: &str) -> PyResult<()> { // Try control channel sender (registered for worker lifetime) if let Some(sender) = get_control_channel_sender() { sender.try_send_log(self.source, data); tracing::trace!( source = ?self.source, bytes = data.len(), "Log routed via control channel" ); return Ok(()); } // Outside setup/prediction context - orphan log // This happens with orphan tasks or edge cases for line in data.lines() { tracing::info!(target: "coglet::user", "{}", line); } Ok(()) } } // ============================================================================ // Installation - called once at worker startup // ============================================================================ /// Install SlotLogWriters as sys.stdout/stderr. /// Called once at worker startup. The writers persist for the lifetime of the process. /// Returns true if installation succeeded. pub fn install_slot_log_writers(py: Python<'_>) -> PyResult { let sys = py.import("sys")?; // Get originals let original_stdout = sys.getattr("stdout")?.unbind(); let original_stderr = sys.getattr("stderr")?.unbind(); // Create writers let stdout_writer = SlotLogWriter::new_stdout(original_stdout); let stderr_writer = SlotLogWriter::new_stderr(original_stderr); // Install sys.setattr("stdout", stdout_writer.into_pyobject(py)?)?; sys.setattr("stderr", stderr_writer.into_pyobject(py)?)?; // Initialize the ContextVar get_prediction_contextvar(py)?; tracing::debug!("Installed SlotLogWriters with prediction_id routing"); Ok(true) } // ============================================================================ // PredictionLogGuard - RAII guard for prediction context // ============================================================================ /// RAII guard that sets the current prediction in the ContextVar. /// /// On creation, registers the SlotSender and sets the ContextVar. 
/// On drop, unregisters the prediction (but ContextVar reset is automatic for async). pub struct PredictionLogGuard { prediction_id: String, #[allow(dead_code)] token: Py, } impl PredictionLogGuard { /// Enter prediction context. /// /// Registers the sender and sets the ContextVar. pub fn enter(py: Python<'_>, prediction_id: String, sender: Arc) -> PyResult { // Register sender in global registry register_prediction(prediction_id.clone(), sender); // Set ContextVar let token = set_current_prediction(py, &prediction_id)?; tracing::trace!(%prediction_id, "Entered prediction log context"); Ok(Self { prediction_id, token, }) } /// Get the prediction ID. #[allow(dead_code)] pub fn prediction_id(&self) -> &str { &self.prediction_id } } impl Drop for PredictionLogGuard { fn drop(&mut self) { // Unregister prediction - this makes orphan tasks fall back to stderr unregister_prediction(&self.prediction_id); // Note: We don't reset the ContextVar here because: // 1. For sync: the context resets naturally when the function returns // 2. For async: each task has its own ContextVar copy, no reset needed // The token is kept just in case we need explicit reset in the future. 
} } // ============================================================================ // Tests // ============================================================================ #[cfg(test)] mod tests { use super::*; use coglet_core::bridge::protocol::SlotResponse; use tokio::sync::mpsc; #[test] fn registry_operations() { let prediction_id = "pred_123".to_string(); let (tx, _rx) = mpsc::unbounded_channel(); let sender = Arc::new(SlotSender::new(tx, std::env::temp_dir())); // Register register_prediction(prediction_id.clone(), sender.clone()); assert!(get_prediction_sender(&prediction_id).is_some()); // Unregister unregister_prediction(&prediction_id); assert!(get_prediction_sender(&prediction_id).is_none()); } #[test] fn slot_sender_sends_log() { let (tx, mut rx) = mpsc::unbounded_channel(); let sender = SlotSender::new(tx, std::env::temp_dir()); sender.send_log(LogSource::Stdout, "hello").unwrap(); let msg = rx.try_recv().unwrap(); match msg { SlotResponse::Log { source, data } => { assert_eq!(source, LogSource::Stdout); assert_eq!(data, "hello"); } _ => panic!("expected Log message"), } } #[test] fn slot_sender_ignores_empty() { let (tx, mut rx) = mpsc::unbounded_channel(); let sender = SlotSender::new(tx, std::env::temp_dir()); sender.send_log(LogSource::Stderr, "").unwrap(); // No message should be sent assert!(rx.try_recv().is_err()); } #[test] fn slot_sender_detects_closed_channel() { let (tx, rx) = mpsc::unbounded_channel::(); drop(rx); // Close receiver let sender = SlotSender::new(tx, std::env::temp_dir()); let result = sender.send_log(LogSource::Stdout, "hello"); assert!(result.is_err()); } } ================================================ FILE: crates/coglet-python/src/metric_scope.rs ================================================ //! Metric scope: type-safe metric recording with ContextVar routing. //! //! Two PyO3 classes: //! - `Scope` — the per-prediction context, obtained via `current_scope()` //! 
- `MetricRecorder` — the `scope.metrics` sub-object with type invariant //! enforcement, dict-style access, and accumulation modes //! //! All validation happens in Rust (PyO3, in-process). IPC sends the validated //! metric to the coglet server via SlotSender. use std::collections::HashMap; use std::sync::{Arc, Mutex, OnceLock}; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::types::PyDict; use pyo3_stub_gen::derive::*; use coglet_core::bridge::protocol::MetricMode; use coglet_core::worker::SlotSender; // ============================================================================ // Value type tracking for type invariant // ============================================================================ /// Coarse type tag for enforcing the type invariant. /// Once a key is set with a type, it cannot be changed without deleting first. #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum MetricValueType { Bool, Int, Float, Str, List, Dict, } impl MetricValueType { /// Classify a Python object into a type tag. fn from_py(obj: &Bound<'_, PyAny>) -> PyResult { // Order matters: bool before int (bool is a subclass of int in Python) if obj.is_instance_of::() { Ok(Self::Bool) } else if obj.is_instance_of::() { Ok(Self::Int) } else if obj.is_instance_of::() { Ok(Self::Float) } else if obj.is_instance_of::() { Ok(Self::Str) } else if obj.is_instance_of::() { Ok(Self::List) } else if obj.is_instance_of::() { Ok(Self::Dict) } else { let type_name = obj.get_type().name()?.to_string(); Err(PyTypeError::new_err(format!( "Unsupported metric value type: {}. 
Expected bool, int, float, str, list, or dict.", type_name ))) } } fn as_str(&self) -> &'static str { match self { Self::Bool => "bool", Self::Int => "int", Self::Float => "float", Self::Str => "str", Self::List => "list", Self::Dict => "dict", } } } // ============================================================================ // MetricRecorder — scope.metrics sub-object // ============================================================================ /// Metric recorder with type invariant enforcement. /// /// Accessed via `scope.metrics`. Supports: /// - `scope.metrics.record(key, value, mode="replace")` — full API /// - `scope.metrics.delete(key)` — delete (required before type change) /// - `scope.metrics[key] = value` — dict-style set (replace mode) /// - `del scope.metrics[key]` — dict-style delete #[gen_stub_pyclass] #[pyclass(name = "MetricRecorder", module = "coglet._sdk")] pub struct MetricRecorder { inner: Mutex>, } struct RecorderInner { /// Type tag per metric key — enforces type invariant. types: HashMap, /// IPC sender to the coglet server. sender: Arc, } impl MetricRecorder { pub fn new(sender: Arc) -> Self { Self { inner: Mutex::new(Some(RecorderInner { types: HashMap::new(), sender, })), } } pub fn noop() -> Self { Self { inner: Mutex::new(None), } } } #[gen_stub_pymethods] #[pymethods] impl MetricRecorder { /// Record a metric value. /// /// Args: /// key: Metric name. Dot-separated keys (e.g. "timing.preprocess") create /// nested objects in the response. /// value: Must be bool, int, float, str, list, or dict. Once a key is set /// with a type, it cannot be changed without calling delete() first. /// mode: Accumulation mode — "replace" (default), "incr" (increment numeric), /// or "append" (push to array). 
#[pyo3(signature = (key, value, mode=None))] fn record( &self, py: Python<'_>, key: &str, value: &Bound<'_, PyAny>, mode: Option<&str>, ) -> PyResult<()> { let mode = parse_mode(mode)?; let mut guard = self.inner.lock().expect("metric_recorder mutex poisoned"); let Some(inner) = guard.as_mut() else { return Ok(()); // no-op outside prediction }; record_impl(py, inner, key, value, mode) } /// Delete a metric key. Required before changing a metric's type. fn delete(&self, key: &str) -> PyResult<()> { let mut guard = self.inner.lock().expect("metric_recorder mutex poisoned"); let Some(inner) = guard.as_mut() else { return Ok(()); }; delete_impl(inner, key) } /// Dict-style set: `scope.metrics["key"] = value` fn __setitem__(&self, py: Python<'_>, key: &str, value: &Bound<'_, PyAny>) -> PyResult<()> { if value.is_none() { return self.delete(key); } let mut guard = self.inner.lock().expect("metric_recorder mutex poisoned"); let Some(inner) = guard.as_mut() else { return Ok(()); }; record_impl(py, inner, key, value, MetricMode::Replace) } /// Dict-style delete: `del scope.metrics["key"]` fn __delitem__(&self, key: &str) -> PyResult<()> { self.delete(key) } fn __repr__(&self) -> String { let guard = self.inner.lock().expect("metric_recorder mutex poisoned"); match guard.as_ref() { Some(inner) => format!("MetricRecorder(keys={})", inner.types.len()), None => "MetricRecorder(inactive)".to_string(), } } } // ============================================================================ // Scope — the per-prediction context // ============================================================================ /// Prediction scope, obtained via `current_scope()`. /// /// Provides access to `scope.metrics` for recording metrics, /// `scope.record_metric()` as a convenience shorthand, and /// `scope.context` for per-prediction context passed in the request. 
#[gen_stub_pyclass]
#[pyclass(name = "Scope", module = "coglet._sdk")]
pub struct Scope {
    metrics_recorder: Py,
    /// Per-prediction context from the request body (`dict[str, str]`).
    context: Py,
}

impl Scope {
    /// Build an active scope: a recorder bound to `sender` plus a Python dict
    /// holding a copy of `context`.
    pub fn new(
        py: Python<'_>,
        sender: Arc,
        context: HashMap,
    ) -> PyResult {
        let recorder = Py::new(py, MetricRecorder::new(sender))?;
        let dict = PyDict::new(py);
        for (k, v) in &context {
            dict.set_item(k, v)?;
        }
        Ok(Self {
            metrics_recorder: recorder,
            context: dict.unbind(),
        })
    }

    /// Build an inactive scope: a no-op recorder and an empty context dict.
    pub fn noop(py: Python<'_>) -> PyResult {
        let recorder = Py::new(py, MetricRecorder::noop())?;
        let dict = PyDict::new(py);
        Ok(Self {
            metrics_recorder: recorder,
            context: dict.unbind(),
        })
    }
}

#[gen_stub_pymethods]
#[pymethods]
impl Scope {
    /// The metric recorder for this prediction.
    #[getter]
    fn metrics(&self, py: Python<'_>) -> Py {
        self.metrics_recorder.clone_ref(py)
    }

    /// Per-prediction context passed in the request body.
    ///
    /// Returns a `dict[str, str]` (empty dict if no context was provided).
    #[getter]
    fn context(&self, py: Python<'_>) -> Py {
        // Hands out a new reference to the same dict object, not a copy.
        self.context.clone_ref(py)
    }

    /// Convenience: record a metric value.
    ///
    /// Equivalent to `scope.metrics.record(key, value, mode)`.
    #[pyo3(signature = (key, value, mode=None))]
    fn record_metric(
        &self,
        py: Python<'_>,
        key: &str,
        value: &Bound<'_, PyAny>,
        mode: Option<&str>,
    ) -> PyResult<()> {
        self.metrics_recorder
            .borrow(py)
            .record(py, key, value, mode)
    }

    fn __repr__(&self, py: Python<'_>) -> String {
        let recorder = self.metrics_recorder.borrow(py);
        format!("Scope({})", recorder.__repr__())
    }
}

// ============================================================================
// Shared implementation
// ============================================================================

/// Translate the optional Python-side mode string into a `MetricMode`.
///
/// `None` and "replace" map to replace; "incr" and its alias "increment" map
/// to increment; "append" maps to append. Anything else raises `TypeError`.
/// Note that the error text lists "incr" but not the "increment" alias.
fn parse_mode(mode: Option<&str>) -> PyResult {
    match mode {
        None | Some("replace") => Ok(MetricMode::Replace),
        Some("incr") | Some("increment") => Ok(MetricMode::Increment),
        Some("append") => Ok(MetricMode::Append),
        Some(other) => Err(PyTypeError::new_err(format!(
            "Invalid metric mode: '{}'. Expected 'replace', 'incr', or 'append'.",
            other
        ))),
    }
}

/// Validate `value` against the recorded type for `key`, then send it.
///
/// Raises `TypeError` for an unsupported value type, a type change on an
/// existing key, or a non-numeric value in increment mode; raises `IOError`
/// if the send fails.
fn record_impl(
    _py: Python<'_>,
    inner: &mut RecorderInner,
    key: &str,
    value: &Bound<'_, PyAny>,
    mode: MetricMode,
) -> PyResult<()> {
    let value_type = MetricValueType::from_py(value)?;

    // Type invariant check
    if let Some(existing_type) = inner.types.get(key)
        && *existing_type != value_type
    {
        return Err(PyTypeError::new_err(format!(
            "Metric '{}' has type {}, cannot set to {} without deleting first",
            key,
            existing_type.as_str(),
            value_type.as_str(),
        )));
    }

    // Mode-specific validation
    if mode == MetricMode::Increment
        && !matches!(value_type, MetricValueType::Int | MetricValueType::Float)
    {
        return Err(PyTypeError::new_err(format!(
            "Increment mode requires int or float, got {}",
            value_type.as_str()
        )));
    }

    let json_value = py_to_json(value)?;

    // The type tag is stored before the send, so it stays recorded even if
    // the send below fails.
    inner.types.insert(key.to_string(), value_type);

    inner
        .sender
        .send_metric(key.to_string(), json_value, mode)
        .map_err(|e| pyo3::exceptions::PyIOError::new_err(format!("Failed to send metric: {}", e)))
}

/// Forget the recorded type for `key` and send a delete for it.
///
/// The delete is expressed as a JSON null sent in replace mode. The key does
/// not have to exist; a delete is sent either way.
fn delete_impl(inner: &mut RecorderInner, key: &str) -> PyResult<()> {
    inner.types.remove(key);
    inner
        .sender
        .send_metric(
            key.to_string(),
            serde_json::Value::Null,
            MetricMode::Replace,
        )
        .map_err(|e| {
            pyo3::exceptions::PyIOError::new_err(format!("Failed to send metric delete: {}", e))
        })
}

// ============================================================================
// ContextVar-based routing (same pattern as log_writer.rs)
// ============================================================================

/// Global ContextVar for the current Scope.
static SCOPE_CONTEXTVAR: OnceLock> = OnceLock::new();

/// Current sync scope (for sync predictions where ContextVar doesn't work across attach calls).
static SYNC_SCOPE: OnceLock>>> = OnceLock::new();

/// Lazily initialise and return the sync scope slot (starts as `None`).
fn get_sync_scope_slot() -> &'static Mutex>> {
    SYNC_SCOPE.get_or_init(|| Mutex::new(None))
}

/// Get or create the ContextVar (`_coglet_metric_scope`) that carries the
/// current Scope.
fn get_scope_contextvar(py: Python<'_>) -> PyResult<&'static Py> {
    if let Some(cv) = SCOPE_CONTEXTVAR.get() {
        return Ok(cv);
    }
    let contextvars = py.import("contextvars")?;
    let cv = contextvars.call_method1("ContextVar", ("_coglet_metric_scope",))?;
    // Losing the race is fine: whichever value was stored first is used.
    match SCOPE_CONTEXTVAR.set(cv.unbind()) {
        Ok(()) => {}
        Err(_already_set) => {}
    }
    SCOPE_CONTEXTVAR.get().ok_or_else(|| {
        pyo3::exceptions::PyRuntimeError::new_err("Failed to initialize scope ContextVar")
    })
}

/// Set the current scope in the ContextVar (for async predictions).
///
/// Returns the object produced by the ContextVar's `set` call.
pub fn set_current_scope(py: Python<'_>, scope: &Py) -> PyResult> {
    let cv = get_scope_contextvar(py)?;
    let token = cv.call_method1(py, "set", (scope,))?;
    Ok(token)
}

/// Set the current sync scope (for sync predictions).
///
/// `Some(scope)` stores a new reference to the scope; `None` clears the slot.
pub fn set_sync_scope(py: Python<'_>, scope: Option<&Py>) {
    let mut slot = get_sync_scope_slot()
        .lock()
        .expect("sync_scope mutex poisoned");
    *slot = scope.map(|s| s.clone_ref(py));
}

/// Clear the sync scope.
pub fn clear_sync_scope() {
    let mut slot = get_sync_scope_slot()
        .lock()
        .expect("sync_scope mutex poisoned");
    *slot = None;
}

/// Python-callable: get the current Scope.
///
/// Returns the active scope if inside a prediction, or a no-op scope otherwise.
#[gen_stub_pyfunction(module = "coglet._sdk")]
#[pyfunction]
#[pyo3(name = "current_scope")]
pub fn py_current_scope(py: Python<'_>) -> PyResult> {
    // Try sync scope first
    {
        let slot = get_sync_scope_slot()
            .lock()
            .expect("sync_scope mutex poisoned");
        if let Some(ref scope) = *slot {
            return Ok(scope.clone_ref(py));
        }
    }

    // Try ContextVar (async predictions)
    // Reads the static directly rather than creating the ContextVar: if it was
    // never initialised, no scope can have been set on it.
    if let Some(cv) = SCOPE_CONTEXTVAR.get() {
        match cv.call_method0(py, "get") {
            Ok(val) => {
                let scope: Py = val.extract(py)?;
                return Ok(scope);
            }
            Err(e) if e.is_instance_of::(py) => {
                // Not set — fall through to no-op
            }
            Err(e) => return Err(e),
        }
    }

    // Outside prediction context — return no-op scope
    // (a fresh no-op scope is allocated on each such call)
    Py::new(py, Scope::noop(py)?)
}

// ============================================================================
// RAII guard for prediction scope lifecycle
// ============================================================================

/// RAII guard that manages the Scope for a prediction.
///
/// On creation, creates a Scope with a MetricRecorder and sets it in
/// ContextVar + sync scope. On drop, clears the scope and releases the
/// Arc so the log-forwarder channel can close.
pub struct ScopeGuard {
    scope: Py,
    // Result of the ContextVar `set` call; stored but never read in this file.
    #[allow(dead_code)]
    token: Py,
}

impl ScopeGuard {
    /// Enter scope for a prediction.
    ///
    /// The ContextVar is set first; the sync scope is only installed once that
    /// has succeeded.
    pub fn enter(
        py: Python<'_>,
        sender: Arc,
        context: HashMap,
    ) -> PyResult {
        let scope = Py::new(py, Scope::new(py, sender, context)?)?;
        let token = set_current_scope(py, &scope)?;
        set_sync_scope(py, Some(&scope));
        Ok(Self { scope, token })
    }
}

impl Drop for ScopeGuard {
    fn drop(&mut self) {
        // Only the sync slot is cleared here; the ContextVar value is not reset.
        clear_sync_scope();
        // Acquire the GIL to release the Arc held by the MetricRecorder.
        // Without this, the Py destructor may not run immediately (PyO3
        // defers ref-count decrements when the GIL is not held), keeping the
        // SlotSender channel alive and blocking the log-forwarder shutdown.
        Python::attach(|py| {
            let scope = self.scope.borrow(py);
            let recorder = scope.metrics_recorder.borrow(py);
            let mut guard = recorder
                .inner
                .lock()
                .expect("metric_recorder mutex poisoned");
            // Drop the RecorderInner (and its Arc)
            // After this the recorder behaves like a no-op recorder for any
            // code still holding a reference to the scope.
            *guard = None;
        });
    }
}

// ============================================================================
// Python → JSON conversion
// ============================================================================

/// Convert a Python value into a `serde_json` value, recursing into lists and
/// dicts.
///
/// Raises `TypeError` for a value that matches none of the handled types.
/// Dict keys are extracted as `String`, so a key that cannot be extracted
/// that way makes the conversion fail.
fn py_to_json(obj: &Bound<'_, PyAny>) -> PyResult {
    if obj.is_none() {
        Ok(serde_json::Value::Null)
    } else if obj.is_instance_of::() {
        Ok(serde_json::Value::Bool(obj.extract::()?))
    } else if obj.is_instance_of::() {
        // NOTE(review): two extractions are attempted in order, the second as a
        // fallback when the first fails; the target types are not visible in
        // this copy of the file — confirm against the repository.
        if let Ok(v) = obj.extract::() {
            Ok(serde_json::json!(v))
        } else {
            Ok(serde_json::json!(obj.extract::()?))
        }
    } else if obj.is_instance_of::() {
        Ok(serde_json::json!(obj.extract::()?))
    } else if obj.is_instance_of::() {
        Ok(serde_json::Value::String(obj.extract::()?))
    } else if obj.is_instance_of::() {
        let list = obj.cast::()?;
        let items: Vec = list
            .iter()
            .map(|item| py_to_json(&item))
            .collect::>()?;
        Ok(serde_json::Value::Array(items))
    } else if obj.is_instance_of::() {
        let dict = obj.cast::()?;
        let mut map = serde_json::Map::new();
        for (k, v) in dict.iter() {
            let key: String = k.extract()?;
            map.insert(key, py_to_json(&v)?);
        }
        Ok(serde_json::Value::Object(map))
    } else {
        let type_name = obj.get_type().name()?.to_string();
        Err(PyTypeError::new_err(format!(
            "Cannot convert {} to JSON metric value",
            type_name
        )))
    }
}

================================================
FILE: crates/coglet-python/src/output.rs
================================================
//! Output processing for prediction results.
//!
//! Converts Python prediction output to JSON-serializable format:
//! - Pydantic models -> dict (via model_dump() / dict())
//! - Dataclasses -> dict (via dataclasses.asdict())
//! - Enums -> .value
//! - datetime -> .isoformat()
//! - PathLike -> base64 data URL
//! - IOBase -> base64 data URL
//!
- numpy int/float/ndarray -> Python int/float/list //! - dict/list/set/tuple/generator -> recursive descent //! //! This replaces the Python modules cog.json and cog.files. use pyo3::prelude::*; use pyo3::types::{PyDict, PyFrozenSet, PyList, PySet, PyString, PyTuple}; /// Process prediction output for JSON serialization. /// /// Calls make_encodeable() to normalize, then encode_files() to convert any /// remaining Path/IOBase objects to base64 data URLs. pub fn process_output<'py>( py: Python<'py>, output: &Bound<'py, PyAny>, ) -> PyResult> { let encodeable = make_encodeable(py, output)?; encode_files(py, &encodeable) } /// Process a single output item (for generator outputs). pub fn process_output_item<'py>( py: Python<'py>, item: &Bound<'py, PyAny>, ) -> PyResult> { process_output(py, item) } /// Normalize a Python object into a JSON-friendly form. /// /// Handles Pydantic models, dataclasses, enums, datetime, numpy types, /// and collections. PathLike objects are passed through (handled later /// by encode_files). fn make_encodeable<'py>(py: Python<'py>, obj: &Bound<'py, PyAny>) -> PyResult> { // Pydantic v2: model_dump() if let Ok(method) = obj.getattr("model_dump") && method.is_callable() { let dumped = method.call0()?; return make_encodeable(py, &dumped); } // Pydantic v1: dict() // Skip plain dicts -- they also have .dict but we handle those below if !obj.is_instance_of::() && let Ok(method) = obj.getattr("dict") && method.is_callable() { let dumped = method.call0()?; return make_encodeable(py, &dumped); } // dataclass instances (not the class itself) let dataclasses = py.import("dataclasses")?; let is_dataclass = dataclasses.getattr("is_dataclass")?; if is_dataclass.call1((obj,))?.is_truthy()? && !obj.is_instance(py.get_type::().as_any())? 
{ let asdict = dataclasses.getattr("asdict")?; let d = asdict.call1((obj,))?; return make_encodeable(py, &d); } // dict if let Ok(dict) = obj.cast_exact::() { let new_dict = PyDict::new(py); for (key, value) in dict.iter() { new_dict.set_item(&key, make_encodeable(py, &value)?)?; } return Ok(new_dict.into_any()); } // list, set, frozenset, tuple, generator if obj.is_instance_of::() || obj.is_instance_of::() || obj.is_instance_of::() || obj.is_instance_of::() || is_generator(py, obj)? { let iter = obj.try_iter()?; let items: Vec> = iter .map(|item| make_encodeable(py, &item?)) .collect::>()?; let list = PyList::new(py, &items)?; return Ok(list.into_any()); } // Enum -> .value let enum_mod = py.import("enum")?; let enum_cls = enum_mod.getattr("Enum")?; if obj.is_instance(&enum_cls)? { return obj.getattr("value"); } // datetime -> .isoformat() let datetime_mod = py.import("datetime")?; let datetime_cls = datetime_mod.getattr("datetime")?; if obj.is_instance(&datetime_cls)? { return obj.call_method0("isoformat"); } // os.PathLike -> pathlib.Path (will be encoded to base64 later by encode_files) let os_mod = py.import("os")?; let pathlike_cls = os_mod.getattr("PathLike")?; if obj.is_instance(&pathlike_cls)? { let pathlib = py.import("pathlib")?; let path_cls = pathlib.getattr("Path")?; return path_cls.call1((obj,)); } // numpy types (optional) if let Ok(np) = py.import("numpy") && !obj.is_instance(py.get_type::().as_any())? { let np_integer = np.getattr("integer")?; if obj.is_instance(&np_integer)? { let builtins = py.import("builtins")?; return builtins.getattr("int")?.call1((obj,)); } let np_floating = np.getattr("floating")?; if obj.is_instance(&np_floating)? { let builtins = py.import("builtins")?; return builtins.getattr("float")?.call1((obj,)); } let np_ndarray = np.getattr("ndarray")?; if obj.is_instance(&np_ndarray)? 
{ return obj.call_method0("tolist"); } } // Primitive / unknown -- pass through Ok(obj.clone()) } /// Recursively walk the output and encode any Path/IOBase objects to base64 data URLs. fn encode_files<'py>(py: Python<'py>, obj: &Bound<'py, PyAny>) -> PyResult> { // str -- return as-is (don't recurse into characters) if obj.is_instance_of::() { return Ok(obj.clone()); } // dict if let Ok(dict) = obj.cast_exact::() { let new_dict = PyDict::new(py); for (key, value) in dict.iter() { new_dict.set_item(&key, encode_files(py, &value)?)?; } return Ok(new_dict.into_any()); } // list if let Ok(list) = obj.cast_exact::() { let items: Vec> = list .iter() .map(|item| encode_files(py, &item)) .collect::>()?; let new_list = PyList::new(py, &items)?; return Ok(new_list.into_any()); } // os.PathLike -> open and base64 encode let os_mod = py.import("os")?; let pathlike_cls = os_mod.getattr("PathLike")?; if obj.is_instance(&pathlike_cls)? { let builtins = py.import("builtins")?; let fh = builtins.getattr("open")?.call1((obj, "rb"))?; let result = file_to_base64(py, &fh); fh.call_method0("close")?; return result; } // io.IOBase -> base64 encode let io_mod = py.import("io")?; let iobase_cls = io_mod.getattr("IOBase")?; if obj.is_instance(&iobase_cls)? { return file_to_base64(py, obj); } // Primitive -- pass through Ok(obj.clone()) } /// Encode a file handle to a base64 data URL. /// /// Seeks to start if seekable, reads all bytes, guesses MIME type from /// the file name, and returns "data:{mime};base64,{encoded}". fn file_to_base64<'py>(py: Python<'py>, fh: &Bound<'py, PyAny>) -> PyResult> { // Seek to start if possible if let Ok(seekable) = fh.call_method0("seekable") && seekable.is_truthy()? { fh.call_method1("seek", (0,))?; } // Read content let content = fh.call_method0("read")?; let bytes: Vec = if content.is_instance_of::() { let s: String = content.extract()?; s.into_bytes() } else { content.extract()? 
}; // Guess MIME type from filename let mime_type = if let Ok(name) = fh.getattr("name") && !name.is_none() { let name_str: String = name.extract()?; let mimetypes = py.import("mimetypes")?; let guess = mimetypes.call_method1("guess_type", (&name_str,))?; let first = guess.get_item(0)?; if first.is_none() { "application/octet-stream".to_string() } else { first.extract()? } } else { "application/octet-stream".to_string() }; // Base64 encode use base64::Engine as _; let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes); let data_url = format!("data:{mime_type};base64,{encoded}"); Ok(PyString::new(py, &data_url).into_any()) } /// Check if a Python object is a generator instance. fn is_generator<'py>(py: Python<'py>, obj: &Bound<'py, PyAny>) -> PyResult { let types_mod = py.import("types")?; let gen_type = types_mod.getattr("GeneratorType")?; obj.is_instance(&gen_type) } #[cfg(test)] mod tests { use super::*; use base64::Engine as _; use pyo3::types::PyDict; /// Helper: evaluate a Python expression and run make_encodeable on it. fn encodeable(py_expr: &str) -> String { pyo3::Python::initialize(); Python::attach(|py| { let locals = PyDict::new(py); // Run any setup + expression, storing result in `obj` let code = format!("import json\n{}\nresult = obj", py_expr); py.run(&std::ffi::CString::new(code).unwrap(), None, Some(&locals)) .expect("failed to evaluate test expression"); let obj = locals.get_item("result").unwrap().unwrap(); let encoded = make_encodeable(py, &obj).expect("make_encodeable failed"); // Convert to JSON string for easy assertion let json_mod = py.import("json").unwrap(); let json_str = json_mod .call_method1("dumps", (&encoded,)) .expect("json.dumps failed"); json_str.extract::().unwrap() }) } /// Helper: evaluate a Python expression and run process_output on it. 
fn processed(py_expr: &str) -> String { pyo3::Python::initialize(); Python::attach(|py| { let locals = PyDict::new(py); let code = format!("import json\n{}\nresult = obj", py_expr); py.run(&std::ffi::CString::new(code).unwrap(), None, Some(&locals)) .expect("failed to evaluate test expression"); let obj = locals.get_item("result").unwrap().unwrap(); let output = process_output(py, &obj).expect("process_output failed"); let json_mod = py.import("json").unwrap(); let json_str = json_mod .call_method1("dumps", (&output,)) .expect("json.dumps failed"); json_str.extract::().unwrap() }) } // ── make_encodeable: primitives ────────────────────────────────── #[test] fn encodeable_string() { assert_eq!(encodeable("obj = 'hello'"), r#""hello""#); } #[test] fn encodeable_int() { assert_eq!(encodeable("obj = 42"), "42"); } #[test] fn encodeable_float() { assert_eq!(encodeable("obj = 3.14"), "3.14"); } #[test] fn encodeable_bool() { assert_eq!(encodeable("obj = True"), "true"); } #[test] fn encodeable_none() { assert_eq!(encodeable("obj = None"), "null"); } // ── make_encodeable: collections ───────────────────────────────── #[test] fn encodeable_list() { assert_eq!(encodeable("obj = [1, 2, 3]"), "[1, 2, 3]"); } #[test] fn encodeable_dict() { assert_eq!( encodeable(r#"obj = {"a": 1, "b": 2}"#), r#"{"a": 1, "b": 2}"# ); } #[test] fn encodeable_tuple_to_list() { assert_eq!(encodeable("obj = (1, 2, 3)"), "[1, 2, 3]"); } #[test] fn encodeable_set_to_list() { // Set with single element to avoid ordering issues assert_eq!(encodeable("obj = {42}"), "[42]"); } #[test] fn encodeable_frozenset_to_list() { assert_eq!(encodeable("obj = frozenset([99])"), "[99]"); } #[test] fn encodeable_nested_dict() { assert_eq!( encodeable(r#"obj = {"outer": {"inner": [1, 2]}}"#), r#"{"outer": {"inner": [1, 2]}}"# ); } // ── make_encodeable: enum ──────────────────────────────────────── #[test] fn encodeable_enum() { assert_eq!( encodeable("import enum\nclass Color(enum.Enum):\n RED = 'red'\nobj = 
Color.RED"), r#""red""# ); } #[test] fn encodeable_int_enum() { assert_eq!( encodeable( "import enum\nclass Priority(enum.IntEnum):\n HIGH = 1\nobj = Priority.HIGH" ), "1" ); } // ── make_encodeable: datetime ──────────────────────────────────── #[test] fn encodeable_datetime() { let result = encodeable("from datetime import datetime\nobj = datetime(2025, 1, 15, 10, 30, 0)"); assert_eq!(result, r#""2025-01-15T10:30:00""#); } // ── make_encodeable: dataclass ─────────────────────────────────── #[test] fn encodeable_dataclass() { assert_eq!( encodeable( "from dataclasses import dataclass\n\ @dataclass\n\ class Point:\n\ \tx: int\n\ \ty: int\n\ obj = Point(x=1, y=2)" ), r#"{"x": 1, "y": 2}"# ); } #[test] fn encodeable_nested_dataclass() { assert_eq!( encodeable( "from dataclasses import dataclass, asdict\n\ @dataclass\n\ class Inner:\n\ \tval: str\n\ # Build nested via dict so class scoping isn't an issue\n\ obj = {'inner': asdict(Inner(val='hello')), 'name': 'test'}" ), r#"{"inner": {"val": "hello"}, "name": "test"}"# ); } // ── make_encodeable: generator ─────────────────────────────────── #[test] fn encodeable_generator() { assert_eq!(encodeable("obj = (x * 2 for x in range(3))"), "[0, 2, 4]"); } // ── make_encodeable: enum value in collection ──────────────────── #[test] fn encodeable_enum_in_list() { assert_eq!( encodeable( "import enum\n\ class Status(enum.Enum):\n\ \tOK = 'ok'\n\ \tERR = 'err'\n\ obj = [Status.OK, Status.ERR]" ), r#"["ok", "err"]"# ); } // ── encode_files / file_to_base64 ──────────────────────────────── #[test] fn encode_pathlike_to_base64() { let result = processed( "import tempfile, pathlib\n\ f = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)\n\ f.write(b'hello world')\n\ f.close()\n\ obj = pathlib.Path(f.name)", ); assert!( result.starts_with(r#""data:text/plain;base64,"#), "expected data URL, got: {result}" ); // Verify the base64 content decodes correctly let b64_part = result.trim_matches('"').split(",").nth(1).unwrap(); let 
decoded = base64::engine::general_purpose::STANDARD .decode(b64_part) .unwrap(); assert_eq!(decoded, b"hello world"); } #[test] fn encode_iobase_to_base64() { let result = processed( "import io\n\ obj = io.BytesIO(b'test bytes')", ); assert!( result.starts_with(r#""data:application/octet-stream;base64,"#), "expected data URL, got: {result}" ); let b64_part = result.trim_matches('"').split(",").nth(1).unwrap(); let decoded = base64::engine::general_purpose::STANDARD .decode(b64_part) .unwrap(); assert_eq!(decoded, b"test bytes"); } #[test] fn encode_iobase_seeks_to_start() { let result = processed( "import io\n\ buf = io.BytesIO(b'rewind me')\n\ buf.read() # advance to end\n\ obj = buf", ); let b64_part = result.trim_matches('"').split(",").nth(1).unwrap(); let decoded = base64::engine::general_purpose::STANDARD .decode(b64_part) .unwrap(); assert_eq!(decoded, b"rewind me", "should seek to start before reading"); } #[test] fn encode_file_in_dict() { let result = processed( "import io\n\ obj = {'output': io.BytesIO(b'nested')}", ); // Parse the JSON to verify structure let parsed: serde_json::Value = serde_json::from_str(&result).unwrap(); let url = parsed["output"].as_str().unwrap(); assert!( url.starts_with("data:application/octet-stream;base64,"), "expected data URL in dict value" ); } #[test] fn encode_file_in_list() { let result = processed( "import io\n\ obj = [io.BytesIO(b'item1'), io.BytesIO(b'item2')]", ); let parsed: serde_json::Value = serde_json::from_str(&result).unwrap(); assert!(parsed.as_array().unwrap().len() == 2); for item in parsed.as_array().unwrap() { assert!(item.as_str().unwrap().starts_with("data:")); } } #[test] fn encode_string_passthrough() { // Strings should NOT be recursed into assert_eq!(processed("obj = 'just a string'"), r#""just a string""#); } #[test] fn encode_mime_type_guessing() { let result = processed( "import tempfile, pathlib\n\ f = tempfile.NamedTemporaryFile(suffix='.png', delete=False)\n\ f.write(b'\\x89PNG')\n\ 
f.close()\n\ obj = pathlib.Path(f.name)", ); assert!( result.contains("image/png"), "expected image/png MIME type, got: {result}" ); } // ── process_output: end-to-end ─────────────────────────────────── #[test] fn process_output_primitives_passthrough() { assert_eq!(processed("obj = 'hello'"), r#""hello""#); assert_eq!(processed("obj = 42"), "42"); assert_eq!(processed("obj = None"), "null"); } #[test] fn process_output_dataclass_with_file() { let result = processed( "from dataclasses import dataclass\n\ import pathlib, tempfile\n\ @dataclass\n\ class Output:\n\ \ttext: str\n\ \tdata: object\n\ f = tempfile.NamedTemporaryFile(suffix='.bin', delete=False)\n\ f.write(b'binary')\n\ f.close()\n\ obj = Output(text='result', data=pathlib.Path(f.name))", ); let parsed: serde_json::Value = serde_json::from_str(&result).unwrap(); assert_eq!(parsed["text"], "result"); assert!(parsed["data"].as_str().unwrap().starts_with("data:")); } } ================================================ FILE: crates/coglet-python/src/predictor.rs ================================================ //! Python predictor loading and invocation. use std::sync::{Arc, OnceLock}; use pyo3::prelude::*; use pyo3::types::PyDict; use coglet_core::worker::SlotSender; use coglet_core::{PredictionError, PredictionOutput, PredictionResult}; use crate::cancel; use crate::input::{self, PreparedInput}; use crate::output; // ============================================================================= // Async helper functions — defined as Python strings, initialized once. // // These must be Python `async def` functions to participate in asyncio's event // loop. They cannot be expressed as pure Rust because they use Python's async // iteration protocol and ContextVar.set() before awaiting a coroutine. // ============================================================================= /// Collects an async generator into a list. Initialized once, reused per-call. 
static COLLECT_ASYNC_GEN: OnceLock> = OnceLock::new(); /// Sets a ContextVar then awaits a coroutine. Initialized once, reused per-call. static CTX_WRAPPER: OnceLock> = OnceLock::new(); /// Get or initialize the `_collect_async_gen` Python helper. fn get_collect_async_gen(py: Python<'_>) -> Result, PredictionError> { if let Some(f) = COLLECT_ASYNC_GEN.get() { return Ok(f.clone_ref(py)); } let code = c"\ async def _collect_async_gen(agen): results = [] async for item in agen: results.append(item) return results "; let globals = PyDict::new(py); py.run(code, Some(&globals), None).map_err(|e| { PredictionError::Failed(format!("Failed to define _collect_async_gen: {e}")) })?; let f = globals .get_item("_collect_async_gen") .map_err(|e| PredictionError::Failed(format!("Failed to get _collect_async_gen: {e}")))? .ok_or_else(|| PredictionError::Failed("_collect_async_gen not found".to_string()))? .unbind(); let _ = COLLECT_ASYNC_GEN.set(f.clone_ref(py)); Ok(f) } /// Get or initialize the `_ctx_wrapper` Python helper. fn get_ctx_wrapper(py: Python<'_>) -> Result, PredictionError> { if let Some(f) = CTX_WRAPPER.get() { return Ok(f.clone_ref(py)); } let code = c"\ async def _ctx_wrapper(coro, prediction_id, contextvar): contextvar.set(prediction_id) return await coro "; let globals = PyDict::new(py); py.run(code, Some(&globals), None) .map_err(|e| PredictionError::Failed(format!("Failed to define _ctx_wrapper: {e}")))?; let f = globals .get_item("_ctx_wrapper") .map_err(|e| PredictionError::Failed(format!("Failed to get _ctx_wrapper: {e}")))? .ok_or_else(|| PredictionError::Failed("_ctx_wrapper not found".to_string()))? .unbind(); let _ = CTX_WRAPPER.set(f.clone_ref(py)); Ok(f) } /// Check if a PyErr is a CancelationException or asyncio.CancelledError. 
fn is_cancelation_exception(py: Python<'_>, err: &PyErr) -> bool { // Check for our static CancelationException type if err.is_instance_of::(py) { return true; } // Check for asyncio.CancelledError if let Ok(asyncio) = py.import("asyncio") && let Ok(cancelled_error) = asyncio.getattr("CancelledError") && err.is_instance(py, &cancelled_error) { return true; } false } /// Format a Python validation error. /// /// Cog validation errors are already formatted as "field: message". fn format_validation_error(py: Python<'_>, err: &PyErr) -> String { err.value(py).to_string() } /// Send a single output item over IPC, routing file outputs to disk. /// /// For Path outputs (os.PathLike): sends the existing file path via send_file_output. /// For IOBase outputs: reads bytes, writes to output_dir via write_file_output. /// For everything else: processes through make_encodeable + upload_files, then send_output. fn send_output_item( py: Python<'_>, item: &Bound<'_, PyAny>, json_module: &Bound<'_, PyAny>, slot_sender: &SlotSender, ) -> Result<(), PredictionError> { let os = py .import("os") .map_err(|e| PredictionError::Failed(format!("Failed to import os: {}", e)))?; let io_mod = py .import("io") .map_err(|e| PredictionError::Failed(format!("Failed to import io: {}", e)))?; let pathlike = os .getattr("PathLike") .map_err(|e| PredictionError::Failed(format!("Failed to get os.PathLike: {}", e)))?; let iobase = io_mod .getattr("IOBase") .map_err(|e| PredictionError::Failed(format!("Failed to get io.IOBase: {}", e)))?; if item.is_instance(&pathlike).unwrap_or(false) { // Path output — file already on disk, send path reference let path_str: String = item .call_method0("__fspath__") .and_then(|p| p.extract()) .map_err(|e| PredictionError::Failed(format!("Failed to get fspath: {}", e)))?; slot_sender .send_file_output(std::path::PathBuf::from(path_str), None) .map_err(|e| PredictionError::Failed(format!("Failed to send file output: {}", e)))?; return Ok(()); } if 
item.is_instance(&iobase).unwrap_or(false) { // IOBase output — read bytes, write to disk via SlotSender // Seek to start if seekable if item .call_method0("seekable") .and_then(|r| r.extract::()) .unwrap_or(false) { let _ = item.call_method1("seek", (0,)); } let data: Vec = item .call_method0("read") .and_then(|d| d.extract()) .map_err(|e| PredictionError::Failed(format!("Failed to read IOBase: {}", e)))?; // Try to guess extension from filename let ext = item .getattr("name") .and_then(|n| n.extract::()) .ok() .and_then(|name| { std::path::Path::new(&name) .extension() .and_then(|e| e.to_str()) .map(|s| s.to_string()) }) .unwrap_or_else(|| "bin".to_string()); slot_sender .write_file_output(&data, &ext, None) .map_err(|e| PredictionError::Failed(format!("Failed to write file output: {}", e)))?; return Ok(()); } // Non-file output - process normally let processed = output::process_output_item(py, item) .map_err(|e| PredictionError::Failed(format!("Failed to process output item: {}", e)))?; let item_str: String = json_module .call_method1("dumps", (&processed,)) .map_err(|e| PredictionError::Failed(format!("Failed to serialize output item: {}", e)))? .extract() .map_err(|e| PredictionError::Failed(format!("Failed to extract output string: {}", e)))?; let item_json: serde_json::Value = serde_json::from_str(&item_str) .map_err(|e| PredictionError::Failed(format!("Failed to parse output JSON: {}", e)))?; slot_sender .send_output(item_json) .map_err(|e| PredictionError::Failed(format!("Failed to send output: {}", e)))?; Ok(()) } /// Type alias for Python object (Py). 
type PyObject = Py; /// How a predict() method executes #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PredictKind { /// Synchronous function: def predict(self, **input) -> Output Sync, /// Async coroutine: async def predict(self, **input) -> Output Async, /// Async generator: async def predict(self, **input) -> AsyncIterator[Output] AsyncGen, } /// Whether and how train() exists #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum TrainKind { /// No train() method None, /// Synchronous: def train(self, **input) -> Output Sync, /// Async: async def train(self, **input) -> Output Async, } /// The predictor's structure and invocation target #[derive(Debug, Clone, PartialEq, Eq)] pub enum PredictorKind { /// Class instance with predict() method, optionally train() Class { predict: PredictKind, train: TrainKind, }, /// Standalone function (e.g., train.py:train) /// The PredictKind describes how the function executes (sync/async/async_gen) StandaloneFunction(PredictKind), } /// A loaded Python predictor instance. /// /// Input coercion (URL->Path/File) and FieldInfo default unwrapping are handled /// in Rust. The Python `_adt` and `_inspector` modules are no longer called. pub struct PythonPredictor { instance: PyObject, /// The predictor's kind (class or standalone function) and method execution types kind: PredictorKind, } // PyObject is Send in PyO3 0.23+ // Safety: We only access the instance through Python::attach() unsafe impl Send for PythonPredictor {} unsafe impl Sync for PythonPredictor {} impl PythonPredictor { /// Load a predictor from a reference like "predict.py:Predictor". 
pub fn load(py: Python<'_>, predictor_ref: &str) -> PyResult { // Import the cog.predictor module to use its loading function let cog_predictor = py.import("cog.predictor")?; let load_fn = cog_predictor.getattr("load_predictor_from_ref")?; // Load the predictor class and instantiate it let instance: PyObject = load_fn.call1((predictor_ref,))?.unbind(); // Check if this is a standalone function (train mode) or a Predictor instance let inspect = py.import("inspect")?; let is_function: bool = inspect .call_method1("isfunction", (instance.bind(py),))? .extract()?; let kind = if is_function { // Standalone function - detect its async nature let (is_async, is_async_gen) = Self::detect_async(py, &instance, "")?; let predict_kind = if is_async_gen { tracing::info!("Detected async generator train()"); PredictKind::AsyncGen } else if is_async { tracing::info!("Detected async train()"); PredictKind::Async } else { tracing::info!("Detected sync train()"); PredictKind::Sync }; PredictorKind::StandaloneFunction(predict_kind) } else { // Class instance - detect predict() and train() methods let (is_async, is_async_gen) = Self::detect_async(py, &instance, "predict")?; let predict_kind = if is_async_gen { tracing::info!("Detected async generator predict()"); PredictKind::AsyncGen } else if is_async { tracing::info!("Detected async predict()"); PredictKind::Async } else { tracing::info!("Detected sync predict()"); PredictKind::Sync }; // Check if train() method exists and if it's async let train_kind = if instance.bind(py).hasattr("train")? 
{ let (train_async, _) = Self::detect_async(py, &instance, "train")?; if train_async { tracing::info!("Detected async train()"); TrainKind::Async } else { tracing::info!("Detected sync train()"); TrainKind::Sync } } else { TrainKind::None }; PredictorKind::Class { predict: predict_kind, train: train_kind, } }; let predictor = Self { instance, kind }; // Patch FieldInfo defaults on predict/train methods so Python uses actual // default values instead of FieldInfo wrapper objects for missing inputs. // Input(default=42, description="...") creates a FieldInfo; without patching, // Python would pass the FieldInfo itself as the default value. if is_function { Self::unwrap_field_info_defaults(py, &predictor.instance, "")?; } else { Self::unwrap_field_info_defaults(py, &predictor.instance, "predict")?; if matches!(predictor.kind, PredictorKind::Class { train, .. } if train != TrainKind::None) { Self::unwrap_field_info_defaults(py, &predictor.instance, "train")?; } } Ok(predictor) } /// Replace FieldInfo defaults with their `.default` values on a method's signature. /// /// When users write `def predict(self, seed: int = Input(default=42, description="..."))`, /// the Python default for `seed` is a `FieldInfo(default=42, ...)` object. If `seed` is /// missing from the input dict, Python would use this FieldInfo as the value — not `42`. /// /// This patches `__defaults__` on the underlying function so Python natively resolves /// to the actual default values. fn unwrap_field_info_defaults( py: Python<'_>, instance: &PyObject, method_name: &str, ) -> PyResult<()> { let field_info_class = py.import("cog.input")?.getattr("FieldInfo")?; // Get the underlying function object let func = if method_name.is_empty() { // Standalone function instance.bind(py).clone() } else { // Bound method — get __func__ for the raw function instance .bind(py) .getattr(method_name)? .getattr("__func__")? 
}; // Patch __defaults__ (positional parameter defaults) if let Ok(defaults) = func.getattr("__defaults__") && !defaults.is_none() { let defaults_tuple = defaults.cast::()?; let mut new_defaults: Vec> = Vec::new(); let mut changed = false; for item in defaults_tuple.iter() { if item.is_instance(&field_info_class)? { new_defaults.push(item.getattr("default")?); changed = true; } else { new_defaults.push(item); } } if changed { let new_tuple = pyo3::types::PyTuple::new(py, &new_defaults)?; func.setattr("__defaults__", new_tuple)?; tracing::debug!("Patched FieldInfo defaults on {}", method_name); } } // Patch __kwdefaults__ (keyword-only parameter defaults) if let Ok(kwdefaults) = func.getattr("__kwdefaults__") && !kwdefaults.is_none() { let kwdefaults_dict = kwdefaults.cast::()?; let mut changed = false; for (key, value) in kwdefaults_dict.iter() { if value.is_instance(&field_info_class)? { kwdefaults_dict.set_item(&key, value.getattr("default")?)?; changed = true; } } if changed { tracing::debug!("Patched FieldInfo kwdefaults on {}", method_name); } } Ok(()) } /// Detect if a method is an async function. /// Returns (is_async, is_async_gen) tuple. /// /// If method_name is empty, checks the instance itself (for standalone functions). fn detect_async( py: Python<'_>, instance: &PyObject, method_name: &str, ) -> PyResult<(bool, bool)> { let inspect = py.import("inspect")?; // If method_name is empty, check the instance itself (standalone function) let target = if method_name.is_empty() { instance.bind(py).clone() } else { instance.bind(py).getattr(method_name)? }; // Check isasyncgenfunction first (it's more specific) let is_async_gen: bool = inspect .call_method1("isasyncgenfunction", (&target,))? .extract()?; if is_async_gen { return Ok((true, true)); } // Check iscoroutinefunction let is_coro: bool = inspect .call_method1("iscoroutinefunction", (&target,))? .extract()?; Ok((is_coro, false)) } /// Returns true if this predictor has an async predict() method. 
pub fn is_async(&self) -> bool {
    // Both predictor shapes carry a PredictKind; pick it, then classify once.
    let predict_kind = match &self.kind {
        PredictorKind::Class { predict, .. } => predict,
        PredictorKind::StandaloneFunction(kind) => kind,
    };
    matches!(predict_kind, PredictKind::Async | PredictKind::AsyncGen)
}

/// Returns true if this predictor has a train() method.
///
/// A standalone function always counts as trainable.
pub fn has_train(&self) -> bool {
    match &self.kind {
        PredictorKind::StandaloneFunction(_) => true,
        PredictorKind::Class { train, .. } => *train != TrainKind::None,
    }
}

/// Returns true if the train() method is async.
///
/// For a standalone function this reports whether the function itself is a
/// coroutine function or an async generator function.
pub fn is_train_async(&self) -> bool {
    match &self.kind {
        PredictorKind::StandaloneFunction(kind) => {
            matches!(kind, PredictKind::Async | PredictKind::AsyncGen)
        }
        PredictorKind::Class { train, .. } => *train == TrainKind::Async,
    }
}

/// Call setup() on the predictor, handling weights parameter if present.
///
/// Does nothing when the instance has no `setup` attribute. Otherwise the
/// cog.predictor helpers decide how to call it:
/// - `has_setup_weights()` checks if setup() has a weights parameter
/// - `extract_setup_weights()` reads from COG_WEIGHTS env or ./weights path
pub fn setup(&self, py: Python<'_>) -> PyResult<()> {
    let instance = self.instance.bind(py);

    if !instance.hasattr("setup")? {
        return Ok(());
    }

    // Both helpers are looked up before either is called.
    let helpers = py.import("cog.predictor")?;
    let has_setup_weights = helpers.getattr("has_setup_weights")?;
    let extract_setup_weights = helpers.getattr("extract_setup_weights")?;

    let wants_weights: bool = has_setup_weights.call1((&instance,))?.extract()?;
    match wants_weights {
        true => {
            // Extract weights from COG_WEIGHTS env or ./weights path
            let weights = extract_setup_weights.call1((&instance,))?;
            instance.call_method1("setup", (weights,))?;
        }
        false => {
            instance.call_method0("setup")?;
        }
    }

    Ok(())
}

/// Get the predict function object for type annotation introspection.
pub fn predict_func<'py>(&self, py: Python<'py>) -> PyResult> { let instance = self.instance.bind(py); match &self.kind { PredictorKind::Class { .. } => instance.getattr("predict"), PredictorKind::StandaloneFunction(_) => Ok(instance.clone()), } } /// Get the train function object for type annotation introspection. pub fn train_func<'py>(&self, py: Python<'py>) -> PyResult> { let instance = self.instance.bind(py); match &self.kind { PredictorKind::Class { .. } => instance.getattr("train"), PredictorKind::StandaloneFunction(_) => Ok(instance.clone()), } } /// Call predict() with the given input dict, returning raw Python output. /// /// For standalone functions, calls the function directly. pub fn predict_raw(&self, py: Python<'_>, input: &Bound<'_, PyDict>) -> PyResult { let (method_name, is_async) = match &self.kind { PredictorKind::Class { predict, .. } => ( "predict", matches!(predict, PredictKind::Async | PredictKind::AsyncGen), ), PredictorKind::StandaloneFunction(predict_kind) => ( "", matches!(predict_kind, PredictKind::Async | PredictKind::AsyncGen), ), }; self.call_method_raw(py, method_name, is_async, input) } /// Call train() with the given input dict, returning raw Python output. /// /// For standalone train functions, calls the function directly. /// For Predictor classes with a train() method, calls instance.train(). pub fn train_raw(&self, py: Python<'_>, input: &Bound<'_, PyDict>) -> PyResult { let (method_name, is_async) = match &self.kind { PredictorKind::Class { train, .. } => ("train", matches!(train, TrainKind::Async)), PredictorKind::StandaloneFunction(predict_kind) => ( "", matches!(predict_kind, PredictKind::Async | PredictKind::AsyncGen), ), }; self.call_method_raw(py, method_name, is_async, input) } /// Internal helper to call a method (predict or train) on the predictor. 
fn call_method_raw( &self, py: Python<'_>, method_name: &str, is_async: bool, input: &Bound<'_, PyDict>, ) -> PyResult { let instance = self.instance.bind(py); // Call the method - returns coroutine if async, result if sync // If method_name is empty, call the instance directly (standalone function) let method_result = if method_name.is_empty() { instance.call((), Some(input))? } else { instance.call_method(method_name, (), Some(input))? }; // If async, run the coroutine with asyncio.run() let result = if is_async { let asyncio = py.import("asyncio")?; asyncio.call_method1("run", (&method_result,))? } else { method_result }; Ok(result.unbind()) } /// Worker mode predict - with input processing and output serialization. pub fn predict_worker( &self, input: serde_json::Value, slot_sender: Arc, ) -> Result { Python::attach(|py| { let json_module = py.import("json").map_err(|e| { PredictionError::Failed(format!("Failed to import json module: {}", e)) })?; let types_module = py.import("types").map_err(|e| { PredictionError::Failed(format!("Failed to import types module: {}", e)) })?; let generator_type = types_module.getattr("GeneratorType").map_err(|e| { PredictionError::Failed(format!("Failed to get GeneratorType: {}", e)) })?; let input_str = serde_json::to_string(&input) .map_err(|e| PredictionError::InvalidInput(e.to_string()))?; let py_input = json_module .call_method1("loads", (input_str,)) .map_err(|e| PredictionError::InvalidInput(format!("Invalid JSON input: {}", e)))?; #[allow(deprecated)] let raw_input_dict = py_input.downcast::().map_err(|_| { PredictionError::InvalidInput("Input must be a JSON object".to_string()) })?; // PreparedInput cleans up temp files on drop (RAII) let func = self.predict_func(py).map_err(|e| { PredictionError::Failed(format!("Failed to get predict function: {}", e)) })?; let prepared = input::prepare_input(py, raw_input_dict, &func) .map_err(|e| PredictionError::InvalidInput(format_validation_error(py, &e)))?; let input_dict = 
prepared.dict(py); // Call predict let result = self.predict_raw(py, &input_dict); // Handle errors (prepared drops here, cleaning up temp files) let result = match result { Ok(r) => r, Err(e) => { drop(prepared); // Explicit cleanup on error path if is_cancelation_exception(py, &e) { return Err(PredictionError::Cancelled); } return Err(PredictionError::Failed(format!("Prediction failed: {}", e))); } }; let result_bound = result.bind(py); let is_generator: bool = result_bound.is_instance(&generator_type).unwrap_or(false); let output = if is_generator { self.process_generator_output(py, result_bound, &json_module, &slot_sender)? } else { self.process_single_output(py, result_bound, &json_module, &slot_sender)? }; // prepared drops here, cleaning up temp files via RAII drop(prepared); Ok(PredictionResult { output, predict_time: None, logs: String::new(), metrics: Default::default(), }) }) } /// Worker mode train - with input processing and output serialization. pub fn train_worker( &self, input: serde_json::Value, slot_sender: Arc, ) -> Result { Python::attach(|py| { let json_module = py.import("json").map_err(|e| { PredictionError::Failed(format!("Failed to import json module: {}", e)) })?; let types_module = py.import("types").map_err(|e| { PredictionError::Failed(format!("Failed to import types module: {}", e)) })?; let generator_type = types_module.getattr("GeneratorType").map_err(|e| { PredictionError::Failed(format!("Failed to get GeneratorType: {}", e)) })?; let input_str = serde_json::to_string(&input) .map_err(|e| PredictionError::InvalidInput(e.to_string()))?; let py_input = json_module .call_method1("loads", (input_str,)) .map_err(|e| PredictionError::InvalidInput(format!("Invalid JSON input: {}", e)))?; #[allow(deprecated)] let raw_input_dict = py_input.downcast::().map_err(|_| { PredictionError::InvalidInput("Input must be a JSON object".to_string()) })?; // PreparedInput cleans up temp files on drop (RAII) let func = self.train_func(py).map_err(|e| { 
PredictionError::Failed(format!("Failed to get train function: {}", e)) })?; let prepared = input::prepare_input(py, raw_input_dict, &func) .map_err(|e| PredictionError::InvalidInput(format_validation_error(py, &e)))?; let input_dict = prepared.dict(py); // Call train let result = self.train_raw(py, &input_dict); // Handle errors let result = match result { Ok(r) => r, Err(e) => { drop(prepared); if is_cancelation_exception(py, &e) { return Err(PredictionError::Cancelled); } return Err(PredictionError::Failed(format!("Training failed: {}", e))); } }; let result_bound = result.bind(py); let is_generator: bool = result_bound.is_instance(&generator_type).unwrap_or(false); let output = if is_generator { self.process_generator_output(py, result_bound, &json_module, &slot_sender)? } else { self.process_single_output(py, result_bound, &json_module, &slot_sender)? }; drop(prepared); Ok(PredictionResult { output, predict_time: None, logs: String::new(), metrics: Default::default(), }) }) } /// Process generator output by streaming each yield over IPC. fn process_generator_output( &self, py: Python<'_>, result: &Bound<'_, PyAny>, json_module: &Bound<'_, PyAny>, slot_sender: &SlotSender, ) -> Result { let iter = result .try_iter() .map_err(|e| PredictionError::Failed(format!("Failed to iterate generator: {}", e)))?; for item in iter { let item = item.map_err(|e| { if is_cancelation_exception(py, &e) { return PredictionError::Cancelled; } PredictionError::Failed(format!("Generator iteration error: {}", e)) })?; send_output_item(py, &item, json_module, slot_sender)?; } // Outputs already streamed over IPC — return empty stream Ok(PredictionOutput::Stream(vec![])) } /// Process single output into PredictionOutput::Single. /// /// For file outputs (Path/IOBase), the file is sent via slot_sender and /// an empty Single(Null) is returned since the output was already streamed. 
fn process_single_output( &self, py: Python<'_>, result: &Bound<'_, PyAny>, json_module: &Bound<'_, PyAny>, slot_sender: &SlotSender, ) -> Result { // Check for file-type outputs first let os = py .import("os") .map_err(|e| PredictionError::Failed(format!("Failed to import os: {}", e)))?; let io_mod = py .import("io") .map_err(|e| PredictionError::Failed(format!("Failed to import io: {}", e)))?; let pathlike = os .getattr("PathLike") .map_err(|e| PredictionError::Failed(format!("Failed to get os.PathLike: {}", e)))?; let iobase = io_mod .getattr("IOBase") .map_err(|e| PredictionError::Failed(format!("Failed to get io.IOBase: {}", e)))?; if result.is_instance(&pathlike).unwrap_or(false) { let path_str: String = result .call_method0("__fspath__") .and_then(|p| p.extract()) .map_err(|e| PredictionError::Failed(format!("Failed to get fspath: {}", e)))?; slot_sender .send_file_output(std::path::PathBuf::from(path_str), None) .map_err(|e| { PredictionError::Failed(format!("Failed to send file output: {}", e)) })?; return Ok(PredictionOutput::Single(serde_json::Value::Null)); } if result.is_instance(&iobase).unwrap_or(false) { if result .call_method0("seekable") .and_then(|r| r.extract::()) .unwrap_or(false) { let _ = result.call_method1("seek", (0,)); } let data: Vec = result .call_method0("read") .and_then(|d| d.extract()) .map_err(|e| PredictionError::Failed(format!("Failed to read IOBase: {}", e)))?; let ext = result .getattr("name") .and_then(|n| n.extract::()) .ok() .and_then(|name| { std::path::Path::new(&name) .extension() .and_then(|e| e.to_str()) .map(|s| s.to_string()) }) .unwrap_or_else(|| "bin".to_string()); slot_sender .write_file_output(&data, &ext, None) .map_err(|e| { PredictionError::Failed(format!("Failed to write file output: {}", e)) })?; return Ok(PredictionOutput::Single(serde_json::Value::Null)); } // List/tuple output — iterate items so file outputs (Path, IOBase) // go through the FileOutput IPC path for upload instead of being // base64-encoded 
inline by process_output. if let Ok(list) = result.cast::() { for item in list.iter() { send_output_item(py, &item, json_module, slot_sender)?; } return Ok(PredictionOutput::Stream(vec![])); } if let Ok(tuple) = result.cast::() { for item in tuple.iter() { send_output_item(py, &item, json_module, slot_sender)?; } return Ok(PredictionOutput::Stream(vec![])); } // Non-file output — process normally let processed = output::process_output(py, result) .map_err(|e| PredictionError::Failed(format!("Failed to process output: {}", e)))?; let result_str: String = json_module .call_method1("dumps", (&processed,)) .map_err(|e| PredictionError::Failed(format!("Failed to serialize output: {}", e)))? .extract() .map_err(|e| { PredictionError::Failed(format!("Failed to extract output string: {}", e)) })?; let output_json: serde_json::Value = serde_json::from_str(&result_str) .map_err(|e| PredictionError::Failed(format!("Failed to parse output JSON: {}", e)))?; Ok(PredictionOutput::Single(output_json)) } /// Worker mode async predict - submits to shared event loop. /// /// Uses run_coroutine_threadsafe to submit the coroutine to the provided event loop. /// Returns the concurrent.futures.Future, is_async_gen flag, and PreparedInput for cleanup. /// Caller should block on future.result() to get the result, then drop PreparedInput. /// /// The prediction_id is used to set up log routing in the event loop thread. 
pub fn predict_async_worker( &self, input: serde_json::Value, event_loop: &Py, prediction_id: &str, ) -> Result<(Py, bool, PreparedInput), PredictionError> { Python::attach(|py| { let json_module = py.import("json").map_err(|e| { PredictionError::Failed(format!("Failed to import json module: {}", e)) })?; let asyncio = py .import("asyncio") .map_err(|e| PredictionError::Failed(format!("Failed to import asyncio: {}", e)))?; let input_str = serde_json::to_string(&input) .map_err(|e| PredictionError::InvalidInput(e.to_string()))?; let py_input = json_module .call_method1("loads", (input_str,)) .map_err(|e| PredictionError::InvalidInput(format!("Invalid JSON input: {}", e)))?; #[allow(deprecated)] let raw_input_dict = py_input.downcast::().map_err(|_| { PredictionError::InvalidInput("Input must be a JSON object".to_string()) })?; let func = self.predict_func(py).map_err(|e| { PredictionError::Failed(format!("Failed to get predict function: {}", e)) })?; let prepared = input::prepare_input(py, raw_input_dict, &func) .map_err(|e| PredictionError::InvalidInput(format_validation_error(py, &e)))?; let input_dict = prepared.dict(py); // Call predict - returns coroutine let instance = self.instance.bind(py); let coro = instance .call_method("predict", (), Some(&input_dict)) .map_err(|e| PredictionError::Failed(format!("Failed to call predict: {}", e)))?; // For async generators, wrap to collect all values let is_async_gen = matches!( &self.kind, PredictorKind::Class { predict: PredictKind::AsyncGen, .. } | PredictorKind::StandaloneFunction(PredictKind::AsyncGen) ); let coro = if is_async_gen { let collect_fn = get_collect_async_gen(py)?; collect_fn .call1(py, (&coro,)) .map_err(|e| { PredictionError::Failed(format!("Failed to wrap async generator: {}", e)) })? 
.into_bound(py) } else { coro }; // Wrap coroutine to set up log routing in the event loop thread let ctx_wrapper = get_ctx_wrapper(py)?; // Get the same ContextVar instance used by SlotLogWriter for log routing let contextvar = crate::log_writer::get_prediction_contextvar(py).map_err(|e| { PredictionError::Failed(format!("Failed to get prediction ContextVar: {}", e)) })?; // Wrap the coroutine with context setup let wrapped_coro = ctx_wrapper .call1(py, (&coro, prediction_id, contextvar.bind(py))) .map_err(|e| { PredictionError::Failed(format!("Failed to wrap coroutine with context: {}", e)) })?; // Submit wrapped coroutine to shared event loop via run_coroutine_threadsafe let future = asyncio .call_method1( "run_coroutine_threadsafe", (wrapped_coro.bind(py), event_loop.bind(py)), ) .map_err(|e| { PredictionError::Failed(format!("Failed to submit coroutine: {}", e)) })?; Ok((future.unbind(), is_async_gen, prepared)) }) } /// Process the result from an async prediction future. /// /// Call this after future.result() returns to convert the Python result /// to a PredictionResult. 
pub fn process_async_result( &self, py: Python<'_>, result: &Bound<'_, PyAny>, is_async_gen: bool, slot_sender: &SlotSender, ) -> Result { let json_module = py .import("json") .map_err(|e| PredictionError::Failed(format!("Failed to import json module: {}", e)))?; let types_module = py.import("types").map_err(|e| { PredictionError::Failed(format!("Failed to import types module: {}", e)) })?; // Process output let output = if is_async_gen { // Result is a pre-collected list — stream each item over IPC if let Ok(list) = result.extract::>>() { for item in list { send_output_item(py, &item, &json_module, slot_sender)?; } } PredictionOutput::Stream(vec![]) } else { // Check if result is a generator (sync generator from async predict) let generator_type = types_module.getattr("GeneratorType").map_err(|e| { PredictionError::Failed(format!("Failed to get GeneratorType: {}", e)) })?; let is_generator: bool = result.is_instance(&generator_type).unwrap_or(false); if is_generator { self.process_generator_output(py, result, &json_module, slot_sender)? } else { self.process_single_output(py, result, &json_module, slot_sender)? } }; Ok(PredictionResult { output, predict_time: None, logs: String::new(), metrics: Default::default(), }) } /// Worker mode async train - submits to shared event loop. 
pub fn train_async_worker( &self, input: serde_json::Value, event_loop: &Py, prediction_id: &str, ) -> Result<(Py, bool, PreparedInput), PredictionError> { Python::attach(|py| { let json_module = py.import("json").map_err(|e| { PredictionError::Failed(format!("Failed to import json module: {}", e)) })?; let asyncio = py .import("asyncio") .map_err(|e| PredictionError::Failed(format!("Failed to import asyncio: {}", e)))?; let input_str = serde_json::to_string(&input) .map_err(|e| PredictionError::InvalidInput(e.to_string()))?; let py_input = json_module .call_method1("loads", (input_str,)) .map_err(|e| PredictionError::InvalidInput(format!("Invalid JSON input: {}", e)))?; #[allow(deprecated)] let raw_input_dict = py_input.downcast::().map_err(|_| { PredictionError::InvalidInput("Input must be a JSON object".to_string()) })?; let func = self.train_func(py).map_err(|e| { PredictionError::Failed(format!("Failed to get train function: {}", e)) })?; let prepared = input::prepare_input(py, raw_input_dict, &func) .map_err(|e| PredictionError::InvalidInput(format_validation_error(py, &e)))?; let input_dict = prepared.dict(py); // Call train - returns coroutine let instance = self.instance.bind(py); let coro = match &self.kind { PredictorKind::StandaloneFunction(_) => instance.call((), Some(&input_dict)), PredictorKind::Class { .. 
} => instance.call_method("train", (), Some(&input_dict)), } .map_err(|e| PredictionError::Failed(format!("Failed to call train: {}", e)))?; // Wrap coroutine to set up log routing let ctx_wrapper = get_ctx_wrapper(py)?; // Get the same ContextVar instance used by SlotLogWriter let contextvar = crate::log_writer::get_prediction_contextvar(py).map_err(|e| { PredictionError::Failed(format!("Failed to get prediction ContextVar: {}", e)) })?; // Wrap the coroutine with context setup let wrapped_coro = ctx_wrapper .call1(py, (&coro, prediction_id, contextvar.bind(py))) .map_err(|e| { PredictionError::Failed(format!("Failed to wrap coroutine with context: {}", e)) })?; // Submit wrapped coroutine to shared event loop let future = asyncio .call_method1( "run_coroutine_threadsafe", (wrapped_coro.bind(py), event_loop.bind(py)), ) .map_err(|e| { PredictionError::Failed(format!("Failed to submit coroutine: {}", e)) })?; // Train doesn't typically use async generators, but we return false for consistency Ok((future.unbind(), false, prepared)) }) } // ========================================================================= // Healthcheck methods // ========================================================================= /// Healthcheck timeout in seconds. const HEALTHCHECK_TIMEOUT: f64 = 5.0; /// Check if the predictor has a healthcheck() method. pub fn has_healthcheck(&self, py: Python<'_>) -> bool { match &self.kind { PredictorKind::Class { .. } => { let instance = self.instance.bind(py); instance.hasattr("healthcheck").unwrap_or(false) } PredictorKind::StandaloneFunction(_) => false, } } /// Check if the healthcheck() method is async. pub fn is_healthcheck_async(&self, py: Python<'_>) -> bool { match &self.kind { PredictorKind::Class { .. 
} => { let instance = self.instance.bind(py); if let Ok(healthcheck) = instance.getattr("healthcheck") { let inspect = py.import("inspect").ok(); if let Some(inspect) = inspect { inspect .call_method1("iscoroutinefunction", (&healthcheck,)) .ok() .and_then(|r| r.extract::().ok()) .unwrap_or(false) } else { false } } else { false } } PredictorKind::StandaloneFunction(_) => false, } } /// Run a synchronous healthcheck with timeout. /// /// Runs the healthcheck in a thread pool executor with a 5 second timeout. pub fn healthcheck_sync(&self, py: Python<'_>) -> coglet_core::orchestrator::HealthcheckResult { use coglet_core::orchestrator::HealthcheckResult; let instance = self.instance.bind(py); // Run healthcheck in executor with timeout, mirroring Python impl let result: PyResult = (|| { let concurrent_futures = py.import("concurrent.futures")?; let thread_pool = concurrent_futures.getattr("ThreadPoolExecutor")?; // Create a small executor just for this healthcheck let executor = thread_pool.call1((1,))?; // Get the healthcheck method let healthcheck_fn = instance.getattr("healthcheck")?; // Submit to executor let future = executor.call_method1("submit", (healthcheck_fn,))?; // Wait with timeout let result = future.call_method1("result", (Self::HEALTHCHECK_TIMEOUT,)); // Shutdown executor let _ = executor.call_method1("shutdown", (false,)); match result { Ok(r) => Ok(r.extract::().unwrap_or(true)), Err(e) => { let err_str = e.to_string(); if err_str.contains("TimeoutError") { Err(pyo3::exceptions::PyTimeoutError::new_err( "Healthcheck timed out", )) } else { Err(e) } } } })(); match result { Ok(true) => HealthcheckResult::healthy(), Ok(false) => HealthcheckResult::unhealthy( "Healthcheck failed: user-defined healthcheck returned False", ), Err(e) => { let err_str = e.to_string(); if err_str.contains("TimeoutError") { HealthcheckResult::unhealthy(format!( "Healthcheck failed: user-defined healthcheck timed out after {:.1} seconds", Self::HEALTHCHECK_TIMEOUT )) } else { 
HealthcheckResult::unhealthy(format!("Healthcheck failed: {}", e)) } } } } /// Run an async healthcheck with timeout. /// /// Runs the healthcheck in the async event loop with a 5 second timeout. pub fn healthcheck_async( &self, py: Python<'_>, event_loop: &Py, ) -> coglet_core::orchestrator::HealthcheckResult { use coglet_core::orchestrator::HealthcheckResult; let instance = self.instance.bind(py); let result: PyResult = (|| { let asyncio = py.import("asyncio")?; // Get the healthcheck coroutine let healthcheck_fn = instance.getattr("healthcheck")?; let coro = healthcheck_fn.call0()?; // Wrap with timeout let wait_for = asyncio.getattr("wait_for")?; let timeout_coro = wait_for.call1((&coro, Self::HEALTHCHECK_TIMEOUT))?; // Submit to event loop let future = asyncio.call_method1( "run_coroutine_threadsafe", (&timeout_coro, event_loop.bind(py)), )?; // Block on result with extra buffer time for event loop overhead let result = future.call_method1("result", (Self::HEALTHCHECK_TIMEOUT + 1.0,)); match result { Ok(r) => Ok(r.extract::().unwrap_or(true)), Err(e) => { let err_str = e.to_string(); if err_str.contains("TimeoutError") || err_str.contains("timed out") { Err(pyo3::exceptions::PyTimeoutError::new_err( "Healthcheck timed out", )) } else { Err(e) } } } })(); match result { Ok(true) => HealthcheckResult::healthy(), Ok(false) => HealthcheckResult::unhealthy( "Healthcheck failed: user-defined healthcheck returned False", ), Err(e) => { let err_str = e.to_string(); if err_str.contains("TimeoutError") { HealthcheckResult::unhealthy(format!( "Healthcheck failed: user-defined healthcheck timed out after {:.1} seconds", Self::HEALTHCHECK_TIMEOUT )) } else { HealthcheckResult::unhealthy(format!("Healthcheck failed: {}", e)) } } } } } ================================================ FILE: crates/coglet-python/src/worker_bridge.rs ================================================ //! Bridge between coglet-worker's PredictHandler trait and PythonPredictor. 
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;

use pyo3::prelude::*;

use coglet_core::bridge::protocol::SlotId;
use coglet_core::worker::{PredictHandler, PredictResult, SetupError, SlotSender};

use crate::predictor::PythonPredictor;

/// What operation the handler performs
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandlerMode {
    /// Calls predict() method
    Predict,
    /// Calls train() method
    Train,
}

/// SDK implementation type detected from the Python predictor.
///
/// This enum allows for future extensibility if additional SDK
/// implementations are needed (e.g., Node.js).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SdkImplementation {
    /// Standard cog Python SDK
    Cog,
    /// Unable to detect SDK type
    Unknown,
}

impl std::fmt::Display for SdkImplementation {
    /// Writes the lowercase name: "cog" or "unknown".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Cog => write!(f, "cog"),
            Self::Unknown => write!(f, "unknown"),
        }
    }
}

/// Current state of a prediction slot
#[derive(Debug, Default)]
pub enum SlotState {
    /// No prediction running
    #[default]
    Idle,
    /// Sync prediction in progress
    SyncPrediction {
        cancelled: bool,
        /// Python thread identifier (for `PyThreadState_SetAsyncExc`)
        py_thread_id: std::ffi::c_long,
    },
    /// Async prediction in progress
    AsyncPrediction {
        /// Future for cancellation
        future: Py<PyAny>,
        cancelled: bool,
    },
}

impl SlotState {
    /// Whether the running prediction has been flagged as cancelled.
    /// An idle slot is never cancelled.
    pub fn is_cancelled(&self) -> bool {
        match self {
            SlotState::SyncPrediction { cancelled, .. }
            | SlotState::AsyncPrediction { cancelled, .. } => *cancelled,
            SlotState::Idle => false,
        }
    }

    /// Flag the running prediction as cancelled; does nothing when idle.
    pub fn mark_cancelled(&mut self) {
        match self {
            SlotState::SyncPrediction { cancelled, .. }
            | SlotState::AsyncPrediction { cancelled, .. } => *cancelled = true,
            SlotState::Idle => { /* no-op */ }
        }
    }
}

/// Wraps PythonPredictor to implement the PredictHandler trait.
/// /// The `is_train` flag determines whether predict() calls the Python /// predict() or train() method. This is set at construction time. /// /// BUG-FOR-BUG COMPATIBILITY: In cog mainline, training routes use a worker /// that was created with is_train=false, so training routes actually call /// predict() instead of train(). We replicate this by always creating the /// handler with is_train=false. To fix this bug, pass is_train=true when /// creating a handler for training routes. pub struct PythonPredictHandler { predictor_ref: String, predictor: Mutex>>, /// Per-slot cancellation state (keyed by SlotId). slots: Mutex>, /// What operation this handler performs (predict or train). /// BUG: cog mainline always uses Predict mode, even for training routes. mode: HandlerMode, /// Shared asyncio event loop for async predictions (runs in dedicated thread). async_loop: Mutex>>, /// Handle to the asyncio loop thread for joining on shutdown. async_thread: Mutex>>, } impl PythonPredictHandler { /// Create a handler in prediction mode. pub fn new(predictor_ref: String) -> Result { let (loop_obj, thread) = Self::init_async_loop()?; Ok(Self { predictor_ref, predictor: Mutex::new(None), slots: Mutex::new(HashMap::new()), mode: HandlerMode::Predict, async_loop: Mutex::new(Some(loop_obj)), async_thread: Mutex::new(Some(thread)), }) } /// Create a handler in training mode. /// /// NOTE: For bug-for-bug compatibility with cog mainline, use new() instead. /// Cog mainline's training routes incorrectly use a predict-mode worker. #[allow(dead_code)] pub fn new_train(predictor_ref: String) -> Result { let (loop_obj, thread) = Self::init_async_loop()?; Ok(Self { predictor_ref, predictor: Mutex::new(None), slots: Mutex::new(HashMap::new()), mode: HandlerMode::Train, async_loop: Mutex::new(Some(loop_obj)), async_thread: Mutex::new(Some(thread)), }) } /// Initialize the shared asyncio event loop in a dedicated thread. 
fn init_async_loop() -> Result<(Py, JoinHandle<()>), SetupError> { Python::attach(|py| { let asyncio = py .import("asyncio") .map_err(|e| SetupError::internal(format!("Failed to import asyncio: {}", e)))?; let loop_obj = asyncio .call_method0("new_event_loop") .map_err(|e| SetupError::internal(format!("Failed to create event loop: {}", e)))?; // Clone for the thread let loop_for_thread = loop_obj.clone().unbind(); let loop_result = loop_obj.unbind(); // Spawn thread running loop.run_forever() let thread = std::thread::spawn(move || { Python::attach(|py| { let loop_ref = loop_for_thread.bind(py); // These errors in the thread are logged but can't be propagated // The thread dying will cause async predictions to fail let Ok(asyncio) = py.import("asyncio") else { tracing::error!("Failed to import asyncio in loop thread"); return; }; if let Err(e) = asyncio.call_method1("set_event_loop", (loop_ref,)) { tracing::error!(error = %e, "Failed to set event loop"); return; } tracing::trace!("Asyncio event loop thread starting"); if let Err(e) = loop_ref.call_method0("run_forever") { tracing::error!(error = %e, "Asyncio event loop error"); } tracing::trace!("Asyncio event loop thread exiting"); }); }); Ok((loop_result, thread)) }) } // NOTE: All mutex locks in this file use .expect(). // See log_writer.rs for the full rationale. Short version: poisoned mutex // means slot isolation is compromised. The panic hook installed by // coglet_core::worker sends a Fatal IPC message and aborts. /// Check the cancelled flag for a slot without clearing it. fn is_cancelled(&self, slot: SlotId) -> bool { let slots = self.slots.lock().expect("slots mutex poisoned"); slots.get(&slot).is_some_and(|s| s.is_cancelled()) } /// Check and clear the cancelled flag for a slot. 
fn take_cancelled(&self, slot: SlotId) -> bool { let mut slots = self.slots.lock().expect("slots mutex poisoned"); let state = slots.entry(slot).or_default(); let was_cancelled = state.is_cancelled(); // Reset to idle after checking cancellation if was_cancelled { *state = SlotState::Idle; } was_cancelled } /// Mark a slot as having a sync prediction in progress. /// /// `py_thread_id` is the Python thread identifier of the thread that will /// run the prediction, for use with `PyThreadState_SetAsyncExc` on cancel. fn start_sync_prediction(&self, slot: SlotId, py_thread_id: std::ffi::c_long) { let mut slots = self.slots.lock().expect("slots mutex poisoned"); slots.insert( slot, SlotState::SyncPrediction { cancelled: false, py_thread_id, }, ); } /// Mark a slot as having an async prediction in progress. fn start_async_prediction(&self, slot: SlotId, future: Py) { let mut slots = self.slots.lock().expect("slots mutex poisoned"); slots.insert( slot, SlotState::AsyncPrediction { future, cancelled: false, }, ); } /// Clear prediction state for a slot. fn finish_prediction(&self, slot: SlotId) { let mut slots = self.slots.lock().expect("slots mutex poisoned"); slots.insert(slot, SlotState::Idle); } /// Cancel an async prediction using future.cancel(). /// Returns true if cancellation was requested, false if no future found. fn cancel_async_future(&self, slot: SlotId) -> bool { Python::attach(|py| { let future = { let slots = self.slots.lock().expect("slots mutex poisoned"); if let Some(SlotState::AsyncPrediction { future, .. }) = slots.get(&slot) { Some(future.clone_ref(py)) } else { None } }; if let Some(future) = future { match future.call_method0(py, "cancel") { Ok(_) => { tracing::trace!(%slot, "Cancelled async future"); true } Err(e) => { tracing::warn!(%slot, error = %e, "Failed to cancel async future"); false } } } else { tracing::trace!(%slot, "No async future to cancel"); false } }) } /// Get a reference to the shared asyncio event loop. 
fn get_async_loop(&self) -> Option> { Python::attach(|py| { self.async_loop .lock() .expect("async_loop mutex poisoned") .as_ref() .map(|l| l.clone_ref(py)) }) } } #[async_trait::async_trait] impl PredictHandler for PythonPredictHandler { async fn setup(&self) -> Result<(), SetupError> { Python::attach(|py| { tracing::info!(predictor_ref = %self.predictor_ref, "Loading predictor"); let pred = PythonPredictor::load(py, &self.predictor_ref) .map_err(|e| SetupError::load(e.to_string()))?; // Detect SDK implementation let sdk_impl = match py.import("cog") { Ok(cog) => match cog.getattr("BasePredictor") { Ok(_) => SdkImplementation::Cog, Err(_) => SdkImplementation::Unknown, }, Err(_) => SdkImplementation::Unknown, }; tracing::info!(sdk_implementation = %sdk_impl, "Detected Cog SDK implementation"); tracing::info!("Running setup"); pred.setup(py) .map_err(|e| SetupError::setup(e.to_string()))?; let mut guard = self.predictor.lock().expect("predictor mutex poisoned"); *guard = Some(Arc::new(pred)); tracing::info!("Setup complete"); Ok(()) }) } async fn predict( &self, slot: SlotId, id: String, input: serde_json::Value, slot_sender: Arc, context: HashMap, ) -> PredictResult { tracing::trace!(%slot, %id, "PythonPredictHandler::predict starting"); // Get predictor let pred = { let guard = self.predictor.lock().expect("predictor mutex poisoned"); match guard.as_ref() { Some(p) => Arc::clone(p), None => { return PredictResult::failed("Predictor not initialized".to_string(), 0.0); } } }; let is_async = pred.is_async(); tracing::trace!(%slot, %id, is_async, "Got predictor"); // Track that we're starting a prediction on this slot. // Capture the Python thread ID for this thread (used by // PyThreadState_SetAsyncExc to inject CancelationException on cancel). // For async predictions, the slot state is updated later with the future. 
let py_thread_id = crate::cancel::current_py_thread_id(); self.start_sync_prediction(slot, py_thread_id); // Check cancellation first (in case cancel was called before we started) if self.take_cancelled(slot) { self.finish_prediction(slot); return PredictResult::cancelled(0.0); } // Enter prediction context - sets cog_prediction_id ContextVar for log routing let prediction_id = id.clone(); let slot_sender_clone = slot_sender.clone(); let log_guard = Python::attach(|py| { crate::log_writer::PredictionLogGuard::enter( py, prediction_id.clone(), slot_sender_clone, ) }); let log_guard = match log_guard { Ok(g) => Some(g), Err(e) => { tracing::warn!(error = %e, "Failed to enter prediction context"); None } }; // Enter metric scope - sets Scope ContextVar for metric recording let slot_sender_for_metrics = slot_sender.clone(); let scope_guard = Python::attach(|py| { crate::metric_scope::ScopeGuard::enter(py, slot_sender_for_metrics, context) }); let scope_guard = match scope_guard { Ok(g) => Some(g), Err(e) => { tracing::warn!(error = %e, "Failed to enter metric scope"); None } }; tracing::trace!(%slot, %id, "Prediction context entered"); // Run prediction or training based on mode. 
let start = std::time::Instant::now(); let result = match self.mode { HandlerMode::Train => { // Training mode - check if train() exists if !pred.has_train() { self.finish_prediction(slot); return PredictResult::failed( "Training not supported by this predictor".to_string(), 0.0, ); } // Use worker-mode train if pred.is_train_async() { // Async train - submit to shared event loop let loop_obj = match self.get_async_loop() { Some(l) => l, None => { return PredictResult::failed( "Async event loop not initialized".to_string(), start.elapsed().as_secs_f64(), ); } }; // Submit coroutine and get future + prepared input for cleanup let (future, is_async_gen, prepared) = match pred .train_async_worker(input, &loop_obj, &id) { Ok(f) => f, Err(e) => { self.finish_prediction(slot); drop(log_guard); return if matches!(e, coglet_core::PredictionError::Cancelled) { PredictResult::cancelled(start.elapsed().as_secs_f64()) } else { PredictResult::failed(e.to_string(), start.elapsed().as_secs_f64()) }; } }; // Update slot state with future for cancellation Python::attach(|py| { self.start_async_prediction(slot, future.clone_ref(py)); }); // Block on future.result() let sender_for_async = slot_sender.clone(); let result = Python::attach(|py| match future.call_method0(py, "result") { Ok(result) => pred.process_async_result( py, result.bind(py), is_async_gen, &sender_for_async, ), Err(e) => { let err_str = e.to_string(); if err_str.contains("CancelledError") || err_str.contains("cancelled") { Err(coglet_core::PredictionError::Cancelled) } else { Err(coglet_core::PredictionError::Failed(format!( "Async training failed: {}", e ))) } } }); // Cleanup temp files via RAII drop(prepared); result } else { // Sync train - set sync prediction ID for log routing crate::log_writer::set_sync_prediction_id(Some(&id)); let r = pred.train_worker(input, slot_sender.clone()); crate::log_writer::set_sync_prediction_id(None); // Upgrade to Cancelled if the slot was marked cancelled // (same logic as sync 
predict above) match r { Err(_) if self.is_cancelled(slot) => { Err(coglet_core::PredictionError::Cancelled) } other => other, } } } HandlerMode::Predict => { // Prediction mode tracing::trace!(%slot, %id, is_async = pred.is_async(), "Running prediction"); if pred.is_async() { // Async predict - submit to shared event loop let loop_obj = match self.get_async_loop() { Some(l) => l, None => { return PredictResult::failed( "Async event loop not initialized".to_string(), start.elapsed().as_secs_f64(), ); } }; // Submit coroutine and get future + prepared input for cleanup let (future, is_async_gen, prepared) = match pred .predict_async_worker(input, &loop_obj, &id) { Ok(f) => f, Err(e) => { self.finish_prediction(slot); drop(log_guard); return if matches!(e, coglet_core::PredictionError::Cancelled) { PredictResult::cancelled(start.elapsed().as_secs_f64()) } else { PredictResult::failed(e.to_string(), start.elapsed().as_secs_f64()) }; } }; // Update slot state with future for cancellation Python::attach(|py| { self.start_async_prediction(slot, future.clone_ref(py)); }); // Block on future.result() let sender_for_async = slot_sender.clone(); let result = Python::attach(|py| match future.call_method0(py, "result") { Ok(result) => pred.process_async_result( py, result.bind(py), is_async_gen, &sender_for_async, ), Err(e) => { let err_str = e.to_string(); if err_str.contains("CancelledError") || err_str.contains("cancelled") { Err(coglet_core::PredictionError::Cancelled) } else { Err(coglet_core::PredictionError::Failed(format!( "Async prediction failed: {}", e ))) } } }); // Cleanup temp files via RAII drop(prepared); result } else { // Sync predict - set sync prediction ID for log routing crate::log_writer::set_sync_prediction_id(Some(&id)); tracing::trace!(%slot, %id, "Calling predict_worker"); let r = pred.predict_worker(input, slot_sender.clone()); tracing::trace!(%slot, %id, "predict_worker returned"); crate::log_writer::set_sync_prediction_id(None); // If the 
prediction failed AND the slot was marked cancelled, // treat it as a cancellation. PyThreadState_SetAsyncExc injects // CancelationException which predict_worker sees as a generic // PyErr — we upgrade it to Cancelled here. match r { Err(_) if self.is_cancelled(slot) => { Err(coglet_core::PredictionError::Cancelled) } other => other, } } } }; tracing::trace!(%slot, %id, "Prediction completed"); self.finish_prediction(slot); // Exit prediction context drop(scope_guard); drop(log_guard); match result { Ok(r) => { let is_stream = r.output.is_stream(); PredictResult::success( output_to_json(r.output), start.elapsed().as_secs_f64(), is_stream, ) } Err(e) => { if matches!(e, coglet_core::PredictionError::Cancelled) { PredictResult::cancelled(start.elapsed().as_secs_f64()) } else { PredictResult::failed(e.to_string(), start.elapsed().as_secs_f64()) } } } } fn cancel(&self, slot: SlotId) { // Mark slot as cancelled and determine how to cancel based on state let mut slots = self.slots.lock().expect("slots mutex poisoned"); if let Some(state) = slots.get_mut(&slot) { state.mark_cancelled(); match state { SlotState::AsyncPrediction { .. } => { drop(slots); // Release lock before calling cancel_async_future // Async: cancel via future.cancel() if !self.cancel_async_future(slot) { tracing::trace!(%slot, "No async future to cancel (prediction may have completed)"); } } SlotState::SyncPrediction { py_thread_id, .. 
} => { let py_thread_id = *py_thread_id; drop(slots); // Release lock // Sync: inject CancelationException into the Python thread // via PyThreadState_SetAsyncExc (fires at next bytecode boundary) crate::cancel::cancel_sync_thread(py_thread_id); } SlotState::Idle => { // Already idle, nothing to cancel tracing::trace!(%slot, "Cancel called on idle slot"); } } } else { tracing::trace!(%slot, "Cancel called on unknown slot"); } } async fn healthcheck(&self) -> coglet_core::orchestrator::HealthcheckResult { // Get predictor let pred = { let guard = self.predictor.lock().expect("predictor mutex poisoned"); match guard.as_ref() { Some(p) => Arc::clone(p), None => { return coglet_core::orchestrator::HealthcheckResult::unhealthy( "Predictor not initialized", ); } } }; // Check if predictor has a healthcheck method let has_healthcheck = Python::attach(|py| pred.has_healthcheck(py)); if !has_healthcheck { // No healthcheck defined = healthy return coglet_core::orchestrator::HealthcheckResult::healthy(); } // Run healthcheck with timeout let is_async = Python::attach(|py| pred.is_healthcheck_async(py)); if is_async { // Async healthcheck - run in event loop with timeout let loop_obj = match self.get_async_loop() { Some(l) => l, None => { return coglet_core::orchestrator::HealthcheckResult::unhealthy( "Async event loop not initialized", ); } }; Python::attach(|py| pred.healthcheck_async(py, &loop_obj)) } else { // Sync healthcheck - run in thread pool with timeout Python::attach(|py| pred.healthcheck_sync(py)) } } } /// Shutdown the asyncio event loop and join the thread. 
impl Drop for PythonPredictHandler { fn drop(&mut self) { // Stop the event loop if let Some(loop_obj) = self .async_loop .lock() .expect("async_loop mutex poisoned") .take() { Python::attach(|py| { let loop_ref = loop_obj.bind(py); // Get the stop method and schedule it via call_soon_threadsafe match loop_ref.getattr("stop") { Ok(stop_method) => { if let Err(e) = loop_ref.call_method1("call_soon_threadsafe", (stop_method,)) { tracing::warn!(error = %e, "Failed to stop asyncio loop"); } } Err(e) => { tracing::warn!(error = %e, "Failed to get loop.stop method"); } } }); } // Join the thread if let Some(thread) = self .async_thread .lock() .expect("async_thread mutex poisoned") .take() && let Err(e) = thread.join() { tracing::warn!("Failed to join asyncio loop thread: {:?}", e); } } } /// Convert PredictionOutput to serde_json::Value fn output_to_json(output: coglet_core::PredictionOutput) -> serde_json::Value { match output { coglet_core::PredictionOutput::Single(v) => v, coglet_core::PredictionOutput::Stream(v) => serde_json::Value::Array(v), } } ================================================ FILE: crates/coglet-python/tests/test_coglet.py ================================================ """Tests for coglet Python bindings.""" import queue import re import socket import subprocess import sys import threading import time from pathlib import Path import coglet import pytest import requests # ============================================================================= # Module structure tests (no server needed) # ============================================================================= class TestModuleStructure: """Tests for coglet module public API and structure.""" def test_version_is_pep440(self) -> None: """__version__ must be a valid PEP 440 version string.""" # PEP 440: N.N.N, N.N.NaN, N.N.NbN, N.N.NrcN, N.N.N.devN, etc. 
assert re.match( r"^\d+\.\d+\.\d+(\.dev\d+|a\d+|b\d+|rc\d+)?(\+.+)?$", coglet.__version__, ), f"Not PEP 440: {coglet.__version__!r}" def test_version_is_str(self) -> None: assert isinstance(coglet.__version__, str) def test_build_info_exists(self) -> None: build = coglet.__build__ assert hasattr(build, "version") assert hasattr(build, "git_sha") assert hasattr(build, "build_time") assert hasattr(build, "rustc_version") def test_build_info_fields_are_strings(self) -> None: build = coglet.__build__ assert isinstance(build.version, str) assert isinstance(build.git_sha, str) assert isinstance(build.build_time, str) assert isinstance(build.rustc_version, str) def test_build_info_version_matches_module_version(self) -> None: assert coglet.__build__.version == coglet.__version__ def test_build_info_repr(self) -> None: r = repr(coglet.__build__) assert r.startswith("BuildInfo(") assert "version=" in r assert "git_sha=" in r def test_build_info_frozen(self) -> None: with pytest.raises(AttributeError): coglet.__build__.version = "hacked" # type: ignore[misc] def test_server_exists(self) -> None: assert hasattr(coglet, "server") def test_server_active_is_false(self) -> None: """Outside a worker subprocess, active should be False.""" assert coglet.server.active is False def test_server_active_is_property(self) -> None: """active should be a property (no parens needed), not callable.""" assert isinstance(coglet.server.active, bool) def test_server_frozen(self) -> None: with pytest.raises(AttributeError): coglet.server.foo = "bar" # type: ignore[attr-defined] def test_server_active_not_settable(self) -> None: with pytest.raises(AttributeError): coglet.server.active = True # type: ignore[misc] def test_server_repr(self) -> None: assert repr(coglet.server) == "coglet.server" def test_sdk_submodule_exists(self) -> None: assert hasattr(coglet, "_sdk") def test_sdk_has_slot_log_writer(self) -> None: assert hasattr(coglet._sdk, "_SlotLogWriter") def test_sdk_has_tee_writer(self) -> 
None: assert hasattr(coglet._sdk, "_TeeWriter") def test_all_excludes_internals(self) -> None: """__all__ should only list public API.""" assert "__version__" in coglet.__all__ assert "__build__" in coglet.__all__ assert "server" in coglet.__all__ # _sdk should not be in __all__ (underscore = private) assert "_sdk" not in coglet.__all__ assert "_impl" not in coglet.__all__ @pytest.fixture def sync_predictor(tmp_path: Path) -> Path: """Create a simple sync predictor.""" predictor = tmp_path / "predict.py" predictor.write_text(""" from cog import BasePredictor class Predictor(BasePredictor): def setup(self): self.prefix = "Hello, " def predict(self, name: str = "World") -> str: return self.prefix + name + "!" """) # Create cog.yaml cog_yaml = tmp_path / "cog.yaml" cog_yaml.write_text(""" predict: "predict.py:Predictor" """) return predictor @pytest.fixture def generator_predictor(tmp_path: Path) -> Path: """Create a generator predictor.""" predictor = tmp_path / "predict.py" predictor.write_text(""" from cog import BasePredictor from typing import Iterator class Predictor(BasePredictor): def setup(self): pass def predict(self, count: int = 3) -> Iterator[str]: for i in range(count): yield f"chunk {i}" """) # Create cog.yaml cog_yaml = tmp_path / "cog.yaml" cog_yaml.write_text(""" predict: "predict.py:Predictor" """) return predictor @pytest.fixture def async_predictor(tmp_path: Path) -> Path: """Create an async predictor.""" predictor = tmp_path / "predict.py" predictor.write_text(""" import asyncio from cog import BasePredictor class Predictor(BasePredictor): def setup(self): self.call_count = 0 async def predict(self, delay: float = 0.1, name: str = "test") -> str: self.call_count += 1 await asyncio.sleep(delay) return f"{name}: done after {delay}s (call #{self.call_count})" """) # Create cog.yaml cog_yaml = tmp_path / "cog.yaml" cog_yaml.write_text(""" predict: "predict.py:Predictor" """) return predictor @pytest.fixture def async_generator_predictor(tmp_path: 
class CogletServer:
    """Context manager that runs a coglet server in a child Python process.

    Entering the context spawns ``coglet.server.serve`` for the ``Predictor``
    class in ``predictor_path`` (working directory: the predictor's folder, so
    ``cog.yaml`` is found), reads the bound port from the child's stderr, and
    blocks until ``/health-check`` reports ``READY``. Leaving the context, or
    failing during startup, stops the child and releases its pipes.

    Attributes:
        predictor_path: Path to the predictor module.
        requested_port: Port passed to the server (0 is the default).
        port: Port parsed from the server's startup log line; ``None`` until
            discovered.
        process: The ``subprocess.Popen`` handle, ``None`` before entering.
        stderr_lines: Every stderr line read so far (used in error messages).
    """

    def __init__(self, predictor_path: Path, port: int = 0):
        self.predictor_path = predictor_path
        self.requested_port = port
        self.port = None
        self.process = None
        self.stderr_lines = []
        self.stderr_queue = queue.Queue()
        self.stderr_thread = None

    def __enter__(self):
        # repr() yields a correctly quoted/escaped Python literal, so paths
        # containing quotes or backslashes do not break the -c snippet.
        predictor_ref = f"{self.predictor_path}:Predictor"
        cmd = [
            sys.executable,
            "-c",
            f"import coglet; coglet.server.serve({predictor_ref!r}, port={self.requested_port})",
        ]
        self.process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,  # Line buffered
            cwd=str(
                self.predictor_path.parent
            ),  # Run from predictor directory to find cog.yaml
        )

        # Start background thread to read stderr
        self.stderr_thread = threading.Thread(target=self._read_stderr, daemon=True)
        self.stderr_thread.start()

        try:
            # Discover actual port from logs
            self._discover_port()

            # Wait for server to become ready
            self._wait_for_ready()
        except BaseException:
            # __exit__ is not called when __enter__ raises, so clean up here
            # or the child process and its pipes would be leaked.
            self._shutdown()
            raise

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._shutdown()

    def _shutdown(self) -> None:
        """Stop the child process (escalating to kill) and release its pipes."""
        process = self.process
        if process is None:
            return

        if process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # The child ignored the terminate request; force it down so
                # it cannot outlive the test.
                process.kill()
                process.wait()

        reader = self.stderr_thread
        if reader is not None:
            reader.join(timeout=1)

        # stdout is never read by this class, so it can always be closed.
        if process.stdout is not None:
            process.stdout.close()
        # Only close stderr once the reader thread is done with it; closing a
        # stream another thread is still blocked on could stall here.
        if process.stderr is not None and (reader is None or not reader.is_alive()):
            process.stderr.close()

    def _read_stderr(self):
        """Background thread that reads stderr and queues lines."""
        try:
            for line in self.process.stderr:
                self.stderr_lines.append(line)
                self.stderr_queue.put(line)
        except (OSError, ValueError):
            pass  # Pipe closed or process terminated

    def _discover_port(self, timeout: float = 5.0):
        """Read stderr until we find the port the server bound to.

        Raises:
            RuntimeError: the server process exited before logging its port.
            TimeoutError: no port line appeared within ``timeout`` seconds.
        """
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                line = self.stderr_queue.get(timeout=0.1)
            except queue.Empty:
                if self.process.poll() is not None:
                    # Process died
                    raise RuntimeError(
                        f"Server process died during startup\nSTDERR:\n{''.join(self.stderr_lines)}"
                    )
                continue

            # Look for: "Starting coglet server on 0.0.0.0:PORT"
            match = re.search(r"Starting coglet server on [\d.]+:(\d+)", line)
            if match:
                self.port = int(match.group(1))
                return

        raise TimeoutError(
            f"Could not discover server port within {timeout}s\nSTDERR:\n{''.join(self.stderr_lines)}"
        )

    def _wait_for_ready(self, timeout: float = 10.0):
        """Poll ``/health-check`` until the server reports ``READY``.

        Raises:
            RuntimeError: the server process exited while we were waiting.
            TimeoutError: the server was not ready within ``timeout`` seconds.

        The caller (``__enter__``) is responsible for stopping the process
        when this raises.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self.process is not None and self.process.poll() is not None:
                # No point polling a dead server until the deadline.
                raise RuntimeError(
                    f"Server process died during startup\nSTDERR:\n{''.join(self.stderr_lines)}"
                )
            try:
                resp = requests.get(
                    f"http://localhost:{self.port}/health-check", timeout=1
                )
                if resp.status_code == 200 and resp.json().get("status") == "READY":
                    return
            except (
                requests.exceptions.ConnectionError,
                requests.exceptions.Timeout,
            ):
                # Not accepting connections / not answering yet; keep polling.
                pass
            time.sleep(0.1)

        raise TimeoutError(
            f"Server did not become ready within {timeout}s (port={self.port})\n"
            f"STDERR:\n{''.join(self.stderr_lines)}"
        )

    @property
    def base_url(self) -> str:
        """Root URL of the running server."""
        return f"http://localhost:{self.port}"

    def health_check(self) -> dict:
        """GET ``/health-check`` and return the decoded JSON body.

        Raises ``requests.HTTPError`` on a non-2xx response.
        """
        resp = requests.get(f"{self.base_url}/health-check")
        resp.raise_for_status()
        return resp.json()

    def predict(self, input_data: dict) -> dict:
        """POST ``input_data`` to ``/predictions`` and return the JSON response."""
        resp = requests.post(
            f"{self.base_url}/predictions",
            json={"input": input_data},
        )
        return resp.json()
class TestGeneratorPredictor:
    """Generator predictors: the yielded chunks come back as a list output."""

    def test_returns_array_output(self, generator_predictor: Path):
        """Three yielded chunks are returned, in order, as one array."""
        expected_chunks = [f"chunk {index}" for index in range(3)]
        with CogletServer(generator_predictor) as server:
            response = server.predict({"count": 3})
        assert response["status"] == "succeeded"
        assert response["output"] == expected_chunks

    def test_custom_count(self, generator_predictor: Path):
        """The ``count`` input controls how many chunks are produced."""
        with CogletServer(generator_predictor) as server:
            response = server.predict({"count": 5})
        chunks = response["output"]
        assert len(chunks) == 5
class TestAsyncGeneratorPredictor:
    """Async generator predictors: yielded chunks come back as a list output."""

    def test_returns_array_output(self, async_generator_predictor: Path):
        """Three asynchronously yielded chunks are returned in order."""
        payload = {"count": 3, "delay": 0.01}
        expected_chunks = [f"async chunk {index}" for index in range(3)]
        with CogletServer(async_generator_predictor) as server:
            response = server.predict(payload)
        assert response["status"] == "succeeded"
        assert response["output"] == expected_chunks
str: await asyncio.sleep(sleep_time) return "completed" """) cog_yaml = tmp_path / "cog.yaml" cog_yaml.write_text(""" predict: "predict.py:Predictor" """) return predictor def _wait_for_health_status( server: "CogletServer", status: str, timeout: float = 5.0 ) -> None: """Poll health check until the expected status is reached, or fail.""" deadline = time.time() + timeout last_status = "" while time.time() < deadline: health = server.health_check() last_status = health["status"] if last_status == status: return time.sleep(0.1) stderr = "".join(server.stderr_lines) pytest.fail( f"Server did not reach status {status!r} within {timeout}s\n" f"Last status: {last_status!r}\n" f"STDERR:\n{stderr}" ) class TestCancellation: """Tests for prediction cancellation.""" def test_cancel_endpoint_returns_404_for_unknown_id(self, sync_predictor: Path): """Test that cancelling an unknown prediction returns 404.""" with CogletServer(sync_predictor) as server: resp = requests.post(f"{server.base_url}/predictions/unknown-id/cancel") assert resp.status_code == 404 result = resp.json() assert result == {} def test_prediction_response_includes_id(self, sync_predictor: Path): """Test that prediction responses include an ID.""" with CogletServer(sync_predictor) as server: result = server.predict({"name": "test"}) assert "id" in result assert result["id"].startswith("pred_") def test_cancel_running_sync_prediction(self, slow_sync_predictor: Path): """Test that cancelling a running sync prediction actually terminates it.""" with CogletServer(slow_sync_predictor) as server: # Start a long-running prediction asynchronously prediction_id = "cancel-sync-test" resp = requests.put( f"{server.base_url}/predictions/{prediction_id}", json={"input": {"duration": 60.0}}, headers={"Prefer": "respond-async"}, ) assert resp.status_code == 202 # Wait for the prediction to actually be processing (slot occupied) _wait_for_health_status(server, "BUSY", timeout=5.0) # Cancel the prediction cancel_resp = 
requests.post( f"{server.base_url}/predictions/{prediction_id}/cancel" ) assert cancel_resp.status_code == 200 # Wait for the server to return to READY (slot freed after cancel) _wait_for_health_status(server, "READY", timeout=10.0) def test_cancel_running_async_prediction(self, slow_async_predictor: Path): """Test that cancelling a running async prediction actually terminates it.""" with CogletServer(slow_async_predictor) as server: # Start a long-running async prediction prediction_id = "cancel-async-test" resp = requests.put( f"{server.base_url}/predictions/{prediction_id}", json={"input": {"sleep_time": 60.0}}, headers={"Prefer": "respond-async"}, ) assert resp.status_code == 202 # Wait for the prediction to actually be processing (slot occupied) _wait_for_health_status(server, "BUSY", timeout=5.0) # Cancel the prediction cancel_resp = requests.post( f"{server.base_url}/predictions/{prediction_id}/cancel" ) assert cancel_resp.status_code == 200 # Wait for the server to return to READY (slot freed after cancel) _wait_for_health_status(server, "READY", timeout=10.0) @pytest.mark.parametrize( ("predictor_fixture", "duration", "ready_timeout"), [ # Busy-loop: cancels immediately at the next bytecode boundary ("slow_sync_predictor", 60.0, 10.0), # time.sleep (nanosleep): blocks in C; cancel fires once sleep returns ("blocking_sleep_predictor", 5.0, 15.0), ], ids=["busy_loop", "nanosleep"], ) def test_repeated_cancel_is_idempotent( self, predictor_fixture: str, duration: float, ready_timeout: float, request: pytest.FixtureRequest, ): """Test that cancelling the same prediction multiple times doesn't panic or break. Covers both busy-loop (bytecode boundaries) and time.sleep (C-level nanosleep). For nanosleep the cancel is deferred until the sleep returns, so we use a short duration and a longer timeout for the server to recover. 
""" predictor_path: Path = request.getfixturevalue(predictor_fixture) with CogletServer(predictor_path) as server: prediction_id = "cancel-repeat-test" resp = requests.put( f"{server.base_url}/predictions/{prediction_id}", json={"input": {"duration": duration}}, headers={"Prefer": "respond-async"}, ) assert resp.status_code == 202 # Wait for the prediction to actually be processing _wait_for_health_status(server, "BUSY", timeout=5.0) # Cancel the same prediction multiple times in rapid succession for i in range(5): cancel_resp = requests.post( f"{server.base_url}/predictions/{prediction_id}/cancel" ) # First cancel returns 200 (found), subsequent may return 200 or # 404 depending on timing — but must never panic or 500. assert cancel_resp.status_code in (200, 404), ( f"Cancel attempt {i + 1} returned unexpected {cancel_resp.status_code}" ) # Server should recover to READY _wait_for_health_status(server, "READY", timeout=ready_timeout) def test_cancel_sync_prediction_connection_drop(self, slow_sync_predictor: Path): """Test that dropping a sync connection cancels the prediction.""" with CogletServer(slow_sync_predictor) as server: # Start a sync (non-async) prediction with a short timeout # The connection drop should trigger cancellation via SyncPredictionGuard sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("localhost", server.port)) request_body = '{"input": {"duration": 60.0}}' http_request = ( f"POST /predictions HTTP/1.1\r\n" f"Host: localhost:{server.port}\r\n" f"Content-Type: application/json\r\n" f"Content-Length: {len(request_body)}\r\n" f"\r\n" f"{request_body}" ) sock.sendall(http_request.encode()) # Wait for the prediction to be processing (slot occupied) _wait_for_health_status(server, "BUSY", timeout=5.0) # Drop the connection abruptly sock.close() # Wait for the server to return to READY (slot freed after cancel) _wait_for_health_status(server, "READY", timeout=10.0) ================================================ FILE: 
crates/deny.toml ================================================ # cargo-deny configuration for coglet crates # See: https://embarkstudios.github.io/cargo-deny/ [graph] all-features = false no-default-features = false [output] feature-depth = 1 [advisories] ignore = [ # Unmaintained unic-* crates from rustpython-parser (transitive dep of pyo3-stub-gen) # No fix available - these are only used at build time for stub generation "RUSTSEC-2025-0075", # unic-char-range "RUSTSEC-2025-0080", # unic-common "RUSTSEC-2025-0081", # unic-char-property "RUSTSEC-2025-0090", # unic-emoji-char "RUSTSEC-2025-0098", # unic-ucd-version "RUSTSEC-2025-0100", # unic-ucd-ident ] [licenses] # Apache-2.0 compatible licenses only (no GPL/copyleft) allow = [ "MIT", "MIT-0", # MIT No Attribution (more permissive than MIT) "Apache-2.0", "Apache-2.0 WITH LLVM-exception", "BSD-2-Clause", "BSD-3-Clause", "ISC", "Zlib", "0BSD", "Unicode-3.0", "Unicode-DFS-2016", "CC0-1.0", "MPL-2.0", # Weak copyleft, Apache compatible for our use "CDLA-Permissive-2.0", # Community Data License, for Mozilla CA certificate data ] confidence-threshold = 0.8 exceptions = [] [licenses.private] ignore = false registries = [] [bans] multiple-versions = "warn" wildcards = "allow" highlight = "all" workspace-default-features = "allow" external-default-features = "allow" allow = [] allow-workspace = false deny = [] skip = [] skip-tree = [] [sources] unknown-registry = "warn" unknown-git = "warn" allow-registry = ["https://github.com/rust-lang/crates.io-index"] allow-git = [] [sources.allow-org] github = [] gitlab = [] bitbucket = [] ================================================ FILE: docs/CNAME ================================================ cog.run ================================================ FILE: docs/cli.md ================================================ # CLI reference ## `cog` Containers for machine learning. 
To get started, take a look at the documentation: https://github.com/replicate/cog **Examples** ``` To run a command inside a Docker environment defined with Cog: $ cog run echo hello world ``` **Options** ``` --debug Show debugging output -h, --help help for cog --no-color Disable colored output --version Show version of Cog ``` ## `cog build` Build a Docker image from the cog.yaml in the current directory. The generated image contains your model code, dependencies, and the Cog runtime. It can be run locally with 'cog predict' or pushed to a registry with 'cog push'. ``` cog build [flags] ``` **Examples** ``` # Build with default settings cog build # Build and tag the image cog build -t my-model:latest # Build without using the cache cog build --no-cache # Build with model weights in a separate layer cog build --separate-weights -t my-model:v1 ``` **Options** ``` -f, --file string The name of the config file. (default "cog.yaml") -h, --help help for build --no-cache Do not use cache when building the image --openapi-schema string Load OpenAPI schema from a file --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --secret stringArray Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file' --separate-weights Separate model weights from code in image layers -t, --tag string A name for the built image in the form 'repository:tag' --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` ## `cog init` Create a cog.yaml and predict.py in the current directory. These files provide a starting template for defining your model's environment and prediction interface. Edit them to match your model's requirements. 
``` cog init [flags] ``` **Examples** ``` # Set up a new Cog project in the current directory cog init ``` **Options** ``` -h, --help help for init ``` ## `cog login` Log in to a container registry. For Replicate's registry (r8.im), this command handles authentication through Replicate's token-based flow. For other registries, this command prompts for username and password, then stores credentials using Docker's credential system. ``` cog login [flags] ``` **Options** ``` -h, --help help for login --token-stdin Pass login token on stdin instead of opening a browser. You can find your Replicate login token at https://replicate.com/auth/token ``` ## `cog predict` Run a prediction. If 'image' is passed, it will run the prediction on that Docker image. It must be an image that has been built by Cog. Otherwise, it will build the model in the current directory and run the prediction on that. ``` cog predict [image] [flags] ``` **Examples** ``` # Run a prediction with named inputs cog predict -i prompt="a photo of a cat" # Pass a file as input cog predict -i image=@photo.jpg # Save output to a file cog predict -i image=@input.jpg -o output.png # Pass multiple inputs cog predict -i prompt="sunset" -i width=1024 -i height=768 # Run against a pre-built image cog predict r8.im/your-username/my-model -i prompt="hello" # Pass inputs as JSON echo '{"prompt": "a cat"}' | cog predict --json @- ``` **Options** ``` -e, --env stringArray Environment variables, in the form name=value -f, --file string The name of the config file. (default "cog.yaml") --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. -h, --help help for predict -i, --input stringArray Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. 
-i path=@image.jpg --json string Pass inputs as JSON object, read from file (@inputs.json) or via stdin (@-) -o, --output string Output path --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --setup-timeout uint32 The timeout for a container to setup (in seconds). (default 300) --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") --use-replicate-token Pass REPLICATE_API_TOKEN from local environment into the model context ``` ## `cog push` Build a Docker image from cog.yaml and push it to a container registry. Cog can push to any OCI-compliant registry. When pushing to Replicate's registry (r8.im), run 'cog login' first to authenticate. ``` cog push [IMAGE] [flags] ``` **Examples** ``` # Push to Replicate cog push r8.im/your-username/my-model # Push to any OCI registry cog push registry.example.com/your-username/model-name # Push with model weights in a separate layer (Replicate only) cog push r8.im/your-username/my-model --separate-weights ``` **Options** ``` -f, --file string The name of the config file. (default "cog.yaml") -h, --help help for push --no-cache Do not use cache when building the image --openapi-schema string Load OpenAPI schema from a file --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --secret stringArray Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file' --separate-weights Separate model weights from code in image layers --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). 
False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` ## `cog run` Run a command inside a Docker environment defined by cog.yaml. Cog builds a temporary image from your cog.yaml configuration and runs the given command inside it. This is useful for debugging, running scripts, or exploring the environment your model will run in. ``` cog run [arg...] [flags] ``` **Examples** ``` # Open a Python interpreter inside the model environment cog run python # Run a script cog run python train.py # Run with environment variables cog run -e HUGGING_FACE_HUB_TOKEN=abc123 python download.py # Expose a port (e.g. for Jupyter) cog run -p 8888 jupyter notebook ``` **Options** ``` -e, --env stringArray Environment variables, in the form name=value -f, --file string The name of the config file. (default "cog.yaml") --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. -h, --help help for run --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") -p, --publish stringArray Publish a container's port to the host, e.g. -p 8000 --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` ## `cog serve` Run a prediction HTTP server. Builds the model and starts an HTTP server that exposes the model's inputs and outputs as a REST API. Compatible with the Cog HTTP protocol. 
``` cog serve [flags] ``` **Examples** ``` # Start the server on the default port (8393) cog serve # Start on a custom port cog serve -p 5000 # Test the server curl http://localhost:8393/predictions \ -X POST \ -H 'Content-Type: application/json' \ -d '{"input": {"prompt": "a cat"}}' ``` **Options** ``` -f, --file string The name of the config file. (default "cog.yaml") --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. -h, --help help for serve -p, --port int Port on which to listen (default 8393) --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --upload-url string Upload URL for file outputs (e.g. https://example.com/upload/) --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` ================================================ FILE: docs/deploy.md ================================================ # Deploy models with Cog Cog containers are Docker containers that serve an HTTP server for running predictions on your model. You can deploy them anywhere that Docker containers run. The server inside Cog containers is **coglet**, a Rust-based prediction server that handles HTTP requests, worker process management, and prediction execution. This guide assumes you have a model packaged with Cog. If you don't, [follow our getting started guide](getting-started-own-model.md), or use [an example model](https://github.com/replicate/cog-examples). 
## Getting started First, build your model: ```console cog build -t my-model ``` You can serve predictions locally with `cog serve`: ```console cog serve # or, from a built image: cog serve my-model ``` Alternatively, start the Docker container directly: ```shell # If your model uses a CPU: docker run -d -p 5001:5000 my-model # If your model uses a GPU: docker run -d -p 5001:5000 --gpus all my-model ``` The server listens on port 5000 inside the container (mapped to 5001 above). To view the OpenAPI schema, open [localhost:5001/openapi.json](http://localhost:5001/openapi.json) in your browser or use cURL to make a request: ```console curl http://localhost:5001/openapi.json ``` To stop the server, run: ```console docker kill my-model ``` To run a prediction on the model, call the `/predictions` endpoint, passing input in the format expected by your model: ```console curl http://localhost:5001/predictions -X POST \ --header "Content-Type: application/json" \ --data '{"input": {"image": "https://.../input.jpg"}}' ``` For more details about the HTTP API, see the [HTTP API reference documentation](http.md). ## Health checks The server exposes a `GET /health-check` endpoint that returns the current status of the model container. Use this for readiness probes in orchestration systems like Kubernetes. ```console curl http://localhost:5001/health-check ``` The response includes a `status` field with values like `STARTING`, `READY`, `BUSY`, `SETUP_FAILED`, or `DEFUNCT`. See the [HTTP API reference](http.md#get-health-check) for full details. ## Concurrency By default, the server processes one prediction at a time. To enable concurrent predictions, set the `concurrency.max` option in `cog.yaml`: ```yaml concurrency: max: 4 ``` See the [`cog.yaml` reference](yaml.md#concurrency) for more details. ## Environment variables You can configure runtime behavior with environment variables: - `COG_SETUP_TIMEOUT`: Maximum time in seconds for the `setup()` method (default: no timeout). 
See the [environment variables reference](environment.md) for the full list. ================================================ FILE: docs/environment.md ================================================ # Environment variables This guide lists the environment variables that change how Cog functions. ## Build-time variables ### `COG_SDK_WHEEL` Controls which cog Python SDK wheel is installed in the Docker image during `cog build`. Takes precedence over `build.sdk_version` in `cog.yaml`. **Supported values:** | Value | Description | | -------------------- | ---------------------------------------------------- | | `pypi` | Install latest version from PyPI | | `pypi:0.12.0` | Install specific version from PyPI | | `dist` | Use wheel from `dist/` directory (requires git repo) | | `https://...` | Install from URL | | `/path/to/wheel.whl` | Install from local file path | **Default behavior:** - **Release builds**: Installs latest cog from PyPI - **Development builds**: Auto-detects wheel in `dist/` directory, falls back to latest PyPI **Examples:** ```console # Use specific PyPI version $ COG_SDK_WHEEL=pypi:0.11.0 cog build # Use local development wheel $ COG_SDK_WHEEL=dist cog build # Use wheel from URL $ COG_SDK_WHEEL=https://example.com/cog-0.12.0-py3-none-any.whl cog build ``` The `dist` option searches for wheels in: 1. `./dist/` (current directory) 2. `$REPO_ROOT/dist/` (if REPO_ROOT is set) 3. `/dist/` (via `git rev-parse`, useful when running from subdirectories) ### `COGLET_WHEEL` Controls which coglet wheel is installed in the Docker image. Coglet is the Rust-based prediction server. **Supported values:** Same as `COG_SDK_WHEEL` **Default behavior:** For development builds, auto-detects a wheel in `dist/`. For release builds, installs the latest version from PyPI. Can be overridden with an explicit value. 
**Examples:** ```console # Use local development wheel $ COGLET_WHEEL=dist cog build # Use specific version from PyPI $ COGLET_WHEEL=pypi:0.1.0 cog build ``` ## Runtime variables ### `COG_NO_UPDATE_CHECK` By default, Cog automatically checks for updates and notifies you if there is a new version available. To disable this behavior, set the `COG_NO_UPDATE_CHECK` environment variable to any value. ```console $ COG_NO_UPDATE_CHECK=1 cog build # runs without automatic update check ``` ### `COG_SETUP_TIMEOUT` Controls the maximum time (in seconds) allowed for the model's `setup()` method to complete. If setup exceeds this timeout, the server will report a setup failure. By default, there is no timeout — setup runs indefinitely. Set to `0` to disable the timeout (same as default). Invalid values are ignored with a warning. ```console $ COG_SETUP_TIMEOUT=300 docker run -p 5000:5000 my-model # 5-minute setup timeout ``` ### `COG_CA_CERT` Injects a custom CA certificate into the Docker image during `cog build`. This is useful when building behind a corporate proxy or VPN that uses custom certificate authorities (e.g. Cloudflare WARP). **Supported values:** | Value | Description | | -------------------------------- | ----------------------------------------------------------- | | `/path/to/cert.crt` | Path to a single PEM certificate file | | `/path/to/certs/` | Directory of `.crt` and `.pem` files (all are concatenated) | | `-----BEGIN CERTIFICATE-----...` | Inline PEM certificate | | `LS0tLS1CRUdJTi...` | Base64-encoded PEM certificate | The certificate is installed into the system CA store and the `SSL_CERT_FILE` and `REQUESTS_CA_BUNDLE` environment variables are set automatically in the built image. **Examples:** ```console # From a file $ COG_CA_CERT=/usr/local/share/ca-certificates/corporate-ca.crt cog build # From a directory of certs $ COG_CA_CERT=/etc/custom-certs/ cog build # Inline (e.g. 
from a CI secret) $ COG_CA_CERT="$(cat /path/to/cert.pem)" cog build ``` ================================================ FILE: docs/getting-started-own-model.md ================================================ # Getting started with your own model This guide will show you how to put your own machine learning model in a Docker image using Cog. If you haven't got a model to try out, you'll want to follow the [main getting started guide](getting-started.md). ## Prerequisites - **macOS or Linux**. Cog works on macOS and Linux, but does not currently support Windows. - **Docker**. Cog uses Docker to create a container for your model. You'll need to [install Docker](https://docs.docker.com/get-docker/) before you can run Cog. ## Initialization First, install Cog if you haven't already: **macOS (recommended):** ```sh brew install replicate/tap/cog ``` **Linux or macOS (manual):** ```sh sudo curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m` sudo chmod +x /usr/local/bin/cog ``` To configure your project for use with Cog, you'll need to add two files: - [`cog.yaml`](yaml.md) defines system requirements, Python package dependencies, etc - [`predict.py`](python.md) describes the prediction interface for your model Use the `cog init` command to generate these files in your project: ```sh $ cd path/to/your/model $ cog init ``` ## Define the Docker environment The `cog.yaml` file defines all the different things that need to be installed for your model to run. You can think of it as a simple way of defining a Docker image. For example: ```yaml build: python_version: "3.13" python_requirements: requirements.txt ``` With a `requirements.txt` containing your dependencies: ``` torch==2.6.0 ``` This will generate a Docker image with Python 3.13 and PyTorch 2 installed, for both CPU and GPU, with the correct version of CUDA, and various other sensible best-practices. 
To run a command inside this environment, prefix it with `cog run`: ``` $ cog run python ✓ Building Docker image from cog.yaml... Successfully built 8f54020c8981 Running 'python' in Docker with the current directory mounted as a volume... ──────────────────────────────────────────────────────────────────────────────────────── Python 3.13.x (main, ...) [GCC 12.2.0] on linux Type "help", "copyright", "credits" or "license" for more information. >>> ``` This is handy for ensuring a consistent environment for development or training. With `cog.yaml`, you can also install system packages and other things. [Take a look at the full reference to see what else you can do.](yaml.md) ## Define how to run predictions The next step is to update `predict.py` to define the interface for running predictions on your model. The `predict.py` generated by `cog init` looks something like this: ```python from cog import BasePredictor, Path, Input import torch class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.net = torch.load("weights.pth") def predict(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5) ) -> Path: """Run a single prediction on the model""" # ... pre-processing ... output = self.net(input) # ... post-processing ... return output ``` Edit your `predict.py` file and fill in the functions with your own model's setup and prediction code. You might need to import parts of your model from another file. You also need to define the inputs to your model as arguments to the `predict()` function, as demonstrated above. For each argument, you need to annotate with a type. 
The supported types are: - `str`: a string - `int`: an integer - `float`: a floating point number - `bool`: a boolean - `cog.File`: a file-like object representing a file (deprecated — use `cog.Path` instead) - `cog.Path`: a path to a file on disk You can provide more information about the input with the `Input()` function, as shown above. It takes these basic arguments: - `description`: A description of what to pass to this input for users of the model - `default`: A default value to set the input to. If this argument is not passed, the input is required. If it is explicitly set to `None`, the input is optional. - `ge`: For `int` or `float` types, the value should be greater than or equal to this number. - `le`: For `int` or `float` types, the value should be less than or equal to this number. - `min_length`: For `str` types, the minimum length of the string. - `max_length`: For `str` types, the maximum length of the string. - `regex`: For `str` types, the string must match this regular expression. - `choices`: For `str` or `int` types, a list of possible values for this input. - `deprecated`: Mark this input as deprecated with a message explaining what to use instead. There are some more advanced options you can pass, too. For more details, [take a look at the prediction interface documentation](python.md). Next, add the line `predict: "predict.py:Predictor"` to your `cog.yaml`, so it looks something like this: ```yaml build: python_version: "3.13" python_requirements: requirements.txt predict: "predict.py:Predictor" ``` That's it! To test this works, try running a prediction on the model: ``` $ cog predict -i image=@input.jpg ✓ Building Docker image from cog.yaml... 
Successfully built 664ef88bc1f4 ✓ Model running in Docker image 664ef88bc1f4 Written output to output.png ``` To pass more inputs to the model, you can add more `-i` options: ``` $ cog predict -i image=@image.jpg -i scale=2.0 ``` In this case it is just a number, not a file, so you don't need the `@` prefix. ## Using GPUs To use GPUs with Cog, add the `gpu: true` option to the `build` section of your `cog.yaml`: ```yaml build: gpu: true ... ``` Cog will use the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) base image and automatically figure out what versions of CUDA and cuDNN to use based on the version of Python, PyTorch, and Tensorflow that you are using. For more details, [see the `gpu` section of the `cog.yaml` reference](yaml.md#gpu). ## Next steps Next, you might want to take a look at: - [A guide explaining how to deploy a model.](deploy.md) - [The reference for `cog.yaml`](yaml.md) - [The reference for the Python library](python.md) ================================================ FILE: docs/getting-started.md ================================================ # Getting started This guide will walk you through what you can do with Cog by using an example model. > [!TIP] > Using a language model to help you write the code for your new Cog model? > > Feed it [https://cog.run/llms.txt](https://cog.run/llms.txt), which has all of Cog's documentation bundled into a single file. To learn more about this format, check out [llmstxt.org](https://llmstxt.org). ## Prerequisites - **macOS or Linux**. Cog works on macOS and Linux, but does not currently support Windows. - **Docker**. Cog uses Docker to create a container for your model. You'll need to [install Docker](https://docs.docker.com/get-docker/) before you can run Cog. 
## Install Cog **macOS (recommended):** ```bash brew install replicate/tap/cog ``` **Linux or macOS (manual):** ```bash sudo curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m` sudo chmod +x /usr/local/bin/cog sudo xattr -d com.apple.quarantine /usr/local/bin/cog 2>/dev/null || true ``` > [!NOTE] > **macOS: "cannot be opened because the developer cannot be verified"** > > If you downloaded the binary manually (via `curl` or a browser) and see this Gatekeeper warning, run: > > ```bash > sudo xattr -d com.apple.quarantine /usr/local/bin/cog > ``` > > Installing via `brew install replicate/tap/cog` handles this automatically. ## Create a project Let's make a directory to work in: ```bash mkdir cog-quickstart cd cog-quickstart ``` ## Run commands The simplest thing you can do with Cog is run a command inside a Docker environment. The first thing you need to do is create a file called `cog.yaml`: ```yaml build: python_version: "3.13" ``` Then, you can run any command inside this environment. For example, enter ```bash cog run python ``` and you'll get an interactive Python shell: ```none ✓ Building Docker image from cog.yaml... Successfully built 8f54020c8981 Running 'python' in Docker with the current directory mounted as a volume... ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Python 3.13.x (main, ...) [GCC 12.2.0] on linux Type "help", "copyright", "credits" or "license" for more information. >>> ``` (Hit Ctrl-D to exit the Python shell.) Inside this Docker environment you can do anything – run a Jupyter notebook, your training script, your evaluation script, and so on. ## Run predictions on a model Let's pretend we've trained a model. With Cog, we can define how to run predictions on it in a standard way, so other people can easily run predictions on it without having to hunt around for a prediction script. 
We need to write some code to describe how predictions are run on the model. Save this to `predict.py`: ```python import os os.environ["TORCH_HOME"] = "." import torch from cog import BasePredictor, Input, Path from PIL import Image from torchvision import models WEIGHTS = models.ResNet50_Weights.IMAGENET1K_V1 class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = models.resnet50(weights=WEIGHTS).to(self.device) self.model.eval() def predict(self, image: Path = Input(description="Image to classify")) -> dict: """Run a single prediction on the model""" img = Image.open(image).convert("RGB") preds = self.model(WEIGHTS.transforms()(img).unsqueeze(0).to(self.device)) top3 = preds[0].softmax(0).topk(3) categories = WEIGHTS.meta["categories"] return {categories[i]: p.detach().item() for p, i in zip(*top3)} ``` We also need to point Cog at this, and tell it what Python dependencies to install. Save this to `requirements.txt`: ``` pillow==11.1.0 torch==2.6.0 torchvision==0.21.0 ``` Then update `cog.yaml` to look like this: ```yaml build: python_version: "3.13" python_requirements: requirements.txt predict: "predict.py:Predictor" ``` > [!TIP] > If you have a machine with an NVIDIA GPU attached, add `gpu: true` to the `build` section of your `cog.yaml` to enable GPU acceleration. Let's grab an image to test the model with: ```bash IMAGE_URL=https://gist.githubusercontent.com/bfirsh/3c2115692682ae260932a67d93fd94a8/raw/56b19f53f7643bb6c0b822c410c366c3a6244de2/mystery.jpg curl $IMAGE_URL > input.jpg ``` Now, let's run the model using Cog: ```bash cog predict -i image=@input.jpg ``` If you see the following output ```json { "tiger_cat": 0.4874822497367859, "tabby": 0.23169134557247162, "Egyptian_cat": 0.09728282690048218 } ``` then it worked! 
Note: The first time you run `cog predict`, the build process will be triggered to generate a Docker image that can run your model. The next time you run `cog predict`, the pre-built image will be used. ## Build an image We can bake your model's code, the trained weights, and the Docker environment into a Docker image. This image serves predictions with an HTTP server, and can be deployed anywhere that Docker runs to serve real-time predictions. ```bash cog build -t resnet # Building Docker image... # Built resnet:latest ``` You can run this image with `cog predict` by passing the image name as an argument: ```bash cog predict resnet -i image=@input.jpg ``` Or, you can run it with Docker directly, and it'll start an HTTP server: ```bash docker run -d --rm -p 5000:5000 resnet ``` We can send inputs directly with `curl`: ```bash curl http://localhost:5000/predictions -X POST \ -H 'Content-Type: application/json' \ -d '{"input": {"image": "https://gist.githubusercontent.com/bfirsh/3c2115692682ae260932a67d93fd94a8/raw/56b19f53f7643bb6c0b822c410c366c3a6244de2/mystery.jpg"}}' ``` As a shorthand, you can add the Docker image's name as an extra line in `cog.yaml`: ```yaml image: "r8.im/replicate/resnet" ``` Once you've done this, you can use `cog push` to build and push the image to a Docker registry: ```bash cog push # Building r8.im/replicate/resnet... # Pushing r8.im/replicate/resnet... # Pushed! ``` The Docker image is now accessible to anyone or any system that has access to this Docker registry. ## Next steps Those are the basics! 
Next, you might want to take a look at: - [A guide to help you set up your own model on Cog.](getting-started-own-model.md) - [A guide explaining how to deploy a model.](deploy.md) - [Reference for `cog.yaml`](yaml.md) - [Reference for the Python library](python.md) ================================================ FILE: docs/http.md ================================================ # HTTP API > [!TIP] > For information about how to run the HTTP server, > see [our documentation on deploying models](deploy.md). When you run a Docker image built by Cog, it serves an HTTP API for making predictions. The server supports both synchronous and asynchronous prediction creation: - **Synchronous**: The server waits until the prediction is completed and responds with the result. - **Asynchronous**: The server immediately returns a response and processes the prediction in the background. The client can create a prediction asynchronously by setting the `Prefer: respond-async` header in their request. When provided, the server responds immediately after starting the prediction with `202 Accepted` status and a prediction object in status `processing`. > [!NOTE] > The only supported way to receive updates on the status of predictions > started asynchronously is using [webhooks](#webhooks). > Polling for prediction status is not currently supported. You can also use certain server endpoints to create predictions idempotently, such that if a client calls this endpoint more than once with the same ID (for example, due to a network interruption) while the prediction is still running, no new prediction is created. Instead, the client receives a `202 Accepted` response with the initial state of the prediction. 
--- Here's a summary of the prediction creation endpoints: | Endpoint | Header | Behavior | | ---------------------------------- | ----------------------- | ---------------------------- | | `POST /predictions` | - | Synchronous, non-idempotent | | `POST /predictions` | `Prefer: respond-async` | Asynchronous, non-idempotent | | `PUT /predictions/<prediction_id>` | - | Synchronous, idempotent | | `PUT /predictions/<prediction_id>` | `Prefer: respond-async` | Asynchronous, idempotent | Choose the endpoint that best fits your needs: - Use synchronous endpoints when you want to wait for the prediction result. - Use asynchronous endpoints when you want to start a prediction and receive updates via webhooks. - Use idempotent endpoints when you need to safely retry requests without creating duplicate predictions. ## Webhooks You can provide a `webhook` parameter in the client request body when creating a prediction. ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 Prefer: respond-async { "input": {"prompt": "A picture of an onion with sunglasses"}, "webhook": "https://example.com/webhook/prediction" } ``` The server makes requests to the provided URL with the current state of the prediction object in the request body at the following times: - `start`: Once, when the prediction starts (`status` is `starting`). - `output`: Each time a predict function generates an output (either once using `return` or multiple times using `yield`) - `logs`: Each time the predict function writes to `stdout` - `completed`: Once, when the prediction reaches a terminal state (`status` is `succeeded`, `canceled`, or `failed`) Webhook requests for `start` and `completed` event types are sent immediately. Webhook requests for `output` and `logs` event types are sent at most once every 500ms. This interval is not configurable. By default, the server sends requests for all event types. 
Clients can specify which events trigger webhook requests with the `webhook_events_filter` parameter in the prediction request body. For example, the following request specifies that webhooks are sent by the server only at the start and end of the prediction: ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 Prefer: respond-async { "input": {"prompt": "A picture of an onion with sunglasses"}, "webhook": "https://example.com/webhook/prediction", "webhook_events_filter": ["start", "completed"] } ``` ## Generating unique prediction IDs Endpoints for creating and canceling a prediction idempotently accept a `prediction_id` parameter in their path. By default, the server runs one prediction at a time, but this can be increased with the [`concurrency.max`](yaml.md#concurrency) setting. When all prediction slots are in use, the server returns `409 Conflict`. The client should ensure prediction slots are available before creating a new prediction with a different ID. Clients are responsible for providing unique prediction IDs. We recommend generating a UUIDv4 or [UUIDv7](https://uuid7.com), base32-encoding that value, and removing the trailing padding characters (`=`). This produces a random identifier that is 26 ASCII characters long. ```python >>> from uuid import uuid4 >>> from base64 import b32encode >>> b32encode(uuid4().bytes).decode('utf-8').lower().rstrip('=') 'wjx3whax6rf4vphkegkhcvpv6a' ``` ## File uploads A model's `predict` function can produce file output by yielding or returning a `cog.Path` or `cog.File` value. By default, files are returned as a base64-encoded [data URL](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs). ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 { "input": {"prompt": "A picture of an onion with sunglasses"} } ``` ```http HTTP/1.1 200 OK Content-Type: application/json { "status": "succeeded", "output": "data:image/png;base64,..." 
} ``` When creating a prediction synchronously, the client can configure a base URL to upload output files to instead by setting the `output_file_prefix` parameter in the request body: ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 { "input": {"prompt": "A picture of an onion with sunglasses"}, "output_file_prefix": "https://example.com/upload" } ``` When the model produces a file output, the server sends the following request to upload the file to the configured URL: ```http PUT /upload HTTP/1.1 Host: example.com Content-Type: multipart/form-data --boundary Content-Disposition: form-data; name="file"; filename="image.png" Content-Type: image/png --boundary-- ``` If the upload succeeds, the server responds with the URL of the uploaded file as the output: ```http HTTP/1.1 200 OK Content-Type: application/json { "status": "succeeded", "output": "https://example.com/upload/image.png" } ``` If the upload fails, the server responds with an error. > [!IMPORTANT] > File uploads for predictions created asynchronously > require `--upload-url` to be specified when starting the HTTP server. ## Endpoints ### `GET /` Returns a discovery document listing available API endpoints, the OpenAPI schema URL, and version information. ```http GET / HTTP/1.1 ``` ```http HTTP/1.1 200 OK Content-Type: application/json { "cog_version": "0.17.0", "docs_url": "/openapi.json", "openapi_url": "/openapi.json", "predictions_url": "/predictions", "health_check_url": "/health-check" } ``` If training is configured, the response also includes a `trainings_url` field. ### `GET /health-check` Returns the current health status of the model container. This endpoint always responds with `200 OK` — check the `status` field in the response body to determine readiness. The response body is a JSON object with the following fields: - `status`: One of the following values: - `STARTING`: The model's `setup()` method is still running. - `READY`: The model is ready to accept predictions. 
- `BUSY`: The model is ready but all prediction slots are in use. - `SETUP_FAILED`: The model's `setup()` method raised an exception. - `DEFUNCT`: The model encountered an unrecoverable error. - `UNHEALTHY`: The model is ready but a user-defined `healthcheck()` method returned `False`. - `setup`: Setup phase details (included once setup has started): - `started_at`: ISO 8601 timestamp of when setup began. - `completed_at`: ISO 8601 timestamp of when setup finished (if complete). - `status`: One of `starting`, `succeeded`, or `failed`. - `logs`: Output captured during setup. - `version`: Runtime version information: - `coglet`: Coglet version. - `cog`: Cog Python SDK version (if available). - `python`: Python version (if available). - `user_healthcheck_error`: Error message from a user-defined `healthcheck()` method (if applicable). ```http GET /health-check HTTP/1.1 ``` ```http HTTP/1.1 200 OK Content-Type: application/json { "status": "READY", "setup": { "started_at": "2025-01-01T00:00:00.000000+00:00", "completed_at": "2025-01-01T00:00:05.000000+00:00", "status": "succeeded", "logs": "" }, "version": { "coglet": "0.17.0", "cog": "0.14.0", "python": "3.13.0" } } ``` ### `GET /openapi.json` The [OpenAPI](https://swagger.io/specification/) specification of the API, which is derived from the input and output types specified in your model's [Predictor](python.md) and [Training](training.md) objects. ### `POST /predictions` Makes a single prediction. The request body is a JSON object with the following fields: - `input`: A JSON object with the same keys as the [arguments to the `predict()` function](python.md). Any `File` or `Path` inputs are passed as URLs. The response body is a JSON object with the following fields: - `status`: Either `succeeded` or `failed`. - `output`: The return value of the `predict()` function. - `error`: If `status` is `failed`, the error message. - `metrics`: An object containing prediction metrics. 
Always includes `predict_time` (elapsed seconds). May also include custom metrics recorded by the model using [`self.record_metric()`](python.md#metrics). ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 { "input": { "image": "https://example.com/image.jpg", "text": "Hello world!" } } ``` ```http HTTP/1.1 200 OK Content-Type: application/json { "status": "succeeded", "output": "data:image/png;base64,...", "metrics": { "predict_time": 4.52 } } ``` If the client sets the `Prefer: respond-async` header in their request, the server responds immediately after starting the prediction with `202 Accepted` status and a prediction object in status `processing`. ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 Prefer: respond-async { "input": {"prompt": "A picture of an onion with sunglasses"} } ``` ```http HTTP/1.1 202 Accepted Content-Type: application/json { "status": "starting" } ``` ### `PUT /predictions/<prediction_id>` Makes a single prediction. This is the idempotent version of the `POST /predictions` endpoint. ```http PUT /predictions/wjx3whax6rf4vphkegkhcvpv6a HTTP/1.1 Content-Type: application/json; charset=utf-8 { "input": {"prompt": "A picture of an onion with sunglasses"} } ``` ```http HTTP/1.1 200 OK Content-Type: application/json { "status": "succeeded", "output": "data:image/png;base64,..." } ``` If the client sets the `Prefer: respond-async` header in their request, the server responds immediately after starting the prediction with `202 Accepted` status and a prediction object in status `processing`. 
```http PUT /predictions/wjx3whax6rf4vphkegkhcvpv6a HTTP/1.1 Content-Type: application/json; charset=utf-8 Prefer: respond-async { "input": {"prompt": "A picture of an onion with sunglasses"} } ``` ```http HTTP/1.1 202 Accepted Content-Type: application/json { "id": "wjx3whax6rf4vphkegkhcvpv6a", "status": "starting" } ``` ### `POST /predictions/<prediction_id>/cancel` A client can cancel an asynchronous prediction by making a `POST /predictions/<prediction_id>/cancel` request using the prediction `id` provided when the prediction was created. For example, if the client creates a prediction by sending the request: ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 Prefer: respond-async { "id": "abcd1234", "input": {"prompt": "A picture of an onion with sunglasses"} } ``` The client can cancel the prediction by sending the request: ```http POST /predictions/abcd1234/cancel HTTP/1.1 ``` A prediction cannot be canceled if it's created synchronously, without the `Prefer: respond-async` header, or created without a provided `id`. If a prediction exists with the provided `id`, the server responds with status `200 OK`. Otherwise, the server responds with status `404 Not Found`. When a prediction is canceled, Cog raises [`CancelationException`](python.md#cancelationexception) in sync predictors (or `asyncio.CancelledError` in async predictors). This exception may be caught by the model to perform necessary cleanup. The cleanup should be brief, ideally completing within a few seconds. After cleanup, the exception must be re-raised using a bare `raise` statement. Failure to re-raise the exception may result in the termination of the container. 
```python from cog import BasePredictor, CancelationException, Input, Path class Predictor(BasePredictor): def predict(self, image: Path = Input(description="Image to process")) -> Path: try: return self.process(image) except CancelationException: self.cleanup() raise # always re-raise ``` ================================================ FILE: docs/llms.txt ================================================ # Cog: Containers for machine learning Cog is an open-source tool that lets you package machine learning models in a standard, production-ready container. You can deploy your packaged model to your own infrastructure, or to [Replicate](https://replicate.com/). ## Highlights - 📦 **Docker containers without the pain.** Writing your own `Dockerfile` can be a bewildering process. With Cog, you define your environment with a [simple configuration file](#how-it-works) and it generates a Docker image with all the best practices: Nvidia base images, efficient caching of dependencies, installing specific Python versions, sensible environment variable defaults, and so on. - 🤬️ **No more CUDA hell.** Cog knows which CUDA/cuDNN/PyTorch/Tensorflow/Python combos are compatible and will set it all up correctly for you. - ✅ **Define the inputs and outputs for your model with standard Python.** Then, Cog generates an OpenAPI schema and validates the inputs and outputs. - 🎁 **Automatic HTTP prediction server**: Your model's types are used to dynamically generate a RESTful HTTP API using a high-performance Rust/Axum server. - 🚀 **Ready for production.** Deploy your model anywhere that Docker images run. Your own infrastructure, or [Replicate](https://replicate.com). 
## How it works Define the Docker environment your model runs in with `cog.yaml`: ```yaml build: gpu: true system_packages: - "libgl1-mesa-glx" - "libglib2.0-0" python_version: "3.13" python_requirements: requirements.txt predict: "predict.py:Predictor" ``` Define how predictions are run on your model with `predict.py`: ```python from cog import BasePredictor, Input, Path import torch class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("./weights.pth") # The arguments and types the model takes as input def predict(self, image: Path = Input(description="Grayscale input image") ) -> Path: """Run a single prediction on the model""" processed_image = preprocess(image) output = self.model(processed_image) return postprocess(output) ``` In the above we accept a path to the image as an input, and return a path to our transformed image after running it through our model. Now, you can run predictions on this model: ```console $ cog predict -i image=@input.jpg --> Building Docker image... --> Running Prediction... --> Output written to output.jpg ``` Or, build a Docker image for deployment: ```console $ cog build -t my-classification-model --> Building Docker image... --> Built my-classification-model:latest $ docker run -d -p 5000:5000 --gpus all my-classification-model $ curl http://localhost:5000/predictions -X POST \ -H 'Content-Type: application/json' \ -d '{"input": {"image": "https://.../input.jpg"}}' ``` Or, combine build and run via the `serve` command: ```console $ cog serve -p 8080 $ curl http://localhost:8080/predictions -X POST \ -H 'Content-Type: application/json' \ -d '{"input": {"image": "https://.../input.jpg"}}' ``` ## Why are we building this? It's really hard for researchers to ship machine learning models to production. Part of the solution is Docker, but it is so complex to get it to work: Dockerfiles, pre-/post-processing, Flask servers, CUDA versions. 
More often than not the researcher has to sit down with an engineer to get the damn thing deployed. [Andreas](https://github.com/andreasjansson) and [Ben](https://github.com/bfirsh) created Cog. Andreas used to work at Spotify, where he built tools for building and deploying ML models with Docker. Ben worked at Docker, where he created [Docker Compose](https://github.com/docker/compose). We realized that, in addition to Spotify, other companies were also using Docker to build and deploy machine learning models. [Uber](https://eng.uber.com/michelangelo-pyml/) and others have built similar systems. So, we're making an open source version so other people can do this too. Hit us up if you're interested in using it or want to collaborate with us. [We're on Discord](https://discord.gg/replicate) or email us at [team@replicate.com](mailto:team@replicate.com). ## Prerequisites - **macOS, Linux or Windows 11**. Cog works on macOS, Linux and Windows 11 with [WSL 2](docs/wsl2/wsl2.md) - **Docker**. Cog uses Docker to create a container for your model. You'll need to [install Docker](https://docs.docker.com/get-docker/) before you can run Cog. If you install Docker Engine instead of Docker Desktop, you will need to [install Buildx](https://docs.docker.com/build/architecture/#buildx) as well. 
## Install If you're using macOS, you can install Cog using Homebrew: ```console brew install replicate/tap/cog ``` You can also download and install the latest release using our [install script](https://cog.run/install): ```sh # bash, zsh, and other shells sh <(curl -fsSL https://cog.run/install.sh) # fish shell sh (curl -fsSL https://cog.run/install.sh | psub) # download with wget and run in a separate command wget -q https://cog.run/install.sh sh ./install.sh ``` You can manually install the latest release of Cog directly from GitHub by running the following commands in a terminal: ```console sudo curl -o /usr/local/bin/cog -L "https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)" sudo chmod +x /usr/local/bin/cog ``` Or, if you are installing Cog in a Dockerfile: ``` RUN sh -c "INSTALL_DIR=\"/usr/local/bin\" SUDO=\"\" $(curl -fsSL https://cog.run/install.sh)" ``` ## Upgrade If you're using macOS and you previously installed Cog with Homebrew, run the following: ```console brew upgrade replicate/tap/cog ``` Otherwise, you can upgrade to the latest version by running the same commands you used to install it. ## Development See [CONTRIBUTING.md](CONTRIBUTING.md) for how to set up a development environment and build from source. 
## Next steps - [Get started with an example model](docs/getting-started.md) - [Get started with your own model](docs/getting-started-own-model.md) - [Using Cog with notebooks](docs/notebooks.md) - [Using Cog with Windows 11](docs/wsl2/wsl2.md) - [Take a look at some examples of using Cog](https://github.com/replicate/cog-examples) - [Deploy models with Cog](docs/deploy.md) - [`cog.yaml` reference](docs/yaml.md) to learn how to define your model's environment - [Prediction interface reference](docs/python.md) to learn how the `Predictor` interface works - [Training interface reference](docs/training.md) to learn how to add a fine-tuning API to your model - [HTTP API reference](docs/http.md) to learn how to use the HTTP API that models serve ## Need help? [Join us in #cog on Discord.](https://discord.gg/replicate) [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/replicate/cog) --- # CLI reference ## `cog` Containers for machine learning. To get started, take a look at the documentation: https://github.com/replicate/cog **Examples** ``` To run a command inside a Docker environment defined with Cog: $ cog run echo hello world ``` **Options** ``` --debug Show debugging output -h, --help help for cog --no-color Disable colored output --version Show version of Cog ``` ## `cog build` Build a Docker image from the cog.yaml in the current directory. The generated image contains your model code, dependencies, and the Cog runtime. It can be run locally with 'cog predict' or pushed to a registry with 'cog push'. ``` cog build [flags] ``` **Examples** ``` # Build with default settings cog build # Build and tag the image cog build -t my-model:latest # Build without using the cache cog build --no-cache # Build with model weights in a separate layer cog build --separate-weights -t my-model:v1 ``` **Options** ``` -f, --file string The name of the config file. 
(default "cog.yaml") -h, --help help for build --no-cache Do not use cache when building the image --openapi-schema string Load OpenAPI schema from a file --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --secret stringArray Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file' --separate-weights Separate model weights from code in image layers -t, --tag string A name for the built image in the form 'repository:tag' --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` ## `cog init` Create a cog.yaml and predict.py in the current directory. These files provide a starting template for defining your model's environment and prediction interface. Edit them to match your model's requirements. ``` cog init [flags] ``` **Examples** ``` # Set up a new Cog project in the current directory cog init ``` **Options** ``` -h, --help help for init ``` ## `cog login` Log in to a container registry. For Replicate's registry (r8.im), this command handles authentication through Replicate's token-based flow. For other registries, this command prompts for username and password, then stores credentials using Docker's credential system. ``` cog login [flags] ``` **Options** ``` -h, --help help for login --token-stdin Pass login token on stdin instead of opening a browser. You can find your Replicate login token at https://replicate.com/auth/token ``` ## `cog predict` Run a prediction. If 'image' is passed, it will run the prediction on that Docker image. It must be an image that has been built by Cog. Otherwise, it will build the model in the current directory and run the prediction on that. 
``` cog predict [image] [flags] ``` **Examples** ``` # Run a prediction with named inputs cog predict -i prompt="a photo of a cat" # Pass a file as input cog predict -i image=@photo.jpg # Save output to a file cog predict -i image=@input.jpg -o output.png # Pass multiple inputs cog predict -i prompt="sunset" -i width=1024 -i height=768 # Run against a pre-built image cog predict r8.im/your-username/my-model -i prompt="hello" # Pass inputs as JSON echo '{"prompt": "a cat"}' | cog predict --json @- ``` **Options** ``` -e, --env stringArray Environment variables, in the form name=value -f, --file string The name of the config file. (default "cog.yaml") --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. -h, --help help for predict -i, --input stringArray Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg --json string Pass inputs as JSON object, read from file (@inputs.json) or via stdin (@-) -o, --output string Output path --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --setup-timeout uint32 The timeout for a container to setup (in seconds). (default 300) --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") --use-replicate-token Pass REPLICATE_API_TOKEN from local environment into the model context ``` ## `cog push` Build a Docker image from cog.yaml and push it to a container registry. Cog can push to any OCI-compliant registry. When pushing to Replicate's registry (r8.im), run 'cog login' first to authenticate. 
``` cog push [IMAGE] [flags] ``` **Examples** ``` # Push to Replicate cog push r8.im/your-username/my-model # Push to any OCI registry cog push registry.example.com/your-username/model-name # Push with model weights in a separate layer (Replicate only) cog push r8.im/your-username/my-model --separate-weights ``` **Options** ``` -f, --file string The name of the config file. (default "cog.yaml") -h, --help help for push --no-cache Do not use cache when building the image --openapi-schema string Load OpenAPI schema from a file --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --secret stringArray Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file' --separate-weights Separate model weights from code in image layers --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` ## `cog run` Run a command inside a Docker environment defined by cog.yaml. Cog builds a temporary image from your cog.yaml configuration and runs the given command inside it. This is useful for debugging, running scripts, or exploring the environment your model will run in. ``` cog run [arg...] [flags] ``` **Examples** ``` # Open a Python interpreter inside the model environment cog run python # Run a script cog run python train.py # Run with environment variables cog run -e HUGGING_FACE_HUB_TOKEN=abc123 python download.py # Expose a port (e.g. for Jupyter) cog run -p 8888 jupyter notebook ``` **Options** ``` -e, --env stringArray Environment variables, in the form name=value -f, --file string The name of the config file. (default "cog.yaml") --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. 
-h, --help help for run --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") -p, --publish stringArray Publish a container's port to the host, e.g. -p 8000 --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` ## `cog serve` Run a prediction HTTP server. Builds the model and starts an HTTP server that exposes the model's inputs and outputs as a REST API. Compatible with the Cog HTTP protocol. ``` cog serve [flags] ``` **Examples** ``` # Start the server on the default port (8393) cog serve # Start on a custom port cog serve -p 5000 # Test the server curl http://localhost:8393/predictions \ -X POST \ -H 'Content-Type: application/json' \ -d '{"input": {"prompt": "a cat"}}' ``` **Options** ``` -f, --file string The name of the config file. (default "cog.yaml") --gpus docker run --gpus GPU devices to add to the container, in the same format as docker run --gpus. -h, --help help for serve -p, --port int Port on which to listen (default 8393) --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") --upload-url string Upload URL for file outputs (e.g. https://example.com/upload/) --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` --- # Deploy models with Cog Cog containers are Docker containers that serve an HTTP server for running predictions on your model. You can deploy them anywhere that Docker containers run. 
The server inside Cog containers is **coglet**, a Rust-based prediction server that handles HTTP requests, worker process management, and prediction execution. This guide assumes you have a model packaged with Cog. If you don't, [follow our getting started guide](getting-started-own-model.md), or use [an example model](https://github.com/replicate/cog-examples). ## Getting started First, build your model: ```console cog build -t my-model ``` You can serve predictions locally with `cog serve`: ```console cog serve # or, from a built image: cog serve my-model ``` Alternatively, start the Docker container directly: ```shell # If your model uses a CPU: docker run -d --name my-model -p 5001:5000 my-model # If your model uses a GPU: docker run -d --name my-model -p 5001:5000 --gpus all my-model ``` The server listens on port 5000 inside the container (mapped to 5001 above). To view the OpenAPI schema, open [localhost:5001/openapi.json](http://localhost:5001/openapi.json) in your browser or use cURL to make a request: ```console curl http://localhost:5001/openapi.json ``` To stop the container, run: ```console docker kill my-model ``` To run a prediction on the model, call the `/predictions` endpoint, passing input in the format expected by your model: ```console curl http://localhost:5001/predictions -X POST \ --header "Content-Type: application/json" \ --data '{"input": {"image": "https://.../input.jpg"}}' ``` For more details about the HTTP API, see the [HTTP API reference documentation](http.md). ## Health checks The server exposes a `GET /health-check` endpoint that returns the current status of the model container. Use this for readiness probes in orchestration systems like Kubernetes. ```console curl http://localhost:5001/health-check ``` The response includes a `status` field with values like `STARTING`, `READY`, `BUSY`, `SETUP_FAILED`, or `DEFUNCT`. See the [HTTP API reference](http.md#get-health-check) for full details. ## Concurrency By default, the server processes one prediction at a time. 
To enable concurrent predictions, set the `concurrency.max` option in `cog.yaml`: ```yaml concurrency: max: 4 ``` See the [`cog.yaml` reference](yaml.md#concurrency) for more details. ## Environment variables You can configure runtime behavior with environment variables: - `COG_SETUP_TIMEOUT`: Maximum time in seconds for the `setup()` method (default: no timeout). See the [environment variables reference](environment.md) for the full list. --- # Environment variables This guide lists the environment variables that change how Cog functions. ## Build-time variables ### `COG_SDK_WHEEL` Controls which cog Python SDK wheel is installed in the Docker image during `cog build`. Takes precedence over `build.sdk_version` in `cog.yaml`. **Supported values:** | Value | Description | | -------------------- | ---------------------------------------------------- | | `pypi` | Install latest version from PyPI | | `pypi:0.12.0` | Install specific version from PyPI | | `dist` | Use wheel from `dist/` directory (requires git repo) | | `https://...` | Install from URL | | `/path/to/wheel.whl` | Install from local file path | **Default behavior:** - **Release builds**: Installs latest cog from PyPI - **Development builds**: Auto-detects wheel in `dist/` directory, falls back to latest PyPI **Examples:** ```console # Use specific PyPI version $ COG_SDK_WHEEL=pypi:0.11.0 cog build # Use local development wheel $ COG_SDK_WHEEL=dist cog build # Use wheel from URL $ COG_SDK_WHEEL=https://example.com/cog-0.12.0-py3-none-any.whl cog build ``` The `dist` option searches for wheels in: 1. `./dist/` (current directory) 2. `$REPO_ROOT/dist/` (if `REPO_ROOT` is set) 3. `<git-repo-root>/dist/` (located via `git rev-parse`, useful when running from subdirectories) ### `COGLET_WHEEL` Controls which coglet wheel is installed in the Docker image. Coglet is the Rust-based prediction server. **Supported values:** Same as `COG_SDK_WHEEL` **Default behavior:** For development builds, auto-detects a wheel in `dist/`. 
For release builds, installs the latest version from PyPI. Can be overridden with an explicit value. **Examples:** ```console # Use local development wheel $ COGLET_WHEEL=dist cog build # Use specific version from PyPI $ COGLET_WHEEL=pypi:0.1.0 cog build ``` ## Runtime variables ### `COG_NO_UPDATE_CHECK` By default, Cog automatically checks for updates and notifies you if there is a new version available. To disable this behavior, set the `COG_NO_UPDATE_CHECK` environment variable to any value. ```console $ COG_NO_UPDATE_CHECK=1 cog build # runs without automatic update check ``` ### `COG_SETUP_TIMEOUT` Controls the maximum time (in seconds) allowed for the model's `setup()` method to complete. If setup exceeds this timeout, the server will report a setup failure. By default, there is no timeout — setup runs indefinitely. Set to `0` to disable the timeout (same as default). Invalid values are ignored with a warning. The variable must be set inside the container, so pass it with `docker run -e`: ```console $ docker run -e COG_SETUP_TIMEOUT=300 -p 5000:5000 my-model # 5-minute setup timeout ``` ### `COG_CA_CERT` Injects a custom CA certificate into the Docker image during `cog build`. This is useful when building behind a corporate proxy or VPN that uses custom certificate authorities (e.g. Cloudflare WARP). **Supported values:** | Value | Description | | -------------------------------- | ----------------------------------------------------------- | | `/path/to/cert.crt` | Path to a single PEM certificate file | | `/path/to/certs/` | Directory of `.crt` and `.pem` files (all are concatenated) | | `-----BEGIN CERTIFICATE-----...` | Inline PEM certificate | | `LS0tLS1CRUdJTi...` | Base64-encoded PEM certificate | The certificate is installed into the system CA store and the `SSL_CERT_FILE` and `REQUESTS_CA_BUNDLE` environment variables are set automatically in the built image. 
**Examples:** ```console # From a file $ COG_CA_CERT=/usr/local/share/ca-certificates/corporate-ca.crt cog build # From a directory of certs $ COG_CA_CERT=/etc/custom-certs/ cog build # Inline (e.g. from a CI secret) $ COG_CA_CERT="$(cat /path/to/cert.pem)" cog build ``` --- # Getting started with your own model This guide will show you how to put your own machine learning model in a Docker image using Cog. If you haven't got a model to try out, you'll want to follow the [main getting started guide](getting-started.md). ## Prerequisites - **macOS or Linux**. Cog works on macOS and Linux, but does not currently support Windows. - **Docker**. Cog uses Docker to create a container for your model. You'll need to [install Docker](https://docs.docker.com/get-docker/) before you can run Cog. ## Initialization First, install Cog if you haven't already: **macOS (recommended):** ```sh brew install replicate/tap/cog ``` **Linux or macOS (manual):** ```sh sudo curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m` sudo chmod +x /usr/local/bin/cog ``` To configure your project for use with Cog, you'll need to add two files: - [`cog.yaml`](yaml.md) defines system requirements, Python package dependencies, etc - [`predict.py`](python.md) describes the prediction interface for your model Use the `cog init` command to generate these files in your project: ```sh $ cd path/to/your/model $ cog init ``` ## Define the Docker environment The `cog.yaml` file defines all the different things that need to be installed for your model to run. You can think of it as a simple way of defining a Docker image. 
For example: ```yaml build: python_version: "3.13" python_requirements: requirements.txt ``` With a `requirements.txt` containing your dependencies: ``` torch==2.6.0 ``` This will generate a Docker image with Python 3.13 and PyTorch 2 installed, for both CPU and GPU, with the correct version of CUDA, and various other sensible best-practices. To run a command inside this environment, prefix it with `cog run`: ``` $ cog run python ✓ Building Docker image from cog.yaml... Successfully built 8f54020c8981 Running 'python' in Docker with the current directory mounted as a volume... ──────────────────────────────────────────────────────────────────────────────────────── Python 3.13.x (main, ...) [GCC 12.2.0] on linux Type "help", "copyright", "credits" or "license" for more information. >>> ``` This is handy for ensuring a consistent environment for development or training. With `cog.yaml`, you can also install system packages and other things. [Take a look at the full reference to see what else you can do.](yaml.md) ## Define how to run predictions The next step is to update `predict.py` to define the interface for running predictions on your model. The `predict.py` generated by `cog init` looks something like this: ```python from cog import BasePredictor, Path, Input import torch class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.net = torch.load("weights.pth") def predict(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5) ) -> Path: """Run a single prediction on the model""" # ... pre-processing ... output = self.net(input) # ... post-processing ... return output ``` Edit your `predict.py` file and fill in the functions with your own model's setup and prediction code. You might need to import parts of your model from another file. 
You also need to define the inputs to your model as arguments to the `predict()` function, as demonstrated above. For each argument, you need to annotate with a type. The supported types are: - `str`: a string - `int`: an integer - `float`: a floating point number - `bool`: a boolean - `cog.File`: a file-like object representing a file (deprecated — use `cog.Path` instead) - `cog.Path`: a path to a file on disk You can provide more information about the input with the `Input()` function, as shown above. It takes these basic arguments: - `description`: A description of what to pass to this input for users of the model - `default`: A default value to set the input to. If this argument is not passed, the input is required. If it is explicitly set to `None`, the input is optional. - `ge`: For `int` or `float` types, the value should be greater than or equal to this number. - `le`: For `int` or `float` types, the value should be less than or equal to this number. - `min_length`: For `str` types, the minimum length of the string. - `max_length`: For `str` types, the maximum length of the string. - `regex`: For `str` types, the string must match this regular expression. - `choices`: For `str` or `int` types, a list of possible values for this input. - `deprecated`: Mark this input as deprecated with a message explaining what to use instead. There are some more advanced options you can pass, too. For more details, [take a look at the prediction interface documentation](python.md). Next, add the line `predict: "predict.py:Predictor"` to your `cog.yaml`, so it looks something like this: ```yaml build: python_version: "3.13" python_requirements: requirements.txt predict: "predict.py:Predictor" ``` That's it! To test this works, try running a prediction on the model: ``` $ cog predict -i image=@input.jpg ✓ Building Docker image from cog.yaml... 
Successfully built 664ef88bc1f4 ✓ Model running in Docker image 664ef88bc1f4 Written output to output.png ``` To pass more inputs to the model, you can add more `-i` options: ``` $ cog predict -i image=@image.jpg -i scale=2.0 ``` In this case it is just a number, not a file, so you don't need the `@` prefix. ## Using GPUs To use GPUs with Cog, add the `gpu: true` option to the `build` section of your `cog.yaml`: ```yaml build: gpu: true ... ``` Cog will use the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) base image and automatically figure out what versions of CUDA and cuDNN to use based on the version of Python, PyTorch, and Tensorflow that you are using. For more details, [see the `gpu` section of the `cog.yaml` reference](yaml.md#gpu). ## Next steps Next, you might want to take a look at: - [A guide explaining how to deploy a model.](deploy.md) - [The reference for `cog.yaml`](yaml.md) - [The reference for the Python library](python.md) --- # Getting started This guide will walk you through what you can do with Cog by using an example model. > [!TIP] > Using a language model to help you write the code for your new Cog model? > > Feed it [https://cog.run/llms.txt](https://cog.run/llms.txt), which has all of Cog's documentation bundled into a single file. To learn more about this format, check out [llmstxt.org](https://llmstxt.org). ## Prerequisites - **macOS or Linux**. Cog works on macOS and Linux, but does not currently support Windows. - **Docker**. Cog uses Docker to create a container for your model. You'll need to [install Docker](https://docs.docker.com/get-docker/) before you can run Cog. 
## Install Cog **macOS (recommended):** ```bash brew install replicate/tap/cog ``` **Linux or macOS (manual):** ```bash sudo curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m` sudo chmod +x /usr/local/bin/cog sudo xattr -d com.apple.quarantine /usr/local/bin/cog 2>/dev/null || true ``` > [!NOTE] > **macOS: "cannot be opened because the developer cannot be verified"** > > If you downloaded the binary manually (via `curl` or a browser) and see this Gatekeeper warning, run: > > ```bash > sudo xattr -d com.apple.quarantine /usr/local/bin/cog > ``` > > Installing via `brew install replicate/tap/cog` handles this automatically. ## Create a project Let's make a directory to work in: ```bash mkdir cog-quickstart cd cog-quickstart ``` ## Run commands The simplest thing you can do with Cog is run a command inside a Docker environment. The first thing you need to do is create a file called `cog.yaml`: ```yaml build: python_version: "3.13" ``` Then, you can run any command inside this environment. For example, enter ```bash cog run python ``` and you'll get an interactive Python shell: ```none ✓ Building Docker image from cog.yaml... Successfully built 8f54020c8981 Running 'python' in Docker with the current directory mounted as a volume... ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Python 3.13.x (main, ...) [GCC 12.2.0] on linux Type "help", "copyright", "credits" or "license" for more information. >>> ``` (Hit Ctrl-D to exit the Python shell.) Inside this Docker environment you can do anything – run a Jupyter notebook, your training script, your evaluation script, and so on. ## Run predictions on a model Let's pretend we've trained a model. With Cog, we can define how to run predictions on it in a standard way, so other people can easily run predictions on it without having to hunt around for a prediction script. 
We need to write some code to describe how predictions are run on the model. Save this to `predict.py`: ```python import os os.environ["TORCH_HOME"] = "." import torch from cog import BasePredictor, Input, Path from PIL import Image from torchvision import models WEIGHTS = models.ResNet50_Weights.IMAGENET1K_V1 class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = models.resnet50(weights=WEIGHTS).to(self.device) self.model.eval() def predict(self, image: Path = Input(description="Image to classify")) -> dict: """Run a single prediction on the model""" img = Image.open(image).convert("RGB") preds = self.model(WEIGHTS.transforms()(img).unsqueeze(0).to(self.device)) top3 = preds[0].softmax(0).topk(3) categories = WEIGHTS.meta["categories"] return {categories[i]: p.detach().item() for p, i in zip(*top3)} ``` We also need to point Cog at this, and tell it what Python dependencies to install. Save this to `requirements.txt`: ``` pillow==11.1.0 torch==2.6.0 torchvision==0.21.0 ``` Then update `cog.yaml` to look like this: ```yaml build: python_version: "3.13" python_requirements: requirements.txt predict: "predict.py:Predictor" ``` > [!TIP] > If you have a machine with an NVIDIA GPU attached, add `gpu: true` to the `build` section of your `cog.yaml` to enable GPU acceleration. Let's grab an image to test the model with: ```bash IMAGE_URL=https://gist.githubusercontent.com/bfirsh/3c2115692682ae260932a67d93fd94a8/raw/56b19f53f7643bb6c0b822c410c366c3a6244de2/mystery.jpg curl $IMAGE_URL > input.jpg ``` Now, let's run the model using Cog: ```bash cog predict -i image=@input.jpg ``` If you see the following output ```json { "tiger_cat": 0.4874822497367859, "tabby": 0.23169134557247162, "Egyptian_cat": 0.09728282690048218 } ``` then it worked! 
Note: The first time you run `cog predict`, the build process will be triggered to build a Docker image that can run your model. The next time you run `cog predict`, the pre-built image will be used. ## Build an image We can bake your model's code, the trained weights, and the Docker environment into a Docker image. This image serves predictions with an HTTP server, and can be deployed to anywhere that Docker runs to serve real-time predictions. ```bash cog build -t resnet # Building Docker image... # Built resnet:latest ``` You can run this image with `cog predict` by passing the image name as an argument: ```bash cog predict resnet -i image=@input.jpg ``` Or, you can run it with Docker directly, and it'll start an HTTP server: ```bash docker run -d --rm -p 5000:5000 resnet ``` We can send inputs directly with `curl`: ```bash curl http://localhost:5000/predictions -X POST \ -H 'Content-Type: application/json' \ -d '{"input": {"image": "https://gist.githubusercontent.com/bfirsh/3c2115692682ae260932a67d93fd94a8/raw/56b19f53f7643bb6c0b822c410c366c3a6244de2/mystery.jpg"}}' ``` As a shorthand, you can add the Docker image's name as an extra line in `cog.yaml`: ```yaml image: "r8.im/replicate/resnet" ``` Once you've done this, you can use `cog push` to build and push the image to a Docker registry: ```bash cog push # Building r8.im/replicate/resnet... # Pushing r8.im/replicate/resnet... # Pushed! ``` The Docker image is now accessible to anyone or any system that has access to this Docker registry. ## Next steps Those are the basics! Next, you might want to take a look at: - [A guide to help you set up your own model on Cog.](getting-started-own-model.md) - [A guide explaining how to deploy a model.](deploy.md) - [Reference for `cog.yaml`](yaml.md) - [Reference for the Python library](python.md) --- # HTTP API > [!TIP] > For information about how to run the HTTP server, > see [our documentation on deploying models](deploy.md). 
When you run a Docker image built by Cog, it serves an HTTP API for making predictions. The server supports both synchronous and asynchronous prediction creation: - **Synchronous**: The server waits until the prediction is completed and responds with the result. - **Asynchronous**: The server immediately returns a response and processes the prediction in the background. The client can create a prediction asynchronously by setting the `Prefer: respond-async` header in their request. When provided, the server responds immediately after starting the prediction with `202 Accepted` status and a prediction object in status `processing`. > [!NOTE] > The only supported way to receive updates on the status of predictions > started asynchronously is using [webhooks](#webhooks). > Polling for prediction status is not currently supported. You can also use certain server endpoints to create predictions idempotently, such that if a client calls such an endpoint more than once with the same ID (for example, due to a network interruption) while the prediction is still running, no new prediction is created. Instead, the client receives a `202 Accepted` response with the initial state of the prediction. --- Here's a summary of the prediction creation endpoints: | Endpoint | Header | Behavior | | ---------------------------------- | ----------------------- | ---------------------------- | | `POST /predictions` | - | Synchronous, non-idempotent | | `POST /predictions` | `Prefer: respond-async` | Asynchronous, non-idempotent | | `PUT /predictions/<prediction_id>` | - | Synchronous, idempotent | | `PUT /predictions/<prediction_id>` | `Prefer: respond-async` | Asynchronous, idempotent | Choose the endpoint that best fits your needs: - Use synchronous endpoints when you want to wait for the prediction result. - Use asynchronous endpoints when you want to start a prediction and receive updates via webhooks. - Use idempotent endpoints when you need to safely retry requests without creating duplicate predictions. 
## Webhooks You can provide a `webhook` parameter in the client request body when creating a prediction. ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 Prefer: respond-async { "input": {"prompt": "A picture of an onion with sunglasses"}, "webhook": "https://example.com/webhook/prediction" } ``` The server makes requests to the provided URL with the current state of the prediction object in the request body at the following times. - `start`: Once, when the prediction starts (`status` is `starting`). - `output`: Each time a predict function generates an output (either once using `return` or multiple times using `yield`) - `logs`: Each time the predict function writes to `stdout` - `completed`: Once, when the prediction reaches a terminal state (`status` is `succeeded`, `canceled`, or `failed`) Webhook requests for `start` and `completed` event types are sent immediately. Webhook requests for `output` and `logs` event types are sent at most once every 500ms. This interval is not configurable. By default, the server sends requests for all event types. Clients can specify which events trigger webhook requests with the `webhook_events_filter` parameter in the prediction request body. For example, the following request specifies that webhooks are sent by the server only at the start and end of the prediction: ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 Prefer: respond-async { "input": {"prompt": "A picture of an onion with sunglasses"}, "webhook": "https://example.com/webhook/prediction", "webhook_events_filter": ["start", "completed"] } ``` ## Generating unique prediction IDs Endpoints for creating and canceling a prediction idempotently accept a `prediction_id` parameter in their path. By default, the server runs one prediction at a time, but this can be increased with the [`concurrency.max`](yaml.md#concurrency) setting. When all prediction slots are in use, the server returns `409 Conflict`. 
The client should ensure prediction slots are available before creating a new prediction with a different ID. Clients are responsible for providing unique prediction IDs. We recommend generating a UUIDv4 or [UUIDv7](https://uuid7.com), base32-encoding that value, lowercasing it, and removing the padding characters (`=`). This produces a random identifier that is 26 ASCII characters long. ```python >>> from uuid import uuid4 >>> from base64 import b32encode >>> b32encode(uuid4().bytes).decode('utf-8').lower().rstrip('=') 'wjx3whax6rf4vphkegkhcvpv6a' ``` ## File uploads A model's `predict` function can produce file output by yielding or returning a `cog.Path` or `cog.File` value. By default, files are returned as a base64-encoded [data URL](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/Data_URLs). ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 { "input": {"prompt": "A picture of an onion with sunglasses"} } ``` ```http HTTP/1.1 200 OK Content-Type: application/json { "status": "succeeded", "output": "data:image/png;base64,..." 
} ``` When creating a prediction synchronously, the client can instead configure a base URL to which output files are uploaded by setting the `output_file_prefix` parameter in the request body: ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 { "input": {"prompt": "A picture of an onion with sunglasses"}, "output_file_prefix": "https://example.com/upload" } ``` When the model produces a file output, the server sends the following request to upload the file to the configured URL: ```http PUT /upload HTTP/1.1 Host: example.com Content-Type: multipart/form-data --boundary Content-Disposition: form-data; name="file"; filename="image.png" Content-Type: image/png --boundary-- ``` If the upload succeeds, the server responds with the URL of the uploaded file as the output: ```http HTTP/1.1 200 OK Content-Type: application/json { "status": "succeeded", "output": "https://example.com/upload/image.png" } ``` If the upload fails, the server responds with an error. > [!IMPORTANT] > File uploads for predictions created asynchronously > require `--upload-url` to be specified when starting the HTTP server. ## Endpoints ### `GET /` Returns a discovery document listing available API endpoints, the OpenAPI schema URL, and version information. ```http GET / HTTP/1.1 ``` ```http HTTP/1.1 200 OK Content-Type: application/json { "cog_version": "0.17.0", "docs_url": "/openapi.json", "openapi_url": "/openapi.json", "predictions_url": "/predictions", "health_check_url": "/health-check" } ``` If training is configured, the response also includes a `trainings_url` field. ### `GET /health-check` Returns the current health status of the model container. This endpoint always responds with `200 OK` — check the `status` field in the response body to determine readiness. The response body is a JSON object with the following fields: - `status`: One of the following values: - `STARTING`: The model's `setup()` method is still running. - `READY`: The model is ready to accept predictions. 
- `BUSY`: The model is ready but all prediction slots are in use. - `SETUP_FAILED`: The model's `setup()` method raised an exception. - `DEFUNCT`: The model encountered an unrecoverable error. - `UNHEALTHY`: The model is ready but a user-defined `healthcheck()` method returned `False`. - `setup`: Setup phase details (included once setup has started): - `started_at`: ISO 8601 timestamp of when setup began. - `completed_at`: ISO 8601 timestamp of when setup finished (if complete). - `status`: One of `starting`, `succeeded`, or `failed`. - `logs`: Output captured during setup. - `version`: Runtime version information: - `coglet`: Coglet version. - `cog`: Cog Python SDK version (if available). - `python`: Python version (if available). - `user_healthcheck_error`: Error message from a user-defined `healthcheck()` method (if applicable). ```http GET /health-check HTTP/1.1 ``` ```http HTTP/1.1 200 OK Content-Type: application/json { "status": "READY", "setup": { "started_at": "2025-01-01T00:00:00.000000+00:00", "completed_at": "2025-01-01T00:00:05.000000+00:00", "status": "succeeded", "logs": "" }, "version": { "coglet": "0.17.0", "cog": "0.14.0", "python": "3.13.0" } } ``` ### `GET /openapi.json` The [OpenAPI](https://swagger.io/specification/) specification of the API, which is derived from the input and output types specified in your model's [Predictor](python.md) and [Training](training.md) objects. ### `POST /predictions` Makes a single prediction. The request body is a JSON object with the following fields: - `input`: A JSON object with the same keys as the [arguments to the `predict()` function](python.md). Any `File` or `Path` inputs are passed as URLs. The response body is a JSON object with the following fields: - `status`: Either `succeeded` or `failed`. - `output`: The return value of the `predict()` function. - `error`: If `status` is `failed`, the error message. - `metrics`: An object containing prediction metrics. 
Always includes `predict_time` (elapsed seconds). May also include custom metrics recorded by the model using [`self.record_metric()`](python.md#metrics). ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 { "input": { "image": "https://example.com/image.jpg", "text": "Hello world!" } } ``` ```http HTTP/1.1 200 OK Content-Type: application/json { "status": "succeeded", "output": "data:image/png;base64,...", "metrics": { "predict_time": 4.52 } } ``` If the client sets the `Prefer: respond-async` header in their request, the server responds immediately after starting the prediction with `202 Accepted` status and a prediction object in status `starting`. ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 Prefer: respond-async { "input": {"prompt": "A picture of an onion with sunglasses"} } ``` ```http HTTP/1.1 202 Accepted Content-Type: application/json { "status": "starting" } ``` ### `PUT /predictions/<prediction_id>` Makes a single prediction. This is the idempotent version of the `POST /predictions` endpoint. ```http PUT /predictions/wjx3whax6rf4vphkegkhcvpv6a HTTP/1.1 Content-Type: application/json; charset=utf-8 { "input": {"prompt": "A picture of an onion with sunglasses"} } ``` ```http HTTP/1.1 200 OK Content-Type: application/json { "status": "succeeded", "output": "data:image/png;base64,..." } ``` If the client sets the `Prefer: respond-async` header in their request, the server responds immediately after starting the prediction with `202 Accepted` status and a prediction object in status `starting`.
```http PUT /predictions/wjx3whax6rf4vphkegkhcvpv6a HTTP/1.1 Content-Type: application/json; charset=utf-8 Prefer: respond-async { "input": {"prompt": "A picture of an onion with sunglasses"} } ``` ```http HTTP/1.1 202 Accepted Content-Type: application/json { "id": "wjx3whax6rf4vphkegkhcvpv6a", "status": "starting" } ``` ### `POST /predictions/<prediction_id>/cancel` A client can cancel an asynchronous prediction by making a `POST /predictions/<prediction_id>/cancel` request using the prediction `id` provided when the prediction was created. For example, if the client creates a prediction by sending the request: ```http POST /predictions HTTP/1.1 Content-Type: application/json; charset=utf-8 Prefer: respond-async { "id": "abcd1234", "input": {"prompt": "A picture of an onion with sunglasses"} } ``` The client can cancel the prediction by sending the request: ```http POST /predictions/abcd1234/cancel HTTP/1.1 ``` A prediction cannot be canceled if it's created synchronously, without the `Prefer: respond-async` header, or created without a provided `id`. If a prediction exists with the provided `id`, the server responds with status `200 OK`. Otherwise, the server responds with status `404 Not Found`. When a prediction is canceled, Cog raises [`CancelationException`](python.md#cancelationexception) in sync predictors (or `asyncio.CancelledError` in async predictors). This exception may be caught by the model to perform necessary cleanup. The cleanup should be brief, ideally completing within a few seconds. After cleanup, the exception must be re-raised using a bare `raise` statement. Failure to re-raise the exception may result in the termination of the container.
```python from cog import BasePredictor, CancelationException, Input, Path class Predictor(BasePredictor): def predict(self, image: Path = Input(description="Image to process")) -> Path: try: return self.process(image) except CancelationException: self.cleanup() raise # always re-raise ``` --- # Notebooks Cog plays nicely with Jupyter notebooks. ## Install the jupyterlab Python package First, add `jupyterlab` to your `requirements.txt` file and reference it in [`cog.yaml`](yaml.md): `requirements.txt`: ``` jupyterlab ``` `cog.yaml`: ```yaml build: python_requirements: requirements.txt ``` ## Run a notebook Cog can run notebooks in the environment you've defined in `cog.yaml` with the following command: ```sh cog run -p 8888 jupyter lab --allow-root --ip=0.0.0.0 ``` ## Use notebook code in your predictor You can also import a notebook into your Cog [Predictor](python.md) file. First, export your notebook to a Python file: ```sh jupyter nbconvert --to script my_notebook.ipynb # creates my_notebook.py ``` Then import the exported Python script into your `predict.py` file. Any functions or variables defined in your notebook will be available to your predictor: ```python from cog import BasePredictor, Input import my_notebook class Predictor(BasePredictor): def predict(self, prompt: str = Input(description="string prompt")) -> str: output = my_notebook.do_stuff(prompt) return output ``` --- # Private package registry This guide describes how to build a Docker image with Cog that fetches Python packages from a private registry during setup. ## `pip.conf` In a directory outside your Cog project, create a `pip.conf` file with an `index-url` set to the registry's URL with embedded credentials. ```conf [global] index-url = https://username:password@my-private-registry.com ``` > **Warning** > Be careful not to commit secrets in Git or include them in Docker images. If your Cog project contains any sensitive files, make sure they're listed in `.gitignore` and `.dockerignore`. 
## `cog.yaml` In your project's [`cog.yaml`](yaml.md) file, add a setup command to run `pip install` with a secret configuration file mounted to `/etc/pip.conf`. ```yaml build: run: - command: pip install mounts: - type: secret id: pip target: /etc/pip.conf ``` ## Build When building or pushing your model with Cog, pass the `--secret` option with an `id` matching the one specified in `cog.yaml`, along with a path to your local `pip.conf` file. ```console $ cog build --secret id=pip,source=/path/to/pip.conf ``` Using a secret mount allows the private registry credentials to be securely passed to the `pip install` setup command, without baking them into the Docker image. > **Warning** > If you run `cog build` or `cog push` and then change the contents of a secret source file, the cached version of the file will be used on subsequent builds, ignoring any changes you made. To update the contents of the target secret file, either change the `id` value in `cog.yaml` and the `--secret` option, or pass the `--no-cache` option to bypass the cache entirely. --- # Prediction interface reference This document defines the API of the `cog` Python module, which is used to define the interface for running predictions on your model. > [!TIP] > Run [`cog init`](getting-started-own-model.md#initialization) to generate an annotated `predict.py` file that can be used as a starting point for setting up your model. > [!TIP] > Using a language model to help you write the code for your new Cog model? > > Feed it [https://cog.run/llms.txt](https://cog.run/llms.txt), which has all of Cog's documentation bundled into a single file. To learn more about this format, check out [llmstxt.org](https://llmstxt.org). 
## Contents - [Contents](#contents) - [`BasePredictor`](#basepredictor) - [`Predictor.setup()`](#predictorsetup) - [`Predictor.predict(**kwargs)`](#predictorpredictkwargs) - [`async` predictors and concurrency](#async-predictors-and-concurrency) - [`Input(**kwargs)`](#inputkwargs) - [Deprecating inputs](#deprecating-inputs) - [Output](#output) - [Returning an object](#returning-an-object) - [Returning a list](#returning-a-list) - [Optional properties](#optional-properties) - [Streaming output](#streaming-output) - [Metrics](#metrics) - [Recording metrics](#recording-metrics) - [Accumulation modes](#accumulation-modes) - [Dot-path keys](#dot-path-keys) - [Type safety](#type-safety) - [Cancellation](#cancellation) - [`CancelationException`](#cancelationexception) - [Input and output types](#input-and-output-types) - [`File()`](#file) - [`Path()`](#path) - [`Secret`](#secret) - [`Optional`](#optional) - [`List`](#list) ## `BasePredictor` You define how Cog runs predictions on your model by defining a class that inherits from `BasePredictor`. It looks something like this: ```python from cog import BasePredictor, Path, Input import torch class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("weights.pth") def predict(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5) ) -> Path: """Run a single prediction on the model""" # ... pre-processing ... output = self.model(image) # ... post-processing ... return output ``` Your Predictor class should define two methods: `setup()` and `predict()`. ### `Predictor.setup()` Prepare the model so multiple predictions run efficiently. Use this _optional_ method to include expensive one-off operations like loading trained models, instantiating data transformations, etc. Many models use this method to download their weights (e.g. 
using [`pget`](https://github.com/replicate/pget)). This has some advantages: - Smaller image sizes - Faster build times - Faster pushes and inference on [Replicate](https://replicate.com) However, this may also significantly increase your `setup()` time. As an alternative, some choose to store their weights directly in the image. You can simply leave your weights in the directory alongside your `cog.yaml` and ensure they are not excluded in your `.dockerignore` file. While this will increase your image size and build time, it offers other advantages: - Faster `setup()` time - Ensures idempotency and reduces your model's reliance on external systems - Preserves reproducibility as your model will be self-contained in the image > When using this method, you should use the `--separate-weights` flag on `cog build` to store weights in a [separate layer](https://github.com/replicate/cog/blob/12ac02091d93beebebed037f38a0c99cd8749806/docs/getting-started.md?plain=1#L219). ### `Predictor.predict(**kwargs)` Run a single prediction. This _required_ method is where you call the model that was loaded during `setup()`, but you may also want to add pre- and post-processing code here. The `predict()` method takes an arbitrary list of named arguments, where each argument name must correspond to an [`Input()`](#inputkwargs) annotation. `predict()` can return strings, numbers, [`cog.Path`](#path) objects representing files on disk, or lists or dicts of those types. You can also define a custom [`Output()`](#outputbasemodel) for more complex return types. ## `async` predictors and concurrency > Added in cog 0.14.0. You may specify your `predict()` method as `async def predict(...)`. 
In addition, if you have an async `predict()` function you may also have an async `setup()` function: ```py class Predictor(BasePredictor): async def setup(self) -> None: print("async setup is also supported...") async def predict(self) -> str: print("async predict"); return "hello world"; ``` Models that have an async `predict()` function can run predictions concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. ## `Input(**kwargs)` Use cog's `Input()` function to define each of the parameters in your `predict()` method: ```py class Predictor(BasePredictor): def predict(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5, ge=1.0, le=10.0) ) -> Path: ``` The `Input()` function takes these keyword arguments: - `description`: A description of what to pass to this input for users of the model. - `default`: A default value to set the input to. If this argument is not passed, the input is required. If it is explicitly set to `None`, the input is optional. - `ge`: For `int` or `float` types, the value must be greater than or equal to this number. - `le`: For `int` or `float` types, the value must be less than or equal to this number. - `min_length`: For `str` types, the minimum length of the string. - `max_length`: For `str` types, the maximum length of the string. - `regex`: For `str` types, the string must match this regular expression. - `choices`: For `str` or `int` types, a list of possible values for this input. - `deprecated`: (optional) If set to `True`, marks this input as deprecated. Deprecated inputs will still be accepted, but tools and UIs may warn users that the input is deprecated and may be removed in the future. See [Deprecating inputs](#deprecating-inputs). Each parameter of the `predict()` method must be annotated with a type like `str`, `int`, `float`, `bool`, etc. 
See [Input and output types](#input-and-output-types) for the full list of supported types. Using the `Input` function provides better documentation and validation constraints to the users of your model, but it is not strictly required. You can also specify default values for your parameters using plain Python, or omit default assignment entirely: ```py class Predictor(BasePredictor): def predict(self, prompt: str = "default prompt", # this is valid iterations: int # also valid ) -> str: # ... ``` ## Deprecating inputs You can mark an input as deprecated by passing `deprecated=True` to the `Input()` function. Deprecated inputs will still be accepted, but tools and UIs may warn users that the input is deprecated and may be removed in the future. This is useful when you want to phase out an input without breaking existing clients immediately: ```py from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, text: str = Input(description="Some deprecated text", deprecated=True), prompt: str = Input(description="Prompt for the model") ) -> str: # ... return prompt ``` ## Output Cog predictors can return a simple data type like a string, number, float, or boolean. Use Python's `-> ` syntax to annotate the return type. Here's an example of a predictor that returns a string: ```py from cog import BasePredictor class Predictor(BasePredictor): def predict(self) -> str: return "hello" ``` ### Returning an object To return a complex object with multiple values, define an `Output` object with multiple fields to return from your `predict()` method: ```py from cog import BasePredictor, BaseModel, File class Output(BaseModel): file: File text: str class Predictor(BasePredictor): def predict(self) -> Output: return Output(text="hello", file=io.StringIO("hello")) ``` Each of the output object's properties must be one of the supported output types. For the full list, see [Input and output types](#input-and-output-types). 
Also, make sure to name the output class as `Output` and nothing else. ### Returning a list The `predict()` method can return a list of any of the supported output types. Here's an example that outputs multiple files: ```py from cog import BasePredictor, Path class Predictor(BasePredictor): def predict(self) -> list[Path]: predictions = ["foo", "bar", "baz"] output = [] for i, prediction in enumerate(predictions): out_path = Path(f"/tmp/out-{i}.txt") with out_path.open("w") as f: f.write(prediction) output.append(out_path) return output ``` Files are named in the format `output..`, e.g. `output.0.txt`, `output.1.txt`, and `output.2.txt` from the example above. ### Optional properties To conditionally omit properties from the Output object, define them using `typing.Optional`: ```py from cog import BaseModel, BasePredictor, Path from typing import Optional class Output(BaseModel): score: Optional[float] file: Optional[Path] class Predictor(BasePredictor): def predict(self) -> Output: if condition: return Output(score=1.5) else: return Output(file=io.StringIO("hello")) ``` ### Streaming output Cog models can stream output as the `predict()` method is running. For example, a language model can output tokens as they're being generated and an image generation model can output images as they are being generated. To support streaming output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. 
```py from cog import BasePredictor, Path from typing import Iterator class Predictor(BasePredictor): def predict(self) -> Iterator[Path]: done = False while not done: output_path, done = do_stuff() yield Path(output_path) ``` If you have an [async `predict()` method](#async-predictors-and-concurrency), you must use `cog.AsyncIterator` instead: ```py from cog import AsyncIterator, BasePredictor, Path class Predictor(BasePredictor): async def predict(self) -> AsyncIterator[Path]: done = False while not done: output_path, done = do_stuff() yield Path(output_path) ``` If you're streaming text output, you can use `ConcatenateIterator` to hint that the output should be concatenated together into a single string. This is useful on Replicate to display the output as a string instead of a list of strings. ```py from cog import BasePredictor, Path, ConcatenateIterator class Predictor(BasePredictor): def predict(self) -> ConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: yield token + " " ``` Or for async `predict()` methods, use `AsyncConcatenateIterator`: ```py from cog import BasePredictor, Path, AsyncConcatenateIterator class Predictor(BasePredictor): async def predict(self) -> AsyncConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: yield token + " " ``` ## Metrics You can record custom metrics from your `predict()` function to track model-specific data like token counts, timing breakdowns, or confidence scores. Metrics are included in the prediction response alongside the output. 
### Recording metrics Use `self.record_metric()` inside your `predict()` method: ```python from cog import BasePredictor class Predictor(BasePredictor): def predict(self, prompt: str) -> str: self.record_metric("temperature", 0.7) self.record_metric("token_count", 42) result = self.model.generate(prompt) return result ``` For advanced use (dict-style access, deleting metrics), use `self.scope`: ```python self.scope.metrics["token_count"] = 42 del self.scope.metrics["token_count"] ``` Metrics appear in the prediction response `metrics` field: ```json { "status": "succeeded", "output": "...", "metrics": { "temperature": 0.7, "token_count": 42, "predict_time": 1.23 } } ``` The `predict_time` metric is always added automatically by the runtime. If you set `predict_time` yourself, the runtime value takes precedence. Supported value types are `bool`, `int`, `float`, `str`, `list`, and `dict`. Setting a metric to `None` deletes it. ### Accumulation modes By default, recording a metric replaces any previous value for that key. You can use accumulation modes to build up values across multiple calls: ```python # Increment a counter (adds to the existing numeric value) self.record_metric("token_count", 1, mode="incr") self.record_metric("token_count", 1, mode="incr") # Result: {"token_count": 2} # Append to an array self.record_metric("steps", "preprocessing", mode="append") self.record_metric("steps", "inference", mode="append") # Result: {"steps": ["preprocessing", "inference"]} # Replace (default behavior) self.record_metric("status", "running", mode="replace") self.record_metric("status", "done", mode="replace") # Result: {"status": "done"} ``` The `mode` parameter accepts `"replace"` (default), `"incr"`, or `"append"`. 
### Dot-path keys Use dot-separated keys to create nested objects in the metrics output: ```python self.record_metric("timing.preprocess", 0.12) self.record_metric("timing.inference", 0.85) ``` This produces nested JSON: ```json { "metrics": { "timing": { "preprocess": 0.12, "inference": 0.85 }, "predict_time": 1.23 } } ``` ### Type safety Once a metric key has been assigned a value of a certain type, it cannot be changed to a different type without deleting it first. This prevents accidental type mismatches when using accumulation modes: ```python self.record_metric("count", 1) # This would raise an error — "count" is an int, not a string: # self.record_metric("count", "oops") # Delete first, then set with new type: del self.scope.metrics["count"] self.record_metric("count", "now a string") ``` Outside an active prediction, `self.record_metric()` and `self.scope` are silent no-ops — no need for `None` checks. ## Cancellation When a prediction is canceled (via the [cancel HTTP endpoint](http.md#post-predictionsprediction_idcancel) or a dropped connection), the Cog runtime interrupts the running `predict()` function. The exception raised depends on whether the predictor is sync or async: | Predictor type | Exception raised | | --------------------------- | ------------------------ | | Sync (`def predict`) | `CancelationException` | | Async (`async def predict`) | `asyncio.CancelledError` | ### `CancelationException` ```python from cog import CancelationException ``` `CancelationException` is raised in **sync** predictors when a prediction is cancelled. It is a `BaseException` subclass — **not** an `Exception` subclass. This means bare `except Exception` blocks in your predict code will not accidentally catch it, matching the behavior of `KeyboardInterrupt` and `asyncio.CancelledError`. You do **not** need to handle this exception in normal predictor code — the runtime manages cancellation automatically. 
However, if you need to run cleanup logic when a prediction is cancelled, you can catch it explicitly: ```python from cog import BasePredictor, CancelationException, Path class Predictor(BasePredictor): def predict(self, image: Path) -> Path: try: return self.process(image) except CancelationException: self.cleanup() raise # always re-raise ``` > [!WARNING] > You **must** re-raise `CancelationException` after cleanup. Swallowing it will prevent the runtime from marking the prediction as canceled, and may result in the termination of the container. `CancelationException` is available as: - `cog.CancelationException` (recommended) - `cog.exceptions.CancelationException` For **async** predictors, cancellation follows standard Python async conventions and raises `asyncio.CancelledError` instead. ## Input and output types Each parameter of the `predict()` method must be annotated with a type. The method's return type must also be annotated. The supported types are: - `str`: a string - `int`: an integer - `float`: a floating point number - `bool`: a boolean - [`cog.File`](#file): a file-like object representing a file - [`cog.Path`](#path): a path to a file on disk - [`cog.Secret`](#secret): a string containing sensitive information ## `File()` > [!WARNING] > `cog.File` is deprecated and will be removed in a future version of Cog. Use [`cog.Path`](#path) instead. The `cog.File` object is used to get files in and out of models. It represents a _file handle_. For models that return a `cog.File` object, the prediction output returned by Cog's built-in HTTP server will be a URL. ```python from cog import BasePredictor, File, Input, Path from PIL import Image class Predictor(BasePredictor): def predict(self, source_image: File = Input(description="Image to enlarge")) -> File: pillow_img = Image.open(source_image) upscaled_image = do_some_processing(pillow_img) return File(upscaled_image) ``` ## `Path()` The `cog.Path` object is used to get files in and out of models. 
It represents a _path to a file on disk_. `cog.Path` is a subclass of Python's [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#basic-use) and can be used as a drop-in replacement. For models that return a `cog.Path` object, the prediction output returned by Cog's built-in HTTP server will be a URL. This example takes an input file, resizes it, and returns the resized image: ```python import tempfile from cog import BasePredictor, Input, Path class Predictor(BasePredictor): def predict(self, image: Path = Input(description="Image to enlarge")) -> Path: upscaled_image = do_some_processing(image) # To output `cog.Path` objects the file needs to exist, so create a temporary file first. # This file will automatically be deleted by Cog after it has been returned. output_path = Path(tempfile.mkdtemp()) / "upscaled.png" upscaled_image.save(output_path) return Path(output_path) ``` ## `Secret` The `cog.Secret` type is used to signify that an input holds sensitive information, like a password or API token. `cog.Secret` is a type that redacts its contents in string representations to prevent accidental disclosure. You can access its contents with the `get_secret_value()` method. ```python from cog import BasePredictor, Secret class Predictor(BasePredictor): def predict(self, api_token: Secret) -> None: # Prints '**********' print(api_token) # Use get_secret_value method to see the secret's content. print(api_token.get_secret_value()) ``` A predictor's `Secret` inputs are represented in OpenAPI with the following schema: ```json { "type": "string", "format": "password", "x-cog-secret": true } ``` Models uploaded to Replicate treat secret inputs differently throughout its system. When you create a prediction on Replicate, any value passed to a `Secret` input is redacted after being sent to the model. > [!WARNING] > Passing secret values to untrusted models can result in > unintended disclosure, exfiltration, or misuse of sensitive data. 
## `Optional` Optional inputs should be explicitly defined as `Optional[T]` so that type checker can warn us about error-prone `None` values. For example, the following code might fail if `prompt` is not specified in the inputs: ```python class Predictor(BasePredictor): def predict(self, prompt: str=Input(description="prompt", default=None)) -> str: return "hello" + prompt # TypeError: can only concatenate str (not "NoneType") to str ``` We can improve it by making `prompt` an `Optional[str]`. Note that `default=None` is now redundant as `Optional` implies it. ```python class Predictor(BasePredictor): def predict(self, prompt: Optional[str]=Input(description="prompt")) -> str: if prompt is None: # type check can warn us if we forget this return "hello" else: return "hello" + prompt ``` Note that the error prone usage of `prompt: str=Input(default=None)` might throw an error in a future release of Cog. ## `List` The List type is also supported in inputs. It can hold any supported type. Example for **List[Path]**: ```py class Predictor(BasePredictor): def predict(self, paths: list[Path]) -> str: output_parts = [] # Use a list to collect file contents for path in paths: with open(path) as f: output_parts.append(f.read()) return "".join(output_parts) ``` The corresponding cog command: ```bash $ echo test1 > 1.txt $ echo test2 > 2.txt $ cog predict -i paths=@1.txt -i paths=@2.txt Running prediction... test1 test2 ``` - Note the repeated inputs with the same name "paths" which constitute the list --- # Training interface reference > [!WARNING] > The `cog train` command is deprecated and will be removed in the next version of Cog. The training API described below may still be used with the HTTP API's `/trainings` endpoint, but the CLI command is no longer recommended for new projects. Cog's training API allows you to define a fine-tuning interface for an existing Cog model, so users of the model can bring their own training data to create derivative fine-tuned models. 
Real-world examples of this API in use include [fine-tuning SDXL with images](https://replicate.com/blog/fine-tune-sdxl) or [fine-tuning Llama 2 with structured text](https://replicate.com/blog/fine-tune-llama-2). ## How it works If you've used Cog before, you've probably seen the [Predictor](./python.md) class, which defines the interface for creating predictions against your model. Cog's training API works similarly: You define a Python function that describes the inputs and outputs of the training process. The inputs are things like training data, epochs, batch size, seed, etc. The output is typically a file with the fine-tuned weights. `cog.yaml`: ```yaml build: python_version: "3.13" train: "train.py:train" ``` `train.py`: ```python from cog import BasePredictor, File import io def train(param: str) -> File: return io.StringIO("hello " + param) ``` Then you can run it like this: ``` $ cog train -i param=train ... $ cat weights hello train ``` You can also use classes if you want to run many model trainings and save on setup time. This works the same way as the [Predictor](./python.md) class with the only difference being the `train` method. `cog.yaml`: ```yaml build: python_version: "3.13" train: "train.py:Trainer" ``` `train.py`: ```python from cog import BasePredictor, File import io class Trainer: def setup(self) -> None: self.base_model = ... 
# Load a big base model def train(self, param: str) -> File: return self.base_model.train(param) # Train on top of a base model ``` ## `Input(**kwargs)` Use Cog's `Input()` function to define each of the parameters in your `train()` function: ```py from cog import Input, Path def train( train_data: Path = Input(description="HTTPS URL of a file containing training data"), learning_rate: float = Input(description="learning rate, for learning!", default=1e-4, ge=0), seed: int = Input(description="random seed to use for training", default=None) ) -> str: return "hello, weights" ``` The `Input()` function takes these keyword arguments: - `description`: A description of what to pass to this input for users of the model. - `default`: A default value to set the input to. If this argument is not passed, the input is required. If it is explicitly set to `None`, the input is optional. - `ge`: For `int` or `float` types, the value must be greater than or equal to this number. - `le`: For `int` or `float` types, the value must be less than or equal to this number. - `min_length`: For `str` types, the minimum length of the string. - `max_length`: For `str` types, the maximum length of the string. - `regex`: For `str` types, the string must match this regular expression. - `choices`: For `str` or `int` types, a list of possible values for this input. Each parameter of the `train()` function must be annotated with a type like `str`, `int`, `float`, `bool`, etc. See [Input and output types](./python.md#input-and-output-types) for the full list of supported types. Using the `Input` function provides better documentation and validation constraints to the users of your model, but it is not strictly required. You can also specify default values for your parameters using plain Python, or omit default assignment entirely: ```py def train(self, training_data: str = "foo bar", # this is valid iterations: int # also valid ) -> str: # ... 
``` ## Training Output Training output is typically a binary weights file. To return a custom output object or a complex object with multiple values, define a `TrainingOutput` object with multiple fields to return from your `train()` function, and specify it as the return type for the train function using Python's `->` return type annotation: ```python from cog import BaseModel, Input, Path class TrainingOutput(BaseModel): weights: Path def train( train_data: Path = Input(description="HTTPS URL of a file containing training data"), learning_rate: float = Input(description="learning rate, for learning!", default=1e-4, ge=0), seed: int = Input(description="random seed to use for training", default=42) ) -> TrainingOutput: weights_file = generate_weights("...") return TrainingOutput(weights=Path(weights_file)) ``` ## Testing If you are doing development of a Cog model like Llama or SDXL, you can test that the fine-tuned code path works before pushing by specifying a `COG_WEIGHTS` environment variable when running `predict`: ```console cog predict -e COG_WEIGHTS=https://replicate.delivery/pbxt/xyz/weights.tar -i prompt="a photo of TOK" ``` --- # Using `cog` on Windows 11 with WSL 2 - [0. Prerequisites](#0-prerequisites) - [1. Install the GPU driver](#1-install-the-gpu-driver) - [2. Unlocking features](#2-unlocking-features) - [2.1. Unlock WSL2](#21-unlock-wsl2) - [2.2. Unlock virtualization](#22-unlock-virtualization) - [2.3. Reboot](#23-reboot) - [3. Update MS Linux kernel](#3-update-ms-linux-kernel) - [4. Configure WSL 2](#4-configure-wsl-2) - [5. Configure CUDA WSL-Ubuntu Toolkit](#5-configure-cuda-wsl-ubuntu-toolkit) - [6. Install Docker](#6-install-docker) - [7. Install `cog` and pull an image](#7-install-cog-and-pull-an-image) - [8. Run a model in WSL 2](#8-run-a-model-in-wsl-2) - [9. References](#9-references) Running cog on Windows is now possible thanks to WSL 2. Follow this guide to enable WSL 2 and GPU passthrough on Windows 11. 
**Windows 10 is not officially supported, as you need to be on an insider build in order to use GPU passthrough.** ## 0. Prerequisites Before beginning installation, make sure you have: - Windows 11. - NVIDIA GPU. - RTX 2000/3000 series - Kesler/Tesla/Volta/Ampere series - Other configurations are not guaranteed to work. ## 1. Install the GPU driver Per NVIDIA, the first order of business is to install the latest Game Ready drivers for your NVIDIA GPU. I have an NVIDIA RTX 2070 Super, so filled out the form as such: ![a form showing the correct model number selected for an RTX 2070 Super](images/nvidia_driver_select.png) Click "search", and follow the dialogue to download and install the driver. Restart your computer once the driver has finished installation. ## 2. Unlocking features Open Windows Terminal as an administrator. - Use start to search for "Terminal" - Right click -> Run as administrator... Run the following powershell command to enable the Windows Subsystem for Linux and Virtual Machine Platform capabilities. ### 2.1. Unlock WSL2 ```powershell dism.exe /online /enable-feature /featurename:Microsoft-Windows-Subsystem-Linux /all /norestart ``` If you see an error about permissions, make sure the terminal you are using is run as an administrator and that you have an account with administrator-level privileges. ### 2.2. Unlock virtualization ```powershell dism.exe /online /enable-feature /featurename:VirtualMachinePlatform /all /norestart ``` If this command fails, make sure to [enable virtualization capabilities](https://docs.microsoft.com/en-us/windows/wsl/troubleshooting#error-0x80370102-the-virtual-machine-could-not-be-started-because-a-required-feature-is-not-installed) in your computer's BIOS/UEFI. A successful output will print `The operation completed successfully.` ![Output from running the above commands successfully. Should read "The operation completed successfully".](images/enable_feature_success.png) ### 2.3. 
Reboot Before moving forward, make sure you reboot your computer so that Windows 11 will have WSL2 and virtualization available to it. ## 3. Update MS Linux kernel Download and run the [WSL2 Linux kernel update package for x64 machines](https://wslstorestorage.blob.core.windows.net/wslblob/wsl_update_x64.msi) msi installer. When prompted for elevated permissions, click 'yes' to approve the installation. To ensure you are using the correct WSL kernel, `open Windows Terminal as an administrator` and enter: ```powershell wsl cat /proc/version ``` This will return a complicated string such as: ```sh Linux version 5.10.102.1-microsoft-standard-WSL2 (oe-user@oe-host) (x86_64-msft-linux-gcc (GCC) 9.3.0, GNU ld (GNU Binutils) 2.34.0.20200220) ``` The version we are interested in is `Linux version 5.10.102.1`. At this point, you should have updated your kernel to be at least `Linux version 5.10.43.3`. If you can't get the correct kernel version to show: Open `Settings` → `Windows Update` → `Advanced options` and ensure `Receive updates for other Microsoft products` is enabled. Then go to `Windows Update` again and click `Check for updates`. ## 4. Configure WSL 2 First, configure Windows to use the virtualization-based version of WSL (version 2) by default. In a Windows Terminal with administrator privileges, type the following: ```powershell wsl --set-default-version 2 ``` Now, you will need to go to the Microsoft Store and [Download Ubuntu 18.04](https://www.microsoft.com/store/apps/9N9TNGVNDL3Q) ![Screenshot showing the "Ubuntu" store page](https://docs.microsoft.com/en-us/windows/wsl/media/ubuntustore.png) Launch the "Ubuntu" app available in your Start Menu. Linux will require its own user account and password, which you will need to enter now: ![a terminal showing input for user account info on WSL 2](https://docs.microsoft.com/en-us/windows/wsl/media/ubuntuinstall.png) ## 5. 
Configure CUDA WSL-Ubuntu Toolkit By default, a shimmed version of the CUDA tooling is provided by your Windows GPU drivers. Important: you should _never_ use instructions for installing CUDA-toolkit in a generic linux fashion. in WSL 2, you _always_ want to use the provided `CUDA Toolkit using WSL-Ubuntu Package`. First, open PowerShell or Windows Command Prompt in administrator mode by right-clicking and selecting "Run as administrator". Then enter the following command: ```powershell wsl.exe ``` This should drop you into your running linux VM. Now you can run the following bash commands to install the correct version of cuda-toolkit for WSL-Ubuntu. Note that the version of CUDA used below may not be the version of CUDA your GPU supports. ```sh sudo apt-key del 7fa2af80 # if this line fails, you may remove it. wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin sudo mv cuda-wsl-ubuntu.pin /etc/apt/preferences.d/cuda-repository-pin-600 wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb sudo dpkg -i cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb sudo cp /var/cuda-repo-wsl-ubuntu-11-7-local/cuda-B81839D3-keyring.gpg /usr/share/keyrings/ sudo apt-get update sudo apt-get -y install cuda-toolkit-11-7 ``` ## 6. Install Docker Download and install [Docker Desktop for Windows](https://desktop.docker.com/win/main/amd64/Docker%20Desktop%20Installer.exe). It has WSL 2 support built in by default. Once installed, run `Docker Desktop`, you can ignore the first-run tutorial. Go to **Settings → General** and ensure **Use the WSL 2 based engine** has a checkmark next to it. Click **Apply & Restart**. !["Use the WSL 2 based engine" is checked in this interface](images/wsl2-enable.png) Reboot your computer one more time. ## 7. 
Install `cog` and pull an image Open Windows Terminal and enter your WSL 2 VM: ```powershell wsl.exe ``` Download and install `cog` inside the VM: ```bash sudo curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m` sudo chmod +x /usr/local/bin/cog ``` Make sure it's available by typing: ```bash which cog # should output /usr/local/bin/cog cog --version # should output the cog version number. ``` ## 8. Run a model in WSL 2 Finally, make sure it works. Let's try running `afiaka87/glid-3-xl` locally: ```bash cog predict 'r8.im/afiaka87/glid-3-xl' -i prompt="a fresh avocado floating in the water" -o prediction.json ``` ![Output from a running cog prediction in Windows Terminal](images/cog_model_output.png) While your prediction is running, you can use `Task Manager` to keep an eye on GPU memory consumption: ![Windows task manager will show the shared host/guest GPU memory](images/memory-usage.png) This model just barely manages to fit under 8 GB of VRAM. Notice that output is returned as JSON for this model as it has a complex return type. You will want to convert the base64 string in the json array to an image. `jq` can help with this: ```sh sudo apt install jq ``` The following bash uses `jq` to grab the first element in our prediction array and converts it from a base64 string to a `png` file. ```bash jq -cs '.[0][0][0]' prediction.json | cut --delimiter "," --field 2 | base64 --ignore-garbage --decode > prediction.png ``` When using WSL 2, you can access Windows binaries with the `.exe` extension. This lets you open photos easily within linux. ```bash explorer.exe prediction.png ``` ![a square image of an avocado, generated by the model](images/glide_out.png) ## 9. References - - - - - --- # `cog.yaml` reference `cog.yaml` defines how to build a Docker image and how to run predictions on your model inside that image. It has three keys: [`build`](#build), [`image`](#image), and [`predict`](#predict). 
It looks a bit like this: ```yaml build: python_version: "3.13" python_requirements: requirements.txt system_packages: - "ffmpeg" - "git" predict: "predict.py:Predictor" ``` Tip: Run [`cog init`](getting-started-own-model.md#initialization) to generate an annotated `cog.yaml` file that can be used as a starting point for setting up your model. ## `build` This stanza describes how to build the Docker image your model runs in. It contains various options within it: ### `cuda` Cog automatically picks the correct version of CUDA to install, but this lets you override it for whatever reason by specifying the minor (`11.8`) or patch (`11.8.0`) version of CUDA to use. For example: ```yaml build: cuda: "11.8" ``` ### `gpu` Enable GPUs for this model. When enabled, the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) base image will be used, and Cog will automatically figure out what versions of CUDA and cuDNN to use based on the version of Python, PyTorch, and Tensorflow that you are using. For example: ```yaml build: gpu: true ``` When you use `cog run` or `cog predict`, Cog will automatically pass the `--gpus=all` flag to Docker. When you run a Docker image built with Cog, you'll need to pass this option to `docker run`. ### `python_requirements` A pip requirements file specifying the Python packages to install. For example: ```yaml build: python_requirements: requirements.txt ``` Your `cog.yaml` file can set either `python_packages` or `python_requirements`, but not both. Use `python_requirements` when you need to configure options like `--extra-index-url` or `--trusted-host` to fetch Python package dependencies. This follows the standard [requirements.txt](https://pip.pypa.io/en/stable/reference/requirements-file-format/) format. To install Git-hosted Python packages, add `git` to the `system_packages` list, then use the `git+https://` syntax to specify the package name. 
For example: `cog.yaml`: ```yaml build: system_packages: - "git" python_requirements: requirements.txt ``` `requirements.txt`: ``` git+https://github.com/huggingface/transformers ``` You can also pin Python package installations to a specific git commit: `cog.yaml`: ```yaml build: system_packages: - "git" python_requirements: requirements.txt ``` `requirements.txt`: ``` git+https://github.com/huggingface/transformers@2d1602a ``` Note that you can use a shortened prefix of the 40-character git commit SHA, but you must use at least six characters, like `2d1602a` above. ### `python_packages` **DEPRECATED**: This will be removed in future versions, please use [python_requirements](#python_requirements) instead. A list of Python packages to install from the PyPi package index, in the format `package==version`. For example: ```yaml build: python_packages: - pillow==8.3.1 - tensorflow==2.5.0 ``` Your `cog.yaml` file can set either `python_packages` or `python_requirements`, but not both. ### `python_version` The minor (`3.13`) or patch (`3.13.1`) version of Python to use. For example: ```yaml build: python_version: "3.13.1" ``` Cog supports Python 3.10, 3.11, 3.12, and 3.13. If you don't define a version, Cog will use the latest version of Python 3.13 or a version of Python that is compatible with the versions of PyTorch or TensorFlow you specify. Note that these are the versions supported **in the Docker container**, not your host machine. You can run any version(s) of Python you wish on your host machine. ### `run` A list of setup commands to run in the environment after your system packages and Python packages have been installed. If you're familiar with Docker, it's like a `RUN` instruction in your `Dockerfile`. For example: ```yaml build: run: - curl -L https://github.com/cowsay-org/cowsay/archive/refs/tags/v3.7.0.tar.gz | tar -xzf - - cd cowsay-3.7.0 && make install ``` Your code is _not_ available to commands in `run`. 
This is so we can build your image efficiently when running locally. Each command in `run` can be either a string or a dictionary in the following format: ```yaml build: run: - command: pip install mounts: - type: secret id: pip target: /etc/pip.conf ``` You can use secret mounts to securely pass credentials to setup commands, without baking them into the image. For more information, see [Dockerfile reference](https://docs.docker.com/engine/reference/builder/#run---mounttypesecret). ### `sdk_version` Pin the version of the cog Python SDK installed in the container. Accepts a [PEP 440](https://peps.python.org/pep-0440/) version string. When omitted, the latest release is installed. ```yaml build: python_version: "3.13" sdk_version: "0.18.0" ``` Pre-release versions are also supported: ```yaml build: sdk_version: "0.18.0a1" ``` When a pre-release `sdk_version` is set, `--pre` is automatically passed to the pip install commands for both `cog` and `coglet`, so pip will resolve matching pre-release packages. The minimum supported version is `0.16.0`. Specifying an older version will cause `cog build` to fail with an error. The `COG_SDK_WHEEL` environment variable takes precedence over `sdk_version`. See [Environment variables](./environment.md) for details. ### `system_packages` A list of Ubuntu APT packages to install. For example: ```yaml build: system_packages: - "ffmpeg" - "libavcodec-dev" ``` ## `concurrency` > Added in cog 0.14.0. This stanza describes the concurrency capabilities of the model. It has one option: ### `max` The maximum number of concurrent predictions the model can process. If this is set, the model must specify an [async `predict()` method](python.md#async-predictors-and-concurrency). For example: ```yaml concurrency: max: 10 ``` ## `image` The name given to built Docker images. If you want to push to a registry, this should also include the registry name. 
For example: ```yaml image: "r8.im/your-username/your-model" ``` r8.im is Replicate's registry, but this can be any Docker registry. If you don't set this, then a name will be generated from the directory name. If you set this, then you can run `cog push` without specifying the model name. If you specify an image name argument when pushing (like `cog push your-username/custom-model-name`), the argument will be used and the value of `image` in cog.yaml will be ignored. ## `predict` The pointer to the `Predictor` object in your code, which defines how predictions are run on your model. For example: ```yaml predict: "predict.py:Predictor" ``` See [the Python API documentation for more information](python.md). ================================================ FILE: docs/notebooks.md ================================================ # Notebooks Cog plays nicely with Jupyter notebooks. ## Install the jupyterlab Python package First, add `jupyterlab` to your `requirements.txt` file and reference it in [`cog.yaml`](yaml.md): `requirements.txt`: ``` jupyterlab ``` `cog.yaml`: ```yaml build: python_requirements: requirements.txt ``` ## Run a notebook Cog can run notebooks in the environment you've defined in `cog.yaml` with the following command: ```sh cog run -p 8888 jupyter lab --allow-root --ip=0.0.0.0 ``` ## Use notebook code in your predictor You can also import a notebook into your Cog [Predictor](python.md) file. First, export your notebook to a Python file: ```sh jupyter nbconvert --to script my_notebook.ipynb # creates my_notebook.py ``` Then import the exported Python script into your `predict.py` file. 
Any functions or variables defined in your notebook will be available to your predictor: ```python from cog import BasePredictor, Input import my_notebook class Predictor(BasePredictor): def predict(self, prompt: str = Input(description="string prompt")) -> str: output = my_notebook.do_stuff(prompt) return output ``` ================================================ FILE: docs/private-package-registry.md ================================================ # Private package registry This guide describes how to build a Docker image with Cog that fetches Python packages from a private registry during setup. ## `pip.conf` In a directory outside your Cog project, create a `pip.conf` file with an `index-url` set to the registry's URL with embedded credentials. ```conf [global] index-url = https://username:password@my-private-registry.com ``` > **Warning** > Be careful not to commit secrets in Git or include them in Docker images. If your Cog project contains any sensitive files, make sure they're listed in `.gitignore` and `.dockerignore`. ## `cog.yaml` In your project's [`cog.yaml`](yaml.md) file, add a setup command to run `pip install` with a secret configuration file mounted to `/etc/pip.conf`. ```yaml build: run: - command: pip install mounts: - type: secret id: pip target: /etc/pip.conf ``` ## Build When building or pushing your model with Cog, pass the `--secret` option with an `id` matching the one specified in `cog.yaml`, along with a path to your local `pip.conf` file. ```console $ cog build --secret id=pip,source=/path/to/pip.conf ``` Using a secret mount allows the private registry credentials to be securely passed to the `pip install` setup command, without baking them into the Docker image. > **Warning** > If you run `cog build` or `cog push` and then change the contents of a secret source file, the cached version of the file will be used on subsequent builds, ignoring any changes you made. 
To update the contents of the target secret file, either change the `id` value in `cog.yaml` and the `--secret` option, or pass the `--no-cache` option to bypass the cache entirely. ================================================ FILE: docs/python.md ================================================ # Prediction interface reference This document defines the API of the `cog` Python module, which is used to define the interface for running predictions on your model. > [!TIP] > Run [`cog init`](getting-started-own-model.md#initialization) to generate an annotated `predict.py` file that can be used as a starting point for setting up your model. > [!TIP] > Using a language model to help you write the code for your new Cog model? > > Feed it [https://cog.run/llms.txt](https://cog.run/llms.txt), which has all of Cog's documentation bundled into a single file. To learn more about this format, check out [llmstxt.org](https://llmstxt.org). ## Contents - [Contents](#contents) - [`BasePredictor`](#basepredictor) - [`Predictor.setup()`](#predictorsetup) - [`Predictor.predict(**kwargs)`](#predictorpredictkwargs) - [`async` predictors and concurrency](#async-predictors-and-concurrency) - [`Input(**kwargs)`](#inputkwargs) - [Deprecating inputs](#deprecating-inputs) - [Output](#output) - [Returning an object](#returning-an-object) - [Returning a list](#returning-a-list) - [Optional properties](#optional-properties) - [Streaming output](#streaming-output) - [Metrics](#metrics) - [Recording metrics](#recording-metrics) - [Accumulation modes](#accumulation-modes) - [Dot-path keys](#dot-path-keys) - [Type safety](#type-safety) - [Cancellation](#cancellation) - [`CancelationException`](#cancelationexception) - [Input and output types](#input-and-output-types) - [`File()`](#file) - [`Path()`](#path) - [`Secret`](#secret) - [`Optional`](#optional) - [`List`](#list) ## `BasePredictor` You define how Cog runs predictions on your model by defining a class that inherits from `BasePredictor`. 
It looks something like this: ```python from cog import BasePredictor, Path, Input import torch class Predictor(BasePredictor): def setup(self): """Load the model into memory to make running multiple predictions efficient""" self.model = torch.load("weights.pth") def predict(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5) ) -> Path: """Run a single prediction on the model""" # ... pre-processing ... output = self.model(image) # ... post-processing ... return output ``` Your Predictor class should define two methods: `setup()` and `predict()`. ### `Predictor.setup()` Prepare the model so multiple predictions run efficiently. Use this _optional_ method to include expensive one-off operations like loading trained models, instantiating data transformations, etc. Many models use this method to download their weights (e.g. using [`pget`](https://github.com/replicate/pget)). This has some advantages: - Smaller image sizes - Faster build times - Faster pushes and inference on [Replicate](https://replicate.com) However, this may also significantly increase your `setup()` time. As an alternative, some choose to store their weights directly in the image. You can simply leave your weights in the directory alongside your `cog.yaml` and ensure they are not excluded in your `.dockerignore` file. While this will increase your image size and build time, it offers other advantages: - Faster `setup()` time - Ensures idempotency and reduces your model's reliance on external systems - Preserves reproducibility as your model will be self-contained in the image > When using this method, you should use the `--separate-weights` flag on `cog build` to store weights in a [separate layer](https://github.com/replicate/cog/blob/12ac02091d93beebebed037f38a0c99cd8749806/docs/getting-started.md?plain=1#L219). ### `Predictor.predict(**kwargs)` Run a single prediction. 
This _required_ method is where you call the model that was loaded during `setup()`, but you may also want to add pre- and post-processing code here. The `predict()` method takes an arbitrary list of named arguments, where each argument name must correspond to an [`Input()`](#inputkwargs) annotation. `predict()` can return strings, numbers, [`cog.Path`](#path) objects representing files on disk, or lists or dicts of those types. You can also define a custom [`Output()`](#outputbasemodel) for more complex return types. ## `async` predictors and concurrency > Added in cog 0.14.0. You may specify your `predict()` method as `async def predict(...)`. In addition, if you have an async `predict()` function you may also have an async `setup()` function: ```py class Predictor(BasePredictor): async def setup(self) -> None: print("async setup is also supported...") async def predict(self) -> str: print("async predict"); return "hello world"; ``` Models that have an async `predict()` function can run predictions concurrently, up to the limit specified by [`concurrency.max`](yaml.md#max) in cog.yaml. Attempting to exceed this limit will return a 409 Conflict response. ## `Input(**kwargs)` Use cog's `Input()` function to define each of the parameters in your `predict()` method: ```py class Predictor(BasePredictor): def predict(self, image: Path = Input(description="Image to enlarge"), scale: float = Input(description="Factor to scale image by", default=1.5, ge=1.0, le=10.0) ) -> Path: ``` The `Input()` function takes these keyword arguments: - `description`: A description of what to pass to this input for users of the model. - `default`: A default value to set the input to. If this argument is not passed, the input is required. If it is explicitly set to `None`, the input is optional. - `ge`: For `int` or `float` types, the value must be greater than or equal to this number. - `le`: For `int` or `float` types, the value must be less than or equal to this number. 
- `min_length`: For `str` types, the minimum length of the string. - `max_length`: For `str` types, the maximum length of the string. - `regex`: For `str` types, the string must match this regular expression. - `choices`: For `str` or `int` types, a list of possible values for this input. - `deprecated`: (optional) If set to `True`, marks this input as deprecated. Deprecated inputs will still be accepted, but tools and UIs may warn users that the input is deprecated and may be removed in the future. See [Deprecating inputs](#deprecating-inputs). Each parameter of the `predict()` method must be annotated with a type like `str`, `int`, `float`, `bool`, etc. See [Input and output types](#input-and-output-types) for the full list of supported types. Using the `Input` function provides better documentation and validation constraints to the users of your model, but it is not strictly required. You can also specify default values for your parameters using plain Python, or omit default assignment entirely: ```py class Predictor(BasePredictor): def predict(self, prompt: str = "default prompt", # this is valid iterations: int # also valid ) -> str: # ... ``` ## Deprecating inputs You can mark an input as deprecated by passing `deprecated=True` to the `Input()` function. Deprecated inputs will still be accepted, but tools and UIs may warn users that the input is deprecated and may be removed in the future. This is useful when you want to phase out an input without breaking existing clients immediately: ```py from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, text: str = Input(description="Some deprecated text", deprecated=True), prompt: str = Input(description="Prompt for the model") ) -> str: # ... return prompt ``` ## Output Cog predictors can return a simple data type like a string, number, float, or boolean. Use Python's `-> <type>` syntax to annotate the return type.
Here's an example of a predictor that returns a string: ```py from cog import BasePredictor class Predictor(BasePredictor): def predict(self) -> str: return "hello" ``` ### Returning an object To return a complex object with multiple values, define an `Output` object with multiple fields to return from your `predict()` method: ```py from cog import BasePredictor, BaseModel, File class Output(BaseModel): file: File text: str class Predictor(BasePredictor): def predict(self) -> Output: return Output(text="hello", file=io.StringIO("hello")) ``` Each of the output object's properties must be one of the supported output types. For the full list, see [Input and output types](#input-and-output-types). Also, make sure to name the output class `Output` and nothing else. ### Returning a list The `predict()` method can return a list of any of the supported output types. Here's an example that outputs multiple files: ```py from cog import BasePredictor, Path class Predictor(BasePredictor): def predict(self) -> list[Path]: predictions = ["foo", "bar", "baz"] output = [] for i, prediction in enumerate(predictions): out_path = Path(f"/tmp/out-{i}.txt") with out_path.open("w") as f: f.write(prediction) output.append(out_path) return output ``` Files are named in the format `output.<index>.<extension>`, e.g. `output.0.txt`, `output.1.txt`, and `output.2.txt` from the example above. ### Optional properties To conditionally omit properties from the Output object, define them using `typing.Optional`: ```py from cog import BaseModel, BasePredictor, Path from typing import Optional class Output(BaseModel): score: Optional[float] file: Optional[Path] class Predictor(BasePredictor): def predict(self) -> Output: if condition: return Output(score=1.5) else: return Output(file=io.StringIO("hello")) ``` ### Streaming output Cog models can stream output as the `predict()` method is running.
For example, a language model can output tokens as they're being generated and an image generation model can output images as they are being generated. To support streaming output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[<type>]` where `<type>` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. ```py from cog import BasePredictor, Path from typing import Iterator class Predictor(BasePredictor): def predict(self) -> Iterator[Path]: done = False while not done: output_path, done = do_stuff() yield Path(output_path) ``` If you have an [async `predict()` method](#async-predictors-and-concurrency), you must use `cog.AsyncIterator` instead: ```py from cog import AsyncIterator, BasePredictor, Path class Predictor(BasePredictor): async def predict(self) -> AsyncIterator[Path]: done = False while not done: output_path, done = do_stuff() yield Path(output_path) ``` If you're streaming text output, you can use `ConcatenateIterator` to hint that the output should be concatenated together into a single string. This is useful on Replicate to display the output as a string instead of a list of strings.
```py from cog import BasePredictor, Path, ConcatenateIterator class Predictor(BasePredictor): def predict(self) -> ConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: yield token + " " ``` Or for async `predict()` methods, use `AsyncConcatenateIterator`: ```py from cog import BasePredictor, Path, AsyncConcatenateIterator class Predictor(BasePredictor): async def predict(self) -> AsyncConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: yield token + " " ``` ## Metrics You can record custom metrics from your `predict()` function to track model-specific data like token counts, timing breakdowns, or confidence scores. Metrics are included in the prediction response alongside the output. ### Recording metrics Use `self.record_metric()` inside your `predict()` method: ```python from cog import BasePredictor class Predictor(BasePredictor): def predict(self, prompt: str) -> str: self.record_metric("temperature", 0.7) self.record_metric("token_count", 42) result = self.model.generate(prompt) return result ``` For advanced use (dict-style access, deleting metrics), use `self.scope`: ```python self.scope.metrics["token_count"] = 42 del self.scope.metrics["token_count"] ``` Metrics appear in the prediction response `metrics` field: ```json { "status": "succeeded", "output": "...", "metrics": { "temperature": 0.7, "token_count": 42, "predict_time": 1.23 } } ``` The `predict_time` metric is always added automatically by the runtime. If you set `predict_time` yourself, the runtime value takes precedence. Supported value types are `bool`, `int`, `float`, `str`, `list`, and `dict`. Setting a metric to `None` deletes it. ### Accumulation modes By default, recording a metric replaces any previous value for that key. 
You can use accumulation modes to build up values across multiple calls: ```python # Increment a counter (adds to the existing numeric value) self.record_metric("token_count", 1, mode="incr") self.record_metric("token_count", 1, mode="incr") # Result: {"token_count": 2} # Append to an array self.record_metric("steps", "preprocessing", mode="append") self.record_metric("steps", "inference", mode="append") # Result: {"steps": ["preprocessing", "inference"]} # Replace (default behavior) self.record_metric("status", "running", mode="replace") self.record_metric("status", "done", mode="replace") # Result: {"status": "done"} ``` The `mode` parameter accepts `"replace"` (default), `"incr"`, or `"append"`. ### Dot-path keys Use dot-separated keys to create nested objects in the metrics output: ```python self.record_metric("timing.preprocess", 0.12) self.record_metric("timing.inference", 0.85) ``` This produces nested JSON: ```json { "metrics": { "timing": { "preprocess": 0.12, "inference": 0.85 }, "predict_time": 1.23 } } ``` ### Type safety Once a metric key has been assigned a value of a certain type, it cannot be changed to a different type without deleting it first. This prevents accidental type mismatches when using accumulation modes: ```python self.record_metric("count", 1) # This would raise an error — "count" is an int, not a string: # self.record_metric("count", "oops") # Delete first, then set with new type: del self.scope.metrics["count"] self.record_metric("count", "now a string") ``` Outside an active prediction, `self.record_metric()` and `self.scope` are silent no-ops — no need for `None` checks. ## Cancellation When a prediction is canceled (via the [cancel HTTP endpoint](http.md#post-predictionsprediction_idcancel) or a dropped connection), the Cog runtime interrupts the running `predict()` function. 
The exception raised depends on whether the predictor is sync or async: | Predictor type | Exception raised | | --------------------------- | ------------------------ | | Sync (`def predict`) | `CancelationException` | | Async (`async def predict`) | `asyncio.CancelledError` | ### `CancelationException` ```python from cog import CancelationException ``` `CancelationException` is raised in **sync** predictors when a prediction is cancelled. It is a `BaseException` subclass — **not** an `Exception` subclass. This means bare `except Exception` blocks in your predict code will not accidentally catch it, matching the behavior of `KeyboardInterrupt` and `asyncio.CancelledError`. You do **not** need to handle this exception in normal predictor code — the runtime manages cancellation automatically. However, if you need to run cleanup logic when a prediction is cancelled, you can catch it explicitly: ```python from cog import BasePredictor, CancelationException, Path class Predictor(BasePredictor): def predict(self, image: Path) -> Path: try: return self.process(image) except CancelationException: self.cleanup() raise # always re-raise ``` > [!WARNING] > You **must** re-raise `CancelationException` after cleanup. Swallowing it will prevent the runtime from marking the prediction as canceled, and may result in the termination of the container. `CancelationException` is available as: - `cog.CancelationException` (recommended) - `cog.exceptions.CancelationException` For **async** predictors, cancellation follows standard Python async conventions and raises `asyncio.CancelledError` instead. ## Input and output types Each parameter of the `predict()` method must be annotated with a type. The method's return type must also be annotated. 
The supported types are: - `str`: a string - `int`: an integer - `float`: a floating point number - `bool`: a boolean - [`cog.File`](#file): a file-like object representing a file - [`cog.Path`](#path): a path to a file on disk - [`cog.Secret`](#secret): a string containing sensitive information ## `File()` > [!WARNING] > `cog.File` is deprecated and will be removed in a future version of Cog. Use [`cog.Path`](#path) instead. The `cog.File` object is used to get files in and out of models. It represents a _file handle_. For models that return a `cog.File` object, the prediction output returned by Cog's built-in HTTP server will be a URL. ```python from cog import BasePredictor, File, Input, Path from PIL import Image class Predictor(BasePredictor): def predict(self, source_image: File = Input(description="Image to enlarge")) -> File: pillow_img = Image.open(source_image) upscaled_image = do_some_processing(pillow_img) return File(upscaled_image) ``` ## `Path()` The `cog.Path` object is used to get files in and out of models. It represents a _path to a file on disk_. `cog.Path` is a subclass of Python's [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#basic-use) and can be used as a drop-in replacement. For models that return a `cog.Path` object, the prediction output returned by Cog's built-in HTTP server will be a URL. This example takes an input file, resizes it, and returns the resized image: ```python import tempfile from cog import BasePredictor, Input, Path class Predictor(BasePredictor): def predict(self, image: Path = Input(description="Image to enlarge")) -> Path: upscaled_image = do_some_processing(image) # To output `cog.Path` objects the file needs to exist, so create a temporary file first. # This file will automatically be deleted by Cog after it has been returned. 
output_path = Path(tempfile.mkdtemp()) / "upscaled.png" upscaled_image.save(output_path) return Path(output_path) ``` ## `Secret` The `cog.Secret` type is used to signify that an input holds sensitive information, like a password or API token. `cog.Secret` is a type that redacts its contents in string representations to prevent accidental disclosure. You can access its contents with the `get_secret_value()` method. ```python from cog import BasePredictor, Secret class Predictor(BasePredictor): def predict(self, api_token: Secret) -> None: # Prints '**********' print(api_token) # Use get_secret_value method to see the secret's content. print(api_token.get_secret_value()) ``` A predictor's `Secret` inputs are represented in OpenAPI with the following schema: ```json { "type": "string", "format": "password", "x-cog-secret": true } ``` Replicate treats the secret inputs of uploaded models differently throughout its system. When you create a prediction on Replicate, any value passed to a `Secret` input is redacted after being sent to the model. > [!WARNING] > Passing secret values to untrusted models can result in > unintended disclosure, exfiltration, or misuse of sensitive data. ## `Optional` Optional inputs should be explicitly defined as `Optional[T]` so that a type checker can warn us about error-prone `None` values. For example, the following code might fail if `prompt` is not specified in the inputs: ```python class Predictor(BasePredictor): def predict(self, prompt: str=Input(description="prompt", default=None)) -> str: return "hello" + prompt # TypeError: can only concatenate str (not "NoneType") to str ``` We can improve it by making `prompt` an `Optional[str]`. Note that `default=None` is now redundant as `Optional` implies it.
```python class Predictor(BasePredictor): def predict(self, prompt: Optional[str]=Input(description="prompt")) -> str: if prompt is None: # type check can warn us if we forget this return "hello" else: return "hello" + prompt ``` Note that the error prone usage of `prompt: str=Input(default=None)` might throw an error in a future release of Cog. ## `List` The List type is also supported in inputs. It can hold any supported type. Example for **List[Path]**: ```py class Predictor(BasePredictor): def predict(self, paths: list[Path]) -> str: output_parts = [] # Use a list to collect file contents for path in paths: with open(path) as f: output_parts.append(f.read()) return "".join(output_parts) ``` The corresponding cog command: ```bash $ echo test1 > 1.txt $ echo test2 > 2.txt $ cog predict -i paths=@1.txt -i paths=@2.txt Running prediction... test1 test2 ``` - Note the repeated inputs with the same name "paths" which constitute the list ================================================ FILE: docs/stylesheets/extra.css ================================================ .md-typeset h1, .md-typeset h2, .md-typeset h3 { font-weight: 600; } /* move the "Cog" header to the left */ [dir="ltr"] .md-header__title { margin-left: 0; } /* Remove the superfluous "Cog" label above the TOC */ .md-nav__title { display: none; } ================================================ FILE: docs/training.md ================================================ # Training interface reference > [!WARNING] > The `cog train` command is deprecated and will be removed in the next version of Cog. The training API described below may still be used with the HTTP API's `/trainings` endpoint, but the CLI command is no longer recommended for new projects. Cog's training API allows you to define a fine-tuning interface for an existing Cog model, so users of the model can bring their own training data to create derivative fine-tuned models. 
Real-world examples of this API in use include [fine-tuning SDXL with images](https://replicate.com/blog/fine-tune-sdxl) or [fine-tuning Llama 2 with structured text](https://replicate.com/blog/fine-tune-llama-2). ## How it works If you've used Cog before, you've probably seen the [Predictor](./python.md) class, which defines the interface for creating predictions against your model. Cog's training API works similarly: You define a Python function that describes the inputs and outputs of the training process. The inputs are things like training data, epochs, batch size, seed, etc. The output is typically a file with the fine-tuned weights. `cog.yaml`: ```yaml build: python_version: "3.13" train: "train.py:train" ``` `train.py`: ```python from cog import BasePredictor, File import io def train(param: str) -> File: return io.StringIO("hello " + param) ``` Then you can run it like this: ``` $ cog train -i param=train ... $ cat weights hello train ``` You can also use classes if you want to run many model trainings and save on setup time. This works the same way as the [Predictor](./python.md) class with the only difference being the `train` method. `cog.yaml`: ```yaml build: python_version: "3.13" train: "train.py:Trainer" ``` `train.py`: ```python from cog import BasePredictor, File import io class Trainer: def setup(self) -> None: self.base_model = ... 
# Load a big base model def train(self, param: str) -> File: return self.base_model.train(param) # Train on top of a base model ``` ## `Input(**kwargs)` Use Cog's `Input()` function to define each of the parameters in your `train()` function: ```py from cog import Input, Path def train( train_data: Path = Input(description="HTTPS URL of a file containing training data"), learning_rate: float = Input(description="learning rate, for learning!", default=1e-4, ge=0), seed: int = Input(description="random seed to use for training", default=None) ) -> str: return "hello, weights" ``` The `Input()` function takes these keyword arguments: - `description`: A description of what to pass to this input for users of the model. - `default`: A default value to set the input to. If this argument is not passed, the input is required. If it is explicitly set to `None`, the input is optional. - `ge`: For `int` or `float` types, the value must be greater than or equal to this number. - `le`: For `int` or `float` types, the value must be less than or equal to this number. - `min_length`: For `str` types, the minimum length of the string. - `max_length`: For `str` types, the maximum length of the string. - `regex`: For `str` types, the string must match this regular expression. - `choices`: For `str` or `int` types, a list of possible values for this input. Each parameter of the `train()` function must be annotated with a type like `str`, `int`, `float`, `bool`, etc. See [Input and output types](./python.md#input-and-output-types) for the full list of supported types. Using the `Input` function provides better documentation and validation constraints to the users of your model, but it is not strictly required. You can also specify default values for your parameters using plain Python, or omit default assignment entirely: ```py def train(self, training_data: str = "foo bar", # this is valid iterations: int # also valid ) -> str: # ... 
``` ## Training Output Training output is typically a binary weights file. To return a custom output object or a complex object with multiple values, define a `TrainingOutput` object with multiple fields to return from your `train()` function, and specify it as the return type for the train function using Python's `->` return type annotation: ```python from cog import BaseModel, Input, Path class TrainingOutput(BaseModel): weights: Path def train( train_data: Path = Input(description="HTTPS URL of a file containing training data"), learning_rate: float = Input(description="learning rate, for learning!", default=1e-4, ge=0), seed: int = Input(description="random seed to use for training", default=42) ) -> TrainingOutput: weights_file = generate_weights("...") return TrainingOutput(weights=Path(weights_file)) ``` ## Testing If you are doing development of a Cog model like Llama or SDXL, you can test that the fine-tuned code path works before pushing by specifying a `COG_WEIGHTS` environment variable when running `predict`: ```console cog predict -e COG_WEIGHTS=https://replicate.delivery/pbxt/xyz/weights.tar -i prompt="a photo of TOK" ``` ================================================ FILE: docs/wsl2/wsl2.md ================================================ # Using `cog` on Windows 11 with WSL 2 - [0. Prerequisites](#0-prerequisites) - [1. Install the GPU driver](#1-install-the-gpu-driver) - [2. Unlocking features](#2-unlocking-features) - [2.1. Unlock WSL2](#21-unlock-wsl2) - [2.2. Unlock virtualization](#22-unlock-virtualization) - [2.3. Reboot](#23-reboot) - [3. Update MS Linux kernel](#3-update-ms-linux-kernel) - [4. Configure WSL 2](#4-configure-wsl-2) - [5. Configure CUDA WSL-Ubuntu Toolkit](#5-configure-cuda-wsl-ubuntu-toolkit) - [6. Install Docker](#6-install-docker) - [7. Install `cog` and pull an image](#7-install-cog-and-pull-an-image) - [8. Run a model in WSL 2](#8-run-a-model-in-wsl-2) - [9. 
References](#9-references) Running cog on Windows is now possible thanks to WSL 2. Follow this guide to enable WSL 2 and GPU passthrough on Windows 11. **Windows 10 is not officially supported, as you need to be on an insider build in order to use GPU passthrough.** ## 0. Prerequisites Before beginning installation, make sure you have: - Windows 11. - NVIDIA GPU. - RTX 2000/3000 series - Kepler/Tesla/Volta/Ampere series - Other configurations are not guaranteed to work. ## 1. Install the GPU driver Per NVIDIA, the first order of business is to install the latest Game Ready drivers for your NVIDIA GPU. I have an NVIDIA RTX 2070 Super, so I filled out the form as such: ![a form showing the correct model number selected for an RTX 2070 Super](images/nvidia_driver_select.png) Click "search", and follow the dialogue to download and install the driver. Restart your computer once the driver has finished installation. ## 2. Unlocking features Open Windows Terminal as an administrator. - Use start to search for "Terminal" - Right click -> Run as administrator... Run the following PowerShell commands to enable the Windows Subsystem for Linux and Virtual Machine Platform capabilities. ### 2.1. Unlock WSL2 ```powershell dism.exe /online /enable-feature /featurename:Microsoft-Windows-Subsystem-Linux /all /norestart ``` If you see an error about permissions, make sure the terminal you are using is run as an administrator and that you have an account with administrator-level privileges. ### 2.2. Unlock virtualization ```powershell dism.exe /online /enable-feature /featurename:VirtualMachinePlatform /all /norestart ``` If this command fails, make sure to [enable virtualization capabilities](https://docs.microsoft.com/en-us/windows/wsl/troubleshooting#error-0x80370102-the-virtual-machine-could-not-be-started-because-a-required-feature-is-not-installed) in your computer's BIOS/UEFI. 
A successful output will print `The operation completed successfully.` ![Output from running the above commands successfully. Should read "The operation completed successfully".](images/enable_feature_success.png) ### 2.3. Reboot Before moving forward, make sure you reboot your computer so that Windows 11 will have WSL2 and virtualization available to it. ## 3. Update MS Linux kernel Download and run the [WSL2 Linux kernel update package for x64 machines](https://wslstorestorage.blob.core.windows.net/wslblob/wsl_update_x64.msi) msi installer. When prompted for elevated permissions, click 'yes' to approve the installation. To ensure you are using the correct WSL kernel, `open Windows Terminal as an administrator` and enter: ```powershell wsl cat /proc/version ``` This will return a complicated string such as: ```sh Linux version 5.10.102.1-microsoft-standard-WSL2 (oe-user@oe-host) (x86_64-msft-linux-gcc (GCC) 9.3.0, GNU ld (GNU Binutils) 2.34.0.20200220) ``` The version we are interested in is `Linux version 5.10.102.1`. At this point, you should have updated your kernel to be at least `Linux version 5.10.43.3`. If you can't get the correct kernel version to show: Open `Settings` → `Windows Update` → `Advanced options` and ensure `Receive updates for other Microsoft products` is enabled. Then go to `Windows Update` again and click `Check for updates`. ## 4. Configure WSL 2 First, configure Windows to use the virtualization-based version of WSL (version 2) by default. In a Windows Terminal with administrator privileges, type the following: ```powershell wsl --set-default-version 2 ``` Now, you will need to go to the Microsoft Store and [Download Ubuntu 18.04](https://www.microsoft.com/store/apps/9N9TNGVNDL3Q) ![Screenshot showing the "Ubuntu" store page](https://docs.microsoft.com/en-us/windows/wsl/media/ubuntustore.png) Launch the "Ubuntu" app available in your Start Menu. 
Linux will require its own user account and password, which you will need to enter now: ![a terminal showing input for user account info on WSL 2](https://docs.microsoft.com/en-us/windows/wsl/media/ubuntuinstall.png) ## 5. Configure CUDA WSL-Ubuntu Toolkit By default, a shimmed version of the CUDA tooling is provided by your Windows GPU drivers. Important: you should _never_ use instructions for installing CUDA-toolkit in a generic linux fashion. In WSL 2, you _always_ want to use the provided `CUDA Toolkit using WSL-Ubuntu Package`. First, open PowerShell or Windows Command Prompt in administrator mode by right-clicking and selecting "Run as administrator". Then enter the following command: ```powershell wsl.exe ``` This should drop you into your running linux VM. Now you can run the following bash commands to install the correct version of cuda-toolkit for WSL-Ubuntu. Note that the version of CUDA used below may not be the version of CUDA your GPU supports. ```sh sudo apt-key del 7fa2af80 # if this line fails, you may remove it. wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin sudo mv cuda-wsl-ubuntu.pin /etc/apt/preferences.d/cuda-repository-pin-600 wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb sudo dpkg -i cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb sudo cp /var/cuda-repo-wsl-ubuntu-11-7-local/cuda-B81839D3-keyring.gpg /usr/share/keyrings/ sudo apt-get update sudo apt-get -y install cuda-toolkit-11-7 ``` ## 6. Install Docker Download and install [Docker Desktop for Windows](https://desktop.docker.com/win/main/amd64/Docker%20Desktop%20Installer.exe). It has WSL 2 support built in by default. Once installed, run `Docker Desktop`; you can ignore the first-run tutorial. Go to **Settings → General** and ensure **Use the WSL 2 based engine** has a checkmark next to it. Click **Apply & Restart**. 
!["Use the WSL 2 based engine" is checked in this interface](images/wsl2-enable.png) Reboot your computer one more time. ## 7. Install `cog` and pull an image Open Windows Terminal and enter your WSL 2 VM: ```powershell wsl.exe ``` Download and install `cog` inside the VM: ```bash sudo curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m` sudo chmod +x /usr/local/bin/cog ``` Make sure it's available by typing: ```bash which cog # should output /usr/local/bin/cog cog --version # should output the cog version number. ``` ## 8. Run a model in WSL 2 Finally, make sure it works. Let's try running `afiaka87/glid-3-xl` locally: ```bash cog predict 'r8.im/afiaka87/glid-3-xl' -i prompt="a fresh avocado floating in the water" -o prediction.json ``` ![Output from a running cog prediction in Windows Terminal](images/cog_model_output.png) While your prediction is running, you can use `Task Manager` to keep an eye on GPU memory consumption: ![Windows task manager will show the shared host/guest GPU memory](images/memory-usage.png) This model just barely manages to fit under 8 GB of VRAM. Notice that output is returned as JSON for this model as it has a complex return type. You will want to convert the base64 string in the json array to an image. `jq` can help with this: ```sh sudo apt install jq ``` The following bash uses `jq` to grab the first element in our prediction array and converts it from a base64 string to a `png` file. ```bash jq -cs '.[0][0][0]' prediction.json | cut --delimiter "," --field 2 | base64 --ignore-garbage --decode > prediction.png ``` When using WSL 2, you can access Windows binaries with the `.exe` extension. This lets you open photos easily within linux. ```bash explorer.exe prediction.png ``` ![a square image of an avocado, generated by the model](images/glide_out.png) ## 9. 
References - - - - - ================================================ FILE: docs/yaml.md ================================================ # `cog.yaml` reference `cog.yaml` defines how to build a Docker image and how to run predictions on your model inside that image. Its top-level keys include [`build`](#build), [`concurrency`](#concurrency), [`image`](#image), and [`predict`](#predict). It looks a bit like this: ```yaml build: python_version: "3.13" python_requirements: requirements.txt system_packages: - "ffmpeg" - "git" predict: "predict.py:Predictor" ``` Tip: Run [`cog init`](getting-started-own-model.md#initialization) to generate an annotated `cog.yaml` file that can be used as a starting point for setting up your model. ## `build` This stanza describes how to build the Docker image your model runs in. It contains various options within it: ### `cuda` Cog automatically picks the correct version of CUDA to install, but this lets you override it for whatever reason by specifying the minor (`11.8`) or patch (`11.8.0`) version of CUDA to use. For example: ```yaml build: cuda: "11.8" ``` ### `gpu` Enable GPUs for this model. When enabled, the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) base image will be used, and Cog will automatically figure out what versions of CUDA and cuDNN to use based on the version of Python, PyTorch, and TensorFlow that you are using. For example: ```yaml build: gpu: true ``` When you use `cog run` or `cog predict`, Cog will automatically pass the `--gpus=all` flag to Docker. When you run a Docker image built with Cog, you'll need to pass this option to `docker run`. ### `python_requirements` A pip requirements file specifying the Python packages to install. For example: ```yaml build: python_requirements: requirements.txt ``` Your `cog.yaml` file can set either `python_packages` or `python_requirements`, but not both. Use `python_requirements` when you need to configure options like `--extra-index-url` or `--trusted-host` to fetch Python package dependencies. 
This follows the standard [requirements.txt](https://pip.pypa.io/en/stable/reference/requirements-file-format/) format. To install Git-hosted Python packages, add `git` to the `system_packages` list, then use the `git+https://` syntax to specify the package name. For example: `cog.yaml`: ```yaml build: system_packages: - "git" python_requirements: requirements.txt ``` `requirements.txt`: ``` git+https://github.com/huggingface/transformers ``` You can also pin Python package installations to a specific git commit: `cog.yaml`: ```yaml build: system_packages: - "git" python_requirements: requirements.txt ``` `requirements.txt`: ``` git+https://github.com/huggingface/transformers@2d1602a ``` Note that you can use a shortened prefix of the 40-character git commit SHA, but you must use at least six characters, like `2d1602a` above. ### `python_packages` **DEPRECATED**: This will be removed in future versions; please use [python_requirements](#python_requirements) instead. A list of Python packages to install from the PyPI package index, in the format `package==version`. For example: ```yaml build: python_packages: - pillow==8.3.1 - tensorflow==2.5.0 ``` Your `cog.yaml` file can set either `python_packages` or `python_requirements`, but not both. ### `python_version` The minor (`3.13`) or patch (`3.13.1`) version of Python to use. For example: ```yaml build: python_version: "3.13.1" ``` Cog supports Python 3.10, 3.11, 3.12, and 3.13. If you don't define a version, Cog will use the latest version of Python 3.13 or a version of Python that is compatible with the versions of PyTorch or TensorFlow you specify. Note that these are the versions supported **in the Docker container**, not your host machine. You can run any version(s) of Python you wish on your host machine. ### `run` A list of setup commands to run in the environment after your system packages and Python packages have been installed. 
If you're familiar with Docker, it's like a `RUN` instruction in your `Dockerfile`. For example: ```yaml build: run: - curl -L https://github.com/cowsay-org/cowsay/archive/refs/tags/v3.7.0.tar.gz | tar -xzf - - cd cowsay-3.7.0 && make install ``` Your code is _not_ available to commands in `run`. This is so we can build your image efficiently when running locally. Each command in `run` can be either a string or a dictionary in the following format: ```yaml build: run: - command: pip install mounts: - type: secret id: pip target: /etc/pip.conf ``` You can use secret mounts to securely pass credentials to setup commands, without baking them into the image. For more information, see [Dockerfile reference](https://docs.docker.com/engine/reference/builder/#run---mounttypesecret). ### `sdk_version` Pin the version of the cog Python SDK installed in the container. Accepts a [PEP 440](https://peps.python.org/pep-0440/) version string. When omitted, the latest release is installed. ```yaml build: python_version: "3.13" sdk_version: "0.18.0" ``` Pre-release versions are also supported: ```yaml build: sdk_version: "0.18.0a1" ``` When a pre-release `sdk_version` is set, `--pre` is automatically passed to the pip install commands for both `cog` and `coglet`, so pip will resolve matching pre-release packages. The minimum supported version is `0.16.0`. Specifying an older version will cause `cog build` to fail with an error. The `COG_SDK_WHEEL` environment variable takes precedence over `sdk_version`. See [Environment variables](./environment.md) for details. ### `system_packages` A list of Ubuntu APT packages to install. For example: ```yaml build: system_packages: - "ffmpeg" - "libavcodec-dev" ``` ## `concurrency` > Added in cog 0.14.0. This stanza describes the concurrency capabilities of the model. It has one option: ### `max` The maximum number of concurrent predictions the model can process. 
If this is set, the model must specify an [async `predict()` method](python.md#async-predictors-and-concurrency). For example: ```yaml concurrency: max: 10 ``` ## `image` The name given to built Docker images. If you want to push to a registry, this should also include the registry name. For example: ```yaml image: "r8.im/your-username/your-model" ``` r8.im is Replicate's registry, but this can be any Docker registry. If you don't set this, then a name will be generated from the directory name. If you set this, then you can run `cog push` without specifying the model name. If you specify an image name argument when pushing (like `cog push your-username/custom-model-name`), the argument will be used and the value of `image` in cog.yaml will be ignored. ## `predict` The pointer to the `Predictor` object in your code, which defines how predictions are run on your model. For example: ```yaml predict: "predict.py:Predictor" ``` See [the Python API documentation for more information](python.md). 
================================================ FILE: go.mod ================================================ module github.com/replicate/cog go 1.26 require ( github.com/anaskhan96/soup v1.2.5 github.com/containerd/errdefs v1.0.0 github.com/creack/pty v1.1.24 github.com/docker/cli v29.2.1+incompatible github.com/docker/docker v28.5.2+incompatible github.com/docker/go-connections v0.6.0 github.com/getkin/kin-openapi v0.133.0 github.com/google/go-containerregistry v0.21.1 github.com/hashicorp/go-version v1.7.0 github.com/logrusorgru/aurora v2.0.3+incompatible github.com/mattn/go-isatty v0.0.20 github.com/mitchellh/go-homedir v1.1.0 github.com/moby/buildkit v0.28.0 github.com/moby/docker-image-spec v1.3.1 github.com/moby/term v0.5.2 github.com/opencontainers/image-spec v1.1.1 github.com/pkg/errors v0.9.1 github.com/replicate/go v0.0.0-20250205165008-b772d7cd506b github.com/rogpeppe/go-internal v1.14.1 github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 github.com/stretchr/testify v1.11.1 github.com/testcontainers/testcontainers-go v0.40.0 github.com/testcontainers/testcontainers-go/modules/registry v0.40.0 github.com/tonistiigi/go-csvvalue v0.0.0-20240814133006-030d3b2625d0 github.com/vincent-petithory/dataurl v1.0.0 github.com/xeipuuv/gojsonschema v1.2.0 github.com/xeonx/timeago v1.0.0-rc5 go.yaml.in/yaml/v4 v4.0.0-rc.4 golang.org/x/crypto v0.48.0 golang.org/x/exp v0.0.0-20250911091902-df9299821621 golang.org/x/sync v0.19.0 golang.org/x/sys v0.42.0 golang.org/x/term v0.40.0 google.golang.org/grpc v1.79.1 ) require ( dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/bitfield/gotestdox v0.2.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect 
github.com/containerd/console v1.0.5 // indirect github.com/containerd/containerd/api v1.10.0 // indirect github.com/containerd/containerd/v2 v2.2.1 // indirect github.com/containerd/continuity v0.4.5 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v1.0.0-rc.2 // indirect github.com/containerd/stargz-snapshotter/estargz v0.18.2 // indirect github.com/containerd/ttrpc v1.2.7 // indirect github.com/containerd/typeurl/v2 v2.2.3 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/distribution/reference v0.6.0 // indirect github.com/dnephin/pflag v1.0.7 // indirect github.com/docker/distribution v2.8.3+incompatible // indirect github.com/docker/docker-credential-helpers v0.9.5 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/ebitengine/purego v0.8.4 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fvbommel/sortorder v1.1.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-openapi/jsonpointer v0.22.4 // indirect github.com/go-openapi/swag/jsonname v0.25.4 // indirect github.com/go-test/deep v1.1.1 // indirect github.com/gofrs/flock v0.13.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/in-toto/attestation v1.1.2 // indirect github.com/in-toto/in-toto-golang v0.10.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 
// indirect github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.18.4 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mailru/easyjson v0.9.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/moby/go-archive v0.2.0 // indirect github.com/moby/locker v1.0.1 // indirect github.com/moby/moby/api v1.53.0 // indirect github.com/moby/moby/client v0.2.2 // indirect github.com/moby/patternmatcher v0.6.0 // indirect github.com/moby/sys/atomicwriter v0.1.0 // indirect github.com/moby/sys/sequential v0.6.0 // indirect github.com/moby/sys/signal v0.7.1 // indirect github.com/moby/sys/user v0.4.0 // indirect github.com/moby/sys/userns v0.1.0 // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/morikuni/aec v1.1.0 // indirect github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 // indirect github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/secure-systems-lab/go-securesystemslib v0.10.0 // indirect github.com/shibumi/go-pathspec v1.3.0 // indirect github.com/shirou/gopsutil/v4 v4.25.6 // indirect github.com/sirupsen/logrus v1.9.4 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/tonistiigi/fsutil v0.0.0-20251211185533-a2aa163d723f // indirect github.com/tonistiigi/units v0.0.0-20180711220420-6950e57a87ea // indirect 
github.com/tonistiigi/vt100 v0.0.0-20240514184818-90bafcd6abab // indirect github.com/vbatts/tar-split v0.12.2 // indirect github.com/woodsbury/decimal128 v1.3.0 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.63.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect go.opentelemetry.io/otel v1.40.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect go.opentelemetry.io/otel/metric v1.40.0 // indirect go.opentelemetry.io/otel/sdk v1.40.0 // indirect go.opentelemetry.io/otel/trace v1.40.0 // indirect go.opentelemetry.io/proto/otlp v1.7.1 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/mod v0.33.0 // indirect golang.org/x/net v0.51.0 // indirect golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4 // indirect golang.org/x/text v0.34.0 // indirect golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.42.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect gotest.tools/gotestsum v1.12.2 // indirect ) replace ( github.com/mholt/archiver/v3 => github.com/bfirsh/archiver/v3 v3.5.1-0.20210316180101-755470a1a69b gopkg.in/fsnotify.v1 => github.com/kolaente/fsnotify v1.4.10-0.20200411160148-1bc3c8ff4048 ) tool ( golang.org/x/tools/cmd/goimports gotest.tools/gotestsum ) ================================================ FILE: go.sum 
================================================ cyphar.com/go-pathrs v0.2.1 h1:9nx1vOgwVvX1mNBWDu93+vaceedpbsDqo+XuBGL40b8= cyphar.com/go-pathrs v0.2.1/go.mod h1:y8f1EMG7r+hCuFf/rXsKqMJrJAUoADZGNh5/vZPKcGc= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/hcsshim v0.14.0-rc.1 h1:qAPXKwGOkVn8LlqgBN8GS0bxZ83hOJpcjxzmlQKxKsQ= github.com/Microsoft/hcsshim v0.14.0-rc.1/go.mod h1:hTKFGbnDtQb1wHiOWv4v0eN+7boSWAHyK/tNAaYZL0c= github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw= github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE= github.com/anaskhan96/soup v1.2.5 h1:V/FHiusdTrPrdF4iA1YkVxsOpdNcgvqT1hG+YtcZ5hM= github.com/anaskhan96/soup v1.2.5/go.mod h1:6YnEp9A2yywlYdM4EgDz9NEHclocMepEtku7wg6Cq3s= github.com/anchore/go-struct-converter v0.1.0 h1:2rDRssAl6mgKBSLNiVCMADgZRhoqtw9dedlWa0OhD30= github.com/anchore/go-struct-converter v0.1.0/go.mod h1:rYqSE9HbjzpHTI74vwPvae4ZVYZd1lue2ta6xHPdblA= github.com/bitfield/gotestdox v0.2.2 h1:x6RcPAbBbErKLnapz1QeAlf3ospg8efBsedU93CDsnE= github.com/bitfield/gotestdox v0.2.2/go.mod h1:D+gwtS0urjBrzguAkTM2wodsTQYFHdpx8eqRJ3N+9pY= github.com/cenkalti/backoff/v4 v4.3.0 
h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= github.com/codahale/rfc6979 v0.0.0-20141003034818-6a90f24967eb h1:EDmT6Q9Zs+SbUoc7Ik9EfrFqcylYqgPZ9ANSbTAntnE= github.com/codahale/rfc6979 v0.0.0-20141003034818-6a90f24967eb/go.mod h1:ZjrT6AXHbDs86ZSdt/osfBi5qfexBrKUdONk989Wnk4= github.com/containerd/cgroups/v3 v3.1.2 h1:OSosXMtkhI6Qove637tg1XgK4q+DhR0mX8Wi8EhrHa4= github.com/containerd/cgroups/v3 v3.1.2/go.mod h1:PKZ2AcWmSBsY/tJUVhtS/rluX0b1uq1GmPO1ElCmbOw= github.com/containerd/console v1.0.5 h1:R0ymNeydRqH2DmakFNdmjR2k0t7UPuiOV/N/27/qqsc= github.com/containerd/console v1.0.5/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk= github.com/containerd/containerd/api v1.10.0 h1:5n0oHYVBwN4VhoX9fFykCV9dF1/BvAXeg2F8W6UYq1o= github.com/containerd/containerd/api v1.10.0/go.mod h1:NBm1OAk8ZL+LG8R0ceObGxT5hbUYj7CzTmR3xh0DlMM= github.com/containerd/containerd/v2 v2.2.1 h1:TpyxcY4AL5A+07dxETevunVS5zxqzuq7ZqJXknM11yk= github.com/containerd/containerd/v2 v2.2.1/go.mod h1:NR70yW1iDxe84F2iFWbR9xfAN0N2F0NcjTi1OVth4nU= github.com/containerd/continuity v0.4.5 h1:ZRoN1sXq9u7V6QoHMcVWGhOwDFqZ4B9i5H6un1Wh0x4= github.com/containerd/continuity v0.4.5/go.mod h1:/lNJvtJKUQStBzpVQ1+rasXO1LAWtUQssk28EZvJ3nE= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= 
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk= github.com/containerd/fifo v1.1.0 h1:4I2mbh5stb1u6ycIABlBw9zgtlK8viPI9QkQNRQEEmY= github.com/containerd/fifo v1.1.0/go.mod h1:bmC4NWMbXlt2EZ0Hc7Fx7QzTFxgPID13eH0Qu+MAb2o= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/nydus-snapshotter v0.15.11 h1:YTdF4rsjFRsfyaIhnWVUSLz8FqJwOyRZ5FhvFjHh7Uc= github.com/containerd/nydus-snapshotter v0.15.11/go.mod h1:EWRd/QJ0b6UKHAqYgiV5gHlqLC2qq5cQiSlXEdVovrA= github.com/containerd/platforms v1.0.0-rc.2 h1:0SPgaNZPVWGEi4grZdV8VRYQn78y+nm6acgLGv/QzE4= github.com/containerd/platforms v1.0.0-rc.2/go.mod h1:J71L7B+aiM5SdIEqmd9wp6THLVRzJGXfNuWCZCllLA4= github.com/containerd/plugin v1.0.0 h1:c8Kf1TNl6+e2TtMHZt+39yAPDbouRH9WAToRjex483Y= github.com/containerd/plugin v1.0.0/go.mod h1:hQfJe5nmWfImiqT1q8Si3jLv3ynMUIBB47bQ+KexvO8= github.com/containerd/stargz-snapshotter/estargz v0.18.2 h1:yXkZFYIzz3eoLwlTUZKz2iQ4MrckBxJjkmD16ynUTrw= github.com/containerd/stargz-snapshotter/estargz v0.18.2/go.mod h1:XyVU5tcJ3PRpkA9XS2T5us6Eg35yM0214Y+wvrZTBrY= github.com/containerd/ttrpc v1.2.7 h1:qIrroQvuOL9HQ1X6KHe2ohc7p+HP/0VE6XPU7elJRqQ= github.com/containerd/ttrpc v1.2.7/go.mod h1:YCXHsb32f+Sq5/72xHubdiJRQY9inL4a4ZQrAbN1q9o= github.com/containerd/typeurl/v2 v2.2.3 h1:yNA/94zxWdvYACdYO8zofhrTVuQY73fFU1y++dYSw40= github.com/containerd/typeurl/v2 v2.2.3/go.mod h1:95ljDnPfD3bAbDJRugOiShd/DlAAsxGtUBhJxIn7SCk= github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA= github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/cpuguy83/go-md2man/v2 v2.0.7 
h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/cyphar/filepath-securejoin v0.6.0 h1:BtGB77njd6SVO6VztOHfPxKitJvd/VPT+OFBFMOi1Is= github.com/cyphar/filepath-securejoin v0.6.0/go.mod h1:A8hd4EnAeyujCJRrICiOWqjS1AX0a9kM5XL+NwKoYSc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/dnephin/pflag v1.0.7 h1:oxONGlWxhmUct0YzKTgrpQv9AUA1wtPBn7zuSjJqptk= github.com/dnephin/pflag v1.0.7/go.mod h1:uxE91IoWURlOiTUIA8Mq5ZZkAv3dPUfZNaT80Zm7OQE= github.com/docker/cli v29.2.1+incompatible h1:n3Jt0QVCN65eiVBoUTZQM9mcQICCJt3akW4pKAbKdJg= github.com/docker/cli v29.2.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk= github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM= github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/docker-credential-helpers v0.9.5 h1:EFNN8DHvaiK8zVqFA2DT6BjXE0GzfLOZ38ggPTKePkY= github.com/docker/docker-credential-helpers v0.9.5/go.mod h1:v1S+hepowrQXITkEfw6o4+BMbGot02wiKpzWhGUZK6c= github.com/docker/go-connections 
v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fvbommel/sortorder v1.1.0 h1:fUmoe+HLsBTctBDoaBwpQo5N+nrCp8g/BjKb/6ZQmYw= github.com/fvbommel/sortorder v1.1.0/go.mod h1:uk88iVf1ovNn1iLfgUVU2F9o5eO30ui720w+kxuqRs0= github.com/getkin/kin-openapi v0.133.0 h1:pJdmNohVIJ97r4AUFtEXRXwESr8b0bD721u/Tz6k8PQ= github.com/getkin/kin-openapi v0.133.0/go.mod h1:boAciF6cXk5FhPqe/NQeBTeenbjqU4LhWBf09ILVvWE= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-openapi/jsonpointer v0.22.4 
h1:dZtK82WlNpVLDW2jlA1YCiVJFVqkED1MegOUy9kR5T4= github.com/go-openapi/jsonpointer v0.22.4/go.mod h1:elX9+UgznpFhgBuaMQ7iu4lvvX1nvNsesQ3oxmYTw80= github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI= github.com/go-openapi/swag/jsonname v0.25.4/go.mod h1:GPVEk9CWVhNvWhZgrnvRA6utbAltopbKwDu8mXNUMag= github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls= github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-containerregistry v0.21.1 h1:sOt/o9BS2b87FnR7wxXPvRKU1XVJn2QCwOS5g8zQXlc= github.com/google/go-containerregistry v0.21.1/go.mod h1:ctO5aCaewH4AK1AumSF5DPW+0+R+d2FmylMJdp5G7p0= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex 
v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/in-toto/attestation v1.1.2 h1:MBFn6lsMq6dptQZJBhalXTcWMb/aJy3V+GX3VYj/V1E= github.com/in-toto/attestation v1.1.2/go.mod h1:gYFddHMZj3DiQ0b62ltNi1Vj5rC879bTmBbrv9CRHpM= github.com/in-toto/in-toto-golang v0.10.0 h1:+s2eZQSK3WmWfYV85qXVSBfqgawi/5L02MaqA4o/tpM= github.com/in-toto/in-toto-golang v0.10.0/go.mod h1:wjT4RiyFlLWCmLUJjwB8oZcjaq7HA390aMJcD3xXgmg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod 
h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczGlG91VSDkswnjF5A8= github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/moby/buildkit v0.28.0 h1:rKulfRRSduHJPNpLTk481fHElqN9tps0VUx8YV/5zsA= github.com/moby/buildkit v0.28.0/go.mod h1:RCuOcj/bVsCriBG8NeFzRxjiCFQKnKP7KOVlNTS18t4= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/go-archive v0.2.0 h1:zg5QDUM2mi0JIM9fdQZWC7U8+2ZfixfTYoHL7rWUcP8= github.com/moby/go-archive v0.2.0/go.mod h1:mNeivT14o8xU+5q1YnNrkQVpK+dnNe/K6fHqnTg4qPU= github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg= github.com/moby/locker 
v1.0.1/go.mod h1:S7SDdo5zpBK84bzzVlKr2V0hz+7x9hWbYC/kq7oQppc= github.com/moby/moby/api v1.53.0 h1:PihqG1ncw4W+8mZs69jlwGXdaYBeb5brF6BL7mPIS/w= github.com/moby/moby/api v1.53.0/go.mod h1:8mb+ReTlisw4pS6BRzCMts5M49W5M7bKt1cJy/YbAqc= github.com/moby/moby/client v0.2.2 h1:Pt4hRMCAIlyjL3cr8M5TrXCwKzguebPAc2do2ur7dEM= github.com/moby/moby/client v0.2.2/go.mod h1:2EkIPVNCqR05CMIzL1mfA07t0HvVUUOl85pasRz/GmQ= github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk= github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc= github.com/moby/policy-helpers v0.0.0-20260211190020-824747bfdd3c h1:hRUo0Ir9PEaa0PQCgg8WvGku0sgmTo/NgnCzMb83iII= github.com/moby/policy-helpers v0.0.0-20260211190020-824747bfdd3c/go.mod h1:2P1OGoTVIrybI4M7yhpkDpqiwOnI3yR+HnNhEyo8ovs= github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw= github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs= github.com/moby/sys/mountinfo v0.7.2 h1:1shs6aH5s4o5H2zQLn796ADW1wMrIwHsyJ2v9KouLrg= github.com/moby/sys/mountinfo v0.7.2/go.mod h1:1YOa8w8Ih7uW0wALDUgT1dTTSBrZ+HiBLGws92L2RU4= github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU= github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko= github.com/moby/sys/signal v0.7.1 h1:PrQxdvxcGijdo6UXXo/lU/TvHUWyPhj7UOpSo8tuvk0= github.com/moby/sys/signal v0.7.1/go.mod h1:Se1VGehYokAkrSQwL4tDzHvETwUZlnY7S5XtQ50mQp8= github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs= github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs= github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g= github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28= github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= github.com/moby/term v0.5.2/go.mod 
h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= github.com/morikuni/aec v1.1.0 h1:vBBl0pUnvi/Je71dsRrhMBtreIqNMYErSAbEeb8jrXQ= github.com/morikuni/aec v1.1.0/go.mod h1:xDRgiq/iw5l+zkao76YTKzKttOp2cwPEne25HDkJnBw= github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 h1:G7ERwszslrBzRxj//JalHPu/3yz+De2J+4aLtSRlHiY= github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037/go.mod h1:2bpvgLBZEtENV5scfDFEtB/5+1M4hkQhDQrccEJ/qGw= github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 h1:bQx3WeLcUWy+RletIKwUIt4x3t8n2SxavmoclizMb8c= github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= github.com/opencontainers/runtime-spec v1.3.0 h1:YZupQUdctfhpZy3TM39nN9Ika5CBWT5diQ8ibYCRkxg= github.com/opencontainers/runtime-spec v1.3.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/selinux v1.13.1 h1:A8nNeceYngH9Ow++M+VVEwJVpdFmrlxsN22F+ISDCJE= github.com/opencontainers/selinux v1.13.1/go.mod h1:S10WXZ/osk2kWOYKy1x2f/eXF5ZHJoUs8UU/2caNRbg= github.com/package-url/packageurl-go v0.1.1 h1:KTRE0bK3sKbFKAk3yy63DpeskU7Cvs/x/Da5l+RtzyU= github.com/package-url/packageurl-go v0.1.1/go.mod h1:uQd4a7Rh3ZsVg5j0lNyAfyxIeGde9yrlhjF78GzeW0c= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod 
h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= github.com/replicate/go v0.0.0-20250205165008-b772d7cd506b h1:GIkpkQ+xwWJ6IRUFmwCLcg+zkZVoKmVXnPjhMncZc4I= github.com/replicate/go v0.0.0-20250205165008-b772d7cd506b/go.mod h1:kUMwEnHJEvWXdu6Py/9fjp7969tsPRYN2a4+Z8BiVEE= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 h1:OkMGxebDjyw0ULyrTYWeN0UNCCkmCWfjPnIA2W6oviI= github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06/go.mod h1:+ePHsJ1keEjQtpvf9HHw0f4ZeJ0TLRsxhunSI2hYJSs= github.com/secure-systems-lab/go-securesystemslib v0.10.0 h1:l+H5ErcW0PAehBNrBxoGv1jjNpGYdZ9RcheFkB2WI14= github.com/secure-systems-lab/go-securesystemslib v0.10.0/go.mod h1:MRKONWmRoFzPNQ9USRF9i1mc7MvAVvF1LlW8X5VWDvk= github.com/shibumi/go-pathspec v1.3.0 h1:QUyMZhFo0Md5B8zV8x2tesohbb5kfbpTi9rBnKh5dkI= github.com/shibumi/go-pathspec v1.3.0/go.mod h1:Xutfslp817l2I1cZvgcfeMQJG5QnU2lh5tVaaMCl3jE= github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs= github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c= github.com/sigstore/sigstore v1.10.4 h1:ytOmxMgLdcUed3w1SbbZOgcxqwMG61lh1TmZLN+WeZE= github.com/sigstore/sigstore v1.10.4/go.mod h1:tDiyrdOref3q6qJxm2G+JHghqfmvifB7hw+EReAfnbI= github.com/sigstore/sigstore-go v1.1.4 h1:wTTsgCHOfqiEzVyBYA6mDczGtBkN7cM8mPpjJj5QvMg= github.com/sigstore/sigstore-go v1.1.4/go.mod h1:2U/mQOT9cjjxrtIUeKDVhL+sHBKsnWddn8URlswdBsg= github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 h1:6C8qej6f1bStuePVkLSFxoU22XBS165D3klxlzRg8F4= github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82/go.mod h1:xe4pgH49k4SsmkQq5OT8abwhWmnzkhpgnXeekbx2efw= github.com/spdx/tools-golang v0.5.7 h1:+sWcKGnhwp3vLdMqPcLdA6QK679vd86cK9hQWH3AwCg= github.com/spdx/tools-golang v0.5.7/go.mod h1:jg7w0LOpoNAw6OxKEzCoqPC2GCTj45LyTlVmXubDsYw= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9/go.mod 
h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU= github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY= github.com/testcontainers/testcontainers-go/modules/registry v0.40.0 h1:z+CymIuT9quh8plBbM+lpncY6diV//q0LbRk+mxMpow= github.com/testcontainers/testcontainers-go/modules/registry v0.40.0/go.mod h1:TWdy7+y7w14Ii5UCSfr7qvxPYI3GE7lc7NEP0ofxlLQ= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/tonistiigi/fsutil v0.0.0-20251211185533-a2aa163d723f h1:Z4NEQ86qFl1mHuCu9gwcE+EYCwDKfXAYXZbdIXyxmEA= github.com/tonistiigi/fsutil v0.0.0-20251211185533-a2aa163d723f/go.mod h1:BKdcez7BiVtBvIcef90ZPc6ebqIWr4JWD7+EvLm6J98= github.com/tonistiigi/go-csvvalue v0.0.0-20240814133006-030d3b2625d0 h1:2f304B10LaZdB8kkVEaoXvAMVan2tl9AiK4G0odjQtE= 
github.com/tonistiigi/go-csvvalue v0.0.0-20240814133006-030d3b2625d0/go.mod h1:278M4p8WsNh3n4a1eqiFcV2FGk7wE5fwUpUom9mK9lE= github.com/tonistiigi/units v0.0.0-20180711220420-6950e57a87ea h1:SXhTLE6pb6eld/v/cCndK0AMpt1wiVFb/YYmqB3/QG0= github.com/tonistiigi/units v0.0.0-20180711220420-6950e57a87ea/go.mod h1:WPnis/6cRcDZSUvVmezrxJPkiO87ThFYsoUiMwWNDJk= github.com/tonistiigi/vt100 v0.0.0-20240514184818-90bafcd6abab h1:H6aJ0yKQ0gF49Qb2z5hI1UHxSQt4JMyxebFR15KnApw= github.com/tonistiigi/vt100 v0.0.0-20240514184818-90bafcd6abab/go.mod h1:ulncasL3N9uLrVann0m+CDlJKWsIAP34MPcOJF6VRvc= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/vbatts/tar-split v0.12.2 h1:w/Y6tjxpeiFMR47yzZPlPj/FcPLpXbTUi/9H7d3CPa4= github.com/vbatts/tar-split v0.12.2/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= github.com/vincent-petithory/dataurl v1.0.0 h1:cXw+kPto8NLuJtlMsI152irrVw9fRDX8AbShPRpg2CI= github.com/vincent-petithory/dataurl v1.0.0/go.mod h1:FHafX5vmDzyP+1CQATJn7WFKc9CvnvxyvZy6I1MrG/U= github.com/woodsbury/decimal128 v1.3.0 h1:8pffMNWIlC0O5vbyHWFZAt5yWvWcrHA+3ovIIjVWss0= github.com/woodsbury/decimal128 v1.3.0/go.mod h1:C5UTmyTjW3JftjUFzOVhC20BEQa2a4ZKOB5I6Zjb+ds= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= 
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xeonx/timeago v1.0.0-rc5 h1:pwcQGpaH3eLfPtXeyPA4DmHWjoQt0Ea7/++FwpxqLxg= github.com/xeonx/timeago v1.0.0-rc5/go.mod h1:qDLrYEFynLO7y5Ho7w3GwgtYgpy5UfhcXIIQvMKVDkA= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ= go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.63.0 h1:2pn7OzMewmYRiNtv1doZnLo3gONcnMHlFnmOR8Vgt+8= go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.63.0/go.mod h1:rjbQTDEPQymPE0YnRQp9/NuPwwtL0sesz/fnqRW/v84= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg= go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms= go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 
h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 h1:aTL7F04bJHUlztTsNGJ2l+6he8c+y/b//eR0jjjemT4= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0/go.mod h1:kldtb7jDTeol0l3ewcmd8SDvx3EmIE7lyvqbasU3QC4= go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g= go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc= go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8= go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE= go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw= go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg= go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw= go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA= go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4= go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= go.yaml.in/yaml/v4 v4.0.0-rc.4 h1:UP4+v6fFrBIb1l934bDl//mmnoIZEDK0idg1+AIvX5U= go.yaml.in/yaml/v4 v4.0.0-rc.4/go.mod h1:aZqd9kCMsGL7AuUv/m/PvWLdg5sjJsZ4oHDEnfPPfY0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod 
h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/exp v0.0.0-20250911091902-df9299821621 h1:2id6c1/gto0kaHYyrixvknJ8tUK/Qs5IsmBtrc+FtgU= golang.org/x/exp v0.0.0-20250911091902-df9299821621/go.mod h1:TwQYMMnGpvZyc+JpB/UAuTNIsVJifOlSkrZkhcvpVUk= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod 
h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4 h1:bTLqdHv7xrGlFbvf5/TXNxy/iUwwdkjhqQTJDjW7aj0= golang.org/x/telemetry v0.0.0-20260209163413-e7419c687ee4/go.mod h1:g5NllXBEermZrmR51cJDQxmJUHUOfRAaNyWBM+R+548= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod 
h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b h1:Mv8VFug0MP9e5vUxfBcE3vUkV6CImK3cMNMIDFjmzxU= google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY= google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/protobuf v1.36.11 
h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/gotestsum v1.12.2 h1:eli4tu9Q2D/ogDsEGSr8XfQfl7mT0JsGOG6DFtUiZ/Q= gotest.tools/gotestsum v1.12.2/go.mod h1:kjRtCglPZVsSU0hFHX3M5VWBM6Y63emHuB14ER1/sow= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk= pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= ================================================ FILE: integration-tests/.gitignore ================================================ .bin/ ================================================ FILE: integration-tests/README.md ================================================ # Cog Integration Tests This directory contains Go-based integration tests for the Cog CLI using the [testscript](https://pkg.go.dev/github.com/rogpeppe/go-internal/testscript) framework. ## Test Formats Most integration tests use the txtar format (`.txtar` files in `tests/`), which provides a simple declarative way to define test scripts and fixtures. 
However, some tests require capabilities that don't fit txtar's sequential execution model and are written as standard Go test functions instead: | Test | Location | Why Go instead of txtar | |------|----------|-------------------------| | `TestConcurrentPredictions` | `concurrent/` | Requires parallel HTTP requests with precise timing coordination | | `TestLogin*` | `login/` | Login requires interactive PTY input and mock HTTP servers | Note: PTY/TTY tests now use the `pty-run` command in txtar format (see Custom Commands below). ## Quick Start ```bash # Run all tests make test-integration # Run fast tests only (skip slow GPU/framework tests) cd integration-tests && go test -short -v # Run a specific test cd integration-tests && go test -v -run TestIntegration/string_predictor # Run with a custom cog binary COG_BINARY=/path/to/cog make test-integration ``` ## Directory Structure ``` integration-tests/ ├── README.md # This file ├── suite_test.go # Main test runner (txtar tests) ├── harness/ │ ├── harness.go # Test harness core │ ├── command.go # Command interface │ └── cmd_pty.go # PTY command implementation ├── tests/ │ └── *.txtar # Test files (one per test case) ├── concurrent/ │ └── concurrent_test.go # Concurrent request tests ├── login/ │ └── login_test.go # Login tests with PTY └── .bin/ └── cog # Cached cog binary (auto-generated) ``` ## Writing Tests Tests are `.txtar` files in the `tests/` directory. Each file is a self-contained test with embedded fixtures. **NOTE: if a test has the suffix `_serial` (e.g. `tests/integration_test_name_serial.txtar`) it will be run in isolation of all other tests. 
By default we run integration tests in Parallel using `t.Parallel()`.** ### Editor Support For syntax highlighting of `.txtar` files: **VS Code:** - [testscript](https://marketplace.visualstudio.com/items?itemName=twpayne.vscode-testscript) by twpayne - Syntax highlighting with embedded file support - [txtar](https://github.com/brody715/vscode-txtar) by brody715 - Alternative txtar extension Install via VS Code: ``` ext install twpayne.vscode-testscript ``` **Zed:** - [zed-txtar](https://github.com/FollowTheProcess/zed-txtar) - Syntax highlighting for txtar files Install via Zed extensions panel or add to your extensions. **Vim/Neovim:** - Use [tree-sitter-go-template](https://github.com/ngalaiko/tree-sitter-go-template) for basic support - Or set filetype manually: `:set ft=conf` for basic highlighting ### Basic Test Structure ```txtar # Comments describe what the test does # This is a test for basic string prediction # Build the Docker image cog build -t $TEST_IMAGE # Run a prediction cog predict $TEST_IMAGE -i s=world stdout 'hello world' # Test that invalid input fails ! 
cog predict $TEST_IMAGE -i wrong=value stderr 'Field required' -- cog.yaml -- build: python_version: "3.12" predict: "predict.py:Predictor" -- predict.py -- from cog import BasePredictor class Predictor(BasePredictor): def predict(self, s: str) -> str: return "hello " + s ``` ### Test File Format - Lines starting with `#` are comments - Lines starting with `-- filename --` begin embedded files - Commands are executed in order - Use `!` prefix for commands expected to fail - Use `stdout` and `stderr` to assert on command output ## Environment Variables The harness automatically sets these environment variables: | Variable | Description | |----------|-------------| | `$TEST_IMAGE` | Unique Docker image name for test isolation | | `$WORK` | Test's temporary working directory | | `$SERVER_URL` | URL of running cog server (after `cog serve`) | | `$HOME` | Real home directory (for Docker credentials) | You can also use: | Variable | Description | |----------|-------------| | `COG_BINARY` | Path to cog binary (defaults to auto-build) | | `TEST_PARALLEL` | Number of parallel tests (default: 4) | Use `go test -short` to skip slow tests. ## Custom Commands ### `cog` - Run cog CLI commands ```txtar cog build -t $TEST_IMAGE cog predict $TEST_IMAGE -i name=value ! cog predict $TEST_IMAGE -i invalid=input # Expected to fail ``` Special handling for `cog serve`: - Runs in background automatically - Allocates a random port - Waits for health check before continuing - Sets `$SERVER_URL` for subsequent commands - Cleans up on test completion ### `curl` - Make HTTP requests to cog server ```txtar cog serve curl GET /health-check curl POST /predictions '{"input":{"s":"hello"}}' stdout '"output":"hello"' ``` Usage: `curl METHOD PATH [BODY]` The `curl` command includes built-in retry logic (10 attempts, 500ms delay) for resilience against timing issues in integration tests. 
### `wait-for` - Wait for conditions **Note**: This command waits for conditions on the **host machine**, not inside Docker containers. For Docker-based tests, use `curl` instead (which has built-in retry logic). ```txtar # Wait for a file to exist (host filesystem only) wait-for file output.txt 30s # Wait for HTTP endpoint (host-accessible URLs only) wait-for http http://localhost:8080/health 200 60s # Wait for file with content wait-for not-empty results.json 30s ``` Usage: `wait-for CONDITION TARGET [ARGS] [TIMEOUT]` ### `pty-run` - Run commands with PTY Run a command with a pseudo-terminal (PTY), sending input from a file and capturing output. ```txtar # Run bash interactively with commands from input file pty-run input.txt cog run $TEST_IMAGE /bin/bash stdout 'expected output' # Run a simple command (no input needed) pty-run /dev/null cog run $TEST_IMAGE echo "hello" stdout 'hello' ``` Usage: `pty-run INPUT_FILE COMMAND [ARGS...]` - The input file contents are written to the PTY as terminal input - Use `/dev/null` if no input is needed - Output is captured and can be matched with `stdout` command - Uses `github.com/creack/pty` which works on both Linux and macOS ## Conditions Use conditions to control when tests run based on environment. Conditions are evaluated by the test runner and can be used with `skip` to conditionally skip tests.
### Available Conditions | Condition | Evaluates to True When | Negated | Example Use Case | |-----------|------------------------|---------|------------------| | `[short]` | `go test -short` is used | `[!short]` | Use `[short] skip` to skip GPU tests, long builds, or slow framework installs when running in short mode | | `[linux]` | Running on Linux | `[!linux]` | Tests requiring Linux-specific features | | `[amd64]` | Running on amd64/x86_64 architecture | `[!amd64]` | Tests requiring specific CPU architecture | | `[linux_amd64]` | Running on Linux AND amd64 | `[!linux_amd64]` | Tests requiring both Linux and amd64 (e.g., `--use-cog-base-image` builds) | ### Usage Examples **Skip slow tests:** ```txtar [short] skip 'requires GPU or long build time' cog build -t $TEST_IMAGE # ... rest of test ``` Skip slow tests with: `go test -short ./...` **Platform-specific tests:** ```txtar [!linux] skip 'requires Linux' # Linux-specific test cog build -t $TEST_IMAGE ``` **Architecture-specific tests:** ```txtar [!amd64] skip 'requires amd64 architecture' # amd64-specific test cog build -t $TEST_IMAGE ``` **Combined platform and architecture:** ```txtar [!linux_amd64] skip 'requires Linux on amd64' # Test that requires both (e.g., --use-cog-base-image builds) cog build -t $TEST_IMAGE --use-cog-base-image ``` ### Condition Logic Conditions can be negated with `!`: - `[short]` - True when `go test -short` is used - Use `[short] skip` to skip a slow test when running in short mode - `[!short]` - True when NOT running with `-short` flag - Use `[!short] skip` to only run a test in short mode (rare) - `[!linux]` - True when NOT on Linux - Use `[!linux] skip` to skip non-Linux tests - `[linux_amd64]` - True when on Linux AND amd64 - Use `[!linux_amd64] skip` to skip tests that need this specific platform Multiple conditions can be used on separate lines: ```txtar [short] skip 'requires long build time' [!linux] skip 'requires Linux' # Only runs on Linux when not using -short flag cog 
build -t $TEST_IMAGE ``` ## Built-in Commands These are provided by testscript itself: | Command | Description | |---------|-------------| | `exec` | Run an arbitrary command | | `stdout PATTERN` | Assert stdout matches regex | | `stderr PATTERN` | Assert stderr matches regex | | `exists FILE` | Assert file exists | | `! exists FILE` | Assert file does not exist | | `cp SRC DST` | Copy file | | `rm FILE` | Remove file | | `mkdir DIR` | Create directory | | `cd DIR` | Change directory | | `env KEY=VALUE` | Set environment variable | | `skip MESSAGE` | Skip the test | | `stop MESSAGE` | Stop test early (success) | See [testscript documentation](https://pkg.go.dev/github.com/rogpeppe/go-internal/testscript) for the full list. ## Test Patterns ### Testing predictions ```txtar cog build -t $TEST_IMAGE cog predict $TEST_IMAGE -i s=hello stdout 'hello' ``` ### Testing server endpoints ```txtar cog build -t $TEST_IMAGE cog serve curl POST /predictions '{"input":{"s":"test"}}' stdout '"output":' ``` ### Testing expected failures ```txtar # Build should fail without predictor ! cog build -t $TEST_IMAGE stderr 'predict' -- cog.yaml -- build: python_version: "3.12" # Note: no predict field ``` ### Testing with subprocess initialization ```txtar cog build -t $TEST_IMAGE cog serve # curl has built-in retry logic for timing resilience curl POST /predictions '{"input":{"s":"test"}}' stdout '"output":"hello test"' -- predict.py -- import subprocess from cog import BasePredictor class Predictor(BasePredictor): def setup(self): self.bg = subprocess.Popen(["./background.sh"]) def predict(self, s: str) -> str: return "hello " + s ``` ### Slow tests (GPU/frameworks) ```txtar [short] skip 'requires long build time' cog build -t $TEST_IMAGE cog predict $TEST_IMAGE stdout 'torch' -- cog.yaml -- build: python_version: "3.12" gpu: true python_packages: - torch==2.7.0 ``` ## How It Works 1. **Test Discovery**: The test runner finds all `.txtar` files in `tests/` 2.
**Setup**: For each test, the harness: - Creates a fresh temporary directory - Extracts embedded files from the txtar - Sets environment variables (`$TEST_IMAGE`, etc.) - Registers cleanup (Docker image removal, server shutdown) 3. **Execution**: Commands run sequentially in the temp directory 4. **Assertions**: `stdout`/`stderr` commands verify output 5. **Cleanup**: Docker images are removed, servers are stopped ## Debugging Failed Tests ### View verbose output ```bash go test -v -run TestIntegration/test_name ``` ### Keep work directory ```bash # Add to test or set in harness env TESTWORK=1 ``` ### Run single test interactively ```bash cd integration-tests go test -v -run TestIntegration/string_predictor -timeout 10m ``` ### Check Docker images ```bash # List test images (should be cleaned up) docker images | grep cog-test ``` ## Adding New Tests 1. Create a new `.txtar` file in `tests/` 2. Name it descriptively (e.g., `async_predictor.txtar`) 3. Add comments explaining what's being tested 4. Include all necessary fixture files inline 5. Run your test: `go test -v -run TestIntegration/your_test_name` ## Common Issues ### Test times out waiting for server The server health check has a 30-second timeout. If your model takes longer to load: - Consider marking it as a slow test with `[short] skip` so it is skipped in short mode - Check for errors in the predictor's `setup()` method ### "SERVER_URL not set" error Make sure `cog serve` is called before `curl`. ### Docker build output cluttering logs Build output is suppressed by default (`BUILDKIT_PROGRESS=quiet`). Errors are still shown. ### Files created in container not visible The `wait-for file` command checks the **host** filesystem, not inside Docker containers. Use `curl` for Docker-based synchronization (it has built-in retry logic).
### Test works locally but fails in CI - CI environments may be slower - increase retry attempts - Check for hardcoded paths or assumptions about the environment - Make sure the test is properly isolated (no shared state) ================================================ FILE: integration-tests/concurrent/concurrent_test.go ================================================ //go:build integration package concurrent_test import ( "encoding/json" "errors" "fmt" "net" "net/http" "os" "os/exec" "path/filepath" "strings" "sync" "syscall" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/replicate/cog/integration-tests/harness" ) // TestConcurrentPredictions tests that concurrent async predictions complete properly // with server shutdown. // // This test verifies: // 1. Multiple predictions can run concurrently // 2. Server shutdown waits for running predictions to complete // 3. All predictions return correct results // // This test is written in Go (not txtar) because it requires parallel HTTP requests // with precise timing coordination that doesn't fit txtar's sequential execution model. 
func TestConcurrentPredictions(t *testing.T) { if testing.Short() { t.Skip("skipping slow test in short mode") } // Create a temp directory for our test project tmpDir, err := os.MkdirTemp("", "cog-concurrent-test-*") require.NoError(t, err, "failed to create temp dir") defer os.RemoveAll(tmpDir) // Write the async-sleep predictor fixture err = os.WriteFile(filepath.Join(tmpDir, "cog.yaml"), []byte(cogYAML), 0o644) require.NoError(t, err, "failed to write cog.yaml") err = os.WriteFile(filepath.Join(tmpDir, "predict.py"), []byte(predictPy), 0o644) require.NoError(t, err, "failed to write predict.py") // Get the cog binary cogBinary, err := harness.ResolveCogBinary() require.NoError(t, err, "failed to resolve cog binary") // Generate unique image name imageName := fmt.Sprintf("cog-concurrent-test-%d", time.Now().UnixNano()) defer func() { exec.Command("docker", "rmi", "-f", imageName).Run() }() // Build the image t.Log("Building image...") buildCmd := exec.Command(cogBinary, "build", "-t", imageName) buildCmd.Dir = tmpDir buildCmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1") output, err := buildCmd.CombinedOutput() require.NoError(t, err, "failed to build image\n%s", output) // Start the server t.Log("Starting server...") port, err := allocatePort() require.NoError(t, err, "failed to allocate port") serveCmd := exec.Command(cogBinary, "serve", "-p", fmt.Sprintf("%d", port)) serveCmd.Dir = tmpDir serveCmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1") err = serveCmd.Start() require.NoError(t, err, "failed to start server") defer func() { serveCmd.Process.Kill() serveCmd.Wait() }() serverURL := fmt.Sprintf("http://127.0.0.1:%d", port) // Wait for server to be ready t.Log("Waiting for server to be ready...") require.True(t, waitForServerReady(serverURL, 60*time.Second), "server did not become ready within timeout") // Fire 5 concurrent predictions t.Log("Starting concurrent predictions...") const numPredictions = 5 var wg sync.WaitGroup results := 
make([]predictionResult, numPredictions) start := time.Now() for i := range numPredictions { wg.Add(1) go func(idx int) { defer wg.Done() results[idx] = makePrediction(serverURL, idx) }(i) } // Wait a bit for all predictions to be accepted but not completed time.Sleep(200 * time.Millisecond) // Shutdown the server while predictions are in-flight t.Log("Sending shutdown request...") shutdownResp, err := http.Post(serverURL+"/shutdown", "application/json", nil) if err != nil { t.Logf("shutdown request error (may be expected): %v", err) } else { shutdownResp.Body.Close() } // Wait for all predictions to complete wg.Wait() elapsed := time.Since(start) t.Logf("All predictions completed in %v", elapsed) // Verify timing - should be < 3s if running concurrently (each sleeps 1s) assert.Less(t, elapsed, 3*time.Second, "predictions took too long (%v), expected < 3s for concurrent execution", elapsed) // Verify all predictions succeeded with correct output for i, result := range results { if !assert.NoError(t, result.err, "prediction %d failed", i) { continue } if !assert.Equal(t, http.StatusOK, result.statusCode, "prediction %d returned unexpected status", i) { continue } expectedOutput := fmt.Sprintf("wake up sleepyhead%d", i) assert.Equal(t, expectedOutput, result.output, "prediction %d output mismatch", i) } } type predictionResult struct { statusCode int output string err error } func makePrediction(serverURL string, idx int) predictionResult { reqBody := fmt.Sprintf(`{"id":"id-%d","input":{"s":"sleepyhead%d","sleep":1.0}}`, idx, idx) resp, err := http.Post( serverURL+"/predictions", "application/json", strings.NewReader(reqBody), ) if err != nil { return predictionResult{err: err} } defer resp.Body.Close() var response struct { Output string `json:"output"` Status string `json:"status"` } if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { return predictionResult{statusCode: resp.StatusCode, err: err} } return predictionResult{ statusCode: 
resp.StatusCode, output: response.Output, } } func waitForServerReady(serverURL string, timeout time.Duration) bool { client := &http.Client{Timeout: 2 * time.Second} deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { resp, err := client.Get(serverURL + "/health-check") if err != nil { time.Sleep(200 * time.Millisecond) continue } var health struct { Status string `json:"status"` } if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { resp.Body.Close() time.Sleep(200 * time.Millisecond) continue } resp.Body.Close() if health.Status == "READY" { return true } if health.Status == "SETUP_FAILED" || health.Status == "DEFUNCT" { return false } time.Sleep(200 * time.Millisecond) } return false } // waitForServerStatus polls /health-check until the server reports the given status. // Unlike waitForServerReady which waits for READY, this can wait for intermediate // states like STARTING (useful for testing signals during setup). func waitForServerStatus(serverURL string, targetStatus string, timeout time.Duration) bool { client := &http.Client{Timeout: 2 * time.Second} deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { resp, err := client.Get(serverURL + "/health-check") if err != nil { time.Sleep(200 * time.Millisecond) continue } var health struct { Status string `json:"status"` } if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { resp.Body.Close() time.Sleep(200 * time.Millisecond) continue } resp.Body.Close() if health.Status == targetStatus { return true } if health.Status == "SETUP_FAILED" || health.Status == "DEFUNCT" { return false } time.Sleep(200 * time.Millisecond) } return false } // allocatePort finds an available TCP port by letting the OS assign one. 
func allocatePort() (int, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return 0, err } defer listener.Close() return listener.Addr().(*net.TCPAddr).Port, nil } // Embedded fixture files const cogYAML = `build: python_version: "3.11" predict: "predict.py:Predictor" concurrency: max: 5 ` const predictPy = `import asyncio from cog import BasePredictor class Predictor(BasePredictor): async def predict(self, s: str, sleep: float) -> str: await asyncio.sleep(sleep) return f"wake up {s}" ` // TestConcurrentAboveLimit tests that sending more predictions than max_concurrency // returns a 409 Conflict for the excess prediction. func TestConcurrentAboveLimit(t *testing.T) { if testing.Short() { t.Skip("skipping slow test in short mode") } tmpDir, err := os.MkdirTemp("", "cog-above-limit-test-*") require.NoError(t, err, "failed to create temp dir") defer os.RemoveAll(tmpDir) err = os.WriteFile(filepath.Join(tmpDir, "cog.yaml"), []byte(aboveLimitCogYAML), 0o644) require.NoError(t, err, "failed to write cog.yaml") err = os.WriteFile(filepath.Join(tmpDir, "predict.py"), []byte(predictPy), 0o644) require.NoError(t, err, "failed to write predict.py") cogBinary, err := harness.ResolveCogBinary() require.NoError(t, err, "failed to resolve cog binary") imageName := fmt.Sprintf("cog-above-limit-test-%d", time.Now().UnixNano()) defer func() { exec.Command("docker", "rmi", "-f", imageName).Run() }() t.Log("Building image...") buildCmd := exec.Command(cogBinary, "build", "-t", imageName) buildCmd.Dir = tmpDir buildCmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1") output, err := buildCmd.CombinedOutput() require.NoError(t, err, "failed to build image\n%s", output) t.Log("Starting server...") port, err := allocatePort() require.NoError(t, err, "failed to allocate port") serveCmd := exec.Command(cogBinary, "serve", "-p", fmt.Sprintf("%d", port)) serveCmd.Dir = tmpDir serveCmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1") err = serveCmd.Start() 
require.NoError(t, err, "failed to start server") defer func() { serveCmd.Process.Kill() serveCmd.Wait() }() serverURL := fmt.Sprintf("http://127.0.0.1:%d", port) t.Log("Waiting for server to be ready...") require.True(t, waitForServerReady(serverURL, 60*time.Second), "server did not become ready within timeout") // Fill all 2 slots with long-running predictions (each sleeps 1s) const maxConcurrency = 2 var wg sync.WaitGroup for i := range maxConcurrency { wg.Add(1) go func(idx int) { defer wg.Done() makePrediction(serverURL, idx) }(i) } // Poll with an overflow request until we get a 409, meaning both slots // are occupied. This avoids a fixed sleep that can flake on slow CI. t.Log("Polling for 409 (all slots occupied)...") deadline := time.Now().Add(10 * time.Second) var resp *http.Response for time.Now().Before(deadline) { extraBody := `{"id":"extra","input":{"s":"overflow","sleep":1.0}}` resp, err = http.Post( serverURL+"/predictions", "application/json", strings.NewReader(extraBody), ) require.NoError(t, err, "failed to send extra prediction") if resp.StatusCode == http.StatusConflict { break } // Got 200 — slots weren't full yet, close and retry resp.Body.Close() time.Sleep(100 * time.Millisecond) } defer resp.Body.Close() require.Equal(t, http.StatusConflict, resp.StatusCode, "extra prediction status = %d, want %d (409 Conflict); slots never filled within timeout", resp.StatusCode, http.StatusConflict) var errResp struct { Error string `json:"error"` Status string `json:"status"` } err = json.NewDecoder(resp.Body).Decode(&errResp) require.NoError(t, err, "failed to decode error response") assert.Equal(t, "failed", errResp.Status, "error response status mismatch") assert.Contains(t, strings.ToLower(errResp.Error), "capacity", "error response error = %q, want string containing \"capacity\"", errResp.Error) wg.Wait() } const aboveLimitCogYAML = `build: python_version: "3.11" predict: "predict.py:Predictor" concurrency: max: 2 ` // TestSIGTERMDuringSetup tests 
that SIGTERM during setup() causes clean shutdown. func TestSIGTERMDuringSetup(t *testing.T) { if testing.Short() { t.Skip("skipping slow test in short mode") } tmpDir, err := os.MkdirTemp("", "cog-sigterm-setup-test-*") require.NoError(t, err, "failed to create temp dir") defer os.RemoveAll(tmpDir) slowSetupCogYAML := `build: python_version: "3.12" predict: "predict.py:Predictor" ` slowSetupPredictPy := `import time from cog import BasePredictor class Predictor(BasePredictor): def setup(self) -> None: time.sleep(30) def predict(self, s: str) -> str: return "hello " + s ` err = os.WriteFile(filepath.Join(tmpDir, "cog.yaml"), []byte(slowSetupCogYAML), 0o644) require.NoError(t, err, "failed to write cog.yaml") err = os.WriteFile(filepath.Join(tmpDir, "predict.py"), []byte(slowSetupPredictPy), 0o644) require.NoError(t, err, "failed to write predict.py") cogBinary, err := harness.ResolveCogBinary() require.NoError(t, err, "failed to resolve cog binary") t.Log("Building image...") imageName := fmt.Sprintf("cog-sigterm-setup-test-%d", time.Now().UnixNano()) defer func() { exec.Command("docker", "rmi", "-f", imageName).Run() }() buildCmd := exec.Command(cogBinary, "build", "-t", imageName) buildCmd.Dir = tmpDir buildCmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1") output, err := buildCmd.CombinedOutput() require.NoError(t, err, "failed to build image\n%s", output) t.Log("Starting server...") port, err := allocatePort() require.NoError(t, err, "failed to allocate port") serveCmd := exec.Command(cogBinary, "serve", "-p", fmt.Sprintf("%d", port)) serveCmd.Dir = tmpDir serveCmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1") err = serveCmd.Start() require.NoError(t, err, "failed to start server") // Poll health-check until setup has begun (status STARTING), // rather than a fixed sleep that can be too short on cold Docker pulls. 
t.Log("Waiting for setup to begin (STARTING status)...") if !waitForServerStatus(fmt.Sprintf("http://127.0.0.1:%d", port), "STARTING", 60*time.Second) { serveCmd.Process.Kill() serveCmd.Wait() t.Fatal("server did not reach STARTING status within timeout") } // Send SIGTERM t.Log("Sending SIGTERM during setup...") err = serveCmd.Process.Signal(syscall.SIGTERM) require.NoError(t, err, "failed to send signal") // Wait for process to exit with a timeout done := make(chan error, 1) go func() { done <- serveCmd.Wait() }() select { case err := <-done: if err == nil { t.Fatal("server exited cleanly after SIGTERM; expected termination by signal") } var exitErr *exec.ExitError if !errors.As(err, &exitErr) { t.Fatalf("server exited with unexpected error type after SIGTERM: %T (%v)", err, err) } ws, ok := exitErr.Sys().(syscall.WaitStatus) if !ok { t.Fatalf("server exited after SIGTERM but wait status was unavailable: %v", err) } if !ws.Signaled() || ws.Signal() != syscall.SIGTERM { t.Fatalf("server exit = %v, want signal %v", ws, syscall.SIGTERM) } case <-time.After(15 * time.Second): serveCmd.Process.Kill() t.Fatal("server did not exit within 15s after SIGTERM") } } ================================================ FILE: integration-tests/harness/cmd_pty.go ================================================ package harness import ( "bytes" "io" "os" "os/exec" "strings" "sync" "time" "github.com/creack/pty" "github.com/rogpeppe/go-internal/testscript" ) // PtyRunCommand implements the 'pty-run' command for testscript. type PtyRunCommand struct { harness *Harness } func (c *PtyRunCommand) Name() string { return "pty-run" } // Run executes a command with a PTY, sending input from a file and capturing output. // // Usage: pty-run [args...] // // The input file contents are written to the PTY as terminal input. // Use /dev/null or an empty file if no input is needed. // The command's output is written to stdout for matching with 'stdout' command. 
//
// This uses github.com/creack/pty which works on both Linux and macOS,
// unlike testscript's native ttyin/ttyout which hangs on macOS due to
// Go bug https://github.com/golang/go/issues/61779.
//
// TODO: Remove this implementation and use testscript's native ttyin/ttyout
// once the Go bug is fixed (check Go 1.26+).
//
// args[0] is the input file, args[1] the command, and the rest its arguments.
// The literal command name "cog" is replaced with the harness's resolved cog
// binary. When neg is true the command is expected to fail. In both the
// success and expected-failure cases the captured PTY output is written to the
// script's stdout.
func (c *PtyRunCommand) Run(ts *testscript.TestScript, neg bool, args []string) {
	if len(args) < 2 {
		ts.Fatalf("pty-run: usage: pty-run [args...]")
	}

	inputFile := args[0]
	cmdName := args[1]
	cmdArgs := args[2:]

	// Read input file ("/dev/null" is treated as "no input" without reading it)
	var input string
	if inputFile != "/dev/null" {
		input = ts.ReadFile(inputFile)
	}

	// Expand environment variables in command and args
	cmdName = os.Expand(cmdName, ts.Getenv)

	// Handle "cog" command specially - use the resolved binary
	if cmdName == "cog" {
		cmdName = c.harness.CogBinary
	}

	expandedArgs := make([]string, len(cmdArgs))
	for i, arg := range cmdArgs {
		expandedArgs[i] = os.Expand(arg, ts.Getenv)
	}

	// Create the command, running in the script's work directory
	cmd := exec.Command(cmdName, expandedArgs...)
	cmd.Dir = ts.Getenv("WORK")

	// Build environment: an explicit minimal set rather than the host
	// environment, plus the optional overrides below when they are set.
	cmd.Env = []string{
		"HOME=" + ts.Getenv("HOME"),
		"PATH=" + ts.Getenv("PATH"),
		"COG_NO_UPDATE_CHECK=1",
	}
	if v := ts.Getenv("REPO_ROOT"); v != "" {
		cmd.Env = append(cmd.Env, "REPO_ROOT="+v)
	}
	if v := ts.Getenv("COG_SDK_WHEEL"); v != "" {
		cmd.Env = append(cmd.Env, "COG_SDK_WHEEL="+v)
	}
	if v := ts.Getenv("COGLET_WHEEL"); v != "" {
		cmd.Env = append(cmd.Env, "COGLET_WHEEL="+v)
	}

	// Start command with PTY
	ptmx, err := pty.Start(cmd)
	if err != nil {
		ts.Fatalf("pty-run: failed to start command with PTY: %v", err)
	}
	defer ptmx.Close()

	// Set terminal size (failure is logged, not fatal)
	if err := pty.Setsize(ptmx, &pty.Winsize{Rows: 24, Cols: 80}); err != nil {
		ts.Logf("pty-run: failed to set terminal size: %v", err)
	}

	// Use shared buffer pattern for reading (avoids race conditions)
	var buf bytes.Buffer
	var mu sync.Mutex // guards buf
	done := make(chan struct{})

	// Reader goroutine: copies everything read from the PTY into buf.
	// The done channel is only consulted between reads; a Read that is
	// already in progress is not interrupted by closing it.
	go func() {
		tmp := make([]byte, 1024)
		for {
			select {
			case <-done:
				return
			default:
				n, err := ptmx.Read(tmp)
				if n > 0 {
					mu.Lock()
					buf.Write(tmp[:n])
					mu.Unlock()
				}
				if err != nil {
					if err != io.EOF {
						// Log non-EOF errors but don't fail - PTY may close unexpectedly
						ts.Logf("pty-run: read error (may be normal): %v", err)
					}
					return
				}
			}
		}
	}()

	// Write input to PTY with small delays between lines for reliability.
	// Empty lines in the input file are skipped.
	if input != "" {
		for line := range strings.SplitSeq(input, "\n") {
			if line == "" {
				continue
			}
			_, err := ptmx.Write([]byte(line + "\n"))
			if err != nil {
				ts.Logf("pty-run: failed to write input (may be normal if command exited): %v", err)
				break
			}
			// Small delay to let the shell process the line
			time.Sleep(50 * time.Millisecond)
		}
	}

	// Wait for command to finish with timeout. The channel is buffered so the
	// waiting goroutine can always deliver its result, even after a timeout.
	cmdDone := make(chan error, 1)
	go func() {
		cmdDone <- cmd.Wait()
	}()

	timeout := 60 * time.Second
	var cmdErr error
	select {
	case cmdErr = <-cmdDone:
		// Command finished
	case <-time.After(timeout):
		// Timeout - kill the process
		_ = cmd.Process.Kill()
		mu.Lock()
		output := buf.String()
		mu.Unlock()
		ts.Logf("pty-run: timeout after %v, partial output: %q", timeout, output)
		ts.Fatalf("pty-run: command timed out after %v", timeout)
		return
	}

	// Give a moment for final output to be captured
	time.Sleep(100 * time.Millisecond)
	close(done)

	// Get final output
	mu.Lock()
	output := buf.String()
	mu.Unlock()

	// Handle negation
	if neg {
		if cmdErr == nil {
			ts.Fatalf("pty-run: command succeeded unexpectedly")
		}
		// Command failed as expected - write output for potential pattern matching
		_, _ = ts.Stdout().Write([]byte(output))
		return
	}

	if cmdErr != nil {
		ts.Logf("pty-run: command output: %q", output)
		ts.Fatalf("pty-run: command failed: %v", cmdErr)
	}

	// Write output to stdout for pattern matching
	_, _ = ts.Stdout().Write([]byte(output))
}

================================================
FILE: integration-tests/harness/command.go
================================================
package harness

import "github.com/rogpeppe/go-internal/testscript"

// Command defines the interface for testscript commands.
type Command interface {
	// Name returns the command name as used in txtar scripts.
	Name() string

	// Run executes the command.
	// neg is true if the command was prefixed with '!' (expecting failure).
	// args are the command arguments.
	Run(ts *testscript.TestScript, neg bool, args []string)
}

// CommandFunc adapts a function to the Command interface.
type CommandFunc struct {
	name string                                                   // name reported by Name
	fn   func(ts *testscript.TestScript, neg bool, args []string) // function invoked by Run
}

// Name returns the name the command was created with.
func (c CommandFunc) Name() string { return c.name }

// Run forwards its arguments unchanged to the wrapped function.
func (c CommandFunc) Run(ts *testscript.TestScript, neg bool, args []string) {
	c.fn(ts, neg, args)
}

// NewCommand creates a Command from a name and function.
func NewCommand(name string, fn func(ts *testscript.TestScript, neg bool, args []string)) Command {
	return CommandFunc{name: name, fn: fn}
}

================================================
FILE: integration-tests/harness/harness.go
================================================
// Package harness provides utilities for running cog integration tests.
package harness

import (
	"context"
	cryptorand "crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	mathrand "math/rand/v2"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/google/go-containerregistry/pkg/crane"
	"github.com/rogpeppe/go-internal/testscript"

	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/registry_testhelpers"
)

// propagatedEnvVars names the host environment variables that are copied into
// testscript environments (by Setup) and into background server processes
// (by cmdCogServe). Add any new variable that must be forwarded to this list
// so both places stay in sync.
var propagatedEnvVars = []string{
	"COG_SDK_WHEEL",     // SDK wheel override
	"COGLET_WHEEL",      // coglet wheel override
	"RUST_LOG",          // Rust logging control
	"COG_CA_CERT",       // custom CA certificates (e.g. Cloudflare WARP)
	"BUILDKIT_PROGRESS", // Docker build output format
}

// Harness (declared further down in this file) provides utilities for running
// cog integration tests. The types grouped below describe per-test resources.

type (
	// serverInfo tracks a running cog serve process and its port.
	serverInfo struct {
		cmd  *exec.Cmd
		port int
	}

	// registryInfo tracks a running test registry container.
	registryInfo struct {
		container *registry_testhelpers.RegistryContainer
		cleanup   func()
		host      string // e.g., "localhost:5432"
	}

	// mockUploadRecord records a single upload received by the mock upload server.
	mockUploadRecord struct {
		Path        string
		ContentType string
		Size        int
	}

	// mockUploadServer is a lightweight HTTP server that accepts PUT requests
	// and records what was uploaded.
	mockUploadServer struct {
		server  *http.Server
		port    int
		mu      sync.Mutex
		uploads []mockUploadRecord
	}
)

// webhookResult is the summary written to stdout by webhook-server-wait.
type webhookResult struct { Status string `json:"status"` OutputSize int `json:"output_size"` HasError bool `json:"has_error"` ErrorMessage string `json:"error_message,omitempty"` Metrics json.RawMessage `json:"metrics,omitempty"` } // webhookServer accepts prediction webhook callbacks from coglet. // It parses the JSON payload to extract status and output size, without // ever exposing the (potentially huge) output to testscript's log buffer. type webhookServer struct { server *http.Server port int mu sync.Mutex result *webhookResult done chan struct{} // closed on first terminal webhook } type Harness struct { CogBinary string // realHome is captured at creation time before testscript overrides HOME realHome string // repoRoot is the path to the cog repository root repoRoot string // serverProcs tracks background cog serve processes for cleanup, keyed by work directory serverProcs map[string]*serverInfo serverProcsMu sync.Mutex // registries tracks test registry containers for cleanup, keyed by work directory registries map[string]*registryInfo registriesMu sync.Mutex // uploadServers tracks mock upload servers for cleanup, keyed by work directory uploadServers map[string]*mockUploadServer uploadServersMu sync.Mutex // webhookServers tracks webhook receiver servers for cleanup, keyed by work directory webhookServers map[string]*webhookServer webhookServersMu sync.Mutex } // New creates a new Harness, resolving the cog binary location. func New() (*Harness, error) { cogBinary, err := ResolveCogBinary() if err != nil { return nil, err } repoRoot, err := findRepoRoot() if err != nil { return nil, err } return &Harness{ CogBinary: cogBinary, realHome: os.Getenv("HOME"), repoRoot: repoRoot, serverProcs: make(map[string]*serverInfo), registries: make(map[string]*registryInfo), uploadServers: make(map[string]*mockUploadServer), webhookServers: make(map[string]*webhookServer), }, nil } // ResolveCogBinary finds the cog binary to use for tests. 
// It checks (in order):
//  1. COG_BINARY environment variable
//  2. Build from source (if in cog repository)
//
// A relative COG_BINARY is resolved against the repository root rather than
// the current (test package) directory. An absolute value is returned as-is.
func ResolveCogBinary() (string, error) {
	if cogBinary := os.Getenv("COG_BINARY"); cogBinary != "" {
		if !filepath.IsAbs(cogBinary) {
			// Resolve relative paths from repo root, not the test package directory.
			repoRoot, err := findRepoRoot()
			if err != nil {
				return "", err
			}
			cogBinary = filepath.Join(repoRoot, cogBinary)
		}
		return cogBinary, nil
	}

	// Build from source
	return buildCogBinary()
}

// buildCogBinary builds the cog binary from source.
// It finds the repository root, builds wheels if needed, and compiles the binary.
// If the binary already exists, it returns the cached path.
func buildCogBinary() (string, error) {
	// Find repository root (where go.mod with module github.com/replicate/cog exists)
	repoRoot, err := findRepoRoot()
	if err != nil {
		return "", fmt.Errorf("failed to find cog repository root: %w", err)
	}

	// Check if binary already exists
	binPath := filepath.Join(repoRoot, "integration-tests", ".bin", "cog")
	if _, err := os.Stat(binPath); err == nil {
		fmt.Printf("Using cached cog binary: %s\n", binPath)
		return binPath, nil
	}

	// Check if wheels exist, build if not
	var (
		wheelsDir            = filepath.Join(repoRoot, "pkg", "wheels")
		cogWheelExists, _    = filepath.Glob(filepath.Join(wheelsDir, "cog-*.whl"))
		cogletWheelExists, _ = filepath.Glob(filepath.Join(wheelsDir, "coglet-*.whl"))
	)

	if len(cogWheelExists) == 0 || len(cogletWheelExists) == 0 {
		fmt.Println("Building Python wheels...")
		if err := runCommand(repoRoot, "mise", "run", "build:wheels"); err != nil {
			return "", fmt.Errorf("failed to build wheels: %w", err)
		}

		fmt.Println("Generating wheel embeds...")
		if err := runCommand(repoRoot, "go", "generate", "./pkg/wheels"); err != nil {
			return "", fmt.Errorf("failed to generate wheel embeds: %w", err)
		}
	}

	// Build the cog binary
	if err := os.MkdirAll(filepath.Dir(binPath), 0o755); err != nil {
		return "", fmt.Errorf("failed to create bin directory: %w", err)
	}

	fmt.Println("Building cog binary...")
	if err := runCommand(repoRoot, "go", "build", "-o", binPath, "./cmd/cog"); err != nil {
		return "", fmt.Errorf("failed to build cog: %w", err)
	}

	return binPath, nil
}

// findRepoRoot finds the cog repository root by looking for go.mod with the main module.
// It walks upward from the current working directory and returns the first
// directory whose go.mod declares module github.com/replicate/cog. go.mod
// files that are unreadable or declare another module are skipped.
func findRepoRoot() (string, error) {
	// Start from current working directory
	dir, err := os.Getwd()
	if err != nil {
		return "", err
	}

	for {
		content, err := os.ReadFile(filepath.Join(dir, "go.mod"))
		if err == nil && declaresCogModule(string(content)) {
			return dir, nil
		}

		parent := filepath.Dir(dir)
		if parent == dir {
			// Reached the filesystem root without finding the module.
			break
		}
		dir = parent
	}

	return "", fmt.Errorf("could not find cog repository root")
}

// declaresCogModule reports whether the go.mod text declares the main cog
// module. It looks at the first "module" directive and compares its path
// exactly, so CRLF line endings, a trailing comment, a quoted path, or a
// missing final newline do not prevent a match.
func declaresCogModule(goMod string) bool {
	const cogModulePath = "github.com/replicate/cog"

	for _, line := range strings.Split(goMod, "\n") {
		fields := strings.Fields(line)
		if len(fields) >= 2 && fields[0] == "module" {
			return strings.Trim(fields[1], `"`) == cogModulePath
		}
	}
	return false
}

// runCommand runs a command in the specified directory, streaming its stdout
// and stderr to this process's stdout and stderr.
func runCommand(dir string, name string, args ...string) error {
	cmd := exec.Command(name, args...)
	cmd.Dir = dir
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	return cmd.Run()
}

// Commands returns the custom testscript commands provided by this harness.
func (h *Harness) Commands() map[string]func(ts *testscript.TestScript, neg bool, args []string) { // Register all commands commands := []Command{ // Built-in commands (defined in this file) NewCommand("cog", h.cmdCog), NewCommand("curl", h.cmdCurl), NewCommand("wait-for", h.cmdWaitFor), NewCommand("docker-run", h.cmdDockerRun), // Registry and OCI bundle testing commands NewCommand("registry-start", h.cmdRegistryStart), NewCommand("registry-inspect", h.cmdRegistryInspect), NewCommand("registry-seed", h.cmdRegistrySeed), NewCommand("docker-push", h.cmdDockerPush), NewCommand("mock-weights", h.cmdMockWeights), // Mock upload server commands NewCommand("upload-server-start", h.cmdUploadServerStart), NewCommand("upload-server-count", h.cmdUploadServerCount), // Webhook receiver commands NewCommand("webhook-server-start", h.cmdWebhookServerStart), NewCommand("webhook-server-wait", h.cmdWebhookServerWait), // PTY command (defined in cmd_pty.go) &PtyRunCommand{harness: h}, } // Build the command map result := make(map[string]func(ts *testscript.TestScript, neg bool, args []string)) for _, cmd := range commands { result[cmd.Name()] = cmd.Run } return result } // cmdCog implements the 'cog' command for testscript. // It handles all cog subcommands, with special handling for certain commands. func (h *Harness) cmdCog(ts *testscript.TestScript, neg bool, args []string) { // Check for subcommands that need special handling if len(args) > 0 && args[0] == "serve" { // Special handling for 'cog serve' - run in background h.cmdCogServe(ts, neg, args[1:]) return } // Default: run cog command normally expandedArgs := make([]string, len(args)) for i, arg := range args { expandedArgs[i] = os.Expand(arg, ts.Getenv) } err := ts.Exec(h.CogBinary, expandedArgs...) if neg { if err == nil { ts.Fatalf("cog command succeeded unexpectedly") } return } if err != nil { ts.Fatalf("cog command failed: %v", err) } } // Setup returns a testscript Setup function that configures the test environment. 
// Fixtures are embedded in the txtar files themselves, so no file copying is needed. func (h *Harness) Setup(env *testscript.Env) error { // Restore real HOME for Docker credential helpers. // Docker credential helpers (e.g., docker-credential-desktop) need the real HOME // to access the macOS keychain. env.Setenv("HOME", h.realHome) // Export repo root for tests that need to reference files outside the work directory env.Setenv("REPO_ROOT", h.repoRoot) // Disable update checks during tests env.Setenv("COG_NO_UPDATE_CHECK", "1") // Propagate host env vars listed in propagatedEnvVars for _, key := range propagatedEnvVars { if val := os.Getenv(key); val != "" { env.Setenv(key, val) } } // Auto-detect wheels from dist/ if not explicitly set via env vars. // CI sets these env vars; locally we need to find them ourselves. distDir := filepath.Join(h.repoRoot, "dist") if os.Getenv("COGLET_WHEEL") == "" { if matches, _ := filepath.Glob(filepath.Join(distDir, "coglet-*.whl")); len(matches) > 0 { env.Setenv("COGLET_WHEEL", distDir) } } if os.Getenv("COG_SDK_WHEEL") == "" { if matches, _ := filepath.Glob(filepath.Join(distDir, "cog-*.whl")); len(matches) > 0 { env.Setenv("COG_SDK_WHEEL", distDir) } } // Generate unique image name for this test run imageName := generateUniqueImageName() env.Setenv("TEST_IMAGE", imageName) // Capture the work directory for this test (used as key for server tracking) workDir := env.WorkDir // Register cleanup to remove the Docker image, stop servers, and cleanup registries env.Defer(func() { // Stop the server for this specific test (if any) h.stopServerByWorkDir(workDir) // Stop the registry for this specific test (if any) h.stopRegistryByWorkDir(workDir) // Stop the upload server for this specific test (if any) h.stopUploadServerByWorkDir(workDir) // Stop the webhook server for this specific test (if any) h.stopWebhookServerByWorkDir(workDir) removeDockerImage(imageName) }) return nil } // generateUniqueImageName creates a unique Docker image 
name for test isolation. func generateUniqueImageName() string { b := make([]byte, 5) if _, err := cryptorand.Read(b); err != nil { // Fall back to a less random but still unique name return fmt.Sprintf("cog-test-%d", os.Getpid()) } return fmt.Sprintf("cog-test-%s", hex.EncodeToString(b)) } // removeDockerImage attempts to remove a Docker image by name. // It silently ignores errors (image may not exist if test failed early). func removeDockerImage(imageName string) { // Remove all images that match the prefix (base and final images) cmd := exec.Command("docker", "images", "--format", "{{.Repository}}:{{.Tag}}", "--filter", fmt.Sprintf("reference=%s*", imageName)) output, err := cmd.Output() if err != nil { return } for img := range strings.SplitSeq(strings.TrimSpace(string(output)), "\n") { if img == "" { continue } exec.Command("docker", "rmi", "-f", img).Run() //nolint:errcheck,gosec } } // cmdCogServe implements background 'cog serve' for testscript. // It starts a cog serve process in the background and waits for it to be healthy. // Usage: cog serve [flags] // Exports $SERVER_URL environment variable with the server address. func (h *Harness) cmdCogServe(ts *testscript.TestScript, neg bool, args []string) { workDir := ts.Getenv("WORK") // Check if server is already running h.serverProcsMu.Lock() if _, exists := h.serverProcs[workDir]; exists { h.serverProcsMu.Unlock() ts.Fatalf("server already running") } h.serverProcsMu.Unlock() // Allocate a random available port port, err := allocatePort() if err != nil { ts.Fatalf("failed to allocate port: %v", err) } // Build command arguments cmdArgs := []string{"serve", "-p", strconv.Itoa(port)} cmdArgs = append(cmdArgs, args...) // Expand environment variables in arguments expandedArgs := make([]string, len(cmdArgs)) for i, arg := range cmdArgs { expandedArgs[i] = os.Expand(arg, ts.Getenv) } // Start the server process cmd := exec.Command(h.CogBinary, expandedArgs...) 
cmd.Dir = workDir // Build environment from testscript. // Always include core vars, plus everything from propagatedEnvVars. var env []string for _, key := range []string{"HOME", "PATH", "REPO_ROOT", "COG_NO_UPDATE_CHECK", "TEST_IMAGE"} { if val := ts.Getenv(key); val != "" { env = append(env, key+"="+val) } } for _, key := range propagatedEnvVars { if val := ts.Getenv(key); val != "" { env = append(env, key+"="+val) } } cmd.Env = env // Capture server output for debugging cmd.Stdout = ts.Stdout() cmd.Stderr = ts.Stderr() if err := cmd.Start(); err != nil { ts.Fatalf("failed to start server: %v", err) } // Store the process for cleanup (keyed by work directory) h.serverProcsMu.Lock() h.serverProcs[workDir] = &serverInfo{cmd: cmd, port: port} h.serverProcsMu.Unlock() // Wait for server to be healthy serverURL := fmt.Sprintf("http://127.0.0.1:%d", port) ts.Setenv("SERVER_URL", serverURL) if !waitForServer(serverURL, 60*time.Second) { if neg { // Test expected the server to fail setup — keep it running // so the test can inspect the health-check status. return } // Try to get server output for debugging _ = cmd.Process.Kill() ts.Fatalf("server did not become healthy within timeout") } if neg { ts.Fatalf("server became healthy, but expected setup failure") } } // cmdCurl implements the 'curl' command for testscript. // It makes HTTP requests to the server started with 'serve'. // Includes built-in retry logic (10 attempts, 500ms delay) for resilience. // Usage: curl [-H key:value]... [method] [path] [body] // Examples: // // curl GET /health-check // curl POST /predictions '{"input":{"s":"hello"}}' // curl -H Prefer:respond-async POST /predictions '{"input":{}}' func (h *Harness) cmdCurl(ts *testscript.TestScript, neg bool, args []string) { if len(args) < 2 { ts.Fatalf("curl: usage: curl [-H key:value]... 
[method] [path] [body | @file]") } // Parse -H flags for extra headers var extraHeaders [][2]string for len(args) >= 2 && args[0] == "-H" { kv := args[1] parts := strings.SplitN(kv, ":", 2) if len(parts) != 2 { ts.Fatalf("curl: invalid header %q (expected key:value)", kv) } extraHeaders = append(extraHeaders, [2]string{ strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]), }) args = args[2:] } if len(args) < 2 { ts.Fatalf("curl: usage: curl [-H key:value]... [method] [path] [body | @file]") } serverURL := ts.Getenv("SERVER_URL") if serverURL == "" { ts.Fatalf("curl: SERVER_URL not set (did you call 'cog serve' first?)") } method := args[0] path := args[1] var body string if len(args) > 2 { body = os.Expand(args[2], ts.Getenv) if strings.HasPrefix(body, "@") { filename := body[1:] data, err := os.ReadFile(ts.MkAbs(filename)) if err != nil { ts.Fatalf("curl: failed to read body file %q: %v", filename, err) } body = string(data) } } // Retry settings maxAttempts := 10 retryDelay := 500 * time.Millisecond client := &http.Client{Timeout: 10 * time.Second} var ( lastErr error lastStatus int lastBody string ) for attempt := 1; attempt <= maxAttempts; attempt++ { req, err := http.NewRequest(method, serverURL+path, strings.NewReader(body)) if err != nil { lastErr = err time.Sleep(retryDelay) continue } if body != "" { req.Header.Set("Content-Type", "application/json") } for _, h := range extraHeaders { req.Header.Set(h[0], h[1]) } resp, err := client.Do(req) //nolint:gosec // G704: URL from test harness, not user input if err != nil { lastErr = err time.Sleep(retryDelay) continue } // Read response body respBodyBytes, readErr := io.ReadAll(resp.Body) if readErr != nil { ts.Fatalf("curl: failed to read response: %v", readErr) } respBody := string(respBodyBytes) _ = resp.Body.Close() lastStatus = resp.StatusCode lastBody = respBody lastErr = nil // Check if this is a successful response statusOK := resp.StatusCode >= 200 && resp.StatusCode < 300 if neg { if !statusOK { // 
Expected to fail — write body to stderr so tests can assert _, _ = ts.Stderr().Write([]byte(respBody)) return } } else { if statusOK { // Success - write body to stdout _, _ = ts.Stdout().Write([]byte(respBody)) return } } // If this isn't the last attempt, wait before retrying if attempt < maxAttempts { time.Sleep(retryDelay) } } // All attempts failed if neg { ts.Fatalf("curl: expected failure but got status %d after %d attempts", lastStatus, maxAttempts) return } if lastErr != nil { ts.Fatalf("curl: all %d attempts failed with error: %v", maxAttempts, lastErr) return } errorMsg := lastBody if len(errorMsg) > 500 { errorMsg = errorMsg[:500] + "..." } ts.Logf("curl: full response body: %s", lastBody) ts.Fatalf("curl: all %d attempts failed with status %d: %s", maxAttempts, lastStatus, errorMsg) } // StopServer stops the background server process for a test script. func (h *Harness) StopServer(ts *testscript.TestScript) { workDir := ts.Getenv("WORK") h.stopServerByWorkDir(workDir) } // stopServerByWorkDir stops the server process associated with a work directory. 
func (h *Harness) stopServerByWorkDir(workDir string) { h.serverProcsMu.Lock() info, exists := h.serverProcs[workDir] if !exists { h.serverProcsMu.Unlock() return } delete(h.serverProcs, workDir) h.serverProcsMu.Unlock() // Try graceful shutdown first via /shutdown endpoint serverURL := fmt.Sprintf("http://127.0.0.1:%d", info.port) shutdownURL := serverURL + "/shutdown" resp, err := http.Post(shutdownURL, "application/json", nil) //nolint:gosec,noctx if err == nil { _ = resp.Body.Close() } // Force kill the cog process if still running if info.cmd.Process != nil { _ = info.cmd.Process.Kill() } _ = info.cmd.Wait() // Also kill any Docker container that may still be running on this port // Find container by port and kill it output, err := exec.Command("docker", "ps", "-q", "--filter", fmt.Sprintf("publish=%d", info.port)).Output() if err == nil && len(output) > 0 { containerID := strings.TrimSpace(string(output)) if containerID != "" { exec.Command("docker", "kill", containerID).Run() //nolint:errcheck,gosec } } } // allocatePort finds an available TCP port. func allocatePort() (int, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return 0, err } defer listener.Close() return listener.Addr().(*net.TCPAddr).Port, nil } // healthCheckResponse represents the JSON response from /health-check type healthCheckResponse struct { Status string `json:"status"` } // waitForServer polls the server's health-check endpoint until it returns READY status. // The server may return HTTP 200 while still in STARTING state (during setup), // so we must check the actual status field in the response. 
func waitForServer(serverURL string, timeout time.Duration) bool { client := &http.Client{Timeout: 5 * time.Second} deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { resp, err := client.Get(serverURL + "/health-check") if err != nil { time.Sleep(200 * time.Millisecond) continue } if resp.StatusCode == http.StatusOK { body, err := io.ReadAll(resp.Body) _ = resp.Body.Close() if err != nil { time.Sleep(200 * time.Millisecond) continue } var health healthCheckResponse if err := json.Unmarshal(body, &health); err != nil { time.Sleep(200 * time.Millisecond) continue } // Return success when the server has completed setup // READY = setup completed, healthcheck passed (or no healthcheck) // UNHEALTHY = setup completed, but user healthcheck failed // BUSY = setup completed, prediction in progress if health.Status == "READY" || health.Status == "UNHEALTHY" || health.Status == "BUSY" { return true } // If setup failed, no point waiting if health.Status == "SETUP_FAILED" || health.Status == "DEFUNCT" { return false } } else { _ = resp.Body.Close() } time.Sleep(200 * time.Millisecond) } return false } // cmdWaitFor implements the 'wait-for' command for testscript. // It waits for a specific condition to become true with retries. 
// Usage: // // wait-for file [timeout] - Wait for file to exist // wait-for http [status] [timeout] - Wait for HTTP endpoint // wait-for not-empty [timeout] - Wait for file with content func (h *Harness) cmdWaitFor(ts *testscript.TestScript, neg bool, args []string) { if len(args) < 2 { ts.Fatalf("wait-for: usage: wait-for [file|http|not-empty] [timeout]") } var ( condition = args[0] target = args[1] // Default timeout of 30 seconds, can be overridden timeout = 30 * time.Second ) if len(args) > 2 { if duration, err := time.ParseDuration(args[len(args)-1]); err == nil { timeout = duration } } deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { var conditionMet bool switch condition { case "file": // Wait for file to exist targetPath := filepath.Join(ts.Getenv("WORK"), target) _, err := os.Stat(targetPath) conditionMet = err == nil case "not-empty": // Wait for file to exist with non-empty content targetPath := filepath.Join(ts.Getenv("WORK"), target) data, err := os.ReadFile(targetPath) conditionMet = err == nil && len(data) > 0 case "http": // Wait for HTTP endpoint to return expected status expectedStatus := http.StatusOK if len(args) > 2 { if status, err := strconv.Atoi(args[2]); err == nil { expectedStatus = status } } client := &http.Client{Timeout: 2 * time.Second} resp, err := client.Get(target) if err == nil { conditionMet = resp.StatusCode == expectedStatus _ = resp.Body.Close() } default: ts.Fatalf("wait-for: unknown condition: %s", condition) } if neg { // For negation, we want the condition to remain false if !conditionMet { return } } else { // Normal case: condition should become true if conditionMet { return } } time.Sleep(200 * time.Millisecond) } if neg { ts.Fatalf("wait-for: condition became true (expected to remain false)") return } ts.Fatalf("wait-for: timeout waiting for condition: %s %s", condition, target) } // cmdDockerRun implements the 'docker-run' command for testscript. // It runs a command inside a Docker container. 
// Usage: // // docker-run [args...] // // The container is run with: // - --rm (auto-remove after exit) // - --add-host=host.docker.internal:host-gateway (for Linux compatibility) // - Working directory mounted if needed func (h *Harness) cmdDockerRun(ts *testscript.TestScript, neg bool, args []string) { if len(args) < 2 { ts.Fatalf("docker-run: usage: docker-run [args...]") } var ( image = os.Expand(args[0], ts.Getenv) containerArgs = make([]string, len(args)-1) ) for i, arg := range args[1:] { containerArgs[i] = os.Expand(arg, ts.Getenv) } // Build docker run command dockerArgs := []string{ "run", "--rm", "--add-host=host.docker.internal:host-gateway", image, } dockerArgs = append(dockerArgs, containerArgs...) cmd := exec.Command("docker", dockerArgs...) cmd.Stdout = ts.Stdout() cmd.Stderr = ts.Stderr() err := cmd.Run() if neg { if err == nil { ts.Fatalf("docker-run: command succeeded unexpectedly") } return } if err != nil { ts.Fatalf("docker-run: command failed: %v", err) } } // ============================================================================= // Registry commands // ============================================================================= // cmdRegistryStart starts a test registry container. // The registry is automatically cleaned up when the test ends. // Usage: registry-start // Exports $TEST_REGISTRY environment variable with the registry address. 
func (h *Harness) cmdRegistryStart(ts *testscript.TestScript, neg bool, args []string) { if neg { ts.Fatalf("registry-start: does not support negation") } workDir := ts.Getenv("WORK") // Check if registry is already running (idempotent) h.registriesMu.Lock() if info, exists := h.registries[workDir]; exists { h.registriesMu.Unlock() // Already started, just ensure env is set ts.Setenv("TEST_REGISTRY", info.host) return } h.registriesMu.Unlock() // Start new registry container, cleanup, err := registry_testhelpers.StartTestRegistryWithCleanup(context.Background()) if err != nil { ts.Fatalf("registry-start: failed to start registry: %v", err) } host := container.RegistryHost() // Store for cleanup h.registriesMu.Lock() h.registries[workDir] = ®istryInfo{ container: container, cleanup: cleanup, host: host, } h.registriesMu.Unlock() ts.Setenv("TEST_REGISTRY", host) ts.Logf("registry-start: started registry at %s", host) } // stopRegistryByWorkDir stops the registry container associated with a work directory. func (h *Harness) stopRegistryByWorkDir(workDir string) { h.registriesMu.Lock() info, exists := h.registries[workDir] if !exists { h.registriesMu.Unlock() return } delete(h.registries, workDir) h.registriesMu.Unlock() if info.cleanup != nil { info.cleanup() } } // cmdRegistrySeed copies an image into the test registry under a new repository:tag. // The source can be a local reference (relative to $TEST_REGISTRY) or an absolute // reference to an external registry (e.g., docker.io/library/python:3.12-slim). // The destination is always relative to $TEST_REGISTRY. 
// // Usage: registry-seed // Examples: // // registry-seed alpine:latest cog-base:cuda11.8-python3.10-torch2.0.1 // registry-seed docker.io/library/python:3.12-slim cog-base:python3.12 func (h *Harness) cmdRegistrySeed(ts *testscript.TestScript, neg bool, args []string) { if neg { ts.Fatalf("registry-seed: does not support negation") } if len(args) < 2 { ts.Fatalf("registry-seed: usage: registry-seed ") } src := os.Expand(args[0], ts.Getenv) dst := os.Expand(args[1], ts.Getenv) testRegistry := ts.Getenv("TEST_REGISTRY") if testRegistry == "" { ts.Fatalf("registry-seed: TEST_REGISTRY not set (call registry-start first)") } // If the source looks like an absolute reference (contains a registry host // with a dot, e.g. "docker.io/library/python:3.12-slim"), use it as-is. // Otherwise treat it as relative to the test registry. srcRef := src if !isAbsoluteImageRef(src) { srcRef = testRegistry + "/" + src } dstRef := testRegistry + "/" + dst if err := crane.Copy(srcRef, dstRef, crane.Insecure); err != nil { ts.Fatalf("registry-seed: failed to copy %s to %s: %v", srcRef, dstRef, err) } ts.Logf("registry-seed: copied %s to %s", srcRef, dstRef) } // isAbsoluteImageRef returns true if ref looks like it contains an explicit // registry host (e.g. "docker.io/library/python:3.12-slim" or // "ghcr.io/foo/bar:latest"). It checks whether the part before the first // slash contains a dot or a colon (port), which distinguishes a registry // host from a simple repository name like "alpine:latest". func isAbsoluteImageRef(ref string) bool { host, _, ok := strings.Cut(ref, "/") if !ok { return false } return strings.Contains(host, ".") || strings.Contains(host, ":") } // cmdRegistryInspect inspects a registry manifest and outputs JSON. // Usage: registry-inspect // Outputs the manifest result as JSON to stdout. 
func (h *Harness) cmdRegistryInspect(ts *testscript.TestScript, neg bool, args []string) { if len(args) < 1 { ts.Fatalf("registry-inspect: usage: registry-inspect ") } imageRef := os.Expand(args[0], ts.Getenv) client := registry.NewRegistryClient() result, err := client.Inspect(context.Background(), imageRef, nil) if neg { if err == nil { ts.Fatalf("registry-inspect: expected failure but succeeded") } return } if err != nil { ts.Fatalf("registry-inspect: failed to inspect %s: %v", imageRef, err) } // Output as JSON output, err := json.MarshalIndent(result, "", " ") if err != nil { ts.Fatalf("registry-inspect: failed to marshal result: %v", err) } _, _ = ts.Stdout().Write(output) _, _ = ts.Stdout().Write([]byte("\n")) } // cmdDockerPush tags and pushes a local image to the test registry. // Usage: docker-push // Example: docker-push $TEST_IMAGE test/mymodel:v1 // The image is pushed to $TEST_REGISTRY/ func (h *Harness) cmdDockerPush(ts *testscript.TestScript, neg bool, args []string) { if len(args) < 2 { ts.Fatalf("docker-push: usage: docker-push ") } localImage := os.Expand(args[0], ts.Getenv) repoTag := os.Expand(args[1], ts.Getenv) testRegistry := ts.Getenv("TEST_REGISTRY") if testRegistry == "" { ts.Fatalf("docker-push: TEST_REGISTRY not set (call registry-start first)") } remoteRef := testRegistry + "/" + repoTag // Tag the image tagCmd := exec.Command("docker", "tag", localImage, remoteRef) tagCmd.Stdout = ts.Stdout() tagCmd.Stderr = ts.Stderr() if err := tagCmd.Run(); err != nil { if neg { return } ts.Fatalf("docker-push: failed to tag image: %v", err) } // Push the image pushCmd := exec.Command("docker", "push", remoteRef) pushCmd.Stdout = ts.Stdout() pushCmd.Stderr = ts.Stderr() err := pushCmd.Run() if neg { if err == nil { ts.Fatalf("docker-push: expected failure but succeeded") } return } if err != nil { ts.Fatalf("docker-push: failed to push image: %v", err) } ts.Logf("docker-push: pushed %s to %s", localImage, remoteRef) } // 
============================================================================= // Mock weights command // ============================================================================= // mockWeightsLock mirrors the structure from pkg/model/weights_lock.go // SYNC: If pkg/model/WeightsLock changes, update this copy. // We duplicate it here to avoid importing pkg/model which transitively imports pkg/wheels. type mockWeightsLock struct { Version string `json:"version"` Created time.Time `json:"created"` Files []mockWeightFile `json:"files"` } // mockWeightFile mirrors WeightFile from pkg/model/weights.go // SYNC: If pkg/model/WeightFile changes, update this copy. type mockWeightFile struct { Name string `json:"name"` Dest string `json:"dest"` DigestOriginal string `json:"digestOriginal"` Digest string `json:"digest"` Size int64 `json:"size"` SizeUncompressed int64 `json:"sizeUncompressed"` MediaType string `json:"mediaType"` ContentType string `json:"contentType,omitempty"` } // cmdMockWeights generates mock weight files and a weights.lock file. 
// Usage: mock-weights [--count N] [--min-size S] [--max-size S] // Defaults: // - count: 2 // - min-size: 1kb // - max-size: 10kb // // Creates files in $WORK/weights/ and writes $WORK/weights.lock func (h *Harness) cmdMockWeights(ts *testscript.TestScript, neg bool, args []string) { if neg { ts.Fatalf("mock-weights: does not support negation") } // Parse arguments count := 2 minSize := int64(1024) // 1KB maxSize := int64(10 * 1024) // 10KB for i := 0; i < len(args); i++ { switch args[i] { case "--count", "-n": if i+1 < len(args) { if n, err := strconv.Atoi(args[i+1]); err == nil { count = n } i++ } case "--min-size": if i+1 < len(args) { if size, err := parseSize(args[i+1]); err == nil { minSize = size } i++ } case "--max-size": if i+1 < len(args) { if size, err := parseSize(args[i+1]); err == nil { maxSize = size } i++ } } } workDir := ts.Getenv("WORK") weightsDir := filepath.Join(workDir, "weights") lockPath := filepath.Join(workDir, "weights.lock") // Create weights directory if err := os.MkdirAll(weightsDir, 0o755); err != nil { ts.Fatalf("mock-weights: failed to create weights dir: %v", err) } var files []mockWeightFile for i := 1; i <= count; i++ { // Random size between min and max size := minSize if maxSize > minSize { size = minSize + mathrand.Int64N(maxSize-minSize+1) //nolint:gosec // test data, not security-sensitive } // Generate identifier (e.g., "weights-001") weightName := fmt.Sprintf("weights-%03d", i) filename := weightName + ".bin" filePath := filepath.Join(weightsDir, filename) // Generate random data data := make([]byte, size) if _, err := cryptorand.Read(data); err != nil { ts.Fatalf("mock-weights: failed to generate random data: %v", err) } // Write file if err := os.WriteFile(filePath, data, 0o644); err != nil { ts.Fatalf("mock-weights: failed to write %s: %v", filename, err) } // Compute digest (uncompressed, since we're not actually compressing for tests) hash := sha256.Sum256(data) digest := "sha256:" + hex.EncodeToString(hash[:]) files 
= append(files, mockWeightFile{ Name: weightName, Dest: "/cache/" + filename, DigestOriginal: digest, Digest: digest, // Same as original since we're not compressing Size: size, SizeUncompressed: size, // MediaType matches production WeightBuilder output (uncompressed). MediaType: "application/vnd.cog.weight.layer.v1", ContentType: "application/octet-stream", }) } // Create weights.lock lock := mockWeightsLock{ Version: "1.0", Created: time.Now().UTC(), Files: files, } lockData, err := json.MarshalIndent(lock, "", " ") if err != nil { ts.Fatalf("mock-weights: failed to marshal weights.lock: %v", err) } if err := os.WriteFile(lockPath, lockData, 0o644); err != nil { ts.Fatalf("mock-weights: failed to write weights.lock: %v", err) } ts.Logf("mock-weights: created %d files in %s", count, weightsDir) } // parseSize parses size strings like "1kb", "10KB", "1mb" into bytes. func parseSize(s string) (int64, error) { s = strings.TrimSpace(strings.ToLower(s)) if s == "" { return 0, fmt.Errorf("empty size string") } var multiplier int64 = 1 var numStr string switch { case strings.HasSuffix(s, "gb"): multiplier = 1024 * 1024 * 1024 numStr = strings.TrimSuffix(s, "gb") case strings.HasSuffix(s, "mb"): multiplier = 1024 * 1024 numStr = strings.TrimSuffix(s, "mb") case strings.HasSuffix(s, "kb"): multiplier = 1024 numStr = strings.TrimSuffix(s, "kb") case strings.HasSuffix(s, "b"): numStr = strings.TrimSuffix(s, "b") default: numStr = s } num, err := strconv.ParseFloat(strings.TrimSpace(numStr), 64) if err != nil { return 0, fmt.Errorf("invalid number: %s", numStr) } if num < 0 { return 0, fmt.Errorf("size cannot be negative") } return int64(num * float64(multiplier)), nil } // ============================================================================= // Mock upload server commands // ============================================================================= // cmdUploadServerStart starts a mock HTTP upload server on the host. 
// It accepts PUT requests, records them, and responds with a Location header. // Usage: upload-server-start // Exports $UPLOAD_SERVER_URL with the server's base URL. func (h *Harness) cmdUploadServerStart(ts *testscript.TestScript, neg bool, args []string) { if neg { ts.Fatalf("upload-server-start: does not support negation") } workDir := ts.Getenv("WORK") h.uploadServersMu.Lock() if _, exists := h.uploadServers[workDir]; exists { h.uploadServersMu.Unlock() ts.Fatalf("upload-server-start: server already running for this test") } h.uploadServersMu.Unlock() mus := &mockUploadServer{} mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPut { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } body, _ := io.ReadAll(r.Body) record := mockUploadRecord{ Path: r.URL.Path, ContentType: r.Header.Get("Content-Type"), Size: len(body), } mus.mu.Lock() mus.uploads = append(mus.uploads, record) mus.mu.Unlock() // Return a clean URL without query params (simulates a signed URL redirect) location := fmt.Sprintf("http://host.docker.internal:%d%s", mus.port, r.URL.Path) w.Header().Set("Location", location) w.WriteHeader(http.StatusOK) }) mus.server = &http.Server{Handler: mux, ReadHeaderTimeout: 10 * time.Second} //nolint:gosec // test harness, not production // Bind to all interfaces so the container can reach us via host.docker.internal ln, err := net.Listen("tcp", "0.0.0.0:0") //nolint:gosec // must be reachable from Docker container if err != nil { ts.Fatalf("upload-server-start: failed to listen: %v", err) } mus.port = ln.Addr().(*net.TCPAddr).Port go func() { _ = mus.server.Serve(ln) }() h.uploadServersMu.Lock() h.uploadServers[workDir] = mus h.uploadServersMu.Unlock() // Advertise host.docker.internal so the container can reach the host server. // On Linux, cog serve adds --add-host=host.docker.internal:host-gateway. // On Mac, Docker Desktop resolves host.docker.internal automatically. 
url := fmt.Sprintf("http://host.docker.internal:%d/", mus.port) ts.Setenv("UPLOAD_SERVER_URL", url) ts.Logf("upload-server-start: listening on 0.0.0.0:%d, container URL: %s", mus.port, url) } // cmdUploadServerCount verifies exactly N uploads were received. // Usage: upload-server-count N func (h *Harness) cmdUploadServerCount(ts *testscript.TestScript, neg bool, args []string) { if len(args) != 1 { ts.Fatalf("upload-server-count: usage: upload-server-count N") } expected, err := strconv.Atoi(args[0]) if err != nil { ts.Fatalf("upload-server-count: invalid count %q: %v", args[0], err) } workDir := ts.Getenv("WORK") h.uploadServersMu.Lock() mus, exists := h.uploadServers[workDir] h.uploadServersMu.Unlock() if !exists { ts.Fatalf("upload-server-count: no upload server running (call upload-server-start first)") } mus.mu.Lock() got := len(mus.uploads) mus.mu.Unlock() if neg { if got == expected { ts.Fatalf("upload-server-count: expected NOT %d uploads but got %d", expected, got) } return } if got != expected { ts.Fatalf("upload-server-count: expected %d uploads but got %d", expected, got) } } // stopUploadServerByWorkDir shuts down the upload server for a work directory. func (h *Harness) stopUploadServerByWorkDir(workDir string) { h.uploadServersMu.Lock() mus, exists := h.uploadServers[workDir] if !exists { h.uploadServersMu.Unlock() return } delete(h.uploadServers, workDir) h.uploadServersMu.Unlock() if mus.server != nil { _ = mus.server.Close() } } // ============================================================================= // Webhook receiver commands // ============================================================================= // cmdWebhookServerStart starts a webhook receiver that accepts prediction callbacks. // It parses the JSON payload to extract status and measure the output size, without // ever exposing the (potentially huge) output to testscript's log buffer. // Usage: webhook-server-start // Exports $WEBHOOK_URL with the server's callback URL. 
func (h *Harness) cmdWebhookServerStart(ts *testscript.TestScript, neg bool, args []string) { if neg { ts.Fatalf("webhook-server-start: does not support negation") } workDir := ts.Getenv("WORK") h.webhookServersMu.Lock() if _, exists := h.webhookServers[workDir]; exists { h.webhookServersMu.Unlock() ts.Fatalf("webhook-server-start: server already running for this test") } h.webhookServersMu.Unlock() ws := &webhookServer{ done: make(chan struct{}), } mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } // Stream-parse the JSON to extract status, measure output size, and // capture metrics without holding the entire output string in memory. // Output is json.RawMessage because it can be a string (single output) // or an array (iterator/streaming output). var payload struct { Status string `json:"status"` Output json.RawMessage `json:"output"` Error string `json:"error"` Metrics json.RawMessage `json:"metrics"` } if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { http.Error(w, "bad json", http.StatusBadRequest) return } w.WriteHeader(http.StatusOK) // Only record terminal statuses switch payload.Status { case "succeeded", "failed", "canceled": default: return } ws.mu.Lock() defer ws.mu.Unlock() // Only record the first terminal callback if ws.result != nil { return } // Compute output size: for strings, use the unquoted length; // for arrays or other types, use the raw JSON byte length. 
outputSize := len(payload.Output) var outputStr string if json.Unmarshal(payload.Output, &outputStr) == nil { outputSize = len(outputStr) } ws.result = &webhookResult{ Status: payload.Status, OutputSize: outputSize, HasError: payload.Error != "", ErrorMessage: payload.Error, Metrics: payload.Metrics, } close(ws.done) }) ws.server = &http.Server{Handler: mux, ReadHeaderTimeout: 10 * time.Second} //nolint:gosec // Bind to all interfaces so the container can reach us via host.docker.internal ln, err := net.Listen("tcp", "0.0.0.0:0") //nolint:gosec if err != nil { ts.Fatalf("webhook-server-start: failed to listen: %v", err) } ws.port = ln.Addr().(*net.TCPAddr).Port go func() { _ = ws.server.Serve(ln) }() h.webhookServersMu.Lock() h.webhookServers[workDir] = ws h.webhookServersMu.Unlock() url := fmt.Sprintf("http://host.docker.internal:%d/webhook", ws.port) ts.Setenv("WEBHOOK_URL", url) ts.Logf("webhook-server-start: listening on 0.0.0.0:%d, container URL: %s", ws.port, url) } // cmdWebhookServerWait blocks until the webhook server receives a terminal prediction callback, // then writes a compact JSON summary to stdout for assertion with stdout/stderr matchers. 
// Usage: webhook-server-wait [timeout] // Default timeout: 120s func (h *Harness) cmdWebhookServerWait(ts *testscript.TestScript, neg bool, args []string) { if neg { ts.Fatalf("webhook-server-wait: does not support negation") } timeout := 120 * time.Second if len(args) > 0 { if d, err := time.ParseDuration(args[0]); err == nil { timeout = d } } workDir := ts.Getenv("WORK") h.webhookServersMu.Lock() ws, exists := h.webhookServers[workDir] h.webhookServersMu.Unlock() if !exists { ts.Fatalf("webhook-server-wait: no webhook server running (call webhook-server-start first)") } select { case <-ws.done: case <-time.After(timeout): ts.Fatalf("webhook-server-wait: timed out after %s waiting for terminal webhook", timeout) } ws.mu.Lock() result := ws.result ws.mu.Unlock() out, _ := json.Marshal(result) _, _ = ts.Stdout().Write(out) _, _ = ts.Stdout().Write([]byte("\n")) } // stopWebhookServerByWorkDir shuts down the webhook server for a work directory. func (h *Harness) stopWebhookServerByWorkDir(workDir string) { h.webhookServersMu.Lock() ws, exists := h.webhookServers[workDir] if !exists { h.webhookServersMu.Unlock() return } delete(h.webhookServers, workDir) h.webhookServersMu.Unlock() if ws.server != nil { _ = ws.server.Close() } } ================================================ FILE: integration-tests/login/login_test.go ================================================ //go:build integration // Package login provides integration tests for the cog login command. 
// // These tests verify: // - Generic registry login with username/password (PTY-based) // - Provider routing based on --registry flag // - Help text and CLI flags // // This test file is written in Go (not txtar) because: // - Login requires interactive input (PTY for generic provider) // - We need fine-grained control over stdin and stdout // // Note: Replicate provider token verification is tested in unit tests // (pkg/provider/replicate/replicate_test.go) since mocking the r8.im // hostname requires DNS-level changes not suitable for integration tests. package login_test import ( "bytes" "os" "os/exec" "runtime" "strings" "sync" "testing" "time" "github.com/creack/pty" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/replicate/cog/integration-tests/harness" ) // TestLoginGenericRegistryPTY tests interactive login to a generic registry. // This test uses PTY to simulate interactive terminal input. func TestLoginGenericRegistryPTY(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } // PTY tests only work reliably on Unix-like systems if runtime.GOOS == "windows" { t.Skip("PTY tests not supported on Windows") } // Get cog binary cogBinary, err := harness.ResolveCogBinary() require.NoError(t, err, "failed to resolve cog binary") // Test login to a fake generic registry // Note: This will fail at the Docker credential save step, but we can verify // the interactive prompts work correctly up to that point t.Run("prompts for username and password", func(t *testing.T) { cmd := exec.Command(cogBinary, "login", "--registry", "fake-registry.example.com") cmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1") // Start with a PTY ptmx, err := pty.Start(cmd) require.NoError(t, err, "failed to start PTY") defer func() { ptmx.Close() cmd.Process.Kill() cmd.Wait() }() // Set terminal size if err := pty.Setsize(ptmx, &pty.Winsize{Rows: 24, Cols: 80}); err != nil { t.Logf("failed to set terminal size: %v", err) } // Use 
a single mutex-protected buffer for thread safety // This avoids the race condition of multiple goroutines reading from the PTY var bufMu bytes.Buffer var mu sync.Mutex // Start a single reader goroutine done := make(chan struct{}) go func() { tmp := make([]byte, 1024) for { select { case <-done: return default: n, err := ptmx.Read(tmp) if n > 0 { mu.Lock() bufMu.Write(tmp[:n]) mu.Unlock() } if err != nil { return } } } }() defer close(done) // Helper to get current buffer contents getOutput := func() string { mu.Lock() defer mu.Unlock() return bufMu.String() } // Helper to wait for a pattern in output with timeout waitForPattern := func(pattern string, timeout time.Duration) (string, bool) { deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { output := getOutput() if strings.Contains(strings.ToLower(output), strings.ToLower(pattern)) { return output, true } time.Sleep(100 * time.Millisecond) } return getOutput(), false } // Wait for and verify username prompt output, found := waitForPattern("username", 5*time.Second) t.Logf("Output after start: %q", output) assert.Contains(t, output, "fake-registry.example.com", "expected output to mention registry host") assert.True(t, found, "expected username prompt, got: %q", output) // Send username _, err = ptmx.Write([]byte("testuser\n")) require.NoError(t, err, "failed to write username") // Wait for password prompt output, found = waitForPattern("password", 3*time.Second) t.Logf("Output after username: %q", output) assert.True(t, found, "expected password prompt, got: %q", output) // Send password (will fail at Docker credential save, but we've verified the flow) _, err = ptmx.Write([]byte("testpass\n")) require.NoError(t, err, "failed to write password") // Read final output briefly (expect failure since we can't actually save credentials) time.Sleep(2 * time.Second) output = getOutput() t.Logf("Final output: %q", output) }) t.Run("rejects empty username", func(t *testing.T) { cmd := 
exec.Command(cogBinary, "login", "--registry", "fake-registry.example.com") cmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1") ptmx, err := pty.Start(cmd) require.NoError(t, err, "failed to start PTY") defer func() { ptmx.Close() cmd.Process.Kill() cmd.Wait() }() // Use a mutex-protected buffer for thread safety var bufMu bytes.Buffer var mu sync.Mutex // Start reader goroutine done := make(chan struct{}) go func() { tmp := make([]byte, 1024) for { select { case <-done: return default: n, err := ptmx.Read(tmp) if n > 0 { mu.Lock() bufMu.Write(tmp[:n]) mu.Unlock() } if err != nil { return } } } }() defer close(done) // Helper to check buffer contents getOutput := func() string { mu.Lock() defer mu.Unlock() return bufMu.String() } // Wait for username prompt deadline := time.Now().Add(5 * time.Second) for time.Now().Before(deadline) { if strings.Contains(strings.ToLower(getOutput()), "username") { break } time.Sleep(100 * time.Millisecond) } output := getOutput() require.Contains(t, strings.ToLower(output), "username", "did not get username prompt: %q", output) // Send empty username _, err = ptmx.Write([]byte("\n")) require.NoError(t, err, "failed to write empty username") // Wait for error about empty username deadline = time.Now().Add(5 * time.Second) for time.Now().Before(deadline) { output = getOutput() if strings.Contains(strings.ToLower(output), "empty") || strings.Contains(strings.ToLower(output), "cannot") { break } time.Sleep(100 * time.Millisecond) } output = getOutput() t.Logf("Output: %q", output) // Verify we got an error about empty username lowerOutput := strings.ToLower(output) assert.True(t, strings.Contains(lowerOutput, "empty") || strings.Contains(lowerOutput, "cannot"), "expected error about empty username, got: %q", output) }) } // TestLoginProviderRouting tests that the --registry flag correctly routes to the appropriate provider. // This test verifies the routing behavior by checking error messages and prompts. 
func TestLoginProviderRouting(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } // PTY tests only work reliably on Unix-like systems if runtime.GOOS == "windows" { t.Skip("PTY tests not supported on Windows") } cogBinary, err := harness.ResolveCogBinary() require.NoError(t, err, "failed to resolve cog binary") tests := []struct { name string registry string expectReplicate bool // True if we expect Replicate provider behavior }{ { name: "default registry uses Replicate", registry: "", expectReplicate: true, }, { name: "r8.im uses Replicate", registry: "r8.im", expectReplicate: true, }, { name: "custom registry uses generic", registry: "ghcr.io", expectReplicate: false, }, { name: "dockerhub uses generic", registry: "docker.io", expectReplicate: false, }, { name: "localhost uses generic", registry: "localhost:5000", expectReplicate: false, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { args := []string{"login"} if tc.registry != "" { args = append(args, "--registry", tc.registry) } cmd := exec.Command(cogBinary, args...) 
cmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1") // Start with a PTY to handle interactive prompts ptmx, err := pty.Start(cmd) require.NoError(t, err, "failed to start PTY") defer func() { ptmx.Close() cmd.Process.Kill() cmd.Wait() }() // Read initial output var buf bytes.Buffer deadline := time.Now().Add(5 * time.Second) tmp := make([]byte, 1024) for time.Now().Before(deadline) { ptmx.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) n, _ := ptmx.Read(tmp) if n > 0 { buf.Write(tmp[:n]) // Check if we have enough output to determine provider output := buf.String() if tc.expectReplicate { // Replicate provider shows "Hit enter to get started" message if strings.Contains(output, "Hit enter") || strings.Contains(output, "browser") { t.Logf("Confirmed Replicate provider: %q", output) return } } else { // Generic provider shows "Username:" prompt directly if strings.Contains(output, "Username") { t.Logf("Confirmed Generic provider: %q", output) return } } } } output := buf.String() if tc.expectReplicate { assert.NotContains(t, output, "Username:", "expected Replicate provider, but got Generic provider with Username prompt") } else { assert.True(t, strings.Contains(output, "Username") || strings.Contains(strings.ToLower(output), "logging in"), "expected Generic provider prompts, got: %q", output) } }) } } // TestLoginEnvironmentVariable tests that COG_REGISTRY_HOST environment variable works. 
func TestLoginEnvironmentVariable(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } // PTY tests only work reliably on Unix-like systems if runtime.GOOS == "windows" { t.Skip("PTY tests not supported on Windows") } cogBinary, err := harness.ResolveCogBinary() require.NoError(t, err, "failed to resolve cog binary") t.Run("COG_REGISTRY_HOST sets default registry", func(t *testing.T) { cmd := exec.Command(cogBinary, "login") cmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1", "COG_REGISTRY_HOST=custom-registry.example.com", ) // Start with a PTY ptmx, err := pty.Start(cmd) require.NoError(t, err, "failed to start PTY") defer func() { ptmx.Close() cmd.Process.Kill() cmd.Wait() }() // Read output var buf bytes.Buffer deadline := time.Now().Add(5 * time.Second) tmp := make([]byte, 1024) for time.Now().Before(deadline) { ptmx.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) n, _ := ptmx.Read(tmp) if n > 0 { buf.Write(tmp[:n]) // Stop early if we see the expected registry if strings.Contains(buf.String(), "custom-registry.example.com") { break } } } output := buf.String() t.Logf("Output: %s", output) // Verify the custom registry is mentioned in output assert.Contains(t, output, "custom-registry.example.com", "expected custom registry in output") // Since custom-registry.example.com is not r8.im, it should use generic provider if !strings.Contains(output, "Username") { t.Logf("Note: Generic provider should prompt for Username") } }) t.Run("--registry flag overrides COG_REGISTRY_HOST", func(t *testing.T) { cmd := exec.Command(cogBinary, "login", "--registry", "override-registry.example.com") cmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1", "COG_REGISTRY_HOST=ignored-registry.example.com", ) // Start with a PTY ptmx, err := pty.Start(cmd) require.NoError(t, err, "failed to start PTY") defer func() { ptmx.Close() cmd.Process.Kill() cmd.Wait() }() // Read output var buf bytes.Buffer deadline := time.Now().Add(5 * time.Second) tmp := 
make([]byte, 1024) for time.Now().Before(deadline) { ptmx.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) n, _ := ptmx.Read(tmp) if n > 0 { buf.Write(tmp[:n]) // Stop early if we see the expected registry if strings.Contains(buf.String(), "override-registry.example.com") { break } } } output := buf.String() t.Logf("Output: %s", output) // Verify the override registry is used, not the env var one assert.Contains(t, output, "override-registry.example.com", "expected override registry in output") assert.NotContains(t, output, "ignored-registry.example.com", "env var registry should have been overridden, but it appeared in output") }) } // TestLoginHelp tests that the login command shows appropriate help text. func TestLoginHelp(t *testing.T) { cogBinary, err := harness.ResolveCogBinary() require.NoError(t, err, "failed to resolve cog binary") cmd := exec.Command(cogBinary, "login", "--help") cmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1") output, err := cmd.CombinedOutput() require.NoError(t, err, "help command failed") helpText := string(output) t.Logf("Help text:\n%s", helpText) // Verify help contains expected information expectedStrings := []string{ "login", "registry", "--token-stdin", "container registry", // Updated description mentions "container registry" } for _, expected := range expectedStrings { assert.True(t, strings.Contains(strings.ToLower(helpText), strings.ToLower(expected)), "expected help to contain %q", expected) } // Verify help mentions both Replicate and generic registry support assert.Contains(t, helpText, "Replicate", "expected help to mention Replicate") assert.True(t, strings.Contains(helpText, "other registries") && strings.Contains(helpText, "username and password"), "expected help to mention generic registry login with username/password") } // TestLoginSuggestFor tests that similar commands are suggested. 
func TestLoginSuggestFor(t *testing.T) { cogBinary, err := harness.ResolveCogBinary() require.NoError(t, err, "failed to resolve cog binary") // Test that "cog auth" suggests "cog login" cmd := exec.Command(cogBinary, "auth") cmd.Env = append(os.Environ(), "COG_NO_UPDATE_CHECK=1") output, err := cmd.CombinedOutput() // We expect an error since "auth" is not a valid command if err == nil { t.Logf("Unexpected success, output: %s", output) } outputStr := string(output) t.Logf("Output for 'cog auth': %s", outputStr) // Check if login is suggested if strings.Contains(outputStr, "login") { t.Logf("'login' suggested for 'auth' command (good)") } } ================================================ FILE: integration-tests/suite_test.go ================================================ //go:build integration package integration_test import ( "fmt" "os" "os/signal" "path/filepath" "runtime" "sort" "strings" "syscall" "testing" "github.com/rogpeppe/go-internal/testscript" "github.com/stretchr/testify/require" "github.com/replicate/cog/integration-tests/harness" ) // TestMain sets up signal handling to force exit on cancellation. // Without this, go test ignores SIGTERM and keeps running when CI cancels. 
func TestMain(m *testing.M) { sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { sig := <-sigCh fmt.Fprintf(os.Stderr, "\nReceived %v, forcing exit...\n", sig) os.Exit(1) }() os.Exit(m.Run()) } func TestIntegration(t *testing.T) { dir := "tests" h, err := harness.New() require.NoError(t, err, "failed to create harness") files, err := filepath.Glob(filepath.Join(dir, "*.txtar")) require.NoError(t, err) sort.Strings(files) for _, f := range files { name := strings.TrimSuffix(filepath.Base(f), filepath.Ext(f)) t.Run(name, func(t *testing.T) { if !strings.HasSuffix(name, "_serial") { t.Parallel() } testscript.Run(t, testscript.Params{ Files: []string{f}, Setup: h.Setup, Cmds: h.Commands(), Condition: condition, }) }) } } // condition provides custom conditions for testscript. // Supported conditions: // - linux/linux_amd64/amd64: platform guards for specialized tests. // // Note: testscript has built-in support for [short] which checks testing.Short(). 
func condition(cond string) (bool, error) { negated := false for strings.HasPrefix(cond, "!") { negated = !negated cond = cond[1:] } var value bool switch cond { case "linux": value = runtime.GOOS == "linux" case "amd64": value = runtime.GOARCH == "amd64" case "linux_amd64": value = runtime.GOOS == "linux" && runtime.GOARCH == "amd64" default: return false, fmt.Errorf("unknown condition: %s", cond) } if negated { value = !value } return value, nil } ================================================ FILE: integration-tests/tests/apt_packages.txtar ================================================ # Skip for cog-dataclass and coglet (Rust) which require Python 3.10+ # Test that system packages are installed correctly via system_packages in cog.yaml # Build the image (the run command verifies git is installed) cog build -t $TEST_IMAGE # Verify the predictor works cog predict $TEST_IMAGE -i s=world stdout 'hello world' -- cog.yaml -- build: gpu: true python_version: "3.10" system_packages: - "git" run: - command: git --version predict: "predict.py:Predictor" -- predict.py -- from cog import BasePredictor class Predictor(BasePredictor): def predict(self, s: str) -> str: return "hello " + s ================================================ FILE: integration-tests/tests/async_generator_precollect.txtar ================================================ # Test that async generator output is pre-collected before response. # # Coglet collects all async generator yields into a list before sending # the response. This test verifies all items arrive in the response and # that predict_time reflects the full generation duration. 
cog serve
curl POST /predictions '{"input":{}}'
stdout '"status":"succeeded"'

# All 5 items should be present in the output
stdout '"output":\["chunk-0","chunk-1","chunk-2","chunk-3","chunk-4"\]'

# predict_time should reflect full generation time (at least ~0.5s)
stdout '"predict_time":(0\.[5-9][0-9]*|[1-9][0-9]*(\.[0-9]+)?)'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"
concurrency:
  max: 1

-- predict.py --
import asyncio
from typing import AsyncIterator

from cog import BasePredictor


class Predictor(BasePredictor):
    async def predict(self) -> AsyncIterator[str]:
        for i in range(5):
            await asyncio.sleep(0.1)
            yield f"chunk-{i}"


================================================
FILE: integration-tests/tests/async_predictor.txtar
================================================
# Build the image
cog build -t $TEST_IMAGE

# Async prediction works
cog predict $TEST_IMAGE -i s=world
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    async def predict(self, s: str) -> str:
        return "hello " + s


================================================
FILE: integration-tests/tests/async_sleep.txtar
================================================
# Test async predictor with sleep

# Build the image
cog build -t $TEST_IMAGE

# Async prediction with sleep works
cog predict $TEST_IMAGE -i s=sleepyhead -i sleep=0.1
stdout 'wake up sleepyhead'

-- cog.yaml --
build:
  python_version: "3.11"
predict: "predict.py:Predictor"
concurrency:
  max: 5

-- predict.py --
import asyncio

from cog import BasePredictor


class Predictor(BasePredictor):
    async def predict(self, s: str, sleep: float) -> str:
        await asyncio.sleep(sleep)
        return f"wake up {s}"


================================================
FILE: integration-tests/tests/bad_dockerignore.txtar
================================================
# Skip for cog-dataclass and coglet (Rust) which require
# Python 3.10+

# Test that build fails with proper error when .cog is in .dockerignore

! cog build -t $TEST_IMAGE
stderr 'The .cog tmp path cannot be ignored by docker in .dockerignore'

-- cog.yaml --
build:
  gpu: true
  python_version: "3.10"
predict: "predict.py:Predictor"

-- .dockerignore --
.cog

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, s: str) -> str:
        return "hello " + s


================================================
FILE: integration-tests/tests/bool_input_output.txtar
================================================
# Test bool as a direct predict input and output type

# Build the image
cog build -t $TEST_IMAGE

# Bool input and output works via JSON (true -> false)
cog predict $TEST_IMAGE --json '{"flag": true}'
stdout '"output": false'

# Bool input and output works via JSON (false -> true)
cog predict $TEST_IMAGE --json '{"flag": false}'
stdout '"output": true'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, flag: bool) -> bool:
        return not flag


================================================
FILE: integration-tests/tests/build_base_image_sha.txtar
================================================
# Test that base image SHA is recorded in labels with --use-cog-base-image
# Source: test_build.py::test_build_base_image_sha
#
# Uses a local test registry to avoid depending on the live r8.im registry.

# Start local registry and seed a cog-base image from Docker Hub's python:3.12-slim
registry-start
registry-seed docker.io/library/python:3.12-slim cog-base:python3.12

# Build with --use-cog-base-image flag against the local registry
env COG_REGISTRY_HOST=$TEST_REGISTRY
cog build -t $TEST_IMAGE --use-cog-base-image

# Verify the base image layer label matches one of the actual layers
exec python3 -c 'import json,os,subprocess; image=os.environ["TEST_IMAGE"]; base_layer=subprocess.check_output(["docker","inspect",image,"--format={{index .Config.Labels \"run.cog.cog-base-image-last-layer-sha\"}}"], text=True).strip(); assert base_layer.startswith("sha256:"), f"Base layer label missing sha256 digest: {base_layer}"; layers=json.loads(subprocess.check_output(["docker","inspect",image,"--format={{json .RootFS.Layers}}"], text=True)); assert base_layer in layers, f"Base layer {base_layer} not found in RootFS layers"; print(base_layer)'
stdout 'sha256:'

-- cog.yaml --
build:
  python_version: "3.12"
predict: predict.py:Predictor

-- predict.py --
import tempfile

from cog import BasePredictor, Path


class Predictor(BasePredictor):
    def setup(self):
        self.foo = "foo"

    def predict(self, text: str, path: Path) -> Path:
        with open(path) as f:
            output = self.foo + text + f.read()
        tmpdir = Path(tempfile.mkdtemp())
        with open(tmpdir / "output.txt", "w") as fh:
            fh.write(output)
        return tmpdir / "output.txt"


================================================
FILE: integration-tests/tests/build_cog_init.txtar
================================================
# Test that cog init creates a buildable project
# Source: test_build.py::test_build_with_cog_init_templates

# Initialize a new cog project
cog init

# Build the initialized project
cog build -t $TEST_IMAGE
stderr 'Image built as'

# Verify the expected files were created
exists cog.yaml
exists predict.py


================================================
FILE: integration-tests/tests/build_cog_version_match.txtar
================================================
# Note: This test can be flaky if the wheel build time and CI run time
# cross day boundaries (version date mismatch). If that happens, the test
# will fail with a date component mismatch.

# Test that cog version in base image contains a version number
# Source: test_build.py::test_cog_install_base_image
#
# This test verifies that when building with --use-cog-base-image,
# the installed Python cog package has a valid version number.
#
# Uses a local test registry to avoid depending on the live r8.im registry.

# Start local registry and seed a cog-base image from Docker Hub's python:3.12-slim
registry-start
registry-seed docker.io/library/python:3.12-slim cog-base:python3.12

# Build using --use-cog-base-image against local registry
env COG_REGISTRY_HOST=$TEST_REGISTRY
cog build -t $TEST_IMAGE --use-cog-base-image=true

# Compare the embedded version label with the Python package version inside the image
exec python3 check_versions.py
stdout '[0-9]+\.[0-9]+\.[0-9]+'

-- check_versions.py --
import os
import re
import subprocess

# SemVer pattern from https://semver.org/
SEMVER_PATTERN = r"^(?P<major>0|[1-9]\d*)\.(?P<minor>0|[1-9]\d*)\.(?P<patch>0|[1-9]\d*)(?:-(?P<prerelease>(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+(?P<buildmetadata>[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$"

# PEP 440 pattern from packaging.version.VERSION_PATTERN
PEP440_PATTERN = r"""
    v?
    (?:
        (?:(?P<epoch>[0-9]+)!)?                           # epoch
        (?P<release>[0-9]+(?:\.[0-9]+)*)                  # release segment
        (?P<pre>                                          # pre-release
            [-_\.]?
            (?P<pre_l>(a|alpha|b|beta|c|rc))
            [-_\.]?
            (?P<pre_n>[0-9]+)?
        )?
        (?P<post>                                         # post release
            (?:-(?P<post_n1>[0-9]+))
            |
            (?:
                [-_\.]?
                (?P<post_l>post|rev|r)
                [-_\.]?
                (?P<post_n2>[0-9]+)?
            )
        )?
        (?P<dev>                                          # dev release
            [-_\.]?
            (?P<dev_l>dev)
            [-_\.]?
            (?P<dev_n>[0-9]+)?
        )?
    )
    (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
"""

def assert_versions_match(semver_version, pep440_version):
    """Assert that a SemVer string and a PEP 440 string describe the same build.

    Both strings are parsed with the module-level SEMVER_PATTERN and
    PEP440_PATTERN. The following must agree, otherwise AssertionError is
    raised with a message naming the mismatch:

    * the major.minor.patch release number;
    * whether the version is a pre-release at all, and, for alpha / beta /
      rc / dev tags, which kind;
    * when both sides carry build metadata, the git hash (prefix match) and
      the date component (exact match).
    """
    sem = re.match(SEMVER_PATTERN, semver_version)
    pep = re.match(PEP440_PATTERN, pep440_version, re.VERBOSE | re.IGNORECASE)

    assert sem, f"Invalid semver version: {semver_version}"
    assert pep, f"Invalid PEP 440 version: {pep440_version}"

    sem_fields = sem.groupdict()
    pep_fields = pep.groupdict()

    # Base release number: major.minor.patch on the SemVer side.
    sem_release = ".".join(sem_fields[part] for part in ("major", "minor", "patch"))
    assert sem_release == pep_fields["release"], (
        f"Release versions do not match: {sem_release} != {pep_fields['release']}"
    )

    # Either both versions are pre-releases or neither is. On the PEP 440
    # side a dev segment counts as a pre-release too.
    sem_pre = sem_fields["prerelease"]
    pep_pre = pep_fields["pre"] or pep_fields["dev"]
    assert bool(sem_pre) == bool(pep_pre), "Pre-release status does not match"

    if sem_pre:
        # (SemVer tag prefix, expected PEP 440 letter, failure message)
        stage_checks = (
            ("alpha", "a", "Alpha pre-release status does not match"),
            ("beta", "b", "Beta pre-release status does not match"),
            ("rc", "rc", "Release candidate pre-release status does not match"),
        )
        for prefix, letter, message in stage_checks:
            if not sem_pre.startswith(prefix):
                continue
            assert pep_fields["pre_l"] == letter, message
            assert not pep_fields["dev"], "Semver pre-release cannot also be a PEP440 dev build"

        if sem_pre.startswith("dev"):
            assert pep_fields["dev_l"] == "dev", "Dev build status does not match"

    sem_meta = sem_fields["buildmetadata"]
    pep_local = pep_fields["local"]
    if pep_local is None or sem_meta is None:
        # Build metadata is only compared when both versions carry it.
        return

    # Both build metadata formats are: g<commit-hash>.d<date>
    # The git short hash length can vary (typically 7-9 chars) depending on
    # git settings and repo size, so the two sides are compared piecewise
    # rather than as whole strings.
    sem_parts = sem_meta.split(".")
    pep_parts = pep_local.split(".")

    # One hash must be a prefix of the other
    # (e.g. "g5606e933" and "g5606e9331" both refer to the same commit).
    sem_hash = sem_parts[0]
    pep_hash = pep_parts[0]
    same_commit = sem_hash.startswith(pep_hash) or pep_hash.startswith(sem_hash)
    assert same_commit, (
        f"Git commit hash does not match: {sem_hash} vs {pep_hash}"
    )

    # The date component, when both sides have one, must match exactly.
    if len(sem_parts) > 1 and len(pep_parts) > 1:
        assert sem_parts[1] == pep_parts[1], (
            f"Date component does not match: {sem_parts[1]} != {pep_parts[1]}"
        )

# Image under test, taken from the TEST_IMAGE environment variable.
image_ref = os.environ["TEST_IMAGE"]

# Version recorded on the image in the run.cog.version label.
inspect_cmd = [
    "docker", "inspect", image_ref,
    "--format={{index .Config.Labels \"run.cog.version\"}}",
]
label_version = subprocess.check_output(inspect_cmd, text=True).strip()

# Version reported by the cog Python package inside the image.
run_cmd = [
    "docker", "run", "--rm", "-t", image_ref,
    "python", "-c", "import cog; print(cog.__version__)",
]
package_version = subprocess.check_output(run_cmd, text=True).strip()

# Both versions must at least start with MAJOR.MINOR.PATCH.
version_prefix = re.compile(r"^\d+\.\d+\.\d+")
assert version_prefix.search(package_version), f"Invalid package version: {package_version}"

# A label of exactly "dev" marks a special dev build; there is nothing to
# compare it against, so version matching is skipped.
if label_version != "dev":
    assert version_prefix.search(label_version), f"Invalid label version: {label_version}"
    assert_versions_match(label_version, package_version)

# The calling script matches this output against a version regex.
print(package_version)

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


# Minimal predictor fixture: returns the input string prefixed with "hello ".
class Predictor(BasePredictor):
    def predict(self, s: str) -> str:
        return "hello " + s


================================================
FILE: integration-tests/tests/build_gpu_labels.txtar
================================================
# Test GPU build labels
# Source: test_build.py::test_build_gpu_model_on_cpu
# Requires git init/tag for version labels

[short] skip 'requires long GPU build time'

# Setup git repo for version labels
exec git init
exec git config user.email noreply@replicate.com
exec git config user.name 'Replicate Test Bot'
exec git config commit.gpgsign false
exec git commit --allow-empty -m initial
exec git tag 0.0.1

# Build the GPU image
cog build -t $TEST_IMAGE

# Check core labels exist
exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.version"}}'
stdout '.+'

exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.config"}}'
stdout '"gpu":true'
stdout '"cuda":'
stdout '"cudnn":'
stdout '"python_version":"3.12"'

exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.openapi_schema"}}'
stdout 'openapi'

# Check OCI labels for version info
exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "org.opencontainers.image.version"}}'
stdout '.+'

exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "org.opencontainers.image.revision"}}'
stdout '.+'

-- cog.yaml --
build:
  python_version: "3.12"
  gpu: true
predict: predict.py:Predictor

-- predict.py --
from cog import BasePredictor

# Minimal predictor fixture: returns its text input unchanged.
class Predictor(BasePredictor):
    def predict(self, text: str) -> str:
        return text


================================================
FILE: integration-tests/tests/build_image_option.txtar
================================================
# Test that image: option in cog.yaml names the built image
# Source: test_build.py::test_build_names_uses_image_option_in_cog_yaml

# Build without explicit -t flag, should use image from cog.yaml
cog build

# Verify image exists with the name from cog.yaml
exec docker images --format '{{.Repository}}'
stdout 'cog-test-image-option'

# Cleanup the custom image
exec docker rmi cog-test-image-option

-- cog.yaml --
image: cog-test-image-option
build:
  python_version: "3.12"
predict: predict.py:Predictor

-- predict.py --
from cog import BasePredictor

# Minimal predictor fixture: returns its text input unchanged.
class Predictor(BasePredictor):
    def predict(self, text: str) -> str:
        return text


================================================
FILE: integration-tests/tests/build_openapi_schema.txtar
================================================
# Test that OpenAPI schema is embedded in image labels
# Source: test_build.py::test_build_with_model

# Build the image
cog build -t $TEST_IMAGE

# Check the openapi_schema label exists and contains expected schema structure
exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.openapi_schema"}}'
stdout '"title":"Input"'
stdout '"required":\["text","path"\]'
stdout '"text":'
stdout '"path":'
stdout '"type":"string"'
stdout '"format":"uri"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: predict.py:Predictor

-- predict.py --
import tempfile

from cog import BasePredictor, Path


class Predictor(BasePredictor):
    # Fixture predictor with one string input and one file input; the
    # surrounding script checks both appear in the generated schema.

    def setup(self):
        # Constant prefix stored on the instance for predict() to use.
        self.foo = "foo"

    def predict(self, text: str, path: Path) -> Path:
        # Prefix + text + contents of the input file.
        with open(path) as src:
            combined = self.foo + text + src.read()

        # Write the result to output.txt inside a fresh temporary directory.
        out_file = Path(tempfile.mkdtemp()) / "output.txt"
        with open(out_file, "w") as dst:
            dst.write(combined)
        return out_file


================================================
FILE: integration-tests/tests/build_openapi_schema_complex.txtar
================================================
# Test that the OpenAPI schema for complex input/output types is correct.
#
# Verifies schema generation for:
# - Multiple input types with constraints (ge, le, choices)
# - Optional fields with defaults
# - Secret type
# - Structured BaseModel output with nested types

cog build -t $TEST_IMAGE

# Extract the schema from the image label
exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.openapi_schema"}}'

# Input schema checks
stdout '"title":"Input"'
stdout '"temperature"'
stdout '"prompt"'
stdout '"style"'
stdout '"api_key"'
stdout '"image"'

# Constraints on temperature
stdout '"minimum":0'
stdout '"maximum":2'

# Choices for style (enum)
stdout '"enum":\["fast","balanced","quality"\]'

# Secret type is flagged with the x-cog-secret extension
stdout '"x-cog-secret":true'

# Optional field has default
stdout '"default":"hello"'

# Output schema has structured type
stdout '"Output"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from typing import Optional

from cog import BaseModel, BasePredictor, Input, Path, Secret


# Structured return type of the predictor below: a string and a float field.
class Output(BaseModel):
    text: str
    score: float


# Schema fixture: each parameter below exercises one Input feature
# (default value, ge/le bounds, choices, Secret type, Optional Path).
class Predictor(BasePredictor):
    def predict(
        self,
        prompt: str = Input(description="The prompt", default="hello"),
        temperature: float = Input(description="Sampling temp", ge=0, le=2, default=0.7),
        style: str = Input(description="Style", choices=["fast", "balanced", "quality"], default="balanced"),
        api_key: Secret = Input(description="API key"),
        image: Optional[Path] = Input(description="Optional image", default=None),
    ) -> Output:
        # Only prompt and temperature feed the result; style, api_key and
        # image are declared but not read here.
        return Output(text=f"generated from {prompt}", score=temperature)


================================================
FILE: integration-tests/tests/build_pip_freeze.txtar
================================================

# Test that pip freeze label is embedded in image
# Source: test_build.py::test_pip_freeze

# Build the image
cog build -t $TEST_IMAGE

# Check the pip_freeze label exists and contains expected packages
exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.pip_freeze"}}'
stdout 'structlog'
stdout 'coglet'

-- cog.yaml --
build:
  python_version: "3.12"
predict: predict.py:Predictor

-- predict.py --
import tempfile

from cog import BasePredictor, Path


class Predictor(BasePredictor):
    # Fixture predictor with one string input and one file input.

    def setup(self):
        # Constant prefix stored on the instance for predict() to use.
        self.foo = "foo"

    def predict(self, text: str, path: Path) -> Path:
        # Prefix + text + contents of the input file.
        with open(path) as src:
            combined = self.foo + text + src.read()

        # Write the result to output.txt inside a fresh temporary directory.
        out_file = Path(tempfile.mkdtemp()) / "output.txt"
        with open(out_file, "w") as dst:
            dst.write(combined)
        return out_file


================================================
FILE: integration-tests/tests/build_python313_base_image.txtar
================================================
# Test Python 3.13 works with --use-cog-base-image
# Source: test_build.py::test_python_313_base_images
#
# Uses a local test registry to avoid depending on the live r8.im registry.

# Start local registry and seed a cog-base image from Docker Hub's python:3.13-slim
registry-start
registry-seed docker.io/library/python:3.13-slim cog-base:python3.13

# Build using Python 3.13 with --use-cog-base-image against local registry
env COG_REGISTRY_HOST=$TEST_REGISTRY
cog build -t $TEST_IMAGE --use-cog-base-image

# Verify build succeeded by running a prediction
cog predict $TEST_IMAGE -i num=7
stdout '14'

-- cog.yaml --
build:
  gpu: false
  python_version: "3.13"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


# Minimal predictor fixture: returns twice the integer input.
class Predictor(BasePredictor):
    def predict(self, num: int) -> int:
        return num * 2


================================================
FILE: integration-tests/tests/build_torch_version_required.txtar
================================================
# Test that build fails with proper error when torch is specified without version and GPU is enabled
# This validates the error message from cudasFromTorch when torch version is empty

! cog build -t $TEST_IMAGE
stderr 'torch version must be specified when using CUDA'

-- cog.yaml --
build:
  gpu: true
  python_version: "3.12"
  python_requirements: requirements.txt
predict: predict.py:Predictor

-- requirements.txt --
torch

-- predict.py --
from cog import BasePredictor

# Minimal predictor fixture: returns its text input unchanged.
class Predictor(BasePredictor):
    def predict(self, text: str) -> str:
        return text


================================================
FILE: integration-tests/tests/ca_cert.txtar
================================================
# Test CA certificate injection via COG_CA_CERT
#
# This test verifies that custom CA certificates can be injected into
# cog-built images, allowing HTTPS connections to servers using those CAs.

# Build the HTTPS test server helper
exec go build -C $REPO_ROOT/test-helpers/https-server -o $WORK/https-server .

# Start the HTTPS test server in background using the embedded certs
exec $WORK/https-server --cert=$WORK/certs/server.crt --key=$WORK/certs/server.key --addr=:8443 &

# Wait for the server to be ready
exec sh -c 'for i in 1 2 3 4 5 6 7 8 9 10; do curl -ksf https://localhost:8443/ && exit 0; sleep 1; done; exit 1'

# Build a minimal cog image
cog build -t $TEST_IMAGE

# ============================================
# Test 1: Without CA cert, HTTPS should FAIL
# ============================================
# The container tries to curl the HTTPS server, whose certificate is signed by
# the private test CA in certs/ca.crt. Without that CA cert installed in the
# image, this should fail with a certificate error.
! docker-run $TEST_IMAGE curl --fail --max-time 5 https://host.docker.internal:8443/

# ============================================
# Test 2: With COG_CA_CERT, HTTPS should WORK
# ============================================
# Set the CA cert environment variable and rebuild
env COG_CA_CERT=$WORK/certs/ca.crt

cog build -t $TEST_IMAGE-with-ca

# Now the HTTPS request should succeed because the CA cert is installed
docker-run $TEST_IMAGE-with-ca curl --fail --max-time 5 https://host.docker.internal:8443/
stdout 'OK'

# Verify the environment variables are set correctly in the image
docker-run $TEST_IMAGE-with-ca printenv SSL_CERT_FILE
stdout '/etc/ssl/certs/ca-certificates.crt'

docker-run $TEST_IMAGE-with-ca printenv REQUESTS_CA_BUNDLE
stdout '/etc/ssl/certs/ca-certificates.crt'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    """Minimal echo predictor; gives the test a model to build into an image."""
    def predict(self, text: str) -> str:
        """Return ``text`` unchanged."""
        return text

-- certs/ca.crt --
-----BEGIN CERTIFICATE-----
MIIDDTCCAfWgAwIBAgIUayiqAjIWvavCGAlFwU4CtOlDyjgwDQYJKoZIhvcNAQEL
BQAwFjEUMBIGA1UEAwwLQ29nIFRlc3QgQ0EwHhcNMjYwMTE0MjAxMjE0WhcNMzYw
MTEyMjAxMjE0WjAWMRQwEgYDVQQDDAtDb2cgVGVzdCBDQTCCASIwDQYJKoZIhvcN
AQEBBQADggEPADCCAQoCggEBAOanC1VJPLjZrfW+hLL6FmDsnNokMU1rI8KCWjrE
G2BLWOODSOECd1TWuJ84leiwOKuqj9FXlHzf0wr/D6MnQq39R4yDHKdbHYVuwRBu
uP3M3M3LWkqs7FDcXRz2htEoSoFAfoNo85Paj8rpFYwzLsuS/DtxX2yM5ja1UAZk
SNjrWF7DY97cT9njLF2QYFLj1unWAlVKoR90cYZZ72S4QIWsTQXBNN3GR/GC80AJ
vaxC83n4fCN94vJgO4reMAlojFNlXSgqQkEf8z+SMcuzHNcV/FkNOArTXHaYLeaB
yKChtDIlHV9W0+Ifsr+qYkWCN2Aznw5Yz5bhXrFLd+BcHUcCAwEAAaNTMFEwHQYD
VR0OBBYEFDAHM6Rgg47dX2D0h7gBw365IZ6kMB8GA1UdIwQYMBaAFDAHM6Rgg47d
X2D0h7gBw365IZ6kMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEB
AMd8xNa6BO7nZBfpwrPaijLMawIkv37ngM2dkNFKPMV/Vl9urAAnCDtjFq/pCXnC
lXvja0vy6nFmRITmetabLdhBHeDe8lcV1OHgyksy6AnRT6zLu/Jw3cx5/U3zLwM7
zLZkSylhYaNVpqaTRyzdFbP5V8h/QjpL+ffYTQgNMtRT0PphV4AvAqpfJJ18Jtcc
K4jIXYMUKU4ZkAT9JUSTXa2aefudzjMMr9GwvZGn/6ZpU3Y8H/DmmizoHNQRuC97
uGnxwufInGQ9W20UnUam9May0tsea654Ebtjw7QDzvMTFIsMFVOjBVOXMuB8PAfL
ATgc+2ToYm+V3Z1f46mm/4U=
-----END CERTIFICATE-----

-- certs/server.crt --
-----BEGIN CERTIFICATE-----
MIIDRzCCAi+gAwIBAgIUa+aDQvQpf7Sq+7foxZ7MNjJKcnIwDQYJKoZIhvcNAQEL
BQAwFjEUMBIGA1UEAwwLQ29nIFRlc3QgQ0EwHhcNMjYwMTE0MjAxMjE5WhcNMzYw
MTEyMjAxMjE5WjAfMR0wGwYDVQQDDBRob3N0LmRvY2tlci5pbnRlcm5hbDCCASIw
DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAJuuQyShcfUyXWhVzYGZ2UTo2Do1
TfHFrcHLsXB2UvuMnH67x3h1ylLM2gmMnMipVFVTL6OAhZQi9BHiRRoCepMyFJKU
RKxkXpIF6knfsmB3FsAcTKale3PUYrzTj63BCGZmnS768Drv9e48A2ZvsPTBkyuN
kjTgC8tIDt5b8xKMRDFG5ZhBxC2+nNBoLBB4ujBF32dWcxGWkFUWsIyN3oH3MsQp
ydO2FnOlc+bQI/hlanl+4CnL/TczOe5O906TL/oC8Wq5nYooGgfG2ZDfSpw7/Ver
lE0nc4Sy8+d8XL2fNWgM9ISL2sCJ1DT8dZVufmda5MIo0UMCRzCmy0Q6aqcCAwEA
AaOBgzCBgDA+BgNVHREENzA1ghRob3N0LmRvY2tlci5pbnRlcm5hbIIJbG9jYWxo
b3N0ggxodHRwcy1zZXJ2ZXKHBH8AAAEwHQYDVR0OBBYEFK/uLPc1TpgjXzMTPCMJ
Qjfx52+6MB8GA1UdIwQYMBaAFDAHM6Rgg47dX2D0h7gBw365IZ6kMA0GCSqGSIb3
DQEBCwUAA4IBAQBYvP1G1bJM6tOVixEqpxWkGd6Ghr3J/R4hzJDKtMfO8O4FdlZJ
FG9OaAXJqWmlXszXF2cD2IcdOeayR1oTDTyaId464Y5Fi5WDhaIOAuwlepvAwQod
p1xXbbI6k2n60sSvaCOT9KWwM/zMda94awc2oBYWAaqAJWtRqR+sHnpIe5PmVQN/
kMfugBZCt21v6tvOEyc94xU6XgqYbbAyZlrFL33KbDpOhnEQzq4F0fKcWit3STR1
oyzlcWaFNysaNOBjxBflBMIKMnpg3DvcK8SUHqCDYjGZzodtC/7f7/I/1zkBvouN
AdvFtZREbY0xv0tnHl0Afx46Z+hR9rPtjzWh
-----END CERTIFICATE-----

-- certs/server.key --
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCbrkMkoXH1Ml1o
Vc2BmdlE6Ng6NU3xxa3By7FwdlL7jJx+u8d4dcpSzNoJjJzIqVRVUy+jgIWUIvQR
4kUaAnqTMhSSlESsZF6SBepJ37JgdxbAHEympXtz1GK804+twQhmZp0u+vA67/Xu
PANmb7D0wZMrjZI04AvLSA7eW/MSjEQxRuWYQcQtvpzQaCwQeLowRd9nVnMRlpBV
FrCMjd6B9zLEKcnTthZzpXPm0CP4ZWp5fuApy/03MznuTvdOky/6AvFquZ2KKBoH
xtmQ30qcO/1Xq5RNJ3OEsvPnfFy9nzVoDPSEi9rAidQ0/HWVbn5nWuTCKNFDAkcw
pstEOmqnAgMBAAECggEAC+dIJP3fK8NdFwQwgW9VCIrRNaoruofF4GKFv7acY7V9
pccP2msPPEODjGVe+4zO8PM6WkMSc6A0j0WAyRtVafnTTt3dXl0SShH/twROrEeO
ysOfLMLMbK/ZmNyISN3QmZvQ+u2e/rKoWD3oeKWjnyNJ8HOTsU1MOY/Z6zCWpl1K
pAiVVvDVvUzrenMLVHlbyHXPYOS+oktctVd57bCNnipG4b/i4pqw08LAkP26jSCm
yIOg0r+RocN4GhhbmzP4Plmv0JCXhQposIDhC7KfbXEWaa3nFF2nfR5QPVUF5tht
xLPVrBMac8oT2owQoerhrgCQLi3b4Lmdloz7TwEcQQKBgQDVrwajIW1z5W+HdUm6
VcJf6+I9v4vL0EXPMSK4+XwLzTkC/RCz0wPfrRGlhdrMZ1wU9zCqFw1LCp+ixsMN
tC3MWGncFY5TOVf+jqom3PTESC5O9AySRAqI1jE9BwnmDkLpQ5zkYZaF/MAg07px
CW5r43EKGL5AW4Umv0OEuAEExwKBgQC6grMIz4IuWnZhj03DPRqSx4AymyqrzMjp
kNzCb1Wz6iKFMeIIPC8Mz0iju9QBVofh8Lxj1Bdqbh3kN4NyfT8nII3i/6Cvb9Ol
agMgTq1q5BgcSl4jliDbn1gCMq9MCJf8y+xXDQCLIKNaJidXQjQZn1XK/bV1IkPx
lkuG6znLIQKBgEVcV+Ix4o5hJi+pEbKLTdnG/pwehek1hMN5ZpT2Xp6SEfR3YqmM
UFCVpAm/hkMdNdWUW1aKvwThwOmcbQoQt2ECPfJziMxY68g0VOTiig0AhQ+Zxk7g
CS9bn4X4t+zWKj//c3jqeGqrnU3KjFVOw2n/3NxzJaZMTs9B/E+jTqlXAoGBALKc
QanZVvjfBulM3BJxrMYNqZZNBGM8HNeYM+E7z54ZRW+6opRyVjh1NUIfuNqDLGPS
MAeF79qrk5KfGxGEIfttcJOHbDE17UBGsrG4xthLkU9eZKK9vb+07ApG0ZsFy896
1l1TBUc3PVgym5AzxUMYVIetyZ1f8CMmZDPThighAoGBAIaoc0LvNVW9O9wT74dr
y1ThnmpgeIfg+4yyfY3Eel5wnfdOJ+J9g0vfV+PH/2B45Jdkr82L+XjPipzP3tqP
fQnfvBCfWXqiCV5yTK4OcVCicKNRj25Oq0vUAgSua8tE6hFaIp5hnmbLYiLIyOAO
ng5irDkEC4tQl7D4WoX1Viye
-----END PRIVATE KEY-----


================================================
FILE: integration-tests/tests/cancel_async_prediction.txtar
================================================
# Test cancellation of a running async prediction.
#
# Submits an async prediction (PUT with Prefer: respond-async) that sleeps
# for 60s via asyncio.sleep, waits for it to start processing, cancels via
# the /cancel endpoint, and verifies the webhook receives "canceled" status.

[short] skip 'requires Docker build'

cog build -t $TEST_IMAGE

webhook-server-start
cog serve --upload-url http://unused/

# Submit async prediction that would run for 60s
curl -H Prefer:respond-async PUT /predictions/cancel-async-test '{"id":"cancel-async-test","input":{"duration":60},"webhook":"$WEBHOOK_URL","webhook_events_filter":["completed"]}'

# Give the prediction time to start processing
exec sleep 2

# Cancel it
curl POST /predictions/cancel-async-test/cancel '{}'

# Wait for webhook callback
webhook-server-wait 30s

# Prediction should be canceled
stdout '"status":"canceled"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import asyncio
from cog import BasePredictor


class Predictor(BasePredictor):
    """Async predictor that awaits a sleep, giving the test a long-running prediction."""

    def setup(self):
        """No setup work is performed."""
        pass

    async def predict(self, duration: float = 60.0) -> str:
        """Await ``asyncio.sleep(duration)`` and then return "completed".

        Args:
            duration: Seconds to sleep; defaults to 60.0.
        """
        await asyncio.sleep(duration)
        return "completed"


================================================
FILE: integration-tests/tests/cancel_repeated.txtar
================================================
# Test that cancelling the same prediction multiple times doesn't panic or break.
#
# Uses time.sleep() (C-level nanosleep) — the harder cancellation case since
# the thread is blocked in native code.  Fires 3 cancel requests in quick
# succession and verifies the webhook still receives "canceled" status.

[short] skip 'requires Docker build'

cog build -t $TEST_IMAGE

webhook-server-start
cog serve --upload-url http://unused/

# Submit async prediction that sleeps for 5s (nanosleep blocks in C;
# cancel fires once sleep returns, so keep it short for CI)
curl -H Prefer:respond-async PUT /predictions/cancel-repeat-test '{"id":"cancel-repeat-test","input":{"duration":5},"webhook":"$WEBHOOK_URL","webhook_events_filter":["completed"]}'

# Give the prediction time to start processing
exec sleep 2

# Cancel it 3 times in rapid succession — none should 500 or panic
curl POST /predictions/cancel-repeat-test/cancel '{}'
curl POST /predictions/cancel-repeat-test/cancel '{}'
curl POST /predictions/cancel-repeat-test/cancel '{}'

# Wait for webhook callback
webhook-server-wait 30s

# Prediction should be canceled
stdout '"status":"canceled"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import time
from cog import BasePredictor


class Predictor(BasePredictor):
    """Synchronous predictor that blocks in ``time.sleep`` for the requested duration."""

    def setup(self):
        """No setup work is performed."""
        pass

    def predict(self, duration: float = 60.0) -> str:
        """Sleep for ``duration`` seconds and then return "completed".

        Args:
            duration: Seconds to sleep; defaults to 60.0.
        """
        # C-level blocking sleep (nanosleep) — harder to cancel than a busy-loop
        time.sleep(duration)
        return "completed"


================================================
FILE: integration-tests/tests/cancel_sync_prediction.txtar
================================================
# Test cancellation of a running sync prediction (a plain `def predict`).
#
# Submits the prediction asynchronously over HTTP (PUT with Prefer:
# respond-async); the predictor busy-loops for 60s. The test waits for it to
# start processing, cancels via the /cancel endpoint, and verifies the webhook
# receives "canceled" status.

[short] skip 'requires Docker build'

cog build -t $TEST_IMAGE

webhook-server-start
cog serve --upload-url http://unused/

# Submit async prediction that would run for 60s
curl -H Prefer:respond-async PUT /predictions/cancel-sync-test '{"id":"cancel-sync-test","input":{"duration":60},"webhook":"$WEBHOOK_URL","webhook_events_filter":["completed"]}'

# Give the prediction time to start processing
exec sleep 2

# Cancel it
curl POST /predictions/cancel-sync-test/cancel '{}'

# Wait for webhook callback
webhook-server-wait 30s

# Prediction should be canceled
stdout '"status":"canceled"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import time
from cog import BasePredictor


class Predictor(BasePredictor):
    """Synchronous predictor that spins in pure Python until a deadline passes."""

    def setup(self):
        """No setup work is performed."""
        pass

    def predict(self, duration: float = 60.0) -> str:
        """Busy-loop for ``duration`` seconds and then return "completed".

        Args:
            duration: Seconds to spin; defaults to 60.0.
        """
        # Busy-loop (hits bytecode boundaries, cancellable via PyThreadState_SetAsyncExc)
        deadline = time.monotonic() + duration
        while time.monotonic() < deadline:
            pass
        return "completed"


================================================
FILE: integration-tests/tests/coglet_iterator_path_output.txtar
================================================
# Test iterator prediction with Path outputs (no upload URL — files written to disk)

cog build -t $TEST_IMAGE
cog predict $TEST_IMAGE

# cog predict writes file outputs to disk, not as base64 to stdout
stderr 'Written output to: output.0.png'
stderr 'Written output to: output.1.png'
stderr 'Written output to: output.2.png'
! stderr 'failed'

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
    - "pillow==10.4.0"
predict: "predict.py:Predictor"

-- predict.py --
import os
import tempfile
from typing import Iterator

from cog import BasePredictor, Path
from PIL import Image


class Predictor(BasePredictor):
    """Generator predictor that yields three small PNG files, one per colour."""

    def predict(self) -> Iterator[Path]:
        """Yield a 10x10 solid-colour PNG for red, blue and green, in that order.

        Each image is saved into its own freshly created temporary directory.
        """
        for colour_name in ("red", "blue", "green"):
            image_path = os.path.join(tempfile.mkdtemp(), f"{colour_name}.png")
            swatch = Image.new("RGB", (10, 10), colour_name)
            swatch.save(image_path)
            yield Path(image_path)


================================================
FILE: integration-tests/tests/coglet_iterator_upload_url.txtar
================================================
# Test that iterator Path outputs are uploaded per-yield to --upload-url.

cog build -t $TEST_IMAGE

# Start mock upload server on the host, sets $UPLOAD_SERVER_URL
upload-server-start

cog serve --upload-url $UPLOAD_SERVER_URL

# Run a prediction — three Path outputs should be uploaded, not base64-encoded
curl POST /predictions '{"input":{}}'
stdout '"status":"succeeded"'
# Outputs should be URLs pointing at the mock server, not data URIs
stdout '"output":\["http://host.docker.internal'
! stdout 'data:image/'

# Verify the mock server received exactly 3 PUT uploads
upload-server-count 3

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
    - "pillow==10.4.0"
predict: "predict.py:Predictor"

-- predict.py --
import os
import tempfile
from typing import Iterator

from cog import BasePredictor, Path
from PIL import Image


class Predictor(BasePredictor):
    """Generator predictor producing one PNG per colour as a ``Path`` output."""

    def predict(self) -> Iterator[Path]:
        """Yield red, blue and green 10x10 PNGs, each written to a new temp directory."""
        palette = ["red", "blue", "green"]
        for shade in palette:
            scratch_dir = tempfile.mkdtemp()
            png_file = os.path.join(scratch_dir, f"{shade}.png")
            picture = Image.new("RGB", (10, 10), shade)
            picture.save(png_file)
            yield Path(png_file)


================================================
FILE: integration-tests/tests/coglet_large_file_upload_serial.txtar
================================================
[short] skip 'large file test - slow'

# Test that a large binary file (50 MiB) is successfully uploaded via --upload-url.
#
# Verifies the upload pipeline handles large binary payloads — not just large
# strings through the IPC spill path (which coglet_large_output.txtar covers).

cog build -t $TEST_IMAGE

upload-server-start
cog serve --upload-url $UPLOAD_SERVER_URL

# Prediction generates a 50 MiB binary file
curl POST /predictions '{"input":{}}'
stdout '"status":"succeeded"'
# Output should be a URL, not inline data
stdout '"output":"http://host.docker.internal'
! stdout 'data:'

# Verify upload server received exactly 1 file
upload-server-count 1

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import os
import tempfile

from cog import BasePredictor, Path


class Predictor(BasePredictor):
    """Predictor that returns a single 50 MiB binary file as a ``Path``."""

    def predict(self) -> Path:
        """Write 50 MiB of a repeating 4-byte pattern to a temp file and return its path."""
        one_kib = b"\xde\xad\xbe\xef" * 256  # 4 bytes * 256 = 1 KiB
        total_chunks = 50 * 1024  # 51200 x 1 KiB = 50 MiB
        target = os.path.join(tempfile.mkdtemp(), "large_output.bin")
        with open(target, "wb") as sink:
            written = 0
            while written < total_chunks:
                sink.write(one_kib)
                written += 1
        return Path(target)


================================================
FILE: integration-tests/tests/coglet_large_input.txtar
================================================
# Test that inputs larger than the 6 MiB IPC inline threshold spill to disk
# and are rehydrated correctly by the worker.
# Without input spilling this would exceed the 8 MiB LengthDelimitedCodec
# frame limit and break the bridge.
#
# Strategy: generate a JSON file with a 7 MiB padding string on the host,
# then POST it via the harness curl @file syntax. The predictor echoes
# back len(padding) to prove the full input survived the spill-rehydrate
# round-trip.
#
# Uses async prediction + webhook so the output goes directly to our
# Go webhook receiver — never through testscript's log buffer.

webhook-server-start
cog serve --upload-url http://unused/

# Generate a ~7 MiB JSON request body.
# dd produces 7340032 bytes of 'A', wrapped in a JSON prediction request.
exec sh -c 'printf "{\"id\":\"large-input-test\",\"input\":{\"padding\":\"" > large_input.json && dd if=/dev/zero bs=1024 count=7168 2>/dev/null | tr "\0" "A" >> large_input.json && printf "\"},\"webhook\":\"$WEBHOOK_URL\",\"webhook_events_filter\":[\"completed\"]}" >> large_input.json'

# POST the large payload via @file
curl -H Prefer:respond-async POST /predictions @large_input.json

# Wait for the webhook callback (up to 120s)
webhook-server-wait

# Prediction succeeded — input was spilled and rehydrated correctly
stdout '"status":"succeeded"'

# Output is the string "7340032" (len of 7 MiB padding), which is 7 bytes
stdout '"output_size":7'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Predictor that reports how long its string input was."""
    def predict(self, padding: str = "") -> str:
        """Return the length of ``padding`` as a decimal string."""
        return str(len(padding))


================================================
FILE: integration-tests/tests/coglet_large_output.txtar
================================================
# Test that outputs larger than the 8MiB IPC frame limit spill to disk
# and are reconstructed correctly by the orchestrator.
# Without spilling this would panic the bridge and poison the slot.
#
# Uses async prediction + webhook so the 9MiB output goes directly to our
# Go webhook receiver — never through testscript's log buffer.
#
# --upload-url is set to a dummy value so cog serve adds
# --add-host=host.docker.internal:host-gateway (needed on Linux for the
# webhook callback to reach the host). Nothing is actually uploaded because
# the output is a plain string, not a Path.

webhook-server-start
cog serve --upload-url http://unused/

# Async prediction — server returns 202 immediately, delivers result to webhook
curl -H Prefer:respond-async POST /predictions '{"id":"large-output-test","webhook":"$WEBHOOK_URL","webhook_events_filter":["completed"]}'

# Wait for the webhook callback (up to 120s)
webhook-server-wait

# 1. Prediction succeeded
stdout '"status":"succeeded"'

# 2. Output is correct — 9 * 1024 * 1024 = 9437184 bytes
stdout '"output_size":9437184'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Predictor that returns a single very large string."""
    def predict(self) -> str:
        """Return a string of 9 * 1024 * 1024 "x" characters."""
        # 9MiB string — exceeds the 8MiB IPC frame limit
        return "x" * (9 * 1024 * 1024)


================================================
FILE: integration-tests/tests/coglet_list_path_single_element.txtar
================================================
# Test that List[Path] with a single element returns an array, not a scalar.
# Regression test: the orchestrator must not collapse [url] → url for
# single-element list outputs. The schema declares Output as "type": "array",
# so the response must always be an array regardless of item count.

cog build -t $TEST_IMAGE

# Start mock upload server on the host, sets $UPLOAD_SERVER_URL
upload-server-start

cog serve --upload-url $UPLOAD_SERVER_URL

# Single element: output MUST be ["url"], not "url"
curl POST /predictions '{"input":{"count":1}}'
stdout '"status":"succeeded"'
# Match array with exactly one URL element (not a bare string)
stdout '"output":\["http://host.docker.internal[^"]*"\]'
! stdout 'data:image/'

upload-server-count 1

# Multiple elements: output must be ["url1","url2"]
curl POST /predictions '{"input":{"count":2}}'
stdout '"status":"succeeded"'
stdout '"output":\["http://host.docker.internal[^"]*","http://host.docker.internal[^"]*"\]'
! stdout 'data:image/'

# 1 from first prediction + 2 from second = 3 total uploads
upload-server-count 3

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
    - "pillow==10.4.0"
predict: "predict.py:Predictor"

-- predict.py --
import os
import tempfile
from typing import List

from cog import BasePredictor, Input, Path
from PIL import Image


class Predictor(BasePredictor):
    """Predictor returning a list of ``count`` small PNG files."""

    def predict(self, count: int = Input(description="Number of images", default=1)) -> List[Path]:
        """Create ``count`` 10x10 PNGs, cycling through four colours, and return their paths."""
        palette = ["red", "blue", "green", "yellow"]

        def render(index: int) -> Path:
            # Colours repeat once the index runs past the end of the palette.
            colour = palette[index % len(palette)]
            target = os.path.join(tempfile.mkdtemp(), f"{colour}.png")
            Image.new("RGB", (10, 10), colour).save(target)
            return Path(target)

        return [render(i) for i in range(count)]


================================================
FILE: integration-tests/tests/coglet_list_path_upload_url.txtar
================================================
# Test that List[Path] outputs are uploaded to --upload-url.
# This verifies that list returns (not just iterators) go through the
# FileOutput IPC path for upload instead of being base64-encoded inline.

cog build -t $TEST_IMAGE

# Start mock upload server on the host, sets $UPLOAD_SERVER_URL
upload-server-start

cog serve --upload-url $UPLOAD_SERVER_URL

# Run a prediction — three Path outputs should be uploaded, not base64-encoded
curl POST /predictions '{"input":{}}'
stdout '"status":"succeeded"'
# Outputs should be URLs pointing at the mock server, not data URIs
stdout '"output":\["http://host.docker.internal'
! stdout 'data:image/'

# Verify the mock server received exactly 3 PUT uploads
upload-server-count 3

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
    - "pillow==10.4.0"
predict: "predict.py:Predictor"

-- predict.py --
import os
import tempfile
from typing import List

from cog import BasePredictor, Path
from PIL import Image


class Predictor(BasePredictor):
    """Predictor returning three PNG files as a ``List[Path]``."""

    def predict(self) -> List[Path]:
        """Create red, blue and green 10x10 PNGs and return their paths in that order."""
        results: List[Path] = []
        for shade in ("red", "blue", "green"):
            png_file = os.path.join(tempfile.mkdtemp(), f"{shade}.png")
            picture = Image.new("RGB", (10, 10), shade)
            picture.save(png_file)
            results.append(Path(png_file))
        return results


================================================
FILE: integration-tests/tests/coglet_metrics.txtar
================================================
# Test that user-emitted metrics appear in sync prediction responses.
#
# Verifies:
# 1. record_metric() with default mode (replace) appears in response
# 2. Increment and append accumulation modes work
# 3. predict_time is always present (system metric)
# 4. predict_time overrides any user-set predict_time
# 5. Dict-style metrics access (scope.metrics["key"] = value)
# 6. Dot-path keys create nested objects

cog serve

# Prediction that records metrics via current_scope()
curl POST /predictions '{"input":{}}'
stdout '"status":"succeeded"'

# User metrics present
stdout '"temperature":0.7'
stdout '"token_count":3'
stdout '"tags":\["fast","cached"\]'
stdout '"timing":\{.*"preprocess":0.1'
stdout '"dict_metric":"hello"'

# System predict_time overrides user value — should be a float, not 999
stdout '"predict_time":'
! stdout '"predict_time":999'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, current_scope


class Predictor(BasePredictor):
    """Predictor that records metrics on the current scope in several ways."""

    def predict(self) -> str:
        """Record metrics on the current scope, then return "ok"."""
        scope = current_scope()

        # Replace mode (default)
        scope.record_metric("temperature", 0.7)

        # Increment mode: 1 + 2 = 3
        scope.record_metric("token_count", 1, mode="incr")
        scope.record_metric("token_count", 2, mode="incr")

        # Append mode: builds array
        scope.record_metric("tags", "fast", mode="append")
        scope.record_metric("tags", "cached", mode="append")

        # Dot-path key creates nested object
        scope.record_metric("timing.preprocess", 0.1)

        # Dict-style access
        scope.metrics["dict_metric"] = "hello"

        # User-set predict_time should be overridden by system
        scope.record_metric("predict_time", 999)

        return "ok"


================================================
FILE: integration-tests/tests/coglet_metrics_webhook.txtar
================================================
# Test that user-emitted metrics appear in webhook payloads.
#
# Uses async prediction with webhook to verify metrics flow through
# the supervisor and webhook sender. The webhook server captures the
# terminal payload including metrics for assertion.
#
# --upload-url is set to a dummy value so cog serve adds
# --add-host=host.docker.internal:host-gateway (needed on Linux for the
# webhook callback to reach the host).

webhook-server-start
cog serve --upload-url http://unused/

# Async prediction — server returns 202, delivers result to webhook
curl -H Prefer:respond-async POST /predictions '{"id":"metrics-webhook-test","webhook":"$WEBHOOK_URL","webhook_events_filter":["completed"]}'

# Wait for the webhook callback
webhook-server-wait

# Prediction succeeded
stdout '"status":"succeeded"'

# User metrics present in webhook payload
stdout '"model_version":"v2.1"'
stdout '"confidence":0.95'

# System predict_time always present
stdout '"predict_time":'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, current_scope


class Predictor(BasePredictor):
    """Predictor that records two metrics on the current scope."""

    def predict(self) -> str:
        """Record ``model_version`` and ``confidence`` metrics, then return "ok"."""
        scope = current_scope()
        scope.record_metric("model_version", "v2.1")
        scope.record_metric("confidence", 0.95)
        return "ok"


================================================
FILE: integration-tests/tests/coglet_single_path_output.txtar
================================================
# Test that a single Path return (not List[Path]) returns a scalar string, not an array.
# This complements coglet_list_path_single_element.txtar — verifying that
# the schema-driven output wrapping correctly distinguishes:
#   -> Path       →  "url"    (scalar)
#   -> List[Path] →  ["url"]  (array, even with one element)

cog build -t $TEST_IMAGE

# Start mock upload server on the host, sets $UPLOAD_SERVER_URL
upload-server-start

cog serve --upload-url $UPLOAD_SERVER_URL

# Single Path output: must be "url" (scalar string), NOT ["url"]
curl POST /predictions '{"input":{}}'
stdout '"status":"succeeded"'
# Output should be a bare string URL, not wrapped in an array
stdout '"output":"http://host.docker.internal'
! stdout '"output":\['
! stdout 'data:image/'

upload-server-count 1

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
    - "pillow==10.4.0"
predict: "predict.py:Predictor"

-- predict.py --
import os
import tempfile

from cog import BasePredictor, Path
from PIL import Image


class Predictor(BasePredictor):
    """Predictor returning exactly one PNG file as a scalar ``Path``."""

    def predict(self) -> Path:
        """Save a 10x10 red PNG named output.png in a new temp directory and return its path."""
        out_dir = tempfile.mkdtemp()
        destination = os.path.join(out_dir, "output.png")
        red_square = Image.new("RGB", (10, 10), "red")
        red_square.save(destination)
        return Path(destination)


================================================
FILE: integration-tests/tests/complex_output.txtar
================================================
# Test complex/structured output type using cog.BaseModel (dataclass)

# Build the image
cog build -t $TEST_IMAGE

# Predict returns structured output
cog predict $TEST_IMAGE -i msg='test error'
stdout '"success": false'
stdout '"error": "test error"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from typing import Optional

from cog import BaseModel, BasePredictor, Path


# Structured result type declared as the return annotation of Predictor.predict.
class ModelOutput(BaseModel):
    # Whether the operation succeeded.
    success: bool
    # Error message, or None.
    error: Optional[str]
    # Path to an image output, or None.
    segmented_image: Optional[Path]


class Predictor(BasePredictor):
    # Predictor whose output is a structured ModelOutput rather than a scalar.
    def predict(self, msg: str) -> ModelOutput:
        # Always reports failure, carrying `msg` as the error and no image.
        return ModelOutput(success=False, error=msg, segmented_image=None)


================================================
FILE: integration-tests/tests/concatenate_iterator_output.txtar
================================================
# Test ConcatenateIterator[str] as predict output type
#
# ConcatenateIterator is the primary streaming text output type for LLMs.
# cog predict renders each yielded token as an array element.

# Build the image
cog build -t $TEST_IMAGE

# Streaming output yields individual tokens
cog predict $TEST_IMAGE -i prompt=hello
stdout '"hello"'
stdout '" world"'
stdout '" !"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, ConcatenateIterator


class Predictor(BasePredictor):
    """Streaming predictor that emits the prompt followed by two fixed tokens."""

    def predict(self, prompt: str) -> ConcatenateIterator[str]:
        """Yield ``prompt``, then " world", then " !"."""
        yield from (prompt, " world", " !")


================================================
FILE: integration-tests/tests/config_subdirectory.txtar
================================================
# Test that cog.yaml is discovered from subdirectories
# Source: test_config.py::test_config

# Create a subdirectory structure
mkdir some
mkdir some/sub
mkdir some/sub/dir

# Run cog from a subdirectory - it should find cog.yaml in parent
cd some/sub/dir
cog run echo hello world
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"


================================================
FILE: integration-tests/tests/debug_secrets.txtar
================================================
# Test that cog debug shows secret mount syntax in Dockerfile output
# Source: test_run.py::test_run_with_secret

# Run cog debug to see the generated Dockerfile
cog debug
stdout 'RUN echo hello world'
stdout 'RUN --mount=type=secret,id=foo,target=secret.txt echo shh'

-- cog.yaml --
build:
  python_version: "3.12"
  run:
    - echo hello world
    - command: >-
        echo shh
      mounts:
        - type: secret
          id: foo
          target: secret.txt
predict: "predict.py:Predictor"

-- secret.txt --
secret content here

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Minimal greeting predictor referenced by this test's cog.yaml."""
    def predict(self, s: str) -> str:
        """Return ``s`` prefixed with "hello "."""
        return "hello " + s


================================================
FILE: integration-tests/tests/dict_output.txtar
================================================
# Test bare dict return type works for predict output

# Build the image
cog build -t $TEST_IMAGE

# Predict returns a dict
cog predict $TEST_IMAGE -i name=alice
stdout '"greeting": "hello alice"'
stdout '"length": 5'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Predictor whose return annotation is a bare ``dict``."""

    def predict(self, name: str) -> dict:
        """Return a dict with a greeting for ``name`` and the length of ``name``."""
        result = {}
        result["greeting"] = "hello " + name
        result["length"] = len(name)
        return result


================================================
FILE: integration-tests/tests/emit_metric_deprecated.txtar
================================================
# Test that a predictor using the deprecated emit_metric() still builds,
# runs, and records metrics correctly.
#
# emit_metric() was dropped without a compat shim in 0.17.0, causing a hard
# ImportError on startup for models that use it. This test covers the two
# call styles used across replicate/* models:
#
#   from cog import emit_metric; emit_metric("key", value)
#   cog.emit_metric("key", value)

cog serve

# Prediction works and the metrics from BOTH call styles appear in the response
curl POST /predictions '{"input":{}}'
stdout '"status":"succeeded"'
# Recorded via `from cog import emit_metric; emit_metric(...)`
stdout '"output_tokens":42'
# Recorded via `cog.emit_metric(...)`
stdout '"input_tokens":10'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import cog
from cog import BasePredictor, emit_metric


class Predictor(BasePredictor):
    """Predictor that records metrics through the deprecated ``emit_metric`` function."""

    def predict(self) -> str:
        """Emit ``output_tokens`` and ``input_tokens`` metrics, then return "ok"."""
        # Both call styles should work
        emit_metric("output_tokens", 42)  # name imported directly from cog
        cog.emit_metric("input_tokens", 10)  # attribute access on the cog module
        return "ok"


================================================
FILE: integration-tests/tests/env_vars.txtar
================================================
# Build the image
cog build -t $TEST_IMAGE

# Environment variables are set
cog predict $TEST_IMAGE -i name=TEST_VAR
stdout 'test_value'

cog predict $TEST_IMAGE -i name=NAME
stdout 'michael'

-- cog.yaml --
predict: "predict.py:Predictor"
build:
  python_version: "3.12"
environment:
  - NAME=michael
  - TEST_VAR=test_value

-- predict.py --
from cog import BasePredictor
import os

class Predictor(BasePredictor):
    """Predictor that reports the value of an environment variable."""

    def predict(self, name: str) -> str:
        """Return ``ENV[<name>]=<value>``; the value renders as ``None`` when unset."""
        value = os.environ.get(name)
        return f"ENV[{name}]={value}"


================================================
FILE: integration-tests/tests/experimental_feature_warning.txtar
================================================
# Test that a predictor importing the deprecated ExperimentalFeatureWarning
# still builds and runs successfully.

# Build the image
cog build -t $TEST_IMAGE

# Prediction works despite the deprecated import
cog predict $TEST_IMAGE -i s=world
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import warnings
from cog import BasePredictor, ExperimentalFeatureWarning

warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning)

class Predictor(BasePredictor):
    """Minimal greeting predictor defined alongside the deprecated import."""
    def predict(self, s: str) -> str:
        """Return ``s`` prefixed with "hello "."""
        return "hello " + s


================================================
FILE: integration-tests/tests/ffmpeg_package.txtar
================================================
# Test that ffmpeg system package is installed (common ML dependency)

# Build the image (the run command verifies ffmpeg is installed)
cog build -t $TEST_IMAGE

# Verify the predictor works
cog predict $TEST_IMAGE -i s=world
stdout 'hello world'

-- cog.yaml --
build:
  gpu: true
  python_version: "3.12"
  python_packages:
    - "torch==2.5.1"
  cuda: "12.4"
  run:
    - command: ffmpeg --help
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Minimal greeting predictor; does not itself invoke ffmpeg."""
    def predict(self, s: str) -> str:
        """Return ``s`` prefixed with "hello "."""
        return "hello " + s


================================================
FILE: integration-tests/tests/file_input.txtar
================================================
# Build the image
cog build -t $TEST_IMAGE

# File input works
cog predict $TEST_IMAGE -i file=@input.txt
stdout 'hello from file'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Path

class Predictor(BasePredictor):
    """Predictor that takes a single ``Path`` input."""
    def predict(self, file: Path) -> str:
        """Return the text contents of ``file``."""
        return file.read_text()

-- input.txt --
hello from file


================================================
FILE: integration-tests/tests/file_list_input.txtar
================================================

# Test list[File] input type (multiple file inputs using File type)

# Build the image
cog build -t $TEST_IMAGE

# Predict with multiple file inputs
cog predict $TEST_IMAGE -i files=@file1.txt -i files=@file2.txt
stdout 'content one'
stdout 'content two'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, File


class Predictor(BasePredictor):
    def predict(self, files: list[File]) -> str:
        # Read each file in order, turn binary payloads into text, and
        # separate the pieces with a blank line.
        def as_text(handle) -> str:
            payload = handle.read()
            return payload.decode('utf-8') if isinstance(payload, bytes) else payload

        return "\n\n".join([as_text(handle) for handle in files])

-- file1.txt --
content one
-- file2.txt --
content two


================================================
FILE: integration-tests/tests/float_input_output.txtar
================================================
# Test float input and output types work correctly
cog build -t $TEST_IMAGE

# Float input and output works
cog predict $TEST_IMAGE -i num=10
stdout '20'

# Negative numbers work
cog predict $TEST_IMAGE -i num=-10
stdout '-20'

-- cog.yaml --
build:
  python_version: "3.11"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Input

class Predictor(BasePredictor):
    # Doubles a float input.
    def predict(
        self,
        num: float = Input(description="Number of things"),
    ) -> float:
        doubled = num * 2.0
        return doubled


================================================
FILE: integration-tests/tests/function_predictor.txtar
================================================
# Test function-based predictor (no class, just a function)

# Build the image
cog build -t $TEST_IMAGE

# Predict using function-based predictor
cog predict $TEST_IMAGE -i prompt=world
stdout 'HELLO WORLD'

-- cog.yaml --
build:
  python_version: "3.13"
predict: "predict.py:run"

-- predict.py --
from cog import Input


def run(prompt: str = Input()) -> str:
    # Function-style predictor: shout the prompt back behind "HELLO".
    shouted = prompt.upper()
    return f"HELLO {shouted}"


================================================
FILE: integration-tests/tests/future_annotations.txtar
================================================

# Test from __future__ import annotations support
# This tests that future annotations work correctly with Path types

# Build the image
cog build -t $TEST_IMAGE

# Predict - creates and returns an image
cog predict $TEST_IMAGE

# Verify output file was created
exec test -f output.png

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
    - pillow==10.0.0
predict: "predict.py:Predictor"

-- predict.py --
from __future__ import annotations

from cog import BasePredictor, Input, Path
from PIL import Image


class Predictor(BasePredictor):
    def predict(self) -> Path:
        """Render a solid red 100x100 image to /tmp/output.png and return its path."""
        destination = Path("/tmp/output.png")
        Image.new("RGB", (100, 100), color="red").save(destination)
        return destination


================================================
FILE: integration-tests/tests/glb_project.txtar
================================================
# Test GLB (3D model) file output

# Create a minimal GLB placeholder file (GLB files start with "glTF" magic bytes)
exec sh -c 'printf "glTF" > mesh.glb'

# Build and predict - verifies GLB file output works
cog build -t $TEST_IMAGE
cog predict $TEST_IMAGE

-- cog.yaml --
build:
  python_version: "3.13"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Path


class Predictor(BasePredictor):
    def setup(self) -> None:
        # Fail fast when the mesh.glb fixture cannot be found.
        mesh = Path("mesh.glb")
        if mesh.exists():
            return
        raise ValueError("Example file mesh.glb does not exist")

    def predict(self) -> Path:
        # The output is simply the fixture file itself.
        mesh = Path("mesh.glb")
        return mesh


================================================
FILE: integration-tests/tests/granite_project.txtar
================================================
# Test that Pydantic 2 is not clobbered to a <2 version

# Build the image
cog build -t $TEST_IMAGE

# Predict and verify pydantic version is preserved
cog predict $TEST_IMAGE
stdout '2.11.9'

-- cog.yaml --
build:
  python_version: "3.11"
  python_requirements: requirements.txt
predict: "predict.py:Predictor"

-- requirements.txt --
pydantic==2.11.9

-- predict.py --
from cog import BasePredictor
import pydantic


class Predictor(BasePredictor):
    # Reports which pydantic version is importable at prediction time.
    def predict(self) -> str:
        installed_version = pydantic.__version__
        return installed_version


================================================
FILE: integration-tests/tests/healthcheck.txtar
================================================
# Test custom healthcheck functionality
# This tests the user-defined healthcheck() method in predictors

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Test 1: Healthy healthcheck returns READY status
curl GET /health-check
stdout '"status":"READY"'
! stdout 'user_healthcheck_error'

# Test 2: Make a prediction to ensure predictor works
curl POST /predictions '{"input":{"text":"world"}}'
stdout '"output":"hello world"'

# Test 3: Health check still works after prediction
curl GET /health-check
stdout '"status":"READY"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    def predict(self, text: str) -> str:
        return "hello {}".format(text)

    def healthcheck(self) -> bool:
        """Report healthy unconditionally."""
        healthy = True
        return healthy


================================================
FILE: integration-tests/tests/healthcheck_async.txtar
================================================
# Test async healthcheck function
# Ensures async healthchecks work correctly

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Test async healthcheck returns healthy
curl GET /health-check
stdout '"status":"READY"'
! stdout 'user_healthcheck_error'

# Make a prediction
curl POST /predictions '{"input":{"text":"world"}}'
stdout '"output":"hello world"'

# Healthcheck should still work
curl GET /health-check
stdout '"status":"READY"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import asyncio
from cog import BasePredictor

class Predictor(BasePredictor):
    def predict(self, text: str) -> str:
        return "hello {}".format(text)

    async def healthcheck(self) -> bool:
        """Yield to the event loop briefly, then report healthy."""
        pause = 0.1  # stand-in for real asynchronous work
        await asyncio.sleep(pause)
        return True


================================================
FILE: integration-tests/tests/healthcheck_async_exception.txtar
================================================
# Test async healthcheck that raises an exception
# Tests that async def healthcheck() exceptions are handled

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Test: Async healthcheck raising exception gives UNHEALTHY with error
curl GET /health-check
stdout '"status":"UNHEALTHY"'
stdout 'user_healthcheck_error'
stdout 'Async healthcheck error'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor
import asyncio

class Predictor(BasePredictor):
    async def predict(self, text: str) -> str:
        return "hello {}".format(text)

    async def healthcheck(self) -> bool:
        """Always fail: await briefly, then raise RuntimeError."""
        pause = 0.01
        await asyncio.sleep(pause)
        raise RuntimeError("Async healthcheck error")


================================================
FILE: integration-tests/tests/healthcheck_async_timeout.txtar
================================================
# Test async healthcheck timeout behavior
# Tests that async def healthcheck() timeouts are handled (5 second limit)

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Server should be healthy initially
curl GET /health-check
stdout '"status":"READY"'

# Trigger slow healthcheck mode via prediction
curl POST /predictions '{"input":{"text":"trigger_slow"}}'
stdout '"status":"succeeded"'

# Now healthcheck should timeout and return UNHEALTHY
curl GET /health-check
stdout '"status":"UNHEALTHY"'
stdout 'user_healthcheck_error'
stdout 'timed out after 5.0 seconds'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor
import asyncio

class Predictor(BasePredictor):
    def setup(self) -> None:
        # Healthchecks answer immediately until a prediction flips this flag.
        self._slow_mode = False

    async def predict(self, text: str) -> str:
        # The magic input "trigger_slow" switches on slow-healthcheck mode.
        wants_slow = text == "trigger_slow"
        if wants_slow:
            self._slow_mode = True
        return "hello {}".format(text)

    async def healthcheck(self) -> bool:
        """Return True, but only after a 10 second await once slow mode is on."""
        if not self._slow_mode:
            return True
        await asyncio.sleep(10)  # deliberately longer than the 5s timeout
        return True


================================================
FILE: integration-tests/tests/healthcheck_async_unhealthy.txtar
================================================
# Test async unhealthy healthcheck behavior
# Tests that async def healthcheck() returning False works correctly

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Test: Async healthcheck returning False gives UNHEALTHY
curl GET /health-check
stdout '"status":"UNHEALTHY"'
stdout 'user_healthcheck_error'
stdout 'returned False'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor
import asyncio

class Predictor(BasePredictor):
    async def predict(self, text: str) -> str:
        return "hello {}".format(text)

    async def healthcheck(self) -> bool:
        """Await briefly, then report unhealthy by returning False."""
        pause = 0.01
        await asyncio.sleep(pause)
        healthy = False
        return healthy


================================================
FILE: integration-tests/tests/healthcheck_during_prediction.txtar
================================================
# Test healthcheck called during active prediction
# Ensures healthcheck pipe doesn't interfere with prediction pipe

# Build the image
cog build -t $TEST_IMAGE

# Start the server (healthchecks must be answered while a prediction is still running)
cog serve

# Start a long-running prediction in background (5 seconds)
exec bash -c 'curl -s -X POST $SERVER_URL/predictions -H "Content-Type: application/json" -d "{\"input\":{\"sleep_time\":5}}" > /tmp/prediction.json &'

# Wait for prediction to start (500ms)
exec sleep 0.5

# Call healthcheck while prediction is running
curl GET /health-check
stdout '"status":"READY"'

# Call healthcheck again (multiple times during prediction)
exec sleep 1
curl GET /health-check
stdout '"status":"READY"'

exec sleep 1
curl GET /health-check
stdout '"status":"READY"'

# Wait for prediction to complete
exec sleep 3

# Verify prediction succeeded
exec bash -c 'cat /tmp/prediction.json | grep -q "\"output\":\"slept for 5 seconds\""'

# Healthcheck should still work after prediction completes
curl GET /health-check
stdout '"status":"READY"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import time
from cog import BasePredictor

class Predictor(BasePredictor):
    def predict(self, sleep_time: int) -> str:
        """Block for ``sleep_time`` seconds, then report how long that was."""
        time.sleep(sleep_time)
        summary = f"slept for {sleep_time} seconds"
        return summary

    def healthcheck(self) -> bool:
        """Report healthy unconditionally, including while predict() is sleeping."""
        healthy = True
        return healthy


================================================
FILE: integration-tests/tests/healthcheck_exception.txtar
================================================
# Test healthcheck that raises an exception
# This tests error handling when healthcheck throws

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Exception in healthcheck should return UNHEALTHY with error message
curl GET /health-check
stdout '"status":"UNHEALTHY"'
stdout 'user_healthcheck_error'
stdout 'Critical system error'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    def setup(self) -> None:
        # Number of healthcheck() invocations seen so far.
        self._healthcheck_calls = 0

    def predict(self, text: str) -> str:
        return "hello {}".format(text)

    def healthcheck(self) -> bool:
        """Pass on the very first call; raise RuntimeError on every later one."""
        self._healthcheck_calls += 1
        is_first_call = self._healthcheck_calls == 1
        if not is_first_call:
            raise RuntimeError("Critical system error")
        return True


================================================
FILE: integration-tests/tests/healthcheck_immediately_after_prediction.txtar
================================================
# Test healthcheck immediately after prediction completes
# Ensures healthcheck works correctly in quick succession with predictions

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Make a quick prediction
curl POST /predictions '{"input":{"text":"world"}}'
stdout '"output":"hello world"'

# Immediately call healthcheck after prediction (no delay)
curl GET /health-check
stdout '"status":"READY"'

# Do it again - rapid fire prediction + healthcheck
curl POST /predictions '{"input":{"text":"again"}}'
stdout '"output":"hello again"'

curl GET /health-check
stdout '"status":"READY"'

# One more time to ensure pattern holds
curl POST /predictions '{"input":{"text":"final"}}'
stdout '"output":"hello final"'

curl GET /health-check
stdout '"status":"READY"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    def predict(self, text: str) -> str:
        return "hello {}".format(text)

    def healthcheck(self) -> bool:
        """Report healthy unconditionally."""
        healthy = True
        return healthy


================================================
FILE: integration-tests/tests/healthcheck_repeated_calls.txtar
================================================
# Test repeated healthcheck calls (simulates supervisor container polling pattern)
# This ensures no resource leaks over many healthcheck calls

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Call healthcheck 50 times sequentially (simulates supervisor polling)
# Each call should succeed and return READY status
exec bash -c 'for i in {1..50}; do curl -s $SERVER_URL/health-check | grep -q "\"status\":\"READY\"" || exit 1; done'

# Make a prediction to ensure system still works after many healthchecks
curl POST /predictions '{"input":{"text":"world"}}'
stdout '"output":"hello world"'

# Healthcheck should still work after prediction
curl GET /health-check
stdout '"status":"READY"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    def __init__(self):
        # Running total of healthcheck() invocations.
        self.call_count = 0

    def predict(self, text: str) -> str:
        return "hello {}".format(text)

    def healthcheck(self) -> bool:
        """Count the invocation and report healthy."""
        self.call_count = self.call_count + 1
        healthy = True
        return healthy


================================================
FILE: integration-tests/tests/healthcheck_timeout.txtar
================================================
# Test healthcheck timeout behavior
# This tests when sync healthcheck takes too long (>5 seconds)

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Server should be healthy initially
curl GET /health-check
stdout '"status":"READY"'

# Trigger slow healthcheck mode via prediction
curl POST /predictions '{"input":{"text":"trigger_slow"}}'
stdout '"status":"succeeded"'

# Now healthcheck should timeout and return UNHEALTHY
curl GET /health-check
stdout '"status":"UNHEALTHY"'
stdout 'user_healthcheck_error'
stdout 'timed out after 5.0 seconds'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import time
from cog import BasePredictor

class Predictor(BasePredictor):
    def setup(self) -> None:
        # Healthchecks answer immediately until a prediction flips this flag.
        self._slow_mode = False

    def predict(self, text: str) -> str:
        # The magic input "trigger_slow" switches on slow-healthcheck mode.
        wants_slow = text == "trigger_slow"
        if wants_slow:
            self._slow_mode = True
        return "hello {}".format(text)

    def healthcheck(self) -> bool:
        """Return True, but only after blocking for 10 seconds once slow mode is on."""
        if not self._slow_mode:
            return True
        time.sleep(10)  # deliberately longer than the 5s timeout
        return True


================================================
FILE: integration-tests/tests/healthcheck_unhealthy.txtar
================================================
# Test unhealthy healthcheck behavior
# This tests when healthcheck returns False

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Unhealthy healthcheck should return UNHEALTHY status
curl GET /health-check
stdout '"status":"UNHEALTHY"'
stdout 'user_healthcheck_error'
stdout 'user-defined healthcheck returned False'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    def setup(self) -> None:
        # Number of healthcheck() invocations seen so far.
        self._healthcheck_calls = 0

    def predict(self, text: str) -> str:
        return "hello {}".format(text)

    def healthcheck(self) -> bool:
        """Healthy on the first call only; every later call returns False."""
        self._healthcheck_calls += 1
        return self._healthcheck_calls == 1


================================================
FILE: integration-tests/tests/int_input_output.txtar
================================================
# Test integer input and output types work correctly
cog build -t $TEST_IMAGE

# Integer input and output works
cog predict $TEST_IMAGE -i num=10
stdout '20'

# Negative numbers work
cog predict $TEST_IMAGE -i num=-10
stdout '-20'

-- cog.yaml --
build:
  python_version: "3.11"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Input

class Predictor(BasePredictor):
    # Doubles an integer input.
    def predict(
        self,
        num: int = Input(description="Number of things"),
    ) -> int:
        doubled = num * 2
        return doubled


================================================
FILE: integration-tests/tests/int_none_output.txtar
================================================

# Test int return type that returns None
# This tests the handling of None values for typed outputs

# Build the image
cog build -t $TEST_IMAGE

# Predict returns None despite int type annotation
# When None is returned, cog shows "No output generated" on stderr
cog predict $TEST_IMAGE
stderr 'No output generated'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    # Deliberately violates its own return annotation: the test checks how a
    # None result is reported when the declared output type is int.
    def predict(self) -> int:
        return None


================================================
FILE: integration-tests/tests/int_predictor.txtar
================================================
# Build the image
cog build -t $TEST_IMAGE

# Integer input and output works
cog predict $TEST_IMAGE -i num=5
stdout '10'

# Negative numbers work
cog predict $TEST_IMAGE -i num=-3
stdout '-6'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    # Doubles an integer input declared without Input() metadata.
    def predict(self, num: int) -> int:
        doubled = num * 2
        return doubled


================================================
FILE: integration-tests/tests/invalid_int_validation.txtar
================================================

# Test input validation with ge/le constraints
# Build should fail because default=1 violates ge=2 constraint

# Build should fail at schema validation
! cog build -t $TEST_IMAGE
stderr 'invalid'
stderr 'minimum'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Input


class Predictor(BasePredictor):
    # The Input() declaration is intentionally inconsistent: default=1 is
    # below the ge=2 lower bound, which is what the build is expected to reject.
    def predict(
        self, num: int = Input(description="Number of things", default=1, ge=2, le=10)
    ) -> int:
        return num * 2


================================================
FILE: integration-tests/tests/iterator_error_midstream.txtar
================================================
# Test that a generator which yields items then raises an exception
# correctly reports failure. This is a common real-world failure mode
# (model produces partial output then hits an error).

webhook-server-start
cog serve --upload-url http://unused/

# Async prediction — generator yields 3 items then raises
curl -H Prefer:respond-async POST /predictions '{"id":"iter-error-test","webhook":"$WEBHOOK_URL","webhook_events_filter":["completed"]}'

webhook-server-wait

# Prediction should fail
stdout '"status":"failed"'
stdout '"has_error":true'
stdout '"error_message":".*generator exploded"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from typing import Iterator

from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self) -> Iterator[str]:
        # Emit three chunks of partial output, then fail mid-stream.
        yield from ("chunk-1", "chunk-2", "chunk-3")
        raise RuntimeError("generator exploded")


================================================
FILE: integration-tests/tests/iterator_string_output.txtar
================================================
# Test Iterator[str] as predict output type
#
# Each string yielded by the Iterator[str] predictor becomes one item of the output array.

# Build the image
cog build -t $TEST_IMAGE

# Iterator output returns items
cog predict $TEST_IMAGE -i count=3
stdout 'item-0'
stdout 'item-1'
stdout 'item-2'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from typing import Iterator

from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, count: int) -> Iterator[str]:
        # Stream "item-0" .. "item-<count-1>", one string per yield.
        yield from (f"item-{index}" for index in range(count))


================================================
FILE: integration-tests/tests/legacy_sdk_schema.txtar
================================================
# Test that building with a legacy SDK (< 0.17.0) falls back to runtime
# schema generation instead of using the static Go tree-sitter parser.
#
# SDK 0.16.12 uses pydantic-based runtime introspection for schema generation
# and predates coglet. Coglet is not installed because:
#   1. No explicit COGLET_WHEEL is set
#   2. The SDK dependency handles it — and 0.16.12 doesn't depend on coglet

# Override SDK wheel to use PyPI 0.16.12 (legacy, pre-coglet)
env COG_SDK_WHEEL=pypi:0.16.12

# Build should succeed — without COG_STATIC_SCHEMA=1 the build
# automatically uses the legacy runtime schema generation path.
cog build -t $TEST_IMAGE

# Predict should work with the legacy SDK's built-in Python HTTP server
cog predict $TEST_IMAGE -i s=world
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    # Trivial predictor: prefixes the input string with "hello ".
    def predict(self, s: str) -> str:
        greeting = "hello " + s
        return greeting


================================================
FILE: integration-tests/tests/list_int_input_output.txtar
================================================
# Test list[int] as predict input and output types

# Build the image
cog build -t $TEST_IMAGE

# List of ints works as input and output
cog predict $TEST_IMAGE --json '{"numbers": [1, 2, 3]}'
stdout '"status": "succeeded"'
stdout '"output":'
stdout '2'
stdout '4'
stdout '6'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, numbers: list[int]) -> list[int]:
        # Double every element, preserving order.
        doubled = []
        for value in numbers:
            doubled.append(value * 2)
        return doubled


================================================
FILE: integration-tests/tests/list_string_output.txtar
================================================
# Test list[str] as predict output type

# Build the image
cog build -t $TEST_IMAGE

# List output returns items
cog predict $TEST_IMAGE -i text='hello world foo'
stdout 'hello'
stdout 'world'
stdout 'foo'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, text: str) -> list[str]:
        # Break the input into whitespace-separated words.
        words = text.split()
        return words


================================================
FILE: integration-tests/tests/many_inputs.txtar
================================================

# Test predictor with many different input types

# Build the image
cog build -t $TEST_IMAGE

# Predict with various input types
cog predict $TEST_IMAGE -i no_default=hello -i path=@path.txt -i image=@image.jpg -i choices=foo -i int_choices=3
stdout 'hello default 20 world jpg foo 6'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Input, Path


class Predictor(BasePredictor):
    def predict(
        self,
        no_default: str,
        default_without_input: str = "default",
        input_with_default: int = Input(default=10),
        path: Path = Input(description="Some path"),
        image: Path = Input(description="Some path"),
        choices: str = Input(choices=["foo", "bar"]),
        int_choices: int = Input(description="hello", choices=[3, 4, 5]),
    ) -> str:
        # Summarise every input as one space-separated line: the strings as
        # given, the integers doubled, the text file's stripped contents and
        # the image file's extension.
        path_contents = path.read_text().strip()
        image_extension = str(image).rsplit(".", 1)[-1]
        pieces = [
            no_default,
            default_without_input,
            str(input_with_default * 2),
            path_contents,
            image_extension,
            choices,
            str(int_choices * 2),
        ]
        return " ".join(pieces)

-- path.txt --
world
-- image.jpg --
fake image content


================================================
FILE: integration-tests/tests/multi_file_schema.txtar
================================================
# Test that schema generation works when the output type is defined in a
# separate Python module. This exercises the cross-file model resolution:
# the predictor imports Output from output_types.py, and the static Go parser
# finds and parses output_types.py to resolve the BaseModel fields.
#
# Covers:
#   - Output type imported from a sibling module (from output_types import Output)
#   - BaseModel fields appear correctly in the OpenAPI schema label
#   - Prediction works end-to-end with the multi-file setup

# Opt in to static schema generation
env COG_STATIC_SCHEMA=1

# Build the image
cog build -t $TEST_IMAGE

# Verify schema is in Docker label with correct structure
exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.openapi_schema"}}'

# Output schema should be an object with fields from output_types.py
stdout '"type":"object"'
stdout '"text":'
stdout '"score":'
stdout '"tags":'

# Input should have the prompt field
stdout '"prompt":'
stdout '"required":\["prompt"\]'

# Predict should work end-to-end
cog predict $TEST_IMAGE -i prompt=hello
stdout '"text": "hello"'
stdout '"score": 1'
stdout '"tags"'

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
    - "pydantic>2"
predict: "predict.py:Predictor"

-- output_types.py --
from pydantic import BaseModel


# Structured prediction result, defined in its own module so the predictor
# has to import it. The test looks for each of these field names in the
# generated schema.
class Output(BaseModel):
    text: str
    score: float
    tags: list[str]

-- predict.py --
from cog import BasePredictor
from output_types import Output


class Predictor(BasePredictor):
    def predict(self, prompt: str) -> Output:
        # Echo the prompt as the text field; score and tags are fixed values.
        result = Output(text=prompt, score=1.0, tags=["default"])
        return result


================================================
FILE: integration-tests/tests/nested_output_types.txtar
================================================
# Test structured output type with multiple field types.
#
# Verifies that coglet correctly serializes a BaseModel output containing:
# - Multiple primitive types (str, int, float, bool)
# - Optional fields (with value and without)
# - List of primitive types
#
# Note: nested BaseModel fields (e.g., a field typed as another BaseModel)
# are NOT supported by cog's type system. This test covers the supported
# complex output patterns.

cog serve

curl POST /predictions '{"input":{"name":"test"}}'
stdout '"status":"succeeded"'
stdout '"name":"test"'
stdout '"count":42'
stdout '"score":0.95'
stdout '"passed":true'
stdout '"tags":\["fast","cached"\]'
stdout '"note":"extra info"'
stdout '"extra":null'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from typing import List, Optional

from cog import BaseModel, BasePredictor


# Structured output mixing primitive fields, a list of strings and two
# optional strings (the predictor fills one and leaves the other as None).
class Output(BaseModel):
    name: str
    count: int
    score: float
    passed: bool
    tags: List[str]
    note: Optional[str]
    extra: Optional[str]


class Predictor(BasePredictor):
    def predict(self, name: str) -> Output:
        # Only the name comes from the request; everything else is a fixed
        # value, with one optional field set and the other left as None.
        fields = dict(
            name=name,
            count=42,
            score=0.95,
            passed=True,
            tags=["fast", "cached"],
            note="extra info",
            extra=None,
        )
        return Output(**fields)


================================================
FILE: integration-tests/tests/no_predictor.txtar
================================================
# Test error when no predictor is defined
# Build should fail when cog.yaml has no predict field

# Build should fail with error about missing predictor
! cog build -t $TEST_IMAGE
stderr 'predict'

-- cog.yaml --
build:
  python_version: "3.12"


================================================
FILE: integration-tests/tests/non_base_predictor_class.txtar
================================================

# Build image
cog build -t $TEST_IMAGE

# Predict using class without BasePredictor
cog predict $TEST_IMAGE -i text=world
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"
predict: predict.py:Predictor

-- predict.py --
from cog import Input


class Predictor:
    # Plain class predictor: it does not inherit from BasePredictor.
    def predict(self, text: str = Input(default="world")) -> str:
        return "hello {}".format(text)


================================================
FILE: integration-tests/tests/non_base_predictor_function.txtar
================================================

# Build image
cog build -t $TEST_IMAGE

# Predict using standalone function
cog predict $TEST_IMAGE -i text=world
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"
predict: predict.py:predict

-- predict.py --
from cog import Input


def predict(text: str = Input(default="world")) -> str:
    # Standalone function predictor: greet the given text.
    return "hello {}".format(text)


================================================
FILE: integration-tests/tests/oci_bundle_build.txtar
================================================
# Test building an OCI bundle with declarative weights in cog.yaml.
# Verifies: cog.yaml weights declaration -> cog weights build -> cog build (COG_OCI_INDEX=1)
# The image should build successfully and predictions should work.

# Create weight files (small, deterministic)
mkdir weights
exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "A" > weights/model-a.bin'
exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "B" > weights/model-b.bin'

# Step 1: Build weights.lock from cog.yaml declarations
cog weights build
stderr 'Generated weights.lock'
stderr '2 file'
exists weights.lock

# Step 2: Build with OCI index mode enabled
env COG_OCI_INDEX=1
cog build -t $TEST_IMAGE
stderr 'Image built as'

# Verify image was built
exec docker image inspect $TEST_IMAGE
stdout 'run.cog.config'

# Verify prediction works
cog predict $TEST_IMAGE -i text=hello
stdout 'processed: hello'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"
weights:
  - name: alpha
    source: weights/model-a.bin
    target: /weights/model-a.bin
  - name: beta
    source: weights/model-b.bin
    target: /weights/model-b.bin

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    """Minimal predictor that echoes its input with a fixed prefix."""

    def predict(self, text: str) -> str:
        """Return ``text`` prefixed with ``processed: ``."""
        result = f"processed: {text}"
        return result


================================================
FILE: integration-tests/tests/oci_bundle_inspect.txtar
================================================
# Test cog inspect on a pushed OCI bundle with declarative weights.
# Verifies: push bundle -> cog inspect --remote --json shows correct structure.
# The inspect output should show an OCI index with image + weight manifests.

[short] skip 'requires local registry'

# Start test registry
registry-start

# Create weight files (small, deterministic)
mkdir weights
exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "A" > weights/model-a.bin'
exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "B" > weights/model-b.bin'

# Build weights.lock
cog weights build
stderr 'Generated weights.lock'
exists weights.lock

# Build and push with OCI index mode
env COG_OCI_INDEX=1
cog push $TEST_REGISTRY/test/inspect-model:v1

# Inspect the pushed bundle
cog inspect --remote --json $TEST_REGISTRY/test/inspect-model:v1

# Verify it's an OCI index
stdout '"type": "index"'

# Verify image manifest is present
stdout '"type": "image"'

# Verify weight manifests are present with correct names
stdout '"type": "weights"'
stdout '"name": "alpha"'
stdout '"name": "beta"'

# Verify weight targets
stdout '"target": "/weights/model-a.bin"'
stdout '"target": "/weights/model-b.bin"'

# Verify layers are populated
stdout '"layers"'
stdout '"digest": "sha256:'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"
weights:
  - name: alpha
    source: weights/model-a.bin
    target: /weights/model-a.bin
  - name: beta
    source: weights/model-b.bin
    target: /weights/model-b.bin

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    """Minimal predictor; its output is the input tagged as processed."""

    def predict(self, text: str) -> str:
        """Build and return the ``processed: <text>`` string."""
        tagged = f"processed: {text}"
        return tagged


================================================
FILE: integration-tests/tests/oci_bundle_push.txtar
================================================
# Test pushing an OCI bundle with declarative weights via cog push.
# Verifies: cog.yaml weights -> cog weights build -> cog push (BundlePusher path)
# The push should create an OCI index with image + weight manifests.

[short] skip 'requires local registry'

# Start test registry
registry-start

# Create weight files (small, deterministic)
mkdir weights
exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "A" > weights/model-a.bin'
exec sh -c 'dd if=/dev/zero bs=1024 count=1 2>/dev/null | tr "\0" "B" > weights/model-b.bin'

# Step 1: Build weights.lock
cog weights build
stderr 'Generated weights.lock'
exists weights.lock

# Step 2: Build and push with OCI index mode
env COG_OCI_INDEX=1
cog push $TEST_REGISTRY/test/bundle-model:v1

# Verify push succeeded — stderr should mention 'Pushing' exactly once
stderr -count=1 'Pushing'

# Step 3: Verify the pushed artifact is an OCI index with image + weight manifests
registry-inspect $TEST_REGISTRY/test/bundle-model:v1
# Verify it's an OCI index
stdout 'application/vnd.oci.image.index.v1\+json'
# Verify weight annotations are present
stdout 'vnd.cog.reference.type.*weights'
stdout 'vnd.cog.weight.name.*alpha'
stdout 'vnd.cog.weight.name.*beta'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"
weights:
  - name: alpha
    source: weights/model-a.bin
    target: /weights/model-a.bin
  - name: beta
    source: weights/model-b.bin
    target: /weights/model-b.bin

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    """Trivial predictor that prefixes its input with ``processed: ``."""

    def predict(self, text: str) -> str:
        """Return the input text with the ``processed: `` prefix applied."""
        output = f"processed: {text}"
        return output


================================================
FILE: integration-tests/tests/optional_path_input.txtar
================================================
# Test optional Path input with None default (Pydantic 1 compatibility)

# Build the image
cog build -t $TEST_IMAGE

# Predict without providing optional input - should return red image
cog predict $TEST_IMAGE

# Verify output file was created
exec test -f output.webp

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
    - "pillow"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Path, Input
from PIL import Image


class Predictor(BasePredictor):
    def predict(
        self,
        test_image: Path | None = Input(description="Test image", default=None)
    ) -> Path:
        """Write a 100x100 solid red image to ./hello.webp and return its path.

        ``test_image`` is accepted but never read.
        """
        destination = Path("./hello.webp")
        red_square = Image.new("RGB", (100, 100), color="red")
        red_square.save(destination)
        return destination


================================================
FILE: integration-tests/tests/path_input.txtar
================================================
# Test Path input type (read file content)

# Build the image
cog build -t $TEST_IMAGE

# Predict reads content from file
cog predict $TEST_IMAGE -i path=@input.txt
stdout 'hello from file'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Path


class Predictor(BasePredictor):
    def predict(self, path: Path) -> str:
        """Return the full text content of the file at ``path``."""
        with open(path) as source:
            contents = source.read()
        return contents

-- input.txt --
hello from file


================================================
FILE: integration-tests/tests/path_input_output.txtar
================================================
# Test Path input and output with setup method

# Build the image
cog build -t $TEST_IMAGE

# Predict with text and path input, returns path output
cog predict $TEST_IMAGE -i text=bar -i path=@input.txt

# Verify output file was created and contains expected content
exec test -f output.txt
exec grep -q foobarbaz output.txt

-- cog.yaml --
build:
  python_version: "3.12"
predict: predict.py:Predictor

-- predict.py --
import tempfile

from cog import BasePredictor, Path


class Predictor(BasePredictor):
    def setup(self):
        """Store the fixed prefix that predict() prepends to its output."""
        self.foo = "foo"

    def predict(self, text: str, path: Path) -> Path:
        """Write prefix + ``text`` + the contents of ``path`` to a new file.

        The file is named output.txt inside a freshly created temporary
        directory; its path is returned.
        """
        with open(path) as source:
            file_text = source.read()
        combined = self.foo + text + file_text

        target = Path(tempfile.mkdtemp()) / "output.txt"
        with open(target, "w") as sink:
            sink.write(combined)
        return target

-- input.txt --
baz


================================================
FILE: integration-tests/tests/path_list_input.txtar
================================================
# Test list[Path] input type (multiple file inputs)

# Build the image
cog build -t $TEST_IMAGE

# Predict with multiple file inputs
cog predict $TEST_IMAGE -i paths=@1.txt -i paths=@2.txt
stdout 'test1'
stdout 'test2'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Path


class Predictor(BasePredictor):
    def predict(self, paths: list[Path]) -> str:
        """Concatenate the text of every input file, in the order given."""
        combined = ""
        for file_path in paths:
            with open(file_path) as handle:
                combined += handle.read()
        return combined

-- 1.txt --
test1
-- 2.txt --
test2


================================================
FILE: integration-tests/tests/path_list_output.txtar
================================================
# Test List[Path] output (multiple file outputs)

# Build the image
cog build -t $TEST_IMAGE

# Predict writes multiple files
cog predict $TEST_IMAGE

# Verify files were created with expected content
exec cat output.0.txt
stdout 'foo'
exec cat output.1.txt
stdout 'bar'
exec cat output.2.txt
stdout 'baz'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from typing import List

from cog import BasePredictor, Path


class Predictor(BasePredictor):
    def predict(self) -> List[Path]:
        """Write "foo", "bar" and "baz" to /tmp/out-<index>.txt and return the paths in order."""
        words = ["foo", "bar", "baz"]
        written = []
        for index in range(len(words)):
            target = Path(f"/tmp/out-{index}.txt")
            with target.open("w") as sink:
                sink.write(words[index])
            written.append(target)
        return written


================================================
FILE: integration-tests/tests/path_output.txtar
================================================
# Test Path output (file output)

# Build the image
cog build -t $TEST_IMAGE

# Predict writes file to output.bmp
cog predict $TEST_IMAGE

# Verify file was created and is non-empty (a 255x255 RGB BMP is ~195KB; the exact size is not checked)
exec test -f output.bmp
exec test -s output.bmp

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
    - "pillow==10.4.0"
predict: "predict.py:Predictor"

-- predict.py --
import os
import tempfile

from cog import BasePredictor, Path
from PIL import Image


class Predictor(BasePredictor):
    def predict(self) -> Path:
        """Save a 255x255 solid red image as prediction.bmp in a new temp directory and return its path."""
        workdir = tempfile.mkdtemp()
        bmp_file = os.path.join(workdir, "prediction.bmp")
        Image.new("RGB", (255, 255), "red").save(bmp_file)
        return Path(bmp_file)


================================================
FILE: integration-tests/tests/predict_existing_image.txtar
================================================
# Test predict works with pre-built image from different directory
# Source: test_predict.py::test_predict_runs_an_existing_image

# Build the image first
cog build -t $TEST_IMAGE

# Create a different directory and run predict from there
mkdir another_dir
cd another_dir

# Run predict on the pre-built image (no cog.yaml in current dir)
cog predict $TEST_IMAGE -i s=world
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, s: str) -> str:
        """Return ``s`` with ``hello `` prepended."""
        prefix = "hello "
        return prefix + s


================================================
FILE: integration-tests/tests/predict_json_file.txtar
================================================
# Test --json @file reads JSON from file
# Source: test_predict.py::test_predict_json_input_filename

# Build and run prediction with JSON from file
cog build -t $TEST_IMAGE
cog predict $TEST_IMAGE --json @input.json
stdout '"status": "succeeded"'
stdout '"output": "hello sackfield"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, s: str) -> str:
        """Greet ``s``: the result is ``hello `` followed by the input."""
        greeting = "hello " + s
        return greeting

-- input.json --
{
    "s": "sackfield"
}


================================================
FILE: integration-tests/tests/predict_json_input.txtar
================================================
# Test --json flag with inline JSON input
# Source: test_predict.py::test_predict_json_input

# Build and run prediction with inline JSON
cog build -t $TEST_IMAGE
cog predict $TEST_IMAGE --json '{"s": "sackfield"}'
stdout '"status": "succeeded"'
stdout '"output": "hello sackfield"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, s: str) -> str:
        """Prepend ``hello `` to the input string and return it."""
        reply = "hello " + s
        return reply


================================================
FILE: integration-tests/tests/predict_json_output_file.txtar
================================================
# Test --json with --output writes JSON to file
# Source: test_predict.py::test_predict_json_output

# Build and run prediction with JSON output to file
cog build -t $TEST_IMAGE
cog predict $TEST_IMAGE --json '{"s": "sackfield"}' --output output.json

# Verify file was created with correct content
exists output.json
exec cat output.json
stdout '"status": "succeeded"'
stdout '"output": "hello sackfield"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, s: str) -> str:
        """Return the greeting ``hello <s>``."""
        salutation = "hello "
        return salutation + s


================================================
FILE: integration-tests/tests/predict_json_stdin.txtar
================================================
# Test --json @- reads JSON from stdin
# Source: test_predict.py::test_predict_json_input_stdin

# Build and run prediction with JSON from stdin
cog build -t $TEST_IMAGE
stdin input.json
cog predict $TEST_IMAGE --json @-
stdout '"status": "succeeded"'
stdout '"output": "hello sackfield"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, s: str) -> str:
        """Build the output by joining ``hello `` and the input ``s``."""
        output = "hello " + s
        return output

-- input.json --
{"s": "sackfield"}


================================================
FILE: integration-tests/tests/predict_json_stdin_dash.txtar
================================================
# Test --json - reads JSON directly from stdin
# Source: test_predict.py::test_predict_json_input_stdin_dash

# Build and run prediction with JSON from stdin using literal '-'
cog build -t $TEST_IMAGE
stdin input.json
cog predict $TEST_IMAGE --json -
stdout '"status": "succeeded"'
stdout '"output": "hello sackfield"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, s: str) -> str:
        """Return ``s`` prefixed with ``hello ``."""
        message = "hello " + s
        return message

-- input.json --
{"s": "sackfield"}


================================================
FILE: integration-tests/tests/predict_many_inputs_image.txtar
================================================

# Test predict with many input types against pre-built image
# Source: test_predict.py::test_predict_many_inputs_with_existing_image
#
# This test builds an image first, then runs predictions against that image
# by its tag (simulating using a pre-built image).

# Build the image first
cog build -t $TEST_IMAGE

# Run prediction against the built image with various input types
# Using @ syntax for file inputs
cog predict $TEST_IMAGE -i no_default=hello -i path=@path.txt -i image=@image.jpg -i choices=foo -i int_choices=3
stdout 'hello default 20 world jpg foo 6'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Input, Path


class Predictor(BasePredictor):
    def predict(
        self,
        no_default: str,
        default_without_input: str = "default",
        input_with_default: int = Input(default=10),
        path: Path = Input(description="Some path"),
        image: Path = Input(description="Some path"),
        choices: str = Input(choices=["foo", "bar"]),
        int_choices: int = Input(description="hello", choices=[3, 4, 5]),
    ) -> str:
        """Echo every input back as one space-separated string.

        The two integer inputs are doubled, ``path`` contributes its stripped
        file contents, and ``image`` contributes only the text after the last
        "." of its string form.
        """
        with path.open() as handle:
            file_text = handle.read().strip()
        extension = str(image).rsplit(".", 1)[-1]
        fields = [
            no_default,
            default_without_input,
            str(input_with_default * 2),
            file_text,
            extension,
            choices,
            str(int_choices * 2),
        ]
        return " ".join(fields)

-- path.txt --
world
-- image.jpg --
fake image content


================================================
FILE: integration-tests/tests/predict_output_file.txtar
================================================
# Test -o flag writes Path output to specified file
# Source: test_predict.py::test_predict_writes_files_to_files_with_custom_name

# Build and run prediction with custom output filename
cog build -t $TEST_IMAGE
cog predict $TEST_IMAGE -o myoutput.bmp

# Verify file exists and has non-zero size
exists myoutput.bmp
exec test -s myoutput.bmp

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
    - "pillow==10.4.0"
predict: "predict.py:Predictor"

-- predict.py --
import os
import tempfile

from cog import BasePredictor, Path
from PIL import Image


class Predictor(BasePredictor):
    def predict(self) -> Path:
        """Create a 255x255 red BMP named prediction.bmp in a fresh temp directory and return its path."""
        scratch_dir = tempfile.mkdtemp()
        image_file = os.path.join(scratch_dir, "prediction.bmp")
        picture = Image.new("RGB", (255, 255), "red")
        picture.save(image_file)
        result = Path(image_file)
        return result


================================================
FILE: integration-tests/tests/predict_output_string.txtar
================================================
# Test -o flag writes string output to file
# Source: test_predict.py::test_predict_writes_strings_to_files

# Build image
cog build -t $TEST_IMAGE

# Run prediction with -o to write output to file
cog predict $TEST_IMAGE -i s=world -o out.txt

# Verify file contents
exec cat out.txt
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, s: str) -> str:
        """Return the string ``hello `` followed by ``s``."""
        head = "hello "
        return head + s


================================================
FILE: integration-tests/tests/predict_sys_exit.txtar
================================================
# Test that sys.exit() in predict() fails the prediction but does NOT kill
# the worker. A subsequent prediction should still succeed.
#
# sys.exit() raises SystemExit, which is caught by the PyO3 boundary.
# The prediction fails, but the worker subprocess stays alive.

cog serve

# First prediction calls sys.exit(1) — should fail
curl POST /predictions '{"input":{"do_exit":true}}'
stdout '"status":"failed"'
stdout '"error":".*SystemExit'

# Second prediction — worker should still be alive and accept it
curl POST /predictions '{"input":{"do_exit":false}}'
stdout '"status":"succeeded"'
stdout '"output":"still alive"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import sys

from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, do_exit: bool = False) -> str:
        """Call ``sys.exit(1)`` when ``do_exit`` is truthy; otherwise return "still alive"."""
        if not do_exit:
            return "still alive"
        # sys.exit raises SystemExit, so nothing after this line runs.
        sys.exit(1)


================================================
FILE: integration-tests/tests/prediction_error_response.txtar
================================================
# Test that a runtime exception in predict() returns a well-formed error response.
#
# When predict() raises an exception, coglet returns HTTP 200 with
# status "failed", the error message, and predict_time in metrics.
# This test verifies the response shape — not just that it fails.

cog serve

# ValueError in predict()
curl POST /predictions '{"input":{"mode":"value_error"}}'
stdout '"status":"failed"'
stdout '"error":".*this is a value error"'
stdout '"predict_time":'

# RuntimeError in predict()
curl POST /predictions '{"input":{"mode":"runtime_error"}}'
stdout '"status":"failed"'
stdout '"error":".*runtime problem"'
stdout '"predict_time":'

# Generic Exception in predict()
curl POST /predictions '{"input":{"mode":"generic"}}'
stdout '"status":"failed"'
stdout '"error":".*something went wrong"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, mode: str) -> str:
        """Raise the exception selected by ``mode``, or return "ok" for any other value."""
        # Each recognised mode maps to the exception class and message to raise.
        failures = {
            "value_error": (ValueError, "this is a value error"),
            "runtime_error": (RuntimeError, "runtime problem"),
            "generic": (Exception, "something went wrong"),
        }
        if mode in failures:
            exc_class, message = failures[mode]
            raise exc_class(message)
        return "ok"


================================================
FILE: integration-tests/tests/pty_echo.txtar
================================================
[short] skip 'slow test - requires Docker build'

# Test that cog run works with PTY for simple commands
# This is a simpler variant that just tests echo works

# Run echo command with PTY (no input needed)
# cog run builds from current directory and runs the command
pty-run /dev/null cog run echo "hello from cog run"
stdout 'hello from cog run'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    """Placeholder predictor that greets its input."""

    def predict(self, s: str = "world") -> str:
        """Return ``hello `` followed by ``s`` (default "world")."""
        greeting = "hello " + s
        return greeting


================================================
FILE: integration-tests/tests/pty_interactive.txtar
================================================
[short] skip 'slow test - requires Docker build'

# Test that cog run /bin/bash works with interactive PTY
# This verifies bidirectional PTY interaction (send input, receive output)

# Run bash with PTY input file - send commands and verify output
# cog run builds from current directory and runs the command
pty-run pty_input.txt cog run /bin/bash
stdout 'SENTINEL_12345'
stdout 'HELLO_WORLD'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    """Greeting predictor; ``s`` defaults to "world"."""

    def predict(self, s: str = "world") -> str:
        """Prepend ``hello `` to ``s`` and return the result."""
        prefix = "hello "
        return prefix + s

-- pty_input.txt --
echo SENTINEL_12345
echo HELLO_WORLD
exit


================================================
FILE: integration-tests/tests/pydantic2.txtar
================================================
# Test explicit Pydantic 2 dependency

# Build the image
cog build -t $TEST_IMAGE

# Predict with Pydantic 2
cog predict $TEST_IMAGE -i s=world
stdout 'hello world'

-- cog.yaml --
build:
  gpu: false
  python_version: "3.12"
  python_packages:
    - "pydantic>2"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, s: str) -> str:
        """Return ``s`` with the ``hello `` prefix added."""
        response = "hello " + s
        return response


================================================
FILE: integration-tests/tests/pydantic2_output.txtar
================================================
# Test that Pydantic v2 BaseModel works as prediction output type.
# Coglet's make_encodeable() must call model_dump() to serialize.

# Build the image
cog build -t $TEST_IMAGE

# Predict returns structured Pydantic output
cog predict $TEST_IMAGE -i name=alice -i score=0.95
stdout '"name": "alice"'
stdout '"score": 0.95'
stdout '"tags"'
stdout 'default'

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
    - "pydantic>2"
predict: "predict.py:Predictor"

-- predict.py --
from typing import List

from pydantic import BaseModel as PydanticBaseModel

from cog import BasePredictor


# Structured output type built on Pydantic's BaseModel (imported above as
# PydanticBaseModel) and used as the return annotation of Predictor.predict.
class Result(PydanticBaseModel):
    # The caller-supplied name.
    name: str
    # The caller-supplied score.
    score: float
    # List of string tags attached to the result.
    tags: List[str]


class Predictor(BasePredictor):
    def predict(self, name: str, score: float = 0.5) -> Result:
        """Wrap ``name`` and ``score`` in a ``Result`` whose only tag is "default"."""
        tags = ["default"]
        result = Result(name=name, score=score, tags=tags)
        return result


================================================
FILE: integration-tests/tests/python313.txtar
================================================
# Test Python 3.13 support

# Build the image
cog build -t $TEST_IMAGE

# Predict with Python 3.13
cog predict $TEST_IMAGE -i num=5
stdout '10'

-- cog.yaml --
build:
  gpu: false
  python_version: "3.13"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, num: int) -> int:
        """Return ``num`` multiplied by two."""
        doubled = num * 2
        return doubled


================================================
FILE: integration-tests/tests/python37_deprecated.txtar
================================================
# Test that Python 3.7 is deprecated and build fails with appropriate error
# Build should fail with deprecation error
! cog build -t $TEST_IMAGE
stderr 'invalid build.python_version "3.7": minimum supported Python version is 3.10'

-- cog.yaml --
build:
  gpu: false
  python_version: "3.7"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Doubling predictor."""

    def predict(self, num: int) -> int:
        """Double the input number."""
        result = num * 2
        return result


================================================
FILE: integration-tests/tests/python38_deprecated.txtar
================================================
# Test that Python 3.8 is deprecated and build fails with appropriate error
# Build should fail with deprecation error
! cog build -t $TEST_IMAGE
stderr 'invalid build.python_version "3.8": minimum supported Python version is 3.10'

-- cog.yaml --
build:
  gpu: false
  python_version: "3.8"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, num: int) -> int:
        """Compute and return twice ``num``."""
        factor = 2
        return num * factor


================================================
FILE: integration-tests/tests/python39_deprecated.txtar
================================================
# Test that Python 3.9 is deprecated and build fails with appropriate error
# Build should fail with deprecation error
! cog build -t $TEST_IMAGE
stderr 'invalid build.python_version "3.9": minimum supported Python version is 3.10'

-- cog.yaml --
build:
  gpu: false
  python_version: "3.9"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    def predict(self, num: int) -> int:
        """Return the product of ``num`` and 2."""
        twice = num * 2
        return twice


================================================
FILE: integration-tests/tests/run_basic.txtar
================================================
# Test basic cog run functionality
# Source: test_run.py::test_run

cog run echo hello world
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"


================================================
FILE: integration-tests/tests/run_stdin_cat.txtar
================================================
# Test stdin piped through cat and returned to stdout
# Source: test_run.py::test_run_with_piped_stdin_returned_to_stdout

stdin input.txt
cog run cat
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.13"

-- input.txt --
hello world


================================================
FILE: integration-tests/tests/run_stdin_unconsumed.txtar
================================================
# Test stdin handling when piped but not consumed by the command
# Source: test_run.py::test_run_with_unconsumed_piped_stdin

stdin input.txt
cog run echo hello-from-echo
stdout 'hello-from-echo'

-- cog.yaml --
build:
  python_version: "3.13"

-- input.txt --
hello-from-stdin


================================================
FILE: integration-tests/tests/scope_context.txtar
================================================
# Test that per-prediction context is available via current_scope().context.
#
# Verifies:
# 1. context dict from request body is accessible in the predictor
# 2. Empty context (default) returns an empty dict
# 3. Multiple key-value pairs are preserved

cog serve

# Prediction with context — predictor returns sorted key:value pairs
curl POST /predictions '{"input":{},"context":{"api_token":"secret123","region":"us-east-1"}}'
stdout '"status":"succeeded"'
stdout '"output":"api_token:secret123, region:us-east-1"'

# Prediction without context — should get empty dict
curl POST /predictions '{"input":{"expect_empty":"true"}}'
stdout '"status":"succeeded"'
stdout '"output":"context_is_empty=True"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Input, current_scope


class Predictor(BasePredictor):
    def predict(self, expect_empty: str = Input(default="false")) -> str:
        """Report on the context of the current prediction scope.

        When ``expect_empty`` is the string "true", return whether the context
        has no entries. Otherwise return its entries as comma-separated
        ``key:value`` pairs, sorted so the output is deterministic.
        """
        context = current_scope().context

        if expect_empty == "true":
            is_empty = len(context) == 0
            return f"context_is_empty={is_empty}"

        # One "key:value" string per entry, in sorted order, so individual keys can be asserted on
        pairs = [f"{key}:{value}" for key, value in sorted(context.items())]
        return ", ".join(pairs)


================================================
FILE: integration-tests/tests/secrets.txtar
================================================
# Test that build secrets can be mounted during Docker build

# Set environment variable for the env-secret
env ENV_SECRET=env_secret_value

# Build with secrets (file-based and env-based)
cog build -t $TEST_IMAGE --secret id=file-secret,src=file-secret.txt --secret id=env-secret,env=ENV_SECRET

-- cog.yaml --
build:
  python_version: "3.13"
  run:
    - command: >-
        ID="file-secret";
        EXPECTED_VALUE="file_secret_value";
        EXPECTED_PATH="/etc/file_secret.txt";
        [ "$(cat "$EXPECTED_PATH")" = "$EXPECTED_VALUE" ] || ( echo "Assertion failed"; exit 1; )
      mounts:
        - type: secret
          id: file-secret
          target: /etc/file_secret.txt
    - command: >-
        ID="env-secret";
        EXPECTED_VALUE="env_secret_value";
        EXPECTED_PATH="/var/env-secret.txt";
        [ "$(cat "$EXPECTED_PATH")" = "$EXPECTED_VALUE" ] || ( echo "Assertion failed"; exit 1; )
      mounts:
        - type: secret
          id: env-secret
          target: /var/env-secret.txt
predict: "predict.py:Predictor"

-- file-secret.txt --
file_secret_value

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    """Simple doubling predictor."""

    def predict(self, num: int) -> int:
        """Return ``num`` times two."""
        answer = num * 2
        return answer


================================================
FILE: integration-tests/tests/sequential_state_leak.txtar
================================================
# Test that module-level state persists across sequential predictions.
#
# This is intentional behavior — the worker process is long-lived and
# module-level state IS shared. This test documents that behavior:
# a module-level list that gets appended to on each call accumulates.

cog serve

# First prediction — list starts empty, appends "a"
# Note: "." in regexes below matches Python's single-quote characters
# (testscript single-quoted args cannot embed literal single quotes)
curl POST /predictions '{"input":{"item":"a"}}'
stdout '"status":"succeeded"'
stdout '"output":"\[.a.\]"'

# Second prediction — list now has ["a"], appends "b"
curl POST /predictions '{"input":{"item":"b"}}'
stdout '"status":"succeeded"'
stdout '"output":"\[.a., .b.\]"'

# Third prediction — list now has ["a", "b"], appends "c"
curl POST /predictions '{"input":{"item":"c"}}'
stdout '"status":"succeeded"'
stdout '"output":"\[.a., .b., .c.\]"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

# Shared list defined at module level, outside the predictor: every call to
# predict() appends to this same object, so entries accumulate across calls.
_state: list = []


class Predictor(BasePredictor):
    def predict(self, item: str) -> str:
        """Append ``item`` to the module-level list and return the list's string form."""
        _state.append(item)
        snapshot = str(_state)
        return snapshot


================================================
FILE: integration-tests/tests/setup_slow_serial.txtar
================================================
[short] skip 'slow test - skip in short mode'

# Test that a slow setup() completes successfully without being killed by
# an internal timeout. Coglet should not impose its own setup timeout —
# the external orchestrator (director) is the authority on setup timeouts.

# Build the image
cog build -t $TEST_IMAGE

# Start the server — setup takes ~15s but should complete fine
cog serve

# Verify the server is healthy and predictions work
curl POST /predictions '{"input":{"s":"hello"}}'
stdout '"output":"hello hello"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import time

from cog import BasePredictor


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Simulate a slow setup by sleeping for 15 seconds, logging before and after."""
        delay_seconds = 15
        print("Starting slow setup...")
        time.sleep(delay_seconds)
        print("Slow setup complete.")

    def predict(self, s: str) -> str:
        """Return ``s`` prefixed with ``hello ``."""
        prefix = "hello "
        return prefix + s


================================================
FILE: integration-tests/tests/setup_subprocess_double_fork.txtar
================================================
[short] skip 'slow test - skip in short mode'

# Test double fork subprocess spawned during setup
# This ensures stream redirection works correctly with daemonized processes

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Make predictions that communicate via file system with forked process
curl POST /predictions '{"input":{"s":"friendo1"}}'
stdout '"output":"hello friendo1"'

curl POST /predictions '{"input":{"s":"friendo2"}}'
stdout '"output":"hello friendo2"'

curl POST /predictions '{"input":{"s":"friendo3"}}'
stdout '"output":"hello friendo3"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import os.path
import signal
import subprocess
import sys
import time

from cog import BasePredictor


class Predictor(BasePredictor):
    """
    This predictor checks the case where a process is spawned during setup and then each
    prediction depends on being able to communicate with that process. In the event that
    stream redirection is not working correctly, the forked process will not be able to
    write to stdout/stderr and will likely exit. Any state other than "running" is
    considered an error condition and raises SystemExit to interrupt any more prediction
    serving.

    This variant runs a forked python process via a shell wrapper to which a "message" is
    sent via file for each call to `predict`.
    """

    def setup(self) -> None:
        print("---> starting background process")

        self.bg = subprocess.Popen(["bash", "run-forker.sh"])

        print(f"---> started background process pid={self.bg.pid}")

    def predict(self, s: str) -> str:
        """
        Hand `s` to the background job through `.inbox` and return whatever it writes
        to `.outbox`.

        Raises SystemExit if the wrapper process has exited, and TimeoutError if no
        `.outbox` file shows up within 5 seconds.
        """
        status = self.bg.poll()

        print(f"---> background job status={status}")

        if status is not None:
            raise SystemExit

        print(f"---> sending message to background job pid={self.bg.pid}")

        # The background job polls for ".inbox" and reads it as soon as it exists.
        # Writing under a temporary name and renaming it into place guarantees the
        # reader never sees the file before its content has been written.
        with open(".inbox.tmp", "w") as inbox:
            inbox.write(s)

        os.replace(".inbox.tmp", ".inbox")

        print(f"---> sent message to background job pid={self.bg.pid}")

        # A monotonic clock keeps the 5 second budget immune to wall-clock jumps.
        deadline = time.monotonic() + 5

        print(f"---> waiting for outbox message from background job pid={self.bg.pid}")

        while not os.path.exists(".outbox"):
            if time.monotonic() > deadline:
                raise TimeoutError

            time.sleep(0.01)

        try:
            with open(".outbox", "r") as outbox:
                print(f"---> relaying message from background job pid={self.bg.pid}")

                return outbox.read()

        finally:
            # Always consume the reply so the next prediction waits for a fresh one.
            os.unlink(".outbox")

-- run-forker.sh --
#!/usr/bin/env bash
# Wrapper launched by predict.py's setup(): start forker.py as a background job,
# then block in `wait` so this bash process stays alive for as long as it runs.
python ./forker.py &
wait

-- forker.py --
import os
import signal
import time


def main():
    """
    Fork once and loop forever in both processes.

    The child answers each request found in `.inbox` by writing "hello " plus the
    request text to `.outbox`. Both processes also print a heartbeat line roughly
    every 10 seconds so their stdout stays exercised while they run.
    """
    child_pid = os.fork()
    is_child = child_pid == 0

    pid = os.getpid()

    # The previous check, `time.time() % 10 == 0`, compared a float timestamp for
    # exact equality and so practically never fired. Track an explicit deadline.
    heartbeat_every = 10
    next_heartbeat = time.monotonic() + heartbeat_every

    while True:
        if is_child and os.path.exists(".inbox"):
            with open(".inbox", "r") as inbox:
                print(f"---> CHILD ({pid}) reading request")

                s = inbox.read()

            os.unlink(".inbox")

            # Write the reply under a temporary name and rename it into place so a
            # reader polling for ".outbox" can never observe it half-written.
            with open(".outbox.tmp", "w") as outbox:
                print(f"---> CHILD ({pid}) sending response")

                outbox.write("hello " + s)

            os.replace(".outbox.tmp", ".outbox")

        if time.monotonic() >= next_heartbeat:
            next_heartbeat += heartbeat_every

            if is_child:
                print(f"---> CHILD ({pid}) " + ("here " * 20))
            else:
                print(f"===> PARENT ({pid})")

        time.sleep(0.01)

if __name__ == "__main__":
    main()


================================================
FILE: integration-tests/tests/setup_subprocess_double_fork_http.txtar
================================================
[short] skip 'slow test - skip in short mode'

# Test double fork subprocess with HTTP server spawned during setup
# This ensures stream redirection works correctly with daemonized HTTP servers

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Make predictions that communicate with forked HTTP server
curl POST /predictions '{"input":{"s":"friendo1"}}'
stdout '"output":"hello friendo1"'

curl POST /predictions '{"input":{"s":"friendo2"}}'
stdout '"output":"hello friendo2"'

curl POST /predictions '{"input":{"s":"friendo3"}}'
stdout '"output":"hello friendo3"'

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
  - requests
predict: "predict.py:Predictor"

-- predict.py --
import signal
import subprocess
import sys

import requests

from cog import BasePredictor


class Predictor(BasePredictor):
    """
    This predictor checks the case where a process is spawned during setup and then each
    prediction depends on being able to communicate with that process. In the event that
    stream redirection is not working correctly, the forked process will not be able to
    write to stdout/stderr and will likely exit. Any state other than "running" is
    considered an error condition and raises SystemExit to interrupt any more prediction
    serving.

    This variant runs a forked python HTTP server via a shell wrapper to which a request
    is made during each call to `predict`.
    """

    def setup(self) -> None:
        print("---> starting background process")

        self.bg = subprocess.Popen(["bash", "run-pong.sh"])

        print(f"---> started background process pid={self.bg.pid}")

        # Wait for HTTP server to be ready: up to 30 attempts, 0.5s apart.
        import time
        for i in range(30):
            try:
                requests.get("http://127.0.0.1:7777/ping", timeout=1)
                print("---> background HTTP server is ready")
                break
            except Exception:
                print(f"---> waiting for HTTP server ({i+1}/30)")
                time.sleep(0.5)
        else:
            # The loop never hit `break`, i.e. every attempt failed.
            raise RuntimeError("Background HTTP server failed to start")

    def predict(self, s: str) -> str:
        """
        Ping the background HTTP server, then return "hello " + s.

        Raises SystemExit if the wrapper process has exited.
        """
        status = self.bg.poll()

        print(f"---> background job status={status}")

        if status is None:
            print(f"---> sending request to background job pid={self.bg.pid}")

            # Bound the request like setup() does; without a timeout a wedged
            # server would block this prediction indefinitely.
            print(requests.get("http://127.0.0.1:7777/ping", timeout=5))

            print(f"---> sent request to background job pid={self.bg.pid}")
        else:
            raise SystemExit

        return "hello " + s

-- run-pong.sh --
#!/usr/bin/env bash
# Wrapper launched by predict.py's setup(): start pong.py as a background job,
# then block in `wait` so this bash process stays alive for as long as it runs.
python ./pong.py &
wait

-- pong.py --
import os
import signal
import time
from random import randint
from wsgiref.simple_server import make_server


def main():
    """Fork: the child serves HTTP on 127.0.0.1:7777, the parent prints every 10s."""
    if os.fork() == 0:
        # Child branch: hand control to the WSGI server.
        make_server("127.0.0.1", 7777, app).serve_forever()
        return

    parent_pid = os.getpid()

    while True:
        print(f"===> PARENT ({parent_pid})")

        time.sleep(10)


def app(environ, start_response):
    """WSGI app: answer /ping with 102-132 "PONG" lines, anything else with a 404."""
    print(f"---> CHILD ({os.getpid()})")

    plain_text = [("content-type", "text/plain")]

    if environ["PATH_INFO"] != "/ping":
        start_response("404 Not Found", plain_text)
        return [b"NO\n"]

    start_response("200 OK", plain_text)
    # Randomise the body length a little so responses vary in size.
    return [b"PONG\n"] * (100 + randint(2, 32))


if __name__ == "__main__":
    main()


================================================
FILE: integration-tests/tests/setup_subprocess_multiprocessing.txtar
================================================
[short] skip 'slow test - skip in short mode'

# Test multiprocessing.Process spawned during setup
# This ensures stream redirection works correctly with Python multiprocessing

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Make a prediction that communicates via multiprocessing.Pipe
# Note: The background process closes the connection after first use,
# so we only test one prediction
curl POST /predictions '{"input":{"s":"friendo1"}}'
stdout '"output":'
stdout '"status":"succeeded"'
# Check the logs show the background process communication worked
stdout 'sending ping to background job'
stdout 'received .* from background job'

-- cog.yaml --
build:
  python_version: "3.10"
predict: "predict.py:Predictor"

-- predict.py --
import atexit
import multiprocessing
import pathlib
import signal
import subprocess
import sys
import time

from cog.types import Path
from cog import BasePredictor

from bg import ponger

def cleanup():
    """Delete every regular `*.tmp` file in the current working directory."""
    leftovers = (entry for entry in pathlib.Path("./").glob("*.tmp") if entry.is_file())
    for leftover in leftovers:
        # missing_ok: tolerate the file disappearing between the glob and the unlink.
        leftover.unlink(missing_ok=True)


atexit.register(cleanup)


class Predictor(BasePredictor):
    """
    This predictor checks the case where a process is spawned during setup via
    multiprocessing and then each prediction causes that process to write to stdout.
    """

    def setup(self) -> None:
        print("---> starting background process")

        cleanup()

        # Duplex pipe: this process talks on parent_conn, the background process
        # is handed child_conn.
        self.parent_conn, self.child_conn = multiprocessing.Pipe()
        self.lock = multiprocessing.Lock()
        self.bg = multiprocessing.Process(
            target=ponger, args=(self.child_conn, self.lock)
        )
        self.bg.start()

        print(f"---> started background process pid={self.bg.pid}")

    def predict(self, s: str) -> Path:
        """
        Ping the background process, wait for its reply, then write "hello " + s to
        a fresh `.tmp` file and return its path.

        Raises SystemExit if the background process is no longer alive.
        """
        if self.bg.is_alive():
            print(f"---> sending ping to background job pid={self.bg.pid}")

            # Send on the parent's end of the pipe. Sending on child_conn, as this
            # code used to, delivers the message to parent_conn, so the recv()
            # below read back our own "ping" and the background process was never
            # actually contacted.
            self.parent_conn.send("ping")

            print(f"---> sent ping to background job pid={self.bg.pid}")

            pong = self.parent_conn.recv()

            print(f"---> received {pong} from background job pid={self.bg.pid}")
        else:
            print(f"---> background job died")

            raise SystemExit

        # time_ns() keeps output file names unique across predictions.
        out = Path(f"cog-test-integration-out.{time.time_ns()}.tmp")
        out.write_text("hello " + s)

        print(f"---> wrote output file {out}")

        return out

-- bg.py --
import multiprocessing.connection
import multiprocessing.synchronize
import os
import time


def ponger(
    conn: multiprocessing.connection.Connection, lock: multiprocessing.synchronize.Lock
):
    """
    Warm up noisily on stdout, then answer a single "ping" on `conn` with "pong".

    After replying, the connection is closed and the function returns. It also
    returns, without replying, if the peer closes the connection first.

    Args:
        conn: Connection to receive the ping on and send the pong back through.
        lock: Lock held while printing and sending the reply.
    """
    # Warm-up phase: 100 progress lines with a steadily growing pause between them.
    for i in range(100):
        print(f"Getting ready for some serious ponginggg ({i+1}%)")
        time.sleep(0.001 + (0.001 * (i + 1)))

    print("ITS PONGIN TIME")

    pid = os.getpid()

    try:
        ping = conn.recv()
    except EOFError:
        # The peer closed the pipe without pinging. Retrying would raise EOFError
        # again immediately and spin forever, so stop instead.
        return

    print(f"received {ping} in {pid}")

    with lock:
        print(f"ponging from {pid}")

        conn.send("pong")
        conn.close()

    # The connection is closed now; looping back into recv() would raise OSError
    # on the closed handle, so return cleanly after the single exchange.


================================================
FILE: integration-tests/tests/setup_subprocess_simple.txtar
================================================
[short] skip 'slow test - skip in short mode'

# Test subprocess spawned during setup that writes to stdout
# This ensures stream redirection works correctly when a background process
# writes output during prediction serving.

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Make predictions - the subprocess writes to stdout when it receives SIGUSR1
curl POST /predictions '{"input":{"s":"friendo1"}}'
stdout '"output":"hello friendo1"'

curl POST /predictions '{"input":{"s":"friendo2"}}'
stdout '"output":"hello friendo2"'

curl POST /predictions '{"input":{"s":"friendo3"}}'
stdout '"output":"hello friendo3"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
import signal
import subprocess
import sys

from cog import BasePredictor


class Predictor(BasePredictor):
    """
    Starts a subprocess during setup and makes every prediction poke it so that it
    writes to stdout. If stream redirection is broken, the subprocess cannot write
    to stdout/stderr and will likely exit; any state other than "running" is treated
    as an error and raises SystemExit to stop further prediction serving.

    This variant runs a plain subprocess and sends it SIGUSR1 on each `predict` call.
    """

    def setup(self) -> None:
        print("---> starting background process")

        child = subprocess.Popen(["bash", "child.sh"])
        self.bg = child

        print(f"---> started background process pid={child.pid}")

    def predict(self, s: str) -> str:
        exit_status = self.bg.poll()

        # Guard: a non-None status means the child has already exited.
        if exit_status is not None:
            print(f"---> background job died status={exit_status}")

            raise SystemExit

        print(f"---> sending signal to background job pid={self.bg.pid}")

        self.bg.send_signal(signal.SIGUSR1)

        print(f"---> sent signal to background job pid={self.bg.pid}")

        greeting = "hello "
        return greeting + s

-- child.sh --
#!/usr/bin/env bash
set -euo pipefail

# SIGUSR1 handler: every signal sent during `predict` makes this process dump 100
# lines to stdout. If stream redirection is broken, this process will likely be
# defunct before the first SIGUSR1 arrives.
_pong() {
  local n
  for ((n = 1; n <= 100; n++)); do
    echo "${0} (${$}) PONG (${n}/100)"
  done
}

trap _pong USR1

# Simulated setup phase: 100 lines, 10ms apart, to fill up any stdout buffer.
for ((step = 1; step <= 100; step++)); do
  echo "${0} ($$) SETTING UP (${step}/100)"
  sleep 0.01
done

# Steady state: whenever the epoch second is a multiple of 10, write one line to
# prove the stdout file descriptor is still usable, then poll again after 0.1s.
while :; do
  if (( $(date +%s) % 10 == 0 )); then
    echo "${0} (${$}) STILL HERE"
    sleep 1
  fi

  sleep 0.1
done


================================================
FILE: integration-tests/tests/setup_timeout_serial.txtar
================================================
[short] skip 'slow test - skip in short mode'

# Test that COG_SETUP_TIMEOUT env var is respected. When set to a value
# shorter than the actual setup duration, setup should fail with SETUP_FAILED.

# Build the image
cog build -t $TEST_IMAGE

# Start the server — setup takes ~15s but timeout is 10s, so it should fail
! cog serve

# Verify the server reports SETUP_FAILED
curl GET /health-check
stdout 'SETUP_FAILED'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"
environment:
  - COG_SETUP_TIMEOUT=10

-- predict.py --
import time

from cog import BasePredictor


class Predictor(BasePredictor):
    """Predictor whose setup sleeps for 15 seconds, longer than the 10s timeout."""

    def setup(self) -> None:
        # Log before and after the stall so the delay is visible in the setup logs.
        print("Starting slow setup...")
        setup_seconds = 15
        time.sleep(setup_seconds)
        print("Slow setup complete.")

    def predict(self, s: str) -> str:
        # Echo the input behind a fixed greeting.
        prefix = "hello "
        return prefix + s


================================================
FILE: integration-tests/tests/setup_worker_tracing_logs.txtar
================================================
[short] skip 'slow test - skip in short mode'

# Test worker tracing logs appear in setuplog
# This ensures Rust tracing from the worker subprocess is properly captured during setup

# Build the image
cog build -t $TEST_IMAGE

# Start the server
cog serve

# Check setuplog contains orchestrator and worker tracing logs
curl GET /health-check
stdout '"status":"READY"'
stdout 'Spawning worker subprocess'
stdout 'File descriptor redirection complete'
stdout 'Connected to slot transport'
stdout 'Server ready'

# Verify that logs emitted after setup-log accumulation stopped are NOT included
! stdout 'Setup complete, now accepting requests'

# Make a prediction to verify it works
curl POST /predictions '{"input":{"s":"test"}}'
stdout '"output":"hello test"'

-- cog.yaml --
build:
  python_version: "3.12"
  python_packages:
  - requests
predict: "predict.py:Predictor"

-- predict.py --
import signal
import subprocess
import sys

import requests

from cog import BasePredictor


class Predictor(BasePredictor):
    """
    This predictor spawns a double-forked HTTP server during setup to test
    that worker tracing logs are properly captured in setuplog.
    """

    def setup(self) -> None:
        print("---> starting background process")

        self.bg = subprocess.Popen(["bash", "run-pong.sh"])

        print(f"---> started background process pid={self.bg.pid}")

        # Wait for HTTP server to be ready: up to 30 attempts, 0.5s apart.
        import time
        for i in range(30):
            try:
                requests.get("http://127.0.0.1:7777/ping", timeout=1)
                print("---> background HTTP server is ready")
                break
            except Exception:
                print(f"---> waiting for HTTP server ({i+1}/30)")
                time.sleep(0.5)
        else:
            # The loop never hit `break`, i.e. every attempt failed.
            raise RuntimeError("Background HTTP server failed to start")

    def predict(self, s: str) -> str:
        """
        Ping the background HTTP server, then return "hello " + s.

        Raises SystemExit if the wrapper process has exited.
        """
        status = self.bg.poll()

        if status is None:
            # Bound the request like setup() does; without a timeout a wedged
            # server would block this prediction indefinitely.
            requests.get("http://127.0.0.1:7777/ping", timeout=5)
        else:
            raise SystemExit

        return "hello " + s

-- run-pong.sh --
#!/usr/bin/env bash
# Wrapper launched by predict.py's setup(): start pong.py as a background job,
# then block in `wait` so this bash process stays alive for as long as it runs.
python ./pong.py &
wait

-- pong.py --
import os
import signal
import time
from random import randint
from wsgiref.simple_server import make_server


def main():
    """Fork: the child serves HTTP on 127.0.0.1:7777, the parent idles forever."""
    if os.fork() == 0:
        # Child branch: hand control to the WSGI server.
        make_server("127.0.0.1", 7777, app).serve_forever()
        return

    # Parent branch: stay alive without doing anything.
    while True:
        time.sleep(10)


def app(environ, start_response):
    """WSGI app: answer /ping with 102-132 "PONG" lines, anything else with a 404."""
    plain_text = [("content-type", "text/plain")]

    if environ["PATH_INFO"] != "/ping":
        start_response("404 Not Found", plain_text)
        return [b"NO\n"]

    start_response("200 OK", plain_text)
    # Randomise the body length a little so responses vary in size.
    return [b"PONG\n"] * (100 + randint(2, 32))


if __name__ == "__main__":
    main()


================================================
FILE: integration-tests/tests/static_schema_fallback.txtar
================================================
# Test that when static schema generation is opted in but encounters an
# unresolvable output type, the build falls back to legacy runtime schema
# generation instead of failing.
#
# The predictor returns a BaseModel subclass imported from a local package
# (mypackage/__init__.py). The static parser's moduleToFilePath() converts
# "mypackage" to "mypackage.py" instead of "mypackage/__init__.py", so the
# file isn't found and the name stays unresolved (ErrUnresolvableType).
# The legacy Python inspector imports modules normally via Python's import
# system, which handles __init__.py packages correctly, so schema generation
# succeeds at runtime.

# Opt in to static schema generation
env COG_STATIC_SCHEMA=1

# Build should succeed — static fails, legacy takes over
cog build -t $TEST_IMAGE
stderr 'Static schema generation failed'
stderr 'Falling back to legacy runtime schema generation'

# Predict should still work end-to-end via the legacy schema
cog predict $TEST_IMAGE -i prompt=hello
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- mypackage/__init__.py --
from cog import BaseModel


# Structured prediction output with a single string field. predict.py imports it
# from this package's __init__.py.
class Output(BaseModel):
    text: str

-- predict.py --
from cog import BasePredictor
from mypackage import Output


class Predictor(BasePredictor):
    # Returns the prompt with " world" appended, wrapped in the Output model
    # imported from the local `mypackage` package.
    def predict(self, prompt: str) -> Output:
        text = prompt + " world"
        return Output(text=text)


================================================
FILE: integration-tests/tests/static_schema_gen.txtar
================================================
# Test that the static schema generator (Go tree-sitter) produces a correct
# OpenAPI schema that is embedded in the Docker image label and served by coglet.
#
# This exercises the full pipeline:
#   1. cog build runs the Go parser on predict.py
#   2. Schema is written to .cog/openapi_schema.json inside the image
#   3. Schema is embedded as the run.cog.openapi_schema Docker label
#   4. coglet loads the schema from disk and serves it at /openapi.json

# Opt in to static schema generation
env COG_STATIC_SCHEMA=1

# Build the image
cog build -t $TEST_IMAGE

# Verify schema is in Docker label with correct structure
exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.openapi_schema"}}'

# Top-level OpenAPI structure
stdout '"openapi":"3.0.2"'

# Input schema: required string field
stdout '"text":'
stdout '"type":"string"'

# Input schema: optional int with default
stdout '"count":'
stdout '"type":"integer"'
stdout '"default":5'

# Input schema: choices generate enum
stdout '"enum":\["fast","balanced","quality"\]'

# Input schema: Path type → uri format
stdout '"format":"uri"'

# Input schema: the "required" array lists the inputs that have no default
stdout '"required":\["text","image"\]'

# Output type
stdout '"type":"string"'

# Predict should work end-to-end
cog predict $TEST_IMAGE -i text=hello -i image=@test.jpg
stdout 'hello-5-fast-jpg'

# Prediction with overrides
cog predict $TEST_IMAGE -i text=world -i count=3 -i mode=quality -i image=@test.jpg
stdout 'world-3-quality-jpg'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Input, Path


class Predictor(BasePredictor):
    # Joins the inputs into "<text>-<count>-<mode>-<ext>", where <ext> is whatever
    # follows the last "." in the image path.
    def predict(
        self,
        text: str,
        image: Path,
        count: int = Input(description="Number of iterations", default=5, ge=1, le=100),
        mode: str = Input(description="Quality mode", default="fast", choices=["fast", "balanced", "quality"]),
    ) -> str:
        image_name = str(image)
        suffix = image_name.rsplit(".", 1)[-1]
        return f"{text}-{count}-{mode}-{suffix}"

-- test.jpg --
fake image content


================================================
FILE: integration-tests/tests/string_list_input.txtar
================================================
# Test list[str] input type

# Build the image
cog build -t $TEST_IMAGE

# Predict with list input
cog predict $TEST_IMAGE -i s=world
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Input


class Predictor(BasePredictor):
    """Predictor that takes a list of strings and greets them joined by "|"."""

    def predict(self, s: list[str] = Input(description="A list of strings to print.")) -> str:
        joined = "|".join(s)
        return "hello " + joined


================================================
FILE: integration-tests/tests/string_none_output.txtar
================================================

# Test str return type that returns None
# This tests the handling of None values for typed outputs

# Build the image
cog build -t $TEST_IMAGE

# Predict returns None despite str type annotation
# When None is returned, cog shows "No output generated" on stderr
cog predict $TEST_IMAGE
stderr 'No output generated'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Predictor that is annotated as returning `str` but always returns None."""

    def predict(self) -> str:
        # Deliberately contradicts the annotation; the None output is the point.
        output = None
        return output


================================================
FILE: integration-tests/tests/string_predictor.txtar
================================================

# Build the image
cog build -t $TEST_IMAGE

# Basic prediction works
cog predict $TEST_IMAGE -i s=world
stdout 'hello world'

# Missing required input fails
! cog predict $TEST_IMAGE -i wrong=value
stderr 's: Field required'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    """Minimal predictor with one required string input."""

    def predict(self, s: str) -> str:
        # Echo the input behind a fixed greeting.
        greeting = "hello "
        return greeting + s


================================================
FILE: integration-tests/tests/subdirectory_predictor.txtar
================================================

# Test predictor in subdirectory with imports from parent

# Build the image
cog build -t $TEST_IMAGE

# Predict using predictor in subdirectory
cog predict $TEST_IMAGE -i s=world
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "my-subdir/predict.py:Predictor"

-- mylib.py --
def concat(a, b):
    """Return `a` and `b` joined by a single space."""
    return " ".join((a, b))

-- my-subdir/predict.py --
from cog import BasePredictor
from mylib import concat


class Predictor(BasePredictor):
    """Predictor living in a subdirectory that relies on `mylib.concat`."""

    def predict(self, s: str) -> str:
        # Delegate the string building to the imported helper.
        greeting = concat("hello", s)
        return greeting


================================================
FILE: integration-tests/tests/tensorflow.txtar
================================================
[short] skip 'slow test - run without -short flag'

# Test TensorFlow build and predict
cog build -t $TEST_IMAGE
cog predict $TEST_IMAGE
stdout '2.11.1'

-- cog.yaml --
build:
  gpu: true
  cuda: "11.8"
  python_version: "3.10"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
    - "xvfb"
  python_requirements: "requirements.txt"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

import tensorflow


class Predictor(BasePredictor):
    """Predictor that reports the version of the imported tensorflow module."""

    def predict(self) -> str:
        installed_version = tensorflow.__version__
        return installed_version

-- requirements.txt --
compel==2.0.3
diffusers>=0.27.1
gputil==1.4.0
loguru==0.7.2
opencv-python>=4.9.0.80
pillow>=10.2.0
psutil==6.1.1
replicate>=1.0.4
sentry-sdk[fastapi,loguru]>=2.16.0
antialiased_cnns==0.3
beautifulsoup4==4.13.4
imageio==2.37.0
ipdb==0.13.13
kornia==0.8.1
matplotlib==3.10.3
numpy==1.23.5
opencv_python==4.11.0.86
Pillow==11.2.1
pytorch_lightning==2.3.3
PyYAML==6.0.2
Requests==2.32.3
scipy==1.15.3
scikit-image==0.24.0
tensorflow==2.11.1
tensorlayer==2.2.5
tf_slim==1.1.0
timm==1.0.15
torch==2.0.1
torchvision==0.15.2
tqdm==4.67.1


================================================
FILE: integration-tests/tests/torch_270_cuda_126.txtar
================================================
[short] skip 'slow test - run without -short flag'

# Test Torch 2.7.0 + CUDA 12.6 base image build
cog build -t $TEST_IMAGE

-- cog.yaml --
build:
  gpu: true
  cuda: "12.6"
  python_version: "3.11"
  python_packages:
    - "torch==2.7.0"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Placeholder predictor; this fixture only exercises the image build."""

    def predict(self, s: str) -> str:
        greeting = "hello "
        return greeting + s


================================================
FILE: integration-tests/tests/torch_271_cuda_128.txtar
================================================
[short] skip 'slow test - run without -short flag'

# Test Torch 2.7.1 + CUDA 12.8 base image build
cog build -t $TEST_IMAGE

-- cog.yaml --
build:
  gpu: true
  cuda: "12.8"
  python_version: "3.12"
  python_packages:
    - "torch==2.7.1+cu128"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Placeholder predictor; this fixture only exercises the image build."""

    def predict(self, s: str) -> str:
        greeting = "hello "
        return greeting + s


================================================
FILE: integration-tests/tests/torch_baseimage_fallback.txtar
================================================
# Test that cog falls back to a CUDA base image (not cog-base) when the
# torch version is below the minimum supported for cog-base images.
# Torch 1.13.0 < MinimumTorchVersion (1.13.1), so no cog-base match exists.
#
# Uses a local registry (empty, no images seeded) to ensure the base image
# lookup fails deterministically, and cog debug to verify the generated
# Dockerfile without doing a real Docker build.

# Start an empty local registry so the cog-base lookup fails deterministically
registry-start
env COG_REGISTRY_HOST=$TEST_REGISTRY

# Generate Dockerfile — should fall back to nvidia/cuda, not cog-base
cog debug
stdout 'FROM nvidia/cuda'
! stdout 'cog-base'

-- cog.yaml --
build:
  gpu: true
  python_version: "3.10"
  python_packages:
    - "torch==1.13.0"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Placeholder predictor; this fixture only inspects the generated Dockerfile."""

    def predict(self, s: str) -> str:
        greeting = "hello "
        return greeting + s


================================================
FILE: integration-tests/tests/torch_baseimage_no_cog_base.txtar
================================================
[short] skip 'slow test - run without -short flag'
skip 'temporarily disabled - takes ~10min, blocking CI'

# Test Torch 1.13.0 base image with --use-cog-base-image=false
cog build -t $TEST_IMAGE --openapi-schema openapi.json --use-cog-base-image=false

-- cog.yaml --
build:
  gpu: true
  python_version: "3.10"
  python_packages:
    - "torch==1.13.0"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Placeholder predictor; this fixture only exercises the image build."""

    def predict(self, s: str) -> str:
        greeting = "hello "
        return greeting + s

-- openapi.json --
{
    "info": {
        "title": "Cog",
        "version": "0.1.0"
    },
    "paths": {
        "/": {
            "get": {
                "summary": "Root",
                "responses": {
                    "200": {
                        "content": {
                            "application/json": {
                                "schema": {
                                    "title": "Response Root  Get"
                                }
                            }
                        },
                        "description": "Successful Response"
                    }
                },
                "operationId": "root__get"
            }
        },
        "/predictions": {
            "post": {
                "summary": "Predict",
                "responses": {
                    "200": {
                        "content": {
                            "application/json": {
                                "schema": {
                                    "$ref": "#/components/schemas/PredictionResponse"
                                }
                            }
                        },
                        "description": "Successful Response"
                    }
                },
                "description": "Run a single prediction on the model",
                "operationId": "predict_predictions_post"
            }
        }
    },
    "openapi": "3.1.0",
    "components": {
        "schemas": {
            "Input": {
                "type": "object",
                "title": "Input",
                "properties": {
                    "s": {
                        "type": "string",
                        "title": "S"
                    }
                }
            },
            "Output": {
                "type": "string",
                "title": "Output"
            },
            "PredictionResponse": {
                "type": "object",
                "title": "PredictionResponse",
                "properties": {
                    "output": {
                        "$ref": "#/components/schemas/Output"
                    }
                }
            }
        }
    }
}


================================================
FILE: integration-tests/tests/torch_baseimage_precompile.txtar
================================================
[short] skip 'slow test - run without -short flag'
skip 'temporarily disabled - takes ~10min, blocking CI'

# Test Torch 1.13.0 base image with --precompile flag
cog build -t $TEST_IMAGE --openapi-schema openapi.json --use-cog-base-image=false --precompile

-- cog.yaml --
build:
  gpu: true
  python_version: "3.10"
  python_packages:
    - "torch==1.13.0"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    """Placeholder predictor; this fixture only exercises the image build."""

    def predict(self, s: str) -> str:
        greeting = "hello "
        return greeting + s

-- openapi.json --
{
    "info": {
        "title": "Cog",
        "version": "0.1.0"
    },
    "paths": {
        "/": {
            "get": {
                "summary": "Root",
                "responses": {
                    "200": {
                        "content": {
                            "application/json": {
                                "schema": {
                                    "title": "Response Root  Get"
                                }
                            }
                        },
                        "description": "Successful Response"
                    }
                },
                "operationId": "root__get"
            }
        },
        "/predictions": {
            "post": {
                "summary": "Predict",
                "responses": {
                    "200": {
                        "content": {
                            "application/json": {
                                "schema": {
                                    "$ref": "#/components/schemas/PredictionResponse"
                                }
                            }
                        },
                        "description": "Successful Response"
                    }
                },
                "description": "Run a single prediction on the model",
                "operationId": "predict_predictions_post"
            }
        }
    },
    "openapi": "3.1.0",
    "components": {
        "schemas": {
            "Input": {
                "type": "object",
                "title": "Input",
                "properties": {
                    "s": {
                        "type": "string",
                        "title": "S"
                    }
                }
            },
            "Output": {
                "type": "string",
                "title": "Output"
            },
            "PredictionResponse": {
                "type": "object",
                "title": "PredictionResponse",
                "properties": {
                    "output": {
                        "$ref": "#/components/schemas/Output"
                    }
                }
            }
        }
    }
}


================================================
FILE: integration-tests/tests/torch_cuda_baseimage.txtar
================================================
# Test that cog correctly resolves a cog-base image for Torch 2.0.1+cu118
# and generates the expected Dockerfile with --use-cog-base-image.
#
# Uses a local test registry seeded with a dummy image to avoid depending
# on the live r8.im registry, which can be flaky in CI.

# Start local registry and seed a dummy cog-base image
registry-start
registry-seed alpine:latest cog-base:cuda11.8-python3.10-torch2.0.1

# Generate Dockerfile using the local registry as the cog-base source
env COG_REGISTRY_HOST=$TEST_REGISTRY
cog debug --use-cog-base-image
stdout 'FROM.*cog-base:cuda11.8-python3.10-torch2.0.1'

-- cog.yaml --
build:
  gpu: true
  python_version: "3.10"
  python_packages:
    - "torch==2.0.1+cu118"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    # Minimal fixture: echoes the input string behind a fixed greeting.
    def predict(self, s: str) -> str:
        greeting = "hello "
        return greeting + s


================================================
FILE: integration-tests/tests/train_basic.txtar
================================================
skip 'cog train requires static schema gen which is gated behind COG_STATIC_SCHEMA=1'

# Test basic training functionality

# Train with input (no pre-built image, runs from cog.yaml)
cog train -i n=42

# Verify weights file was created with correct size.
# The pattern is anchored so only a byte count of exactly 42 passes (an
# unanchored '42' would also accept 142 or 420); the optional whitespace
# tolerates wc implementations that pad their output.
exec test -f weights.bin
exec sh -c 'wc -c < weights.bin'
stdout '^\s*42\s*$'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"
train: "train.py:train"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    # Prediction side of the fixture; the training entry point lives in train.py.
    def predict(self, s: str) -> str:
        prefix = "hello "
        return prefix + s

-- train.py --
from cog import BaseModel, Input, Path


class TrainingOutput(BaseModel):
    # Path of the weights file produced by train().
    weights: Path


def train(
    n: int,
) -> TrainingOutput:
    # Fill weights.bin with n copies of "a" so the test can check its size.
    weights_file = "weights.bin"
    with open(weights_file, "w") as out:
        out.write("a" * n)

    return TrainingOutput(weights=Path(weights_file))


================================================
FILE: integration-tests/tests/train_deprecated.txtar
================================================
# Test that the train command shows a deprecation warning

# Train command should fail (no cog.yaml) but still show deprecation warning
! cog train
stderr 'Command "train" is deprecated'
stderr 'will be removed in a future version of Cog'

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    # Placeholder predictor; this test only exercises the `cog train` warning.
    def predict(self, s: str) -> str:
        greeting = "hello "
        return greeting + s


================================================
FILE: integration-tests/tests/training_setup.txtar
================================================
skip 'cog train requires static schema gen which is gated behind COG_STATIC_SCHEMA=1'

# Test that training with setup method works correctly

# Train with input
cog train -i s=world
stderr 'Trainer is setting up.'

# Verify weights file was created with correct content
exec cat weights
stdout 'hello train world'

-- cog.yaml --
build:
  python_version: "3.12"
train: "train.py:Trainer"
predict: "predict.py:Predictor"

-- train.py --
from cog import BasePredictor

class Trainer(BasePredictor):
    # Every method prints a marker line so the test output shows which ones ran.
    def setup(self) -> None:
        print("Trainer is setting up.")

    def train(self, s: str) -> str:
        print("Trainer.train called.")
        trained = "hello train " + s
        return trained

    def predict(self, s: str) -> str:
        print("Trainer.predict called.")
        predicted = "hello predict " + s
        return predicted

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    # Plain predictor that sits alongside the Trainer defined in train.py.
    def predict(self, s: str) -> str:
        prefix = "hello "
        return prefix + s


================================================
FILE: integration-tests/tests/union_type.txtar
================================================
# Test new-style union type annotation (str | None)

# Build the image
cog build -t $TEST_IMAGE

# Predict with input
cog predict $TEST_IMAGE -i text=world
stdout 'hello world'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor, Input


class Predictor(BasePredictor):
    # Fixture for a new-style `str | None` input annotation.
    def setup(self):
        # Greeting stored once and reused by every prediction.
        self.prefix = "hello"

    def predict(
        self,
        text: str | None = Input(
            description="Text to prefix with 'hello '", default=None
        ),
    ) -> str:
        # `text` is optional and defaults to None; concatenating None to a str
        # raises TypeError, so return just the prefix when it is omitted.
        if text is None:
            return self.prefix
        return self.prefix + " " + text


================================================
FILE: integration-tests/tests/webhook_delivery_failure.txtar
================================================
# Test that webhook delivery failure does not crash the server.
#
# When the webhook URL is unreachable, coglet retries with backoff
# but eventually gives up. The server should remain healthy.
#
# --upload-url is set to a dummy so cog serve adds host networking.

cog serve --upload-url http://unused/

# Async prediction with a bogus webhook URL — delivery will fail
curl -H Prefer:respond-async POST /predictions '{"id":"webhook-fail-test","webhook":"http://host.docker.internal:1/nonexistent","webhook_events_filter":["completed"]}'

# Server should remain healthy while webhook delivery retries/fails.
# Poll repeatedly instead of sleeping a fixed duration.
exec bash -c 'for i in {1..10}; do curl -sf $SERVER_URL/health-check | grep -q "\"status\":\"READY\"" || exit 1; sleep 0.5; done'

# A new sync prediction should still work
curl POST /predictions '{"input":{}}'
stdout '"status":"succeeded"'
stdout '"output":"ok"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    # Takes no inputs and always returns the same constant output.
    def predict(self) -> str:
        response = "ok"
        return response


================================================
FILE: integration-tests/tests/webhook_prediction_error.txtar
================================================
# Test that a failed prediction delivers the correct webhook payload.
#
# When predict() raises an exception during an async prediction, the webhook
# should receive status "failed" with the error message populated.

webhook-server-start
cog serve --upload-url http://unused/

# Async prediction that raises an exception
curl -H Prefer:respond-async POST /predictions '{"id":"webhook-error-test","webhook":"$WEBHOOK_URL","webhook_events_filter":["completed"]}'

webhook-server-wait

# Webhook should report failure with error
stdout '"status":"failed"'
stdout '"has_error":true'
stdout '"error_message":".*prediction went wrong"'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor


class Predictor(BasePredictor):
    # Always fails so the test can inspect the webhook payload of a failed run.
    def predict(self) -> str:
        reason = "prediction went wrong"
        raise RuntimeError(reason)


================================================
FILE: integration-tests/tests/weights_build.txtar
================================================
# Test that cog weights build generates weights.lock

# Build should fail without weights section
! cog weights build
stderr 'no weights defined'

# Add weights section and create weight file
cp cog-with-weights.yaml cog.yaml
mkdir models
exec sh -c 'echo "test model content" > models/model.bin'

# Build weights.lock
cog weights build
stderr 'Generated weights.lock'
stderr '1 file'

# Verify weights.lock was created
exists weights.lock

# Verify weights.lock contains expected content
exec grep -q '"name": "model"' weights.lock
exec grep -q '"dest": "/cache/model.bin"' weights.lock
exec grep -q '"digestOriginal": "sha256:' weights.lock

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- cog-with-weights.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"
weights:
  - name: model
    source: models/model.bin
    target: /cache/model.bin

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    # Placeholder predictor referenced by cog.yaml; the test never runs it.
    def predict(self, s: str) -> str:
        greeting = "hello "
        return greeting + s


================================================
FILE: integration-tests/tests/weights_push_inspect.txtar
================================================
# Test weights push and inspect lifecycle against a local registry.
# Verifies: cog weights build -> cog weights push -> cog weights inspect (synced).

[short] skip 'requires local registry'

# Start test registry
registry-start

# Create weight files (small, deterministic)
mkdir weights
exec sh -c 'dd if=/dev/zero bs=512 count=1 2>/dev/null | tr "\0" "A" > weights/model-a.bin'
exec sh -c 'dd if=/dev/zero bs=512 count=1 2>/dev/null | tr "\0" "B" > weights/model-b.bin'

# Step 1: Build weights.lock
cog weights build
stderr 'Generated weights.lock'
stderr '2 file'
exists weights.lock

# Verify lock file structure
exec grep -q '"name": "alpha"' weights.lock
exec grep -q '"name": "beta"' weights.lock
exec grep -q '"digest": "sha256:' weights.lock

# Step 2: Push weights to local registry (repo only, no tag)
cog weights push $TEST_REGISTRY/test/weights-model
stderr 'Pushed 2 weight artifact'
# Push output should show the full ref for each weight
stderr 'weights-alpha-'
stderr 'weights-beta-'

# Verify that a ref which includes a :tag is rejected
! cog weights push $TEST_REGISTRY/test/weights-model:v1
stderr 'includes a tag or digest'

# Step 3: Inspect — both weights should be synced
cog weights inspect $TEST_REGISTRY/test/weights-model --json
stdout '"status": "synced"'
! stdout '"status": "local-only"'
! stdout '"status": "digest-mismatch"'
# Inspect should show the ref and layers for each weight
stdout '"ref":'
stdout '"layers":'

# Verify that a ref which includes a :tag is rejected for inspect too
! cog weights inspect $TEST_REGISTRY/test/weights-model:v1
stderr 'includes a tag or digest'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"
weights:
  - name: alpha
    source: weights/model-a.bin
    target: /weights/model-a.bin
  - name: beta
    source: weights/model-b.bin
    target: /weights/model-b.bin

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    # Placeholder predictor referenced by cog.yaml; only weights commands run here.
    def predict(self, s: str) -> str:
        prefix = "hello "
        return prefix + s


================================================
FILE: integration-tests/tests/wheel_coglet_missing.txtar
================================================
# Test COGLET_WHEEL with non-existent path gives clear error
#
# This test verifies that when COGLET_WHEEL points to a non-existent file,
# a clear error message is shown.

env COG_SDK_WHEEL=$REPO_ROOT/dist
env COGLET_WHEEL=/nonexistent/path/coglet.whl
! cog build -t $TEST_IMAGE
stderr 'path not found'
stderr '/nonexistent/path/coglet.whl'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    # Placeholder predictor; the build in this test is expected to fail first.
    def predict(self, s: str) -> str:
        greeting = "hello "
        return greeting + s


================================================
FILE: integration-tests/tests/wheel_resolution.txtar
================================================
# Test wheel resolution from environment variables
#
# This test verifies that the COG_SDK_WHEEL environment variable correctly
# resolves wheel paths (COGLET_WHEEL errors are covered by
# wheel_coglet_missing.txtar), including:
# - COG_SDK_WHEEL pointing to a directory resolves wheels inside it
# - Clear errors when the wheel is not found
#
# Note: Tests run from a temp directory, so we use $REPO_ROOT/dist
# (an absolute path exported by the test harness) to find wheels.

# Test 1: COG_SDK_WHEEL pointing to repo dist/ directory finds wheel
env COG_SDK_WHEEL=$REPO_ROOT/dist
cog build --debug -t $TEST_IMAGE
stderr 'Using local cog wheel:'

# Test 2: Relative path that doesn't exist gives a clear error
env COG_SDK_WHEEL=./nonexistent/wheel.whl
! cog build -t $TEST_IMAGE
stderr 'path not found'

-- cog.yaml --
build:
  python_version: "3.12"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor

class Predictor(BasePredictor):
    # Minimal predictor so `cog build` has a model to package.
    def predict(self, s: str) -> str:
        prefix = "hello "
        return prefix + s


================================================
FILE: integration-tests/tests/zsh_package.txtar
================================================
# Test that zsh system package is installed and available in /bin
cog build -t $TEST_IMAGE

cog predict $TEST_IMAGE
stdout ',sh,'
stdout ',zsh,'

-- cog.yaml --
build:
  python_version: "3.12"
  system_packages:
    - "zsh"
predict: "predict.py:Predictor"

-- predict.py --
from cog import BasePredictor
import os

class Predictor(BasePredictor):
    # Reports the entries of /bin so the test can look for the installed shells.
    def predict(self) -> str:
        entries = os.listdir("/bin")
        return "hello " + ",".join(entries)


================================================
FILE: mise.toml
================================================
# =============================================================================
# Cog Development Tasks
# =============================================================================
#
# Run `mise tasks` to see all available tasks.
#
# ## Task Caching
#
# Some build tasks use `sources` and `outputs` for caching. When sources haven't
# changed since the last run, mise skips the task. To force a rebuild:
#
#     mise run build:sdk --force
#     mise run build:coglet:wheel:linux-x64 --force
#
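# Cached tasks:
#   - build:cog              - skips if Go sources (cmd/, pkg/, go.mod, go.sum) unchanged
#   - build:sdk              - skips if python/**/*.py unchanged
#   - build:coglet:wheel:*   - skips if crates/**/*.rs unchanged (per-target tasks
#                              only; the native build:coglet:wheel always runs)
#   - generate:stubs         - skips if coglet-python source unchanged
#   - generate:compat        - skips if tools/compatgen unchanged
#   - docs, docs:cli, docs:llm - skip if their listed sources are unchanged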
#
# Tasks that always run (no caching):
#   - fmt:*, lint:*, test:*  - always check current state
#   - clean:*                - always destructive
#
# =============================================================================

experimental_monorepo_root = true

[monorepo]
config_roots = ["."]

[tools]

go = "latest"
uv = "0.9.26"
"pipx:nox" = { version = "2025.11.12", uvx = true, uvx_args = "--python-preference=managed -p 3.13" }
"aqua:ziglang/zig" = "0.15.2"
"rust" = { version = "1.93.0", components = "rustfmt,clippy", targets ="x86_64-unknown-linux-gnu,aarch64-unknown-linux-gnu,aarch64-apple-darwin" }
cargo-binstall = "1.16.6"
# Cargo tools - use aqua backend where available for faster binary downloads
# and better security (cosign/SLSA verification). Remaining cargo: tools use binstall.
"aqua:EmbarkStudios/cargo-deny" = "0.19.0"
"aqua:mitsuhiko/insta" = "1.46.0"
"cargo:cargo-nextest" = "0.9.120"
"cargo:maturin" = "1.11.5"
"aqua:rust-lang/rustup" = "latest"
"aqua:rust-lang/rustup/rustup-init" = "latest"
"aqua:rust-cross/cargo-zigbuild" = "0.20.1"
"aqua:gotestyourself/gotestsum" = "1.13.0"
"aqua:golangci/golangci-lint" = "2.10.1"
ruff = "0.14.13"
ty = "0.0.10"

[env]
_.path = "./bin"
_.file = [".env"]
_.python.venv = ".venv"
# Set REPO_ROOT only if not already set (e.g., by CI)
REPO_ROOT = "{{env.REPO_ROOT | default(value=config_root)}}"
# CGo required for go-tree-sitter (static Python schema parser)
CGO_ENABLED = "1"

[settings]
lockfile = true
experimental = true

# =============================================================================
# Helper tasks (hidden)
# =============================================================================

[tasks._setup_dist]
hide = true
silent = true
description = "Create dist directory"
run = "mkdir -p dist"

[tasks._setup_venv]
hide = true
silent = true
description = "Ensure root .venv exists with Python"
run = "test -d .venv || uv venv --quiet"

[tasks._clean_dist]
hide = true
silent = true
description = "Clean dist directory"
run = "rm -f dist/cog-*.whl dist/cog-*.tar.gz dist/coglet-*.whl"

# =============================================================================
# Build tasks
# =============================================================================

[tasks.build]
alias = "build:all"
description = "Build all components"
run = [
  { tasks = ["build:cog", "build:coglet:wheel:linux-x64", "build:sdk"] },
  { task = "_build_summary" },
]

[tasks._build_summary]
hide = true
description = "Print build artifacts summary"
# Lists whichever artifacts currently exist under dist/: the first cog CLI
# binary (with the version it reports), coglet wheels, and SDK wheels.
run = """
#!/usr/bin/env bash
echo ""
echo "=== Build Artifacts ==="
# Check the captured path, not the pipeline status: `ls ... | head -1` exits
# with head's status (0) even when ls matched nothing.
BINARY=$(ls dist/go/*/cog 2>/dev/null | head -1)
if [ -n "$BINARY" ]; then
  VERSION=$("$BINARY" --version 2>/dev/null || echo "unknown")
  echo "  cli:        $BINARY ($VERSION)"
fi
# An unmatched glob stays literal, so guard each entry with -f.
for whl in dist/coglet-*.whl; do
  [ -f "$whl" ] && echo "  coglet:     $whl"
done
for whl in dist/cog-*.whl; do
  [ -f "$whl" ] && echo "  python-sdk: $whl"
done
echo ""
"""

[tasks.install]
depends = ["build:cog"]
description = "Build and symlink cog CLI"
usage = 'arg "[dest]" help="Directory to symlink into (e.g. ~/.local/bin)" default="~/.local/bin"'
# The script replaces a leading "~" in the destination with $HOME, takes the
# first dist/go/*/cog binary, resolves it to an absolute path, and symlinks it
# as <dest>/cog. It exits with an error if no binary is found.
run = """
#!/usr/bin/env bash
set -e
DEST="${usage_dest/#\\~/$HOME}"
BINARY=$(ls dist/go/*/cog 2>/dev/null | head -1)
if [ -z "$BINARY" ]; then
    echo "Error: no cog binary found in dist/go/. Run 'mise run build:cog' first." >&2
    exit 1
fi
BINARY="$(cd "$(dirname "$BINARY")" && pwd)/$(basename "$BINARY")"
mkdir -p "$DEST"
ln -sf "$BINARY" "$DEST/cog"
echo "Installed $DEST/cog -> $BINARY"
"""

[tasks."build:cog"]
description = "Build cog CLI (development)"
sources = ["cmd/**/*.go", "pkg/**/*.go", "go.mod", "go.sum", "crates/Cargo.toml"]
outputs = ["dist/go/*/cog"]
run = """
#!/usr/bin/env bash
set -e
# Don't set COG_VERSION — let goreleaser's snapshot template produce a dev version
# (e.g. 0.17.1-dev+gabcdef) so local wheel auto-detection works.
GOFLAGS=-buildvcs=false go run github.com/goreleaser/goreleaser/v2@latest build --clean --snapshot --single-target --id cog --output cog
"""

[tasks."build:cog:release"]
description = "Build cog CLI (release)"
run = "go run github.com/goreleaser/goreleaser/v2@latest build --clean --single-target --id cog --output cog"

[tasks."build:rust"]
description = "Build Rust workspace"
run = "cargo build --manifest-path crates/Cargo.toml --workspace"

[tasks."build:rust:release"]
description = "Build Rust workspace (release)"
run = "cargo build --manifest-path crates/Cargo.toml --workspace --release"

[tasks."build:coglet"]
description = "Build coglet Python wheel (development, local install)"
run = [
  { task = "_setup_venv" },
  "maturin develop --manifest-path crates/coglet-python/Cargo.toml",
]

[tasks."build:coglet:wheel"]
description = "Build coglet Python wheel (native platform)"
# No sources/outputs caching: the output glob dist/coglet-*.whl is too broad
# and would falsely match cross-compiled wheels (e.g. linux-x64), causing skips.
# Without sources/outputs this task always rebuilds, so --force is not needed.
run = [
  { tasks = ["_setup_dist", "_setup_venv"] },
  "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml",
]

[tasks."build:coglet:wheel:linux-x64"]
description = "Build coglet Python wheel for Linux x86_64"
# The Cargo workspace manifest is crates/Cargo.toml, so the lock file to watch
# for dependency changes is crates/Cargo.lock rather than a repo-root one.
sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "crates/Cargo.lock"]
outputs = ["dist/coglet-*manylinux*x86_64*.whl"]
run = [
  { tasks = ["_setup_dist", "_setup_venv"] },
  "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target x86_64-unknown-linux-gnu --zig",
]

[tasks."build:coglet:wheel:linux-arm64"]
description = "Build coglet Python wheel for Linux ARM64"
# The Cargo workspace manifest is crates/Cargo.toml, so the lock file to watch
# for dependency changes is crates/Cargo.lock rather than a repo-root one.
sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "crates/Cargo.lock"]
outputs = ["dist/coglet-*manylinux*aarch64*.whl"]
run = [
  { tasks = ["_setup_dist", "_setup_venv"] },
  "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target aarch64-unknown-linux-gnu --zig",
]

[tasks."build:coglet:wheel:darwin-arm64"]
description = "Build coglet Python wheel for macOS ARM64"
# The Cargo workspace manifest is crates/Cargo.toml, so the lock file to watch
# for dependency changes is crates/Cargo.lock rather than a repo-root one.
sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "crates/Cargo.lock"]
outputs = ["dist/coglet-*-macosx_*_arm64.whl"]
run = [
  { tasks = ["_setup_dist", "_setup_venv"] },
  "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target aarch64-apple-darwin",
]

[tasks."build:coglet:wheel:darwin-x64"]
description = "Build coglet Python wheel for macOS x86_64"
# The Cargo workspace manifest is crates/Cargo.toml, so the lock file to watch
# for dependency changes is crates/Cargo.lock rather than a repo-root one.
sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "crates/Cargo.lock"]
outputs = ["dist/coglet-*-macosx_*_x86_64.whl"]
run = [
  { tasks = ["_setup_dist", "_setup_venv"] },
  "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target x86_64-apple-darwin",
]

[tasks."build:coglet:wheel:all"]
description = "Build coglet Python wheels for all platforms"
# Runs the two Linux wheel tasks (x64 and arm64) and then the native-platform
# wheel task. The darwin-specific wheel tasks are not invoked from here.
run = [
  { task = "_setup_dist" },
  { tasks = ["build:coglet:wheel:linux-x64", "build:coglet:wheel:linux-arm64"] },
  { task = "build:coglet:wheel" },
]

[tasks."build:sdk"]
description = "Build cog SDK wheel"
sources = ["python/**/*.py", "pyproject.toml", "crates/Cargo.toml"]
outputs = ["dist/cog-*.whl", "dist/cog-*.tar.gz"]
run = [
  { tasks = ["_setup_dist", "_setup_venv"] },
  """
#!/usr/bin/env bash
set -euo pipefail
# Version from Cargo.toml, converted to PEP 440
RAW=$(grep '^version' crates/Cargo.toml | head -1 | sed 's/.*"\\(.*\\)"/\\1/')
export SETUPTOOLS_SCM_PRETEND_VERSION=$(echo "$RAW" | sed -E 's/-alpha/a/; s/-beta/b/; s/-rc/rc/; s/-dev/.dev/')
echo "Building SDK wheel: $SETUPTOOLS_SCM_PRETEND_VERSION"
uv build --out-dir=dist .
""",
]

[tasks."build:wheels"]
description = "Build all wheels (coglet + sdk)"
run = [
  { task = "_clean_dist" },
  { task = "_setup_dist" },
  { tasks = ["build:coglet:wheel:all", "build:sdk"] },
]

# =============================================================================
# Test tasks
# =============================================================================

[tasks.test]
description = "Run all unit tests (set INTEGRATION_TESTS=1 to include integration)"
run = """
#!/usr/bin/env bash
set -e
mise run test:go
mise run test:rust
mise run test:python
if [ "${INTEGRATION_TESTS:-}" = "1" ]; then
  mise run test:integration
fi
"""

[tasks."test:go"]
description = "Run Go tests"
run = "gotestsum -- -short -timeout 1200s -parallel 5 ./..."

[tasks."test:rust"]
description = "Run Rust workspace tests"
run = "cargo nextest run --manifest-path crates/Cargo.toml --workspace --exclude coglet-python --no-tests=pass"

[tasks."test:python"]
description = "Run Python SDK tests (latest supported Python)"
depends = ["build:coglet:wheel"]
run = "nox -s tests -p 3.13"

[tasks."test:python:all"]
description = "Run Python SDK tests on all supported Python versions"
depends = ["build:coglet:wheel"]
run = "nox -s tests"

[tasks."test:coglet:python"]
description = "Run coglet Python binding tests (latest Python)"
depends = ["build:coglet:wheel"]
run = "nox -s coglet -p 3.13"

[tasks."test:coglet:python:all"]
description = "Run coglet Python binding tests on all supported Python versions"
depends = ["build:coglet:wheel"]
run = "nox -s coglet"

[tasks."test:fuzz"]
description = "Run Go fuzz tests (FUZZTIME=30s per target by default)"
run = """
#!/usr/bin/env bash
set -e
FUZZTIME="${FUZZTIME:-30s}"
echo "Fuzzing schema type resolution ($FUZZTIME)..."
go test ./pkg/schema/ -run='^$' -fuzz=FuzzResolveSchemaType -fuzztime="$FUZZTIME"
echo "Fuzzing JSON schema generation ($FUZZTIME)..."
go test ./pkg/schema/ -run='^$' -fuzz=FuzzJSONSchema -fuzztime="$FUZZTIME"
echo "Fuzzing Python parser ($FUZZTIME)..."
go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParsePredictor -fuzztime="$FUZZTIME"
echo "Fuzzing type annotation parsing ($FUZZTIME)..."
go test ./pkg/schema/python/ -run='^$' -fuzz=FuzzParseTypeAnnotation -fuzztime="$FUZZTIME"
echo "All fuzz targets passed."
"""

[tasks."test:integration"]
description = "Run integration tests (skips slow tests by default, set SHORT=0 for full suite)"
depends = ["clean:integration", "build:cog", "build:sdk", "build:coglet:wheel:linux-x64"]
run = """
#!/usr/bin/env bash
set -e
SHORT_FLAG="-short"
if [ "${SHORT:-1}" = "0" ]; then
  SHORT_FLAG=""
fi
# If first arg is a bare name (no dash), treat as test name filter;
# remaining args are passed through to go test.
# e.g. mise run test:integration coglet_large_output -count=4
if [ $# -gt 0 ] && [[ "$1" != -* ]]; then
  gotestsum -- -tags integration -v $SHORT_FLAG -run "TestIntegration/$1" "${@:2}" -timeout 30m ./integration-tests/...
else
  gotestsum -- -tags integration -v $SHORT_FLAG -parallel ${TEST_PARALLEL:-4} "$@" -timeout 30m ./integration-tests/...
fi
"""

# =============================================================================
# Format tasks
# =============================================================================

[tasks.fmt]
alias = ["format", "fmt:check", "format:check"]
description = "Check formatting for all languages (non-destructive)"
run = [
  { tasks = ["fmt:go", "fmt:rust", "fmt:python", "fmt:docs"] },
]

[tasks."fmt:fix"]
alias = "format:fix"
description = "Fix formatting for all languages"
run = [
  { tasks = ["fmt:go:fix", "fmt:rust:fix", "fmt:python:fix", "fmt:docs:fix"] },
]

# Go formatting
[tasks."fmt:go"]
alias = "fmt:go:check"
description = "Check Go formatting"
# goimports runs twice: once with -d to print the diffs, then with -l so the
# task fails when any file is listed as needing changes.
run = """
#!/usr/bin/env bash
set -e
go tool goimports -d .
test -z "$(go tool goimports -l .)"
"""

[tasks."fmt:go:fix"]
description = "Fix Go formatting"
run = "go tool goimports -w -d ."

# Rust formatting
[tasks."fmt:rust"]
alias = "fmt:rust:check"
description = "Check Rust formatting"
run = "cargo fmt --manifest-path crates/Cargo.toml --all -- --check"

[tasks."fmt:rust:fix"]
description = "Fix Rust formatting"
run = "cargo fmt --manifest-path crates/Cargo.toml --all"

# Python formatting
[tasks."fmt:python"]
alias = "fmt:python:check"
description = "Check Python formatting"
run = "ruff format --check ."

[tasks."fmt:python:fix"]
description = "Fix Python formatting"
run = "ruff format ."

# Docs formatting
[tasks."fmt:docs"]
alias = "fmt:docs:check"
description = "Check docs formatting"
run = "npx prettier --check 'docs/**/*.md' README.md"

[tasks."fmt:docs:fix"]
description = "Fix docs formatting"
run = "npx prettier -w 'docs/**/*.md' README.md"

# =============================================================================
# Lint tasks
# =============================================================================

[tasks.lint]
alias = "lint:check"
description = "Run linters for all languages (non-destructive)"
run = [
  { tasks = ["lint:go", "lint:rust", "lint:python"] },
]

[tasks."lint:fix"]
description = "Fix lint issues for all languages"
run = [
  { tasks = ["lint:go:fix", "lint:rust:fix", "lint:python:fix"] },
]

# Go linting
[tasks."lint:go"]
alias = "lint:go:check"
description = "Lint Go code"
run = "golangci-lint run ./..."

[tasks."lint:go:fix"]
description = "Fix Go lint issues (limited auto-fix)"
run = "golangci-lint run --fix ./..."

# Rust linting
[tasks."lint:rust"]
alias = ["lint:rust:check", "lint:rust:clippy"]
description = "Lint Rust code (clippy)"
run = "cargo clippy --manifest-path crates/Cargo.toml --workspace -- -D warnings"

[tasks."lint:rust:deny"]
description = "Check Rust licenses and advisories"
run = "cargo deny --manifest-path crates/Cargo.toml check"

[tasks."lint:rust:fix"]
description = "Fix Rust lint issues"
run = "cargo clippy --manifest-path crates/Cargo.toml --workspace --fix --allow-dirty -- -D warnings"

# Python linting
[tasks."lint:python"]
alias = "lint:python:check"
description = "Lint Python code"
run = """
#!/usr/bin/env bash
set -e
ruff check .
mise run typecheck:python
"""

[tasks."lint:python:fix"]
description = "Fix Python lint issues"
run = "ruff check --fix ."

# =============================================================================
# Typecheck tasks
# =============================================================================

[tasks.typecheck]
description = "Run type checking for all languages"
run = [
  { tasks = ["typecheck:rust", "typecheck:python"] },
]

[tasks."typecheck:rust"]
description = "Type check Rust code (cargo check)"
run = "cargo check --manifest-path crates/Cargo.toml --workspace"

[tasks."typecheck:python"]
description = "Type check Python code"
run = "nox -s typecheck"

# =============================================================================
# Generate tasks
# =============================================================================

[tasks.generate]
description = "Run all code generation"
# Only generate:stubs is run from here; generate:compat is not included and
# must be invoked on its own.
run = [
  { tasks = ["generate:stubs"] },
]

[tasks."generate:stubs"]
alias = "stub:generate"
description = "Generate Python type stubs for coglet"
sources = ["crates/coglet-python/src/**/*.rs", "crates/coglet-python/Cargo.toml", "crates/coglet-python/coglet/__init__.py"]
outputs = ["crates/coglet-python/coglet/_sdk/__init__.pyi", "crates/coglet-python/coglet/__init__.pyi"]
dir = "crates/coglet-python"
run = [
  { task = "_setup_venv" },
  "uv run --active cargo run --bin stub_gen",
]

[tasks."generate:compat"]
description = "Regenerate CUDA/PyTorch/TensorFlow compatibility matrices"
sources = ["tools/compatgen/**/*.go"]
outputs = ["pkg/config/cuda_base_images.json", "pkg/config/torch_compatibility_matrix.json", "pkg/config/tf_compatibility_matrix.json"]
# An optional first argument selects one matrix: cuda, torch, tensorflow (or tf).
# With no argument, or "all", every matrix is regenerated.
run = """
#!/usr/bin/env bash
set -e
target="${1:-all}"

# One helper per matrix so the single-target and "all" paths share one command.
gen_cuda() {
  echo "Generating CUDA base images..."
  go run ./tools/compatgen/main.go cuda -o pkg/config/cuda_base_images.json
}

gen_torch() {
  echo "Generating PyTorch compatibility matrix..."
  go run ./tools/compatgen/main.go torch -o pkg/config/torch_compatibility_matrix.json
}

gen_tensorflow() {
  echo "Generating TensorFlow compatibility matrix..."
  go run ./tools/compatgen/main.go tensorflow -o pkg/config/tf_compatibility_matrix.json
}

case "$target" in
  cuda)
    gen_cuda
    ;;
  torch)
    gen_torch
    ;;
  tensorflow|tf)
    gen_tensorflow
    ;;
  all)
    gen_cuda
    gen_torch
    gen_tensorflow
    ;;
  *)
    # Usage errors go to stderr so they are not mistaken for normal output.
    echo "Unknown target: $target" >&2
    echo "Usage: mise run generate:compat [cuda|torch|tensorflow|all]" >&2
    exit 1
    ;;
esac
echo "Done."
"""

# =============================================================================
# Stub tasks
# =============================================================================

[tasks."stub:check"]
description = "Check that coglet Python stubs are up to date"
dir = "crates/coglet-python"
run = [
  { task = "generate:stubs" },
  '''
#!/usr/bin/env bash
set -e
if ! git diff --quiet -- '**/*.pyi'; then
  echo "ERROR: Stubs are out of date:"
  git diff -- '**/*.pyi'
  echo ""
  echo "Run 'mise run generate:stubs' to update."
  exit 1
fi
echo "Stubs are up to date."
''',
]

[tasks."stub:typecheck"]
description = "Type-check coglet stubs with ty"
run = "ty check crates/coglet-python/coglet/__init__.pyi"

# =============================================================================
# Clean tasks
# =============================================================================

[tasks.clean]
description = "Clean all build artifacts"
run = [
  { tasks = ["clean:go", "clean:rust", "clean:python", "clean:integration"] },
  { task = "_clean_dist" },
]

[tasks."clean:go"]
description = "Clean Go build artifacts"
run = "rm -rf cog base-image dist/go"

[tasks."clean:rust"]
description = "Clean Rust build artifacts"
run = "cd crates && cargo clean"

[tasks."clean:python"]
description = "Clean Python build artifacts"
run = "rm -rf .tox build python/cog.egg-info .venv crates/coglet-python/.venv crates/coglet-python/coglet/*.so"

[tasks."clean:integration"]
description = "Clean cached integration test binary and embedded wheels"
run = "rm -f integration-tests/.bin/cog pkg/wheels/cog-*.whl pkg/wheels/coglet-*.whl"

# =============================================================================
# Docs tasks
# =============================================================================

[tasks.docs]
description = "Build documentation"
sources = ["docs/**/*.md", "README.md", "CONTRIBUTING.md", "mkdocs.yml"]
outputs = ["site/**"]
run = """
#!/usr/bin/env bash
set -e
uv pip install mkdocs-material
sed 's/docs\\///g' README.md > ./docs/README.md
cp CONTRIBUTING.md ./docs/
mkdocs build
"""

[tasks."docs:serve"]
description = "Serve documentation locally"
run = """
#!/usr/bin/env bash
set -e
uv pip install mkdocs-material
sed 's/docs\\///g' README.md > ./docs/README.md
cp CONTRIBUTING.md ./docs/
mkdocs serve
"""

[tasks."docs:llm"]
description = "Update LLM documentation (llms.txt)"
depends = ["docs:cli"]
sources = ["README.md", "docs/*.md"]
outputs = ["docs/llms.txt"]
run = """
#!/usr/bin/env bash
set -e
# Concatenate README (minus contributors section) + all docs into llms.txt
# Use awk instead of sed for cross-platform compatibility (BSD sed vs GNU sed)
# Only include git-tracked files (docs/ may contain mkdocs-generated copies of CONTRIBUTING.md, README.md)
(awk '/^## Contributors/{exit} {print}' README.md; for file in $(git ls-files 'docs/*.md'); do printf '\n\n---\n\n' && cat "$file"; done) > docs/llms.txt
echo "Updated docs/llms.txt"
"""

[tasks."docs:llm:check"]
description = "Check that llms.txt is up to date"
run = """
#!/usr/bin/env bash
set -e
tmpfile=$(mktemp)
trap 'rm -f "$tmpfile"' EXIT
# Generate to temp file and compare
# Only include git-tracked files (docs/ may contain mkdocs-generated copies of CONTRIBUTING.md, README.md)
(awk '/^## Contributors/{exit} {print}' README.md; for file in $(git ls-files 'docs/*.md'); do printf '\n\n---\n\n' && cat "$file"; done) > "$tmpfile"
if ! diff -q "$tmpfile" docs/llms.txt > /dev/null 2>&1; then
  echo "ERROR: docs/llms.txt is out of date. Run 'mise run docs:llm' to update."
  exit 1
fi
echo "docs/llms.txt is up to date"
"""

[tasks."docs:cli"]
description = "Generate CLI reference documentation"
sources = ["pkg/cli/*.go", "cmd/cog/*.go"]
outputs = ["docs/cli.md"]
run = "go run ./tools/gendocs/main.go -o docs/cli.md"

[tasks."docs:cli:check"]
description = "Check that CLI docs are up to date"
run = """
#!/usr/bin/env bash
set -e
tmpfile=$(mktemp)
trap 'rm -f "$tmpfile"' EXIT
# Generate to temp file and compare
go run ./tools/gendocs/main.go -o "$tmpfile"
if ! diff -q "$tmpfile" docs/cli.md > /dev/null 2>&1; then
  echo "ERROR: docs/cli.md is out of date. Run 'mise run docs:cli' to update."
  exit 1
fi
echo "docs/cli.md is up to date"
"""

# =============================================================================
# CI tasks - granular for parallel execution and caching
# =============================================================================

# Build tasks (run first to produce artifacts)
[tasks."ci:build"]
description = "CI: Build all artifacts"
run = [
  { tasks = ["ci:build:sdk", "ci:build:coglet"] },
]

[tasks."ci:build:sdk"]
description = "CI: Build SDK wheel and sdist"
run = [
  { task = "_setup_dist" },
  """
#!/usr/bin/env bash
set -euo pipefail
# Version from Cargo.toml, converted to PEP 440
RAW=$(grep '^version' crates/Cargo.toml | head -1 | sed 's/.*"\\(.*\\)"/\\1/')
export SETUPTOOLS_SCM_PRETEND_VERSION=$(echo "$RAW" | sed -E 's/-alpha/a/; s/-beta/b/; s/-rc/rc/; s/-dev/.dev/')
echo "Building SDK wheel: $SETUPTOOLS_SCM_PRETEND_VERSION"
uv build --out-dir=dist .
""",
]

[tasks."ci:build:coglet"]
description = "CI: Build coglet wheel"
run = [
  { task = "_setup_dist" },
  { task = "build:coglet:wheel:linux-x64" },
]

[tasks."ci:test:integration"]
description = "CI: Run integration tests with GitHub Actions output (full suite)"
# exec ensures signals (SIGTERM from CI cancellation) go directly to gotestsum
run = "exec gotestsum --format github-actions -- -tags integration -parallel ${TEST_PARALLEL:-4} -timeout 30m ./integration-tests/..."

# =============================================================================
# Publish tasks (future)
# =============================================================================

[tasks."publish:coglet"]
description = "Publish coglet to PyPI"
run = """
#!/usr/bin/env bash
set -e
echo "TODO: Implement coglet PyPI publish"
echo "Wheels in dist/: $(ls dist/coglet-*.whl 2>/dev/null || echo 'none')"
"""

[tasks."publish:sdk"]
description = "Publish cog SDK to PyPI"
run = """
#!/usr/bin/env bash
set -e
echo "TODO: Implement SDK PyPI publish"
echo "Wheels in dist/: $(ls dist/cog-*.whl 2>/dev/null || echo 'none')"
"""


================================================
FILE: mkdocs.yml
================================================
site_name: Cog
repo_url: https://github.com/replicate/cog
docs_dir: docs/
nav:
  - README: README.md
  - Getting Started: getting-started.md
  - Using your own model: getting-started-own-model.md
  - Deploy your model: deploy.md
  - YAML spec: yaml.md
  - Prediction API: python.md
  - Training API: training.md
  - HTTP API: http.md
  - CLI: cli.md
  - Environment variables: environment.md
  - Private registry: private-package-registry.md
  - Notebooks: notebooks.md
  - Windows: wsl2/wsl2.md
  - Contributing: CONTRIBUTING.md
  - License: https://github.com/replicate/cog/blob/main/LICENSE
  - llms.txt: llms.txt

theme:
  name: material
  font:
    text: "Roboto"
    code: "Roboto Mono"
  favicon: favicon.svg

  # Display a link to edit pages right on GitHub
  features:
    - content.action.edit

  icon:
    logo: material/cog
    repo: simple/github
  palette:
    # Palette toggle for light mode
    - media: "(prefers-color-scheme: light)"
      scheme: default
      primary: black
      toggle:
        icon: material/weather-night
        name: Switch to dark mode

    # Palette toggle for dark mode
    - media: "(prefers-color-scheme: dark)"
      scheme: slate
      primary: black
      toggle:
        icon: material/weather-sunny
        name: Switch to light mode

edit_uri: edit/main/docs/

markdown_extensions:
  - toc:
      permalink: "#"
  - markdown.extensions.codehilite:
      guess_lang: true
  - admonition
  - codehilite
  - extra
  - pymdownx.highlight
  - pymdownx.superfences
extra_css:
  - stylesheets/extra.css

extra:
  # Hide the "made with material" thing
  generator: false
  social:
    - icon: simple/github
      link: https://github.com/replicate/cog
    - icon: simple/discord
      link: https://discord.gg/replicate
    - icon: simple/x
      link: https://x.com/replicate
    - icon: simple/youtube
      link: https://youtube.com/@replicatehq

copyright: Cog is an open-source project from Replicate


================================================
FILE: noxfile.py
================================================
"""Nox sessions for cog Python SDK testing."""

import glob
import platform

import nox

# Use uv for venv creation and Python management (uv auto-downloads Python if needed)
nox.options.default_venv_backend = "uv"

PYTHON_VERSIONS = ["3.10", "3.11", "3.12", "3.13"]
PYTHON_DEFAULT = "3.13"

# Test dependencies (mirrored from pyproject.toml [dependency-groups].test)
TEST_DEPS = [
    "pytest",
    "pytest-timeout",
    "pytest-xdist",
    "pytest-cov",
]


def _find_compatible_wheel(pattern: str) -> str | None:
    """Find a wheel matching the current platform from dist/.

    Returns None when no wheels exist at all.  Raises RuntimeError when
    wheels exist but none are compatible — that means the build produced
    the wrong platform and should be fixed, not silently papered over.
    """
    wheels = glob.glob(pattern)
    if not wheels:
        return None

    system = platform.system().lower()
    machine = platform.machine().lower()
    platform_tags = {
        ("darwin", "arm64"): "macosx",
        ("darwin", "x86_64"): "macosx",
        ("linux", "x86_64"): "manylinux",
        ("linux", "aarch64"): "manylinux",
    }
    tag = platform_tags.get((system, machine))
    if tag:
        for whl in wheels:
            if tag in whl or "none-any" in whl:
                return whl
        raise RuntimeError(
            f"Found wheel(s) in dist/ but none compatible with {system}/{machine}:\n"
            + "\n".join(f"  {w}" for w in wheels)
            + "\nRun 'mise run build:coglet:wheel' to build a native wheel."
        )

    # Unknown platform — let pip figure it out
    return wheels[0]


def _install_coglet(session: nox.Session) -> None:
    """Install the coglet wheel from dist/, reporting an error if none exists."""
    coglet_wheel = _find_compatible_wheel("dist/coglet-*.whl")
    if not coglet_wheel:
        session.error(
            "No coglet wheel found in dist/. Run 'mise run build:coglet:wheel' first."
        )
        return
    session.install(coglet_wheel)


def _install_package(session: nox.Session) -> None:
    """Install coglet first, then the cog SDK itself.

    A pre-built SDK wheel from dist/ is preferred. Without one the SDK is
    installed in editable mode from the source tree; that fails in CI
    (setuptools_scm needs a full git checkout), so CI must run build:sdk
    first.
    """
    _install_coglet(session)
    sdk_wheel = _find_compatible_wheel("dist/cog-*.whl")
    install_args = (sdk_wheel,) if sdk_wheel else ("-e", ".")
    session.install(*install_args)


@nox.session(python=PYTHON_VERSIONS)
def tests(session: nox.Session) -> None:
    """Run pytest over python/tests with coverage reporting for python/cog.

    When session.posargs is non-empty it replaces the default ``-n auto -vv``.
    """
    _install_package(session)
    session.install(*TEST_DEPS)

    extra_args = session.posargs or ["-n", "auto", "-vv"]
    command = [
        "pytest",
        "python/tests",
        "--cov=python/cog",
        "--cov-report=term-missing:skip-covered",
    ]
    session.run(*command, *extra_args)


@nox.session(python=PYTHON_DEFAULT)
def typecheck(session: nox.Session) -> None:
    """Type-check with a pinned pyright, forwarding session.posargs to it."""
    pyright_requirement = "pyright==1.1.375"
    _install_package(session)
    session.install(pyright_requirement)
    session.run(*["pyright", *session.posargs])


@nox.session(name="coglet", python=PYTHON_VERSIONS)
def coglet_tests(session: nox.Session) -> None:
    """Run the pytest suite in crates/coglet-python/tests, forwarding session.posargs."""
    _install_package(session)
    session.install(*("pytest", "requests"))
    pytest_command = ["pytest", "crates/coglet-python/tests", "-v"]
    session.run(*pytest_command, *session.posargs)


================================================
FILE: pkg/cli/baseimage.go
================================================
package cli

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"strings"

	"github.com/spf13/cobra"

	"github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/dockercontext"
	"github.com/replicate/cog/pkg/dockerfile"
	"github.com/replicate/cog/pkg/global"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/update"
	"github.com/replicate/cog/pkg/util/console"
)

// Flag-backed settings shared by the base-image subcommands. They are bound to
// the --cuda, --python and --torch flags by addBaseImageFlags.
var (
	baseImageCUDAVersion   string
	baseImagePythonVersion string
	baseImageTorchVersion  string
)

// NewBaseImageRootCommand builds the root "base-image" command and registers
// its dockerfile, build and generate-matrix subcommands. The returned error is
// always nil.
func NewBaseImageRootCommand() (*cobra.Command, error) {
	// preRun executes before every subcommand: it enables debug logging when
	// requested, suppresses usage output on failure, and calls
	// update.DisplayAndCheckForRelease, logging any error at debug level.
	preRun := func(cmd *cobra.Command, args []string) {
		if global.Debug {
			console.SetLevel(console.DebugLevel)
		}
		cmd.SilenceUsage = true
		err := update.DisplayAndCheckForRelease(cmd.Context())
		if err != nil {
			console.Debugf("%s", err)
		}
	}

	root := &cobra.Command{
		Use:              "base-image",
		Short:            "Cog base image commands. This is an experimental feature with no guarantees of future support.",
		Version:          fmt.Sprintf("%s (built %s)", global.Version, global.BuildTime),
		PersistentPreRun: preRun,
		// This stops errors being printed because we print them in cmd/cog/cog.go
		SilenceErrors: true,
	}
	setPersistentFlags(root)

	for _, sub := range []*cobra.Command{
		newBaseImageDockerfileCommand(),
		newBaseImageBuildCommand(),
		newBaseImageGenerateMatrix(),
	} {
		root.AddCommand(sub)
	}

	return root, nil
}

// newBaseImageGenerateMatrix returns the "generate-matrix" subcommand, which
// prints the known base image configurations as a JSON array. The --cuda,
// --python and --torch flags each accept a comma-separated list; a
// configuration is kept only if it matches every non-empty list.
func newBaseImageGenerateMatrix() *cobra.Command {
	// splitList breaks a comma-separated flag value into its non-empty items.
	splitList := func(value string) []string {
		return strings.FieldsFunc(value, func(r rune) bool { return r == ',' })
	}

	// permits reports whether value passes the filter; an empty filter
	// accepts every value.
	permits := func(allowed []string, value string) bool {
		if len(allowed) == 0 {
			return true
		}
		for _, candidate := range allowed {
			if candidate == value {
				return true
			}
		}
		return false
	}

	run := func(cmd *cobra.Command, args []string) error {
		cudaFilter := splitList(baseImageCUDAVersion)
		pythonFilter := splitList(baseImagePythonVersion)
		torchFilter := splitList(baseImageTorchVersion)

		known := dockerfile.BaseImageConfigurations()
		matrix := make([]dockerfile.BaseImageConfiguration, 0, len(known))
		for _, entry := range known {
			if !permits(cudaFilter, entry.CUDAVersion) {
				continue
			}
			if !permits(pythonFilter, entry.PythonVersion) {
				continue
			}
			if !permits(torchFilter, entry.TorchVersion) {
				continue
			}
			matrix = append(matrix, entry)
		}

		encoded, err := json.Marshal(matrix)
		if err != nil {
			return err
		}
		fmt.Println(string(encoded))
		return nil
	}

	cmd := &cobra.Command{
		Use:   "generate-matrix",
		Short: "Generate a matrix of Cog base image versions (JSON)",
		RunE:  run,
		Args:  cobra.MaximumNArgs(0),
	}
	addBaseImageFlags(cmd)
	return cmd
}

// newBaseImageDockerfileCommand returns the "dockerfile" subcommand, which
// prints the Dockerfile generated for the base image selected by the --cuda,
// --python and --torch flags.
func newBaseImageDockerfileCommand() *cobra.Command {
	printDockerfile := func(cmd *cobra.Command, args []string) error {
		ctx := cmd.Context()

		generator, err := baseImageGeneratorFromFlags(ctx)
		if err != nil {
			return err
		}

		contents, err := generator.GenerateDockerfile(ctx)
		if err != nil {
			return err
		}

		fmt.Println(contents)
		return nil
	}

	cmd := &cobra.Command{
		Use:   "dockerfile",
		Short: "Display Cog base image Dockerfile",
		RunE:  printDockerfile,
		Args:  cobra.MaximumNArgs(0),
	}
	addBaseImageFlags(cmd)
	addNoCacheFlag(cmd)
	addBuildProgressOutputFlag(cmd)

	return cmd
}

// newBaseImageBuildCommand returns the "build" subcommand, which generates the
// base image Dockerfile for the selected CUDA/Python/Torch versions and builds
// it with Docker, using the current working directory as the working dir.
func newBaseImageBuildCommand() *cobra.Command {
	build := func(cmd *cobra.Command, args []string) error {
		ctx := cmd.Context()

		dockerClient, err := docker.NewClient(ctx)
		if err != nil {
			return err
		}

		generator, err := baseImageGeneratorFromFlags(ctx)
		if err != nil {
			return err
		}

		contents, err := generator.GenerateDockerfile(ctx)
		if err != nil {
			return err
		}

		workingDir, err := os.Getwd()
		if err != nil {
			return err
		}

		name := dockerfile.BaseImageName(baseImageCUDAVersion, baseImagePythonVersion, baseImageTorchVersion)

		_, err = dockerClient.ImageBuild(ctx, command.ImageBuildOptions{
			WorkingDir:         workingDir,
			DockerfileContents: contents,
			ImageName:          name,
			NoCache:            buildNoCache,
			ProgressOutput:     buildProgressOutput,
			Epoch:              &config.BuildSourceEpochTimestamp,
			ContextDir:         dockercontext.StandardBuildDirectory,
		})
		if err != nil {
			return err
		}

		fmt.Println("Successfully built image: " + name)
		return nil
	}

	cmd := &cobra.Command{
		Use:   "build",
		Short: "Build Cog base image",
		RunE:  build,
		Args:  cobra.MaximumNArgs(0),
	}
	addBaseImageFlags(cmd)

	return cmd
}

// addBaseImageFlags registers the --cuda, --python and --torch version flags,
// plus the hidden --timestamp flag, on cmd.
func addBaseImageFlags(cmd *cobra.Command) {
	flags := cmd.Flags()
	flags.StringVar(&baseImageCUDAVersion, "cuda", "", "CUDA version")
	flags.StringVar(&baseImagePythonVersion, "python", "", "Python version")
	flags.StringVar(&baseImageTorchVersion, "torch", "", "Torch version")

	addBuildTimestampFlag(cmd)
}

// baseImageGeneratorFromFlags constructs a base image Dockerfile generator
// from the current --cuda, --python and --torch flag values, backed by a new
// Docker client and a new registry client.
func baseImageGeneratorFromFlags(ctx context.Context) (*dockerfile.BaseImageGenerator, error) {
	dockerClient, err := docker.NewClient(ctx)
	if err != nil {
		return nil, err
	}

	registryClient := registry.NewRegistryClient()

	// NOTE(review): the meaning of the trailing `true` is defined by
	// dockerfile.NewBaseImageGenerator and is not visible from this file.
	return dockerfile.NewBaseImageGenerator(
		ctx,
		registryClient,
		baseImageCUDAVersion,
		baseImagePythonVersion,
		baseImageTorchVersion,
		dockerClient,
		true,
	)
}


================================================
FILE: pkg/cli/build.go
================================================
package cli

import (
	"fmt"
	"os"
	"strings"

	"github.com/spf13/cobra"
	"github.com/spf13/pflag"

	"github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/model"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/util/console"
)

// Values bound to the build-related command-line flags; each is registered by
// the corresponding add*Flag helper (buildTag by newBuildCommand).
var (
	buildTag              string
	buildSeparateWeights  bool
	buildSecrets          []string
	buildNoCache          bool
	buildProgressOutput   string
	buildSchemaFile       string
	buildUseCudaBaseImage string
	buildDockerfileFile   string
	buildUseCogBaseImage  bool
	buildStrip            bool
	buildPrecompile       bool
	configFilename        string
)

const useCogBaseImageFlagKey = "use-cog-base-image"

// newBuildCommand returns the "cog build" command with all of its build flags
// registered. Conflicting base-image flags are rejected in PreRunE before
// buildCommand runs.
func newBuildCommand() *cobra.Command {
	const longHelp = `Build a Docker image from the cog.yaml in the current directory.

The generated image contains your model code, dependencies, and the Cog
runtime. It can be run locally with 'cog predict' or pushed to a registry
with 'cog push'.`

	const examples = `  # Build with default settings
  cog build

  # Build and tag the image
  cog build -t my-model:latest

  # Build without using the cache
  cog build --no-cache

  # Build with model weights in a separate layer
  cog build --separate-weights -t my-model:v1`

	cmd := &cobra.Command{
		Use:     "build",
		Short:   "Build an image from cog.yaml",
		Long:    longHelp,
		Example: examples,
		Args:    cobra.NoArgs,
		RunE:    buildCommand,
		PreRunE: checkMutuallyExclusiveFlags,
	}

	// Shared flag helpers, registered in the same order as before.
	registrars := []func(*cobra.Command){
		addBuildProgressOutputFlag,
		addSecretsFlag,
		addNoCacheFlag,
		addSeparateWeightsFlag,
		addSchemaFlag,
		addUseCudaBaseImageFlag,
		addDockerfileFlag,
		addUseCogBaseImageFlag,
		addBuildTimestampFlag,
		addStripFlag,
		addPrecompileFlag,
		addConfigFlag,
	}
	for _, register := range registrars {
		register(cmd)
	}

	cmd.Flags().StringVarP(&buildTag, "tag", "t", "", "A name for the built image in the form 'repository:tag'")
	return cmd
}

// buildCommand implements "cog build": it loads the model source named by the
// --file flag, picks an image name, and builds the image through the model
// resolver.
//
// The image name is, in order of preference: the --tag flag, the `image` value
// from the config, or a name derived from the project directory.
func buildCommand(cmd *cobra.Command, args []string) error {
	ctx := cmd.Context()

	dockerClient, err := docker.NewClient(ctx)
	if err != nil {
		return err
	}

	src, err := model.NewSource(configFilename)
	if err != nil {
		return err
	}

	imageName := buildTag
	if imageName == "" {
		imageName = src.Config.Image
	}
	if imageName == "" {
		imageName = config.DockerImageName(src.ProjectDir)
	}

	console.Infof("Building Docker image from environment in cog.yaml as %s...", console.Bold(imageName))
	console.Info("")

	opts := buildOptionsFromFlags(cmd, imageName, nil)
	built, err := model.NewResolver(dockerClient, registry.NewRegistryClient()).Build(ctx, src, opts)
	if err != nil {
		return err
	}

	console.Info("")
	console.Successf("Image built as %s", console.Bold(built.ImageRef()))

	return nil
}

// addBuildProgressOutputFlag registers --progress on cmd. The default comes
// from $BUILDKIT_PROGRESS when set; otherwise it is "plain" on a dumb terminal
// ($TERM == "dumb") and "auto" everywhere else.
func addBuildProgressOutputFlag(cmd *cobra.Command) {
	defaultOutput := os.Getenv("BUILDKIT_PROGRESS")
	switch {
	case defaultOutput != "":
		// An explicit environment setting wins.
	case os.Getenv("TERM") == "dumb":
		defaultOutput = "plain"
	default:
		defaultOutput = "auto"
	}

	flags := cmd.Flags()
	flags.StringVar(&buildProgressOutput, "progress", defaultOutput, "Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet'")
}

// addSecretsFlag registers the repeatable --secret flag on cmd, collecting
// each occurrence into buildSecrets.
func addSecretsFlag(cmd *cobra.Command) {
	const usage = "Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file'"
	flags := cmd.Flags()
	flags.StringArrayVar(&buildSecrets, "secret", []string{}, usage)
}

// addNoCacheFlag registers the --no-cache boolean flag (default false) on cmd.
func addNoCacheFlag(cmd *cobra.Command) {
	const usage = "Do not use cache when building the image"
	flags := cmd.Flags()
	flags.BoolVar(&buildNoCache, "no-cache", false, usage)
}

// addSeparateWeightsFlag registers the --separate-weights boolean flag
// (default false) on cmd.
func addSeparateWeightsFlag(cmd *cobra.Command) {
	const usage = "Separate model weights from code in image layers"
	flags := cmd.Flags()
	flags.BoolVar(&buildSeparateWeights, "separate-weights", false, usage)
}

// addSchemaFlag registers the --openapi-schema flag on cmd, storing the given
// file path in buildSchemaFile.
func addSchemaFlag(cmd *cobra.Command) {
	const usage = "Load OpenAPI schema from a file"
	flags := cmd.Flags()
	flags.StringVar(&buildSchemaFile, "openapi-schema", "", usage)
}

// addUseCudaBaseImageFlag registers the --use-cuda-base-image string flag on
// cmd.
//
// NOTE(review): the registered default is "auto", while the help text labels
// 'true' as the default — confirm which one is intended.
func addUseCudaBaseImageFlag(cmd *cobra.Command) {
	const usage = "Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects"
	flags := cmd.Flags()
	flags.StringVar(&buildUseCudaBaseImage, "use-cuda-base-image", "auto", usage)
}

// addDockerfileFlag registers the --dockerfile flag on cmd and hides it from
// help output.
func addDockerfileFlag(cmd *cobra.Command) {
	const name = "dockerfile"

	flags := cmd.Flags()
	flags.StringVar(&buildDockerfileFile, name, "", "Path to a Dockerfile. If set, cog will use this Dockerfile instead of generating one from cog.yaml")

	// Walk the flag set and mark only the flag registered above as hidden.
	hide := func(f *pflag.Flag) {
		if f.Name != name {
			return
		}
		f.Hidden = true
	}
	flags.VisitAll(hide)
}

// addUseCogBaseImageFlag registers the --use-cog-base-image boolean flag
// (default true) on cmd.
func addUseCogBaseImageFlag(cmd *cobra.Command) {
	const usage = "Use pre-built Cog base image for faster cold boots"
	flags := cmd.Flags()
	flags.BoolVar(&buildUseCogBaseImage, useCogBaseImageFlagKey, true, usage)
}

// addBuildTimestampFlag registers the hidden --timestamp flag on cmd, storing
// its value in config.BuildSourceEpochTimestamp (default -1).
func addBuildTimestampFlag(cmd *cobra.Command) {
	const name = "timestamp"

	flags := cmd.Flags()
	flags.Int64Var(&config.BuildSourceEpochTimestamp, name, -1, "Number of seconds since Epoch to use for the build timestamp; this rewrites the timestamp of each layer. Useful for reproducibility. (`-1` to disable timestamp rewrites)")
	// MarkHidden only fails for an unknown flag; this one was just registered.
	_ = flags.MarkHidden(name)
}

// addStripFlag registers the hidden --strip boolean flag (default false) on
// cmd.
func addStripFlag(cmd *cobra.Command) {
	const name = "strip"

	flags := cmd.Flags()
	flags.BoolVar(&buildStrip, name, false, "Whether to strip shared libraries for faster inference times")
	// MarkHidden only fails for an unknown flag; this one was just registered.
	_ = flags.MarkHidden(name)
}

// addPrecompileFlag registers the hidden --precompile boolean flag (default
// false) on cmd.
func addPrecompileFlag(cmd *cobra.Command) {
	const name = "precompile"

	flags := cmd.Flags()
	flags.BoolVar(&buildPrecompile, name, false, "Whether to precompile python files for faster load times")
	// MarkHidden only fails for an unknown flag; this one was just registered.
	_ = flags.MarkHidden(name)
}

// addConfigFlag registers --file/-f on cmd, which names the config file to
// load (default "cog.yaml").
func addConfigFlag(cmd *cobra.Command) {
	const usage = "The name of the config file."
	flags := cmd.Flags()
	flags.StringVarP(&configFilename, "file", "f", "cog.yaml", usage)
}

// checkMutuallyExclusiveFlags is a cobra PreRunE hook that rejects an
// invocation which explicitly sets more than one of --use-cog-base-image,
// --use-cuda-base-image and --dockerfile.
//
// A flag that is not registered on cmd is treated as unset instead of being
// dereferenced, so the hook is safe on commands that expose only some of the
// three flags. It returns nil when at most one of them was changed.
func checkMutuallyExclusiveFlags(cmd *cobra.Command, args []string) error {
	candidates := []string{useCogBaseImageFlagKey, "use-cuda-base-image", "dockerfile"}
	flagsSet := make([]string, 0, len(candidates))
	for _, name := range candidates {
		// cmd.Flag returns nil when no flag with that name exists.
		if f := cmd.Flag(name); f != nil && f.Changed {
			flagsSet = append(flagsSet, "--"+name)
		}
	}
	if len(flagsSet) > 1 {
		return fmt.Errorf("The flags %s are mutually exclusive: you can only set one of them.", strings.Join(flagsSet, " and "))
	}
	return nil
}

// DetermineUseCogBaseImage reports the user's explicit --use-cog-base-image
// choice. It returns nil when the flag was not set on the command line, and
// otherwise a pointer to a copy of the flag's value.
func DetermineUseCogBaseImage(cmd *cobra.Command) *bool {
	if cmd.Flags().Changed(useCogBaseImageFlagKey) {
		chosen := buildUseCogBaseImage
		return &chosen
	}
	return nil
}

// buildOptionsFromFlags assembles a model.BuildOptions from the package-level
// flag variables. imageName and annotations differ per command, so the caller
// supplies them.
func buildOptionsFromFlags(cmd *cobra.Command, imageName string, annotations map[string]string) model.BuildOptions {
	opts := model.BuildOptions{
		ImageName:   imageName,
		Annotations: annotations,
	}

	opts.Secrets = buildSecrets
	opts.NoCache = buildNoCache
	opts.SeparateWeights = buildSeparateWeights
	opts.UseCudaBaseImage = buildUseCudaBaseImage
	opts.ProgressOutput = buildProgressOutput
	opts.SchemaFile = buildSchemaFile
	opts.DockerfileFile = buildDockerfileFile
	// Stays nil unless --use-cog-base-image was given explicitly.
	opts.UseCogBaseImage = DetermineUseCogBaseImage(cmd)
	opts.Strip = buildStrip
	opts.Precompile = buildPrecompile
	opts.OCIIndex = model.OCIIndexEnabled()

	return opts
}


================================================
FILE: pkg/cli/debug.go
================================================
package cli

import (
	"fmt"
	"os"
	"path/filepath"

	"github.com/spf13/cobra"

	"github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/dockerfile"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/util/console"
)

// imageName holds the value of the debug command's --image-name flag. When it
// is empty and weights are separated, cmdDockerfile overwrites it with a name
// derived from the project directory.
var imageName string

// newDebugCommand returns the hidden "debug" command, which prints the
// Dockerfile(s) that would be generated for the current project.
func newDebugCommand() *cobra.Command {
	cmd := &cobra.Command{
		Use:    "debug",
		Hidden: true,
		Short:  "Generate a Dockerfile from cog",
		RunE:   cmdDockerfile,
	}

	registrars := []func(*cobra.Command){
		addSeparateWeightsFlag,
		addUseCudaBaseImageFlag,
		addDockerfileFlag,
		addUseCogBaseImageFlag,
		addBuildTimestampFlag,
		addConfigFlag,
	}
	for _, register := range registrars {
		register(cmd)
	}

	flags := cmd.Flags()
	flags.StringVarP(&imageName, "image-name", "", "", "The image name to use for the generated Dockerfile")

	return cmd
}

// cmdDockerfile implements the "debug" command. It loads the project config,
// builds a Dockerfile generator configured from the base-image flags, and
// prints the result: a single Dockerfile by default, or the weights
// Dockerfile, runner Dockerfile and .dockerignore when --separate-weights is
// set.
func cmdDockerfile(cmd *cobra.Command, args []string) error {
	ctx := cmd.Context()

	// Find the root project directory
	rootDir, err := config.GetProjectDir(configFilename)
	if err != nil {
		return err
	}

	configFile, err := os.Open(filepath.Join(rootDir, configFilename))
	if err != nil {
		return &config.ParseError{Filename: configFilename, Err: err}
	}

	// The file is only needed while loading, so close it straight away whether
	// or not loading succeeded.
	loaded, loadErr := config.Load(configFile, rootDir)
	_ = configFile.Close()
	if loadErr != nil {
		return loadErr
	}

	cfg, projectDir := loaded.Config, loaded.RootDir

	// Display any deprecation warnings
	for _, warning := range loaded.Warnings {
		console.Warnf("%s", warning.Error())
	}

	dockerClient, err := docker.NewClient(ctx)
	if err != nil {
		return err
	}

	registryClient := registry.NewRegistryClient()
	generator, err := dockerfile.NewGenerator(cfg, projectDir, configFilename, dockerClient, registryClient, true)
	if err != nil {
		return fmt.Errorf("Error creating Dockerfile generator: %w", err)
	}
	defer func() {
		if cleanupErr := generator.Cleanup(); cleanupErr != nil {
			console.Warnf("Error cleaning up after build: %v", cleanupErr)
		}
	}()

	generator.SetUseCudaBaseImage(buildUseCudaBaseImage)
	// Only override the generator's own choice when the flag was given.
	if explicit := DetermineUseCogBaseImage(cmd); explicit != nil {
		generator.SetUseCogBaseImage(*explicit)
	}

	if !buildSeparateWeights {
		contents, err := generator.GenerateDockerfileWithoutSeparateWeights(ctx)
		if err != nil {
			return err
		}

		console.Output(contents)
		return nil
	}

	if imageName == "" {
		imageName = config.DockerImageName(projectDir)
	}

	weightsDockerfile, runnerDockerfile, dockerignore, err := generator.GenerateModelBaseWithSeparateWeights(ctx, imageName)
	if err != nil {
		return err
	}

	console.Output(fmt.Sprintf("=== Weights Dockerfile contents:\n%s\n===\n", weightsDockerfile))
	console.Output(fmt.Sprintf("=== Runner Dockerfile contents:\n%s\n===\n", runnerDockerfile))
	console.Output(fmt.Sprintf("=== DockerIgnore contents:\n%s===\n", dockerignore))

	return nil
}


================================================
FILE: pkg/cli/init-templates/base/.dockerignore
================================================
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file

# Exclude Git files
**/.git
**/.github
**/.gitignore

# Exclude Python tooling
.python-version

# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache

# Exclude Python virtual environment
/venv


================================================
FILE: pkg/cli/init-templates/base/.github/workflows/push.yaml
================================================
name: Push to Replicate

on:
  # Workflow dispatch allows you to manually trigger the workflow from GitHub.com
  # Go to your repo, click "Actions", click "Push to Replicate", click "Run workflow"
  workflow_dispatch:
    inputs:
      model_name:
        description: 'Enter the model name, like "alice/bunny-detector". If unset, this will default to the value of `image` in cog.yaml.'
  # # Uncomment these lines to trigger the workflow on every push to the main branch
  # push:
  #   branches:
  #     - main

jobs:
  push_to_replicate:
    name: Push to Replicate

    # If your model is large, the default GitHub Actions runner may not
    # have enough disk space. If you need more space you can set up a
    # bigger runner on GitHub.
    runs-on: ubuntu-latest

    steps:
      # This action cleans up disk space to make more room for your
      # model code, weights, etc.
      - name: Free disk space
        uses: jlumbroso/free-disk-space@v1.3.1
        with:
          tool-cache: false
          docker-images: false

      - name: Checkout
        uses: actions/checkout@v4

      # This action installs Docker buildx and Cog (and optionally CUDA)
      - name: Setup Cog
        uses: replicate/setup-cog@v2
        with:
          # If you add a CI auth token to your GitHub repository secrets,
          # the action will authenticate with Replicate automatically so you
          # can push your model without needing to pass in a token.
          #
          # To generate a CLI auth token, run `cog login` or visit this page
          # in your browser: https://replicate.com/account/api-token
          token: ${{ secrets.REPLICATE_CLI_AUTH_TOKEN }}

      # If you trigger the workflow manually, you can specify the model name.
      # If you leave it blank (or if the workflow is triggered by a push), the
      # model name will be derived from the `image` value in cog.yaml.
      - name: Push to Replicate
        run: |
          if [ -n "${{ inputs.model_name }}" ]; then
            cog push r8.im/${{ inputs.model_name }}
          else
            cog push
          fi


================================================
FILE: pkg/cli/init-templates/base/cog.yaml
================================================
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml

build:
  # set to true if your model requires a GPU
  gpu: false

  # a list of ubuntu apt packages to install
  # system_packages:
  #   - "libgl1-mesa-glx"
  #   - "libglib2.0-0"

  # python version in the form '3.11' or '3.11.4'
  python_version: "3.13"

  # path to a Python requirements.txt file
  python_requirements: requirements.txt

  # commands run after the environment is set up
  # run:
  #   - "echo env is ready!"
  #   - "echo another command if needed"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"


================================================
FILE: pkg/cli/init-templates/base/predict.py
================================================
# Prediction interface for Cog ⚙️
# https://cog.run/python

from cog import BasePredictor, Input, Path


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # self.model = torch.load("./weights.pth")

    def predict(
        self,
        image: Path = Input(description="Grayscale input image"),
        scale: float = Input(
            description="Factor to scale image by", ge=0, le=10, default=1.5
        ),
    ) -> Path:
        """Run a single prediction on the model"""
        # processed_input = preprocess(image)
        # output = self.model(processed_input, scale)
        # return postprocess(output)


================================================
FILE: pkg/cli/init-templates/base/requirements.txt
================================================
# This is a normal Python requirements.txt file.

# You can add dependencies directly from PyPI:
# 
# numpy==1.26.4
# torch==2.2.1
# torchvision==0.17.1


# You can also add Git repos as dependencies, but you'll need to add git to the system_packages list in cog.yaml:
# 
# build:
#   system_packages:
#     - "git"
# 
# Then you can use a URL like this:
# 
# git+https://github.com/huggingface/transformers


# You can also pin Git repos to a specific commit:
# 
# git+https://github.com/huggingface/transformers@2d1602a


================================================
FILE: pkg/cli/init.go
================================================
package cli

import (
	"embed"
	"fmt"
	"io"
	"net/http"
	"os"
	"path"
	"time"

	"github.com/spf13/cobra"

	"github.com/replicate/cog/pkg/util/console"
	"github.com/replicate/cog/pkg/util/files"
)

//go:embed init-templates/**/*
var initTemplates embed.FS

// newInitCommand returns the "cog init" command, which takes no arguments and
// writes the starter template files into the current directory.
func newInitCommand() *cobra.Command {
	const longHelp = `Create a cog.yaml and predict.py in the current directory.

These files provide a starting template for defining your model's environment
and prediction interface. Edit them to match your model's requirements.`

	const examples = `  # Set up a new Cog project in the current directory
  cog init`

	return &cobra.Command{
		Use:        "init",
		SuggestFor: []string{"new", "start"},
		Short:      "Configure your project for use with Cog",
		Long:       longHelp,
		Example:    examples,
		RunE:       initCommand,
		Args:       cobra.MaximumNArgs(0),
	}
}

// initCommand implements "cog init": it copies every entry of the embedded
// "base" template into the current working directory. Files that already
// exist are left untouched by processTemplateFile.
func initCommand(cmd *cobra.Command, args []string) error {
	console.Info("Setting up the current directory for use with Cog...")
	console.Info("")

	cwd, err := os.Getwd()
	if err != nil {
		return err
	}

	const initTemplate = "base"

	// Discover all files in the embedded template directory
	templateDir := path.Join("init-templates", initTemplate)
	entries, err := initTemplates.ReadDir(templateDir)
	if err != nil {
		return fmt.Errorf("Error reading template directory: %w", err)
	}

	for _, entry := range entries {
		// Directories are walked recursively; plain files are copied directly.
		process := processTemplateFile
		if entry.IsDir() {
			process = processTemplateDirectory
		}
		if err := process(initTemplates, templateDir, entry.Name(), cwd); err != nil {
			return err
		}
	}

	console.Successf("\nDone! For next steps, check out the docs at https://cog.run/getting-started")

	return nil
}

// processTemplateDirectory copies every file beneath templateDir/subDir of the
// embedded template into cwd, keeping each file's path relative to
// templateDir.
//
// templateDir stays fixed at the template root for the whole walk and subDir
// accumulates the nested path. Re-rooting templateDir on each level would drop
// the leading path components, so a template file such as
// ".github/workflows/push.yaml" would be written to "workflows/push.yaml".
//
// It returns the first error from reading a directory or processing a file.
func processTemplateDirectory(fs embed.FS, templateDir, subDir, cwd string) error {
	subDirPath := path.Join(templateDir, subDir)
	entries, err := fs.ReadDir(subDirPath)
	if err != nil {
		return fmt.Errorf("Error reading subdirectory %s: %w", subDirPath, err)
	}

	for _, entry := range entries {
		// Path of this entry relative to the template root.
		relativePath := path.Join(subDir, entry.Name())

		if entry.IsDir() {
			// Recursively process nested subdirectories
			if err := processTemplateDirectory(fs, templateDir, relativePath, cwd); err != nil {
				return err
			}
			continue
		}

		// Process files in subdirectories
		if err := processTemplateFile(fs, templateDir, relativePath, cwd); err != nil {
			return err
		}
	}

	return nil
}

// processTemplateFile writes one template file into cwd. filename is the
// file's slash-separated path relative to templateDir and is also its
// destination path relative to cwd.
//
// An existing destination is skipped. Missing parent directories are created.
// "AGENTS.md" is fetched over HTTP first, with the embedded copy used as the
// fallback when the download fails.
func processTemplateFile(fs embed.FS, templateDir, filename, cwd string) error {
	destination := path.Join(cwd, filename)

	exists, err := files.Exists(destination)
	if err != nil {
		return fmt.Errorf("Error checking if %s exists: %w", destination, err)
	}
	if exists {
		console.Infof("Skipped existing %s", filename)
		return nil
	}

	parent := path.Dir(destination)
	if err := os.MkdirAll(parent, os.ModePerm); err != nil {
		return fmt.Errorf("Error creating directory %s: %w", parent, err)
	}

	templatePath := path.Join(templateDir, filename)

	var content []byte
	if filename == "AGENTS.md" {
		// Try to download from Replicate docs
		downloaded, downloadErr := downloadAgentsFile()
		if downloadErr == nil {
			content = downloaded
		} else {
			console.Infof("Failed to download AGENTS.md: %v", downloadErr)
			console.Infof("Using template version instead...")
			// Fall back to template version
			bundled, readErr := fs.ReadFile(templatePath)
			if readErr != nil {
				return fmt.Errorf("Error reading template %s: %w", filename, readErr)
			}
			content = bundled
		}
	} else {
		// Regular template file processing
		content, err = fs.ReadFile(templatePath)
		if err != nil {
			return fmt.Errorf("Error reading %s: %w", filename, err)
		}
	}

	console.Infof("Creating %s", console.Bold(filename))

	if err := os.WriteFile(destination, content, 0o644); err != nil {
		return fmt.Errorf("Error writing %s: %w", destination, err)
	}
	return nil
}

// downloadAgentsFile performs an HTTP GET for the Cog llms.txt document on
// replicate.com and returns the response body.
//
// The request uses a 10 second client timeout. Any response status other than
// 200 OK is reported as an error of the form "HTTP <code>".
func downloadAgentsFile() ([]byte, error) {
	const agentsURL = "https://replicate.com/docs/reference/cog/llms.txt"

	httpClient := http.Client{Timeout: 10 * time.Second}

	response, err := httpClient.Get(agentsURL)
	if err != nil {
		return nil, fmt.Errorf("%w", err)
	}
	defer response.Body.Close()

	if code := response.StatusCode; code != http.StatusOK {
		return nil, fmt.Errorf("HTTP %d", code)
	}

	body, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	return body, nil
}


================================================
FILE: pkg/cli/init_test.go
================================================
package cli

import (
	"os"
	"path"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestInit runs initCommand in an empty temporary directory and checks that
// the expected project files are created there.
func TestInit(t *testing.T) {
	workDir := t.TempDir()
	require.NoError(t, os.Chdir(workDir))

	require.NoError(t, initCommand(nil, []string{}))

	expected := []string{".dockerignore", "cog.yaml", "predict.py", "requirements.txt"}
	for _, name := range expected {
		require.FileExists(t, path.Join(workDir, name))
	}
}

// TestInitSkipExisting checks that a second initCommand run leaves files that
// already exist untouched: each file is overwritten with a sentinel value
// between the two runs and must still hold that sentinel afterwards.
func TestInitSkipExisting(t *testing.T) {
	workDir := t.TempDir()
	require.NoError(t, os.Chdir(workDir))

	created := []string{".dockerignore", "cog.yaml", "predict.py"}
	sentinels := []struct {
		name string
		body string
	}{
		{name: "cog.yaml", body: "test123"},
		{name: "predict.py", body: "test456"},
		{name: ".dockerignore", body: "test789"},
	}

	// First run creates the files.
	require.NoError(t, initCommand(nil, []string{}))
	for _, name := range created {
		require.FileExists(t, path.Join(workDir, name))
	}

	// Replace the contents so it is detectable whether the second run
	// rewrites any of them.
	for _, s := range sentinels {
		require.NoError(t, os.WriteFile(path.Join(workDir, s.name), []byte(s.body), 0o644))
	}

	// Second run should skip the files that already exist.
	runErr := initCommand(nil, []string{})
	require.NoError(t, runErr)
	for _, name := range created {
		require.FileExists(t, path.Join(workDir, name))
	}

	// The sentinel contents must have survived the second run.
	for _, s := range sentinels {
		got, err := os.ReadFile(path.Join(workDir, s.name))
		require.NoError(t, err)
		require.Equal(t, []byte(s.body), got)
	}
}


================================================
FILE: pkg/cli/inspect.go
================================================
package cli

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"strings"

	"github.com/spf13/cobra"

	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/model"
	"github.com/replicate/cog/pkg/registry"
)

// InspectOutput is the structured output for cog inspect --json.
//
// Reference is the reference string that was inspected and CogVersion is
// copied from the resolved model. Type is "index" when the model resolved to
// an OCI index, in which case Index is populated; otherwise Type is "image"
// and Image is populated when image details are available. Index and Image
// are omitted from the JSON encoding when nil.
type InspectOutput struct {
	Reference  string           `json:"reference"`
	Type       string           `json:"type"` // "image" or "index"
	CogVersion string           `json:"cogVersion"`
	Index      *InspectIndex    `json:"index,omitempty"`
	Image      *InspectManifest `json:"image,omitempty"`
}

// InspectIndex represents an OCI index in inspect output.
//
// Reference, Digest and MediaType are copied from the resolved model's index.
// Manifests holds one entry per child manifest listed in that index.
type InspectIndex struct {
	Reference string            `json:"reference"`
	Digest    string            `json:"digest"`
	MediaType string            `json:"mediaType"`
	Manifests []InspectManifest `json:"manifests"`
}

// InspectManifest represents a manifest entry in inspect output.
//
// An entry describes either a runnable image or a weights artifact, as
// indicated by Type. Name and Target are only filled in for weights entries
// and Platform only for image entries. Layers is filled in on a best-effort
// basis: it stays nil when the layer list could not be fetched.
type InspectManifest struct {
	Type        string            `json:"type"`           // "image" or "weights"
	Name        string            `json:"name,omitempty"` // weight name from AnnotationWeightName
	Digest      string            `json:"digest"`
	MediaType   string            `json:"mediaType"`
	Size        int64             `json:"size"`
	Platform    string            `json:"platform,omitempty"` // "linux/amd64"
	Target      string            `json:"target,omitempty"`   // weight mount path from AnnotationWeightDest
	Annotations map[string]string `json:"annotations,omitempty"`
	Layers      []InspectLayer    `json:"layers"`
}

// InspectLayer represents a layer in inspect output.
//
// The three fields mirror the digest, size and media type of one layer
// descriptor taken from an image manifest.
type InspectLayer struct {
	Digest    string `json:"digest"`
	Size      int64  `json:"size"`
	MediaType string `json:"mediaType"`
}

// newInspectCommand builds the hidden `inspect` command. It takes exactly one
// positional argument and exposes the --local, --remote, --json and --raw
// boolean flags, all of which are forwarded to inspectCommand.
func newInspectCommand() *cobra.Command {
	var local, remote, asJSON, asRaw bool

	inspectCmd := &cobra.Command{
		Use:    "inspect ",
		Short:  "Inspect a model image or OCI index",
		Hidden: true,
		Args:   cobra.ExactArgs(1),
		RunE: func(c *cobra.Command, positional []string) error {
			return inspectCommand(c, positional, local, remote, asJSON, asRaw)
		},
	}

	flagSet := inspectCmd.Flags()
	flagSet.BoolVar(&local, "local", false, "Only inspect local docker daemon")
	flagSet.BoolVar(&remote, "remote", false, "Only inspect remote registry")
	flagSet.BoolVar(&asJSON, "json", false, "Output as JSON")
	flagSet.BoolVar(&asRaw, "raw", false, "Output raw JSON fragments (one per line)")

	return inspectCmd
}

// inspectCommand resolves the model reference in args[0] and prints a
// description of it to stdout.
//
// localOnly and remoteOnly restrict where the reference is resolved; when
// both are set, localOnly takes precedence. Output format is chosen in this
// order: rawOutput streams one JSON record per line via streamRaw, jsonOutput
// prints a single indented JSON document, and otherwise a human-readable
// summary is printed.
func inspectCommand(cmd *cobra.Command, args []string, localOnly, remoteOnly, jsonOutput, rawOutput bool) error {
	ctx := cmd.Context()

	ref, err := model.ParseRef(args[0])
	if err != nil {
		return err
	}

	dockerClient, err := docker.NewClient(ctx)
	if err != nil {
		return err
	}

	regClient := registry.NewRegistryClient()
	resolver := model.NewResolver(dockerClient, regClient)

	// Build resolve options
	var opts []model.Option
	switch {
	case localOnly:
		opts = append(opts, model.LocalOnly())
	case remoteOnly:
		opts = append(opts, model.RemoteOnly())
	}

	m, err := resolver.Inspect(ctx, ref, opts...)
	if err != nil {
		return err
	}

	// Raw mode produces its own records and never reads the aggregated
	// InspectOutput, so hand off before buildInspectOutput performs its
	// per-manifest layer lookups against the registry.
	if rawOutput {
		return streamRaw(ctx, ref.String(), m, regClient)
	}

	out, err := buildInspectOutput(ctx, ref.String(), m, regClient)
	if err != nil {
		return err
	}

	if jsonOutput {
		enc := json.NewEncoder(os.Stdout)
		enc.SetIndent("", "  ")
		return enc.Encode(out)
	}

	printInspectText(out)
	return nil
}

// buildInspectOutput converts a resolved model into the InspectOutput shape.
//
// When the model carries an index, every child manifest is converted with
// buildManifestEntry; otherwise a single image entry is built from m.Image
// (if present). Layer lists are looked up through fetchLayers on a
// best-effort basis: a failed lookup simply leaves Layers unset. The returned
// error is always nil.
func buildInspectOutput(ctx context.Context, reference string, m *model.Model, reg registry.Client) (*InspectOutput, error) {
	result := &InspectOutput{
		Reference:  reference,
		CogVersion: m.CogVersion,
	}

	// Plain image: at most one manifest entry.
	if m.Index == nil {
		result.Type = "image"
		if m.Image == nil {
			return result, nil
		}

		entry := &InspectManifest{
			Type:   "image",
			Digest: m.Image.Digest,
		}
		if p := m.Image.Platform; p != nil {
			segments := []string{p.OS, p.Architecture}
			if p.Variant != "" {
				segments = append(segments, p.Variant)
			}
			entry.Platform = strings.Join(segments, "/")
		}

		// Layers can only be looked up when the digest is known.
		if m.Image.Digest != "" {
			if layers, err := fetchLayers(ctx, reference, m.Image.Digest, reg); err == nil {
				entry.Layers = layers
			}
		}

		result.Image = entry
		return result, nil
	}

	// OCI index: one entry per child manifest.
	result.Type = "index"
	index := &InspectIndex{
		Reference: m.Index.Reference,
		Digest:    m.Index.Digest,
		MediaType: m.Index.MediaType,
	}
	for _, child := range m.Index.Manifests {
		entry := buildManifestEntry(child)
		if layers, err := fetchLayers(ctx, reference, child.Digest, reg); err == nil {
			entry.Layers = layers
		}
		index.Manifests = append(index.Manifests, entry)
	}
	result.Index = index

	return result, nil
}

// buildManifestEntry converts one index manifest descriptor into an
// InspectManifest.
//
// Digest, media type, size and annotations are always copied. Weights
// manifests additionally get their name and target read from the
// AnnotationWeightName and AnnotationWeightDest annotations; every other
// manifest is reported as an image with an "os/arch[/variant]" platform
// string when platform information is present. Layers are not filled in here.
func buildManifestEntry(im model.IndexManifest) InspectManifest {
	entry := InspectManifest{
		Digest:      im.Digest,
		MediaType:   im.MediaType,
		Size:        im.Size,
		Annotations: im.Annotations,
	}

	if im.Type == model.ManifestTypeWeights {
		entry.Type = "weights"
		entry.Name = im.Annotations[model.AnnotationWeightName]
		entry.Target = im.Annotations[model.AnnotationWeightDest]
		return entry
	}

	entry.Type = "image"
	if p := im.Platform; p != nil {
		segments := []string{p.OS, p.Architecture}
		if p.Variant != "" {
			segments = append(segments, p.Variant)
		}
		entry.Platform = strings.Join(segments, "/")
	}
	return entry
}

// fetchLayers returns the layer descriptors of the manifest identified by
// digest, looked up in the repository that reference points at.
//
// The lookup address is "<repository>@<digest>", where the repository is
// derived by parsing reference. The result is nil when the manifest lists no
// layers. Any parse, fetch or manifest-decoding error is returned unchanged.
func fetchLayers(ctx context.Context, reference, digest string, reg registry.Client) ([]InspectLayer, error) {
	parsed, err := model.ParseRef(reference)
	if err != nil {
		return nil, err
	}
	repository := parsed.Ref.Context().String()

	image, err := reg.GetImage(ctx, repository+"@"+digest, nil)
	if err != nil {
		return nil, err
	}

	mf, err := image.Manifest()
	if err != nil {
		return nil, err
	}

	// Deliberately left nil when there are no layers.
	var result []InspectLayer
	for _, desc := range mf.Layers {
		layer := InspectLayer{
			Digest:    desc.Digest.String(),
			Size:      desc.Size,
			MediaType: string(desc.MediaType),
		}
		result = append(result, layer)
	}
	return result, nil
}

// rawStep is one line of `cog inspect --raw` output.
//
// Step names the stage being reported ("resolve", "index", "manifest" or
// "model"). Data carries the stage's summary values and Manifest, when set,
// carries the decoded raw manifest JSON. Both are omitted from the encoding
// when empty.
type rawStep struct {
	Step     string `json:"step"`
	Data     any    `json:"data,omitempty"`
	Manifest any    `json:"manifest,omitempty"`
}

// streamRaw writes the inspection result to stdout as a sequence of JSON
// records, one per line:
//
//  1. "resolve"  - reference, cog version and whether the model is an index
//  2. "index"    - index digest, media type and manifest count (index only)
//  3. "manifest" - one record per child manifest (index only), including the
//     decoded raw manifest when it could be fetched from the registry
//  4. "model"    - closing summary
//
// Fetching a raw manifest is best effort: on any failure the manifest record
// is still emitted, just without the raw document. A failure to write a
// record is returned immediately so that no further registry requests are
// made for output nobody can receive.
func streamRaw(ctx context.Context, reference string, m *model.Model, reg registry.Client) error {
	enc := json.NewEncoder(os.Stdout)

	kind := "image"
	if m.Index != nil {
		kind = "index"
	}

	// Step 1: resolve
	err := enc.Encode(rawStep{
		Step: "resolve",
		Data: map[string]any{
			"reference":  reference,
			"cogVersion": m.CogVersion,
			"type":       kind,
		},
	})
	if err != nil {
		return err
	}

	if m.Index != nil {
		// Step 2: index
		err := enc.Encode(rawStep{
			Step: "index",
			Data: map[string]any{
				"digest":    m.Index.Digest,
				"mediaType": m.Index.MediaType,
				"count":     len(m.Index.Manifests),
			},
		})
		if err != nil {
			return err
		}

		// Every child manifest lives in the same repository, so the
		// reference only needs to be parsed once. A parse failure is not
		// fatal: the manifests are then reported without their raw form.
		var repo string
		ref, parseErr := model.ParseRef(reference)
		if parseErr == nil {
			repo = ref.Ref.Context().String()
		}

		// Step 3: per-child manifests
		for _, im := range m.Index.Manifests {
			step := rawStep{
				Step: "manifest",
				Data: buildManifestEntry(im),
			}
			if parseErr == nil {
				if parsed, ok := rawManifestForDigest(ctx, repo+"@"+im.Digest, reg); ok {
					step.Manifest = parsed
				}
			}
			if err := enc.Encode(step); err != nil {
				return err
			}
		}
	}

	// Final step: model summary
	return enc.Encode(rawStep{
		Step: "model",
		Data: map[string]any{
			"reference":  reference,
			"cogVersion": m.CogVersion,
		},
	})
}

// rawManifestForDigest fetches the image at digestRef and returns its raw
// manifest decoded into generic JSON values. The boolean is false when the
// image could not be fetched, had no readable manifest, or the manifest was
// not valid JSON.
func rawManifestForDigest(ctx context.Context, digestRef string, reg registry.Client) (any, bool) {
	img, err := reg.GetImage(ctx, digestRef, nil)
	if err != nil {
		return nil, false
	}

	raw, err := img.RawManifest()
	if err != nil {
		return nil, false
	}

	var parsed any
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return nil, false
	}
	return parsed, true
}

// printInspectText prints the human-readable form of an inspection result to
// stdout: a short header, then either the index summary followed by each
// child manifest, or the single image manifest.
func printInspectText(out *InspectOutput) {
	fmt.Printf("Model: %s\n", out.Reference)
	if out.Type == "index" {
		fmt.Println("Type:  Model Bundle (OCI Index)")
	} else {
		fmt.Println("Type:  Image")
	}
	fmt.Printf("Cog:   %s\n", out.CogVersion)
	fmt.Println()

	if out.Index != nil {
		fmt.Printf("Index: %s\n", indexDigestRef(out.Index.Reference, out.Index.Digest))
		fmt.Printf("  Tag:       %s\n", out.Reference)
		fmt.Printf("  Digest:    %s\n", out.Index.Digest)
		fmt.Printf("  MediaType: %s\n", out.Index.MediaType)
		fmt.Printf("  Manifests: %d\n", len(out.Index.Manifests))
		fmt.Println()

		for _, m := range out.Index.Manifests {
			printManifestText(m, "  ")
			fmt.Println()
		}
	} else if out.Image != nil {
		printManifestText(*out.Image, "")
	}
}

// indexDigestRef builds "<repository>@<digest>" from an image reference and a
// digest. When either input is empty the digest is returned on its own.
//
// The repository is obtained by dropping any "@digest" suffix and then any
// ":tag" suffix from reference. A colon only counts as a tag separator when
// it appears after the last "/", so a registry port such as
// "localhost:5000/repo" is preserved.
func indexDigestRef(reference, digest string) string {
	if reference == "" || digest == "" {
		return digest
	}

	repo := reference
	if at := strings.Index(repo, "@"); at != -1 {
		repo = repo[:at]
	}
	if colon := strings.LastIndex(repo, ":"); colon > strings.LastIndex(repo, "/") {
		repo = repo[:colon]
	}
	return repo + "@" + digest
}

// printManifestText prints one manifest entry to stdout, prefixing every line
// with indent.
//
// The heading shows the weight name for weights entries and the platform for
// everything else, with a placeholder when that value is empty. It is
// followed by the digest, the size (plus the summed layer size when layers
// are known), the optional target and media type, and finally one line per
// layer.
func printManifestText(m InspectManifest, indent string) {
	switch m.Type {
	case "weights":
		label := m.Name
		if label == "" {
			label = "(unnamed)"
		}
		fmt.Printf("%s[weights] %s\n", indent, label)
	default:
		label := m.Platform
		if label == "" {
			label = "(unknown)"
		}
		fmt.Printf("%s[image] %s\n", indent, label)
	}

	fmt.Printf("%s  Digest: %s\n", indent, m.Digest)

	// Show manifest size, plus the total layer size when layers are available.
	hasLayers := len(m.Layers) > 0
	if !hasLayers {
		fmt.Printf("%s  Size:   %s\n", indent, formatSize(m.Size))
	} else {
		var total int64
		for i := range m.Layers {
			total += m.Layers[i].Size
		}
		fmt.Printf("%s  Size:   %s (Layers: %s)\n", indent, formatSize(m.Size), formatSize(total))
	}

	if m.Target != "" {
		fmt.Printf("%s  Target: %s\n", indent, m.Target)
	}
	if m.MediaType != "" {
		fmt.Printf("%s  Type:   %s\n", indent, m.MediaType)
	}

	if !hasLayers {
		return
	}
	fmt.Printf("%s  Layers: %d\n", indent, len(m.Layers))
	for _, layer := range m.Layers {
		fmt.Printf("%s    %s  %s  %s\n", indent, layer.Digest, formatSize(layer.Size), layer.MediaType)
	}
}


================================================
FILE: pkg/cli/login.go
================================================
package cli

import (
	"github.com/spf13/cobra"

	"github.com/replicate/cog/pkg/global"
	"github.com/replicate/cog/pkg/provider"
	"github.com/replicate/cog/pkg/provider/setup"
	"github.com/replicate/cog/pkg/util/console"
)

// newLoginCommand builds the `login` command. It accepts no positional
// arguments and a single --token-stdin boolean flag; the work is done by
// login.
func newLoginCommand() *cobra.Command {
	const longHelp = `Log in to a container registry.

For Replicate's registry (r8.im), this command handles authentication
through Replicate's token-based flow.

For other registries, this command prompts for username and password,
then stores credentials using Docker's credential system.`

	loginCmd := &cobra.Command{
		Use:        "login",
		Short:      "Log in to a container registry",
		Long:       longHelp,
		SuggestFor: []string{"auth", "authenticate", "authorize"},
		Args:       cobra.MaximumNArgs(0),
		RunE:       login,
	}

	loginCmd.Flags().Bool("token-stdin", false, "Pass login token on stdin instead of opening a browser. You can find your Replicate login token at https://replicate.com/auth/token")

	return loginCmd
}

// login authenticates against the globally configured registry host by
// delegating to the provider registered for that host.
//
// If no provider matches, a warning suggesting `docker login` is printed and
// nil is returned.
func login(cmd *cobra.Command, args []string) error {
	// Make sure the provider registry is populated before it is queried.
	setup.Init()

	// Global registry host (can be set via --registry flag or COG_REGISTRY_HOST env var)
	host := global.ReplicateRegistryHost

	readTokenFromStdin, err := cmd.Flags().GetBool("token-stdin")
	if err != nil {
		return err
	}

	registryProvider := provider.DefaultRegistry().ForHost(host)
	if registryProvider != nil {
		loginOpts := provider.LoginOptions{
			TokenStdin: readTokenFromStdin,
			Host:       host,
		}
		return registryProvider.Login(cmd.Context(), loginOpts)
	}

	// This shouldn't happen since GenericProvider matches everything
	console.Warnf("No provider found for registry '%s'.", host)
	console.Infof("Please use 'docker login %s' to authenticate.", host)
	return nil
}


================================================
FILE: pkg/cli/predict.go
================================================
package cli

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"os"
	"os/signal"
	"path"
	"path/filepath"
	"strings"
	"syscall"
	"time"

	"github.com/getkin/kin-openapi/openapi3"
	"github.com/mitchellh/go-homedir"
	"github.com/spf13/cobra"
	"golang.org/x/sys/unix"

	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/model"
	r8_path "github.com/replicate/cog/pkg/path"
	"github.com/replicate/cog/pkg/predict"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/util/console"
	"github.com/replicate/cog/pkg/util/files"
	"github.com/replicate/cog/pkg/util/mime"
)

// StdinPath is the placeholder that selects standard input as the source of
// --json input, either on its own ("-") or after the "@" file marker ("@-").
const StdinPath = "-"

// Package-level storage for command-line flag values.
//
//   - envFlags: values of --env/-e, each in the form name=value.
//   - inputFlags: values of --input/-i, each in the form name=value.
//   - outPath: value of --output/-o.
//   - setupTimeout: value of --setup-timeout, in seconds.
//   - useReplicateAPIToken: value of --use-replicate-token.
//   - inputJSON: value of --json.
var (
	envFlags             []string
	inputFlags           []string
	outPath              string
	setupTimeout         uint32
	useReplicateAPIToken bool
	inputJSON            string
)

// newPredictCommand builds the `predict` command. It accepts at most one
// positional argument (an image name), registers the shared build and runtime
// flags, and adds the prediction-specific --input, --output, --env,
// --use-replicate-token and --json flags.
func newPredictCommand() *cobra.Command {
	predictCmd := &cobra.Command{
		Use:        "predict [image]",
		Short:      "Run a prediction",
		SuggestFor: []string{"infer"},
		Args:       cobra.MaximumNArgs(1),
		RunE:       cmdPredict,
		Long: `Run a prediction.

If 'image' is passed, it will run the prediction on that Docker image.
It must be an image that has been built by Cog.

Otherwise, it will build the model in the current directory and run
the prediction on that.`,
		Example: `  # Run a prediction with named inputs
  cog predict -i prompt="a photo of a cat"

  # Pass a file as input
  cog predict -i image=@photo.jpg

  # Save output to a file
  cog predict -i image=@input.jpg -o output.png

  # Pass multiple inputs
  cog predict -i prompt="sunset" -i width=1024 -i height=768

  # Run against a pre-built image
  cog predict r8.im/your-username/my-model -i prompt="hello"

  # Pass inputs as JSON
  echo '{"prompt": "a cat"}' | cog predict --json @-`,
	}

	// Flags shared with other commands.
	addUseCudaBaseImageFlag(predictCmd)
	addUseCogBaseImageFlag(predictCmd)
	addBuildProgressOutputFlag(predictCmd)
	addDockerfileFlag(predictCmd)
	addGpusFlag(predictCmd)
	addSetupTimeoutFlag(predictCmd)
	addConfigFlag(predictCmd)

	// Flags specific to running a prediction.
	flagSet := predictCmd.Flags()
	flagSet.StringArrayVarP(&inputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg")
	flagSet.StringVarP(&outPath, "output", "o", "", "Output path")
	flagSet.StringArrayVarP(&envFlags, "env", "e", []string{}, "Environment variables, in the form name=value")
	flagSet.BoolVar(&useReplicateAPIToken, "use-replicate-token", false, "Pass REPLICATE_API_TOKEN from local environment into the model context")
	flagSet.StringVar(&inputJSON, "json", "", "Pass inputs as JSON object, read from file (@inputs.json) or via stdin (@-)")

	return predictCmd
}

// readStdin consumes standard input until EOF and returns it as a string.
func readStdin() (string, error) {
	raw, err := io.ReadAll(os.Stdin)
	if err == nil {
		return string(raw), nil
	}
	return "", fmt.Errorf("Failed to read JSON from stdin: %w", err)
}

// parseJSONInput decodes the value of the --json flag into a map of inputs.
//
// The JSON text is taken from one of three places, depending on jsonInput:
//   - "@-" or "-": standard input
//   - "@<path>": the file at <path>
//   - anything else: jsonInput itself
func parseJSONInput(jsonInput string) (map[string]any, error) {
	isReference := strings.HasPrefix(jsonInput, "@")
	source := strings.TrimPrefix(jsonInput, "@")

	var payload string
	switch {
	case source == StdinPath:
		// Covers both the bare "-" and the "@-" spelling.
		text, err := readStdin()
		if err != nil {
			return nil, err
		}
		payload = text
	case isReference:
		contents, err := os.ReadFile(source)
		if err != nil {
			return nil, fmt.Errorf("Failed to read JSON from file %q: %w", source, err)
		}
		payload = string(contents)
	default:
		// The flag value is the JSON document itself.
		payload = jsonInput
	}

	var decoded map[string]any
	if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
		return nil, fmt.Errorf("Failed to parse JSON: %w", err)
	}
	return decoded, nil
}

// transformPathsToBase64URLs returns a copy of inputs in which every
// top-level string value starting with "@" is replaced by a base64 data URL
// holding the contents of the file named after the "@".
//
// The media type is derived from the file extension and defaults to
// application/octet-stream when the extension is not recognised. All other
// values are copied through unchanged. An unreadable file aborts the whole
// transformation with an error.
func transformPathsToBase64URLs(inputs map[string]any) (map[string]any, error) {
	converted := make(map[string]any)

	for name, raw := range inputs {
		ref, isString := raw.(string)
		if !isString || !strings.HasPrefix(ref, "@") {
			converted[name] = raw
			continue
		}

		location := ref[1:]
		contents, err := os.ReadFile(location)
		if err != nil {
			return nil, fmt.Errorf("Failed to read file %q: %w", location, err)
		}

		contentType := mime.TypeByExtension(filepath.Ext(location))
		if contentType == "" {
			contentType = "application/octet-stream"
		}

		encoded := base64.StdEncoding.EncodeToString(contents)
		converted[name] = fmt.Sprintf("data:%s;base64,%s", contentType, encoded)
	}

	return converted, nil
}

// cmdPredict implements `cog predict`.
//
// With no positional argument the model in the current directory is built
// and the project directory is mounted at /src; otherwise args[0] names an
// existing image, which is pulled and validated. A container is then started
// from the image, setup is awaited, and a single prediction is run using the
// inputs given through --json or --input.
//
// GPUs are requested automatically when the model declares GPU support and
// --gpus was not given. If starting with that automatic selection fails
// because of a missing device driver, the container is started once more
// without GPUs.
func cmdPredict(cmd *cobra.Command, args []string) error {
	ctx := cmd.Context()

	// --json and --input are mutually exclusive. Reject the combination
	// before doing anything expensive, so the user does not sit through an
	// image build or pull and container setup just to be told the invocation
	// was invalid.
	if inputJSON != "" && len(inputFlags) > 0 {
		return fmt.Errorf("Must use one of --json or --input to provide model inputs")
	}

	dockerClient, err := docker.NewClient(ctx)
	if err != nil {
		return err
	}

	imageName := ""
	volumes := []command.Volume{}
	gpus := gpusFlag

	resolver := model.NewResolver(dockerClient, registry.NewRegistryClient())

	if len(args) == 0 {
		// Build image
		src, err := model.NewSource(configFilename)
		if err != nil {
			return err
		}

		console.Info("Building Docker image from environment in cog.yaml...")
		console.Info("")
		m, err := resolver.Build(ctx, src, serveBuildOptions(cmd))
		if err != nil {
			return err
		}
		imageName = m.ImageRef()

		// ExcludeSource build doesn't have /src in it, so mount as volume
		volumes = append(volumes, command.Volume{
			Source:      src.ProjectDir,
			Destination: "/src",
		})

		if gpus == "" && m.HasGPU() {
			gpus = "all"
		}
	} else {
		// Use existing image
		imageName = args[0]

		// If the image name contains '=', then it's probably a mistake
		if strings.Contains(imageName, "=") {
			return fmt.Errorf("Invalid image name '%s'. Did you forget `-i`?", imageName)
		}

		// Pull the image (if needed) and validate it's a Cog model
		ref, err := model.ParseRef(imageName)
		if err != nil {
			return err
		}
		m, err := resolver.Pull(ctx, ref)
		if err != nil {
			return err
		}

		if gpus == "" && m.HasGPU() {
			gpus = "all"
		}
	}

	console.Info("")
	console.Info("Starting Docker image and running setup()...")

	// Automatically propagate RUST_LOG for Rust coglet debugging
	env := envFlags
	if rustLog := os.Getenv("RUST_LOG"); rustLog != "" {
		env = append(env, "RUST_LOG="+rustLog)
	}

	predictor, err := predict.NewPredictor(ctx, command.RunOptions{
		GPUs:    gpus,
		Image:   imageName,
		Volumes: volumes,
		Env:     env,
	}, false, dockerClient)
	if err != nil {
		return err
	}

	// Stop the container when the user interrupts with Ctrl-C.
	go func() {
		captureSignal := make(chan os.Signal, 1)
		signal.Notify(captureSignal, syscall.SIGINT)

		<-captureSignal

		console.Info("Stopping container...")
		if err := predictor.Stop(ctx); err != nil {
			console.Warnf("Failed to stop container: %s", err)
		}
	}()

	timeout := time.Duration(setupTimeout) * time.Second
	if err := predictor.Start(ctx, os.Stderr, timeout); err != nil {
		// Only retry if we're using a GPU but the user didn't explicitly select a GPU with --gpus
		// If the user specified the wrong GPU, they are explicitly selecting a GPU and they'll want to hear about it
		if gpus == "all" && errors.Is(err, docker.ErrMissingDeviceDriver) {
			console.Info("Missing device driver, re-trying without GPU")

			// Best effort: the failed container is being replaced anyway.
			_ = predictor.Stop(ctx)
			predictor, err = predict.NewPredictor(ctx, command.RunOptions{
				Image:   imageName,
				Volumes: volumes,
				Env:     env,
			}, false, dockerClient)
			if err != nil {
				return err
			}

			if err := predictor.Start(ctx, os.Stderr, timeout); err != nil {
				return err
			}
		} else {
			return err
		}
	}

	// FIXME: will not run on signal
	defer func() {
		console.Debugf("Stopping container...")
		// use background context to ensure stop signal is still sent after root context is canceled
		if err := predictor.Stop(context.Background()); err != nil {
			console.Warnf("Failed to stop container: %s", err)
		}
	}()

	if inputJSON != "" {
		return predictJSONInputs(*predictor, inputJSON, outPath, false)
	}
	return predictIndividualInputs(*predictor, inputFlags, outPath, false)
}

// isURI reports whether ref is a non-nil schema describing a string with the
// "uri" format.
func isURI(ref *openapi3.Schema) bool {
	if ref == nil {
		return false
	}
	return ref.Type.Is("string") && ref.Format == "uri"
}

// predictJSONInputs runs a prediction (or training, when isTrain is set)
// whose inputs come from the --json flag value jsonInput.
//
// The JSON is parsed, "@file" string values are inlined as data URLs, and the
// result is converted to predict.Inputs: string values are passed as strings
// and every other value is passed as its JSON encoding. The prediction is
// then run in JSON output mode.
func predictJSONInputs(predictor predict.Predictor, jsonInput string, outputPath string, isTrain bool) error {
	parsed, err := parseJSONInput(jsonInput)
	if err != nil {
		return err
	}

	resolved, err := transformPathsToBase64URLs(parsed)
	if err != nil {
		return err
	}

	inputs := make(predict.Inputs)
	for name, raw := range resolved {
		if text, isString := raw.(string); isString {
			inputs[name] = predict.Input{String: &text}
			continue
		}

		// Anything that is not a string is forwarded as raw JSON.
		encoded, err := json.Marshal(raw)
		if err != nil {
			return fmt.Errorf("Failed to marshal input %q to JSON: %w", name, err)
		}
		message := json.RawMessage(encoded)
		inputs[name] = predict.Input{Json: &message}
	}

	return runPrediction(predictor, inputs, outputPath, isTrain, true)
}

// predictIndividualInputs runs a prediction (or training, when isTrain is
// set) whose inputs come from name=value flags. The flags are interpreted
// against the predictor's schema and the prediction is run in non-JSON output
// mode.
func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, outputPath string, isTrain bool) error {
	apiSchema, err := predictor.GetSchema()
	if err != nil {
		return err
	}

	parsedInputs, err := parseInputFlags(inputFlags, apiSchema, isTrain)
	if err != nil {
		return err
	}

	return runPrediction(predictor, parsedInputs, outputPath, isTrain, false)
}

// runPrediction sends inputs to the predictor and presents the result.
//
// isTrain selects the /trainings schema path instead of /predictions when
// looking up the output type. outputPath is the user-supplied --output value
// (may be empty, and may carry a leading "@"). needsJSON selects JSON mode,
// in which the whole prediction object is emitted as indented JSON rather
// than just its output value.
//
// File outputs (URI strings and arrays of URI strings in the schema) are
// always written to disk. Other outputs are written to outputPath when one
// was given and printed to the console otherwise.
//
// In JSON mode a prediction that did not succeed terminates the process with
// exit status 1 after its JSON has been emitted; in non-JSON mode it is
// returned as an error.
func runPrediction(predictor predict.Predictor, inputs predict.Inputs, outputPath string, isTrain bool, needsJSON bool) error {
	if isTrain {
		console.Info("Running training...")
	} else {
		console.Info("Running prediction...")
	}
	console.Info("")

	// Schema path whose 200 response describes the output type.
	endpoint := "/predictions"
	if isTrain {
		endpoint = "/trainings"
	}

	writeOutputToDisk := outputPath != ""
	fallbackPath := "output"
	if needsJSON {
		fallbackPath = "output.json"
	}

	outputPath, err := ensureOutputWriteable(strings.TrimPrefix(outputPath, "@"), fallbackPath)
	if err != nil {
		return fmt.Errorf("Output path is not writable: %w", err)
	}

	if needsJSON && !strings.HasSuffix(outputPath, ".json") {
		console.Warnf("--output value does not have a .json suffix: %s", path.Base(outputPath))
	}

	// emit delivers a rendered, non-file result: to outputPath when --output
	// was given, to the console otherwise.
	emit := func(data []byte) error {
		if !writeOutputToDisk {
			console.Output(string(data))
			return nil
		}
		written, err := files.WriteFile(data, outputPath)
		if err != nil {
			return fmt.Errorf("Failed to write output: %w", err)
		}
		console.Infof("Written output to: %s", written)
		return nil
	}

	reqCtx := predict.RequestContext{}

	if useReplicateAPIToken {
		reqCtx.ReplicateAPIToken = os.Getenv("REPLICATE_API_TOKEN")
		if reqCtx.ReplicateAPIToken == "" {
			return fmt.Errorf("Failed to find REPLICATE_API_TOKEN in the current environment when called with --use-replicate-token")
		}
	}

	prediction, err := predictor.Predict(inputs, reqCtx)
	if err != nil {
		return fmt.Errorf("Failed to run prediction: %w", err)
	}

	schema, err := predictor.GetSchema()
	if err != nil {
		return err
	}

	// Safely extract output schema with nil checks to avoid panics on malformed schemas
	var outputSchema *openapi3.Schema
	if pathItem := schema.Paths.Value(endpoint); pathItem != nil {
		if pathItem.Post != nil {
			if resp := pathItem.Post.Responses.Value("200"); resp != nil && resp.Value != nil {
				if content, ok := resp.Value.Content["application/json"]; ok && content.Schema != nil {
					if content.Schema.Value != nil {
						if outputProp, ok := content.Schema.Value.Properties["output"]; ok && outputProp != nil {
							outputSchema = outputProp.Value
						}
					}
				}
			}
		}
	}
	if outputSchema == nil {
		return fmt.Errorf("invalid OpenAPI schema: missing output definition for %s", endpoint)
	}

	fileOutputPath := outputPath
	if needsJSON {
		// Strip the suffix when in JSON mode.
		fileOutputPath = r8_path.TrimExt(fileOutputPath)
	}

	if prediction.Status == "succeeded" && prediction.Output != nil {
		transformed, err := processFileOutputs(*prediction.Output, outputSchema, fileOutputPath)
		if err != nil {
			return err
		}
		prediction.Output = &transformed
	}

	if needsJSON {
		rawJSON, err := json.Marshal(prediction)
		if err != nil {
			return fmt.Errorf("Failed to encode prediction output as JSON: %w", err)
		}
		var indentedJSON bytes.Buffer
		if err := json.Indent(&indentedJSON, rawJSON, "", "  "); err != nil {
			return err
		}

		if err := emit(indentedJSON.Bytes()); err != nil {
			return err
		}

		// Exit with non-zero code if the prediction has failed.
		// NOTE: os.Exit terminates immediately; deferred functions in this
		// call chain do not run.
		if prediction.Status != "succeeded" {
			os.Exit(1)
		}

		return nil
	}

	if prediction.Status != "succeeded" {
		return fmt.Errorf("Prediction failed with status %q: %s", prediction.Status, prediction.Error)
	}

	if prediction.Output == nil {
		console.Warn("No output generated")
		return nil
	}

	// Handle default presentation of output types.
	// 1. For Path and list[Path] do nothing. We already print info for each file write.
	// 2. For everything else we want to print the raw value.
	switch {
	case isURI(outputSchema):
		return nil
	// An array schema without an items definition is malformed; guard the
	// dereference so it falls through to the generic JSON presentation
	// instead of panicking.
	case outputSchema.Type.Is("array") && outputSchema.Items != nil && isURI(outputSchema.Items.Value):
		return nil
	case outputSchema.Type.Is("string"):
		// Output the raw string.
		s, ok := (*prediction.Output).(string)
		if !ok {
			return fmt.Errorf("Failed to convert prediction output to string")
		}
		return emit([]byte(s))
	default:
		// Treat everything else as JSON -- ints, floats, bools will all be presented
		// as raw values. Lists and objects will be pretty printed JSON.
		output, err := prettyJSONMarshal(prediction.Output)
		if err != nil {
			return err
		}

		// No special handling for needsJSON here.
		return emit(output)
	}
}

// ensureOutputWriteable resolves the path that output should be written to
// and checks that it can be written.
//
// An empty outputPath is replaced by fallbackPath. After "~" expansion:
//   - if nothing exists at the path, its parent directory must be writable
//     and the path itself is returned;
//   - if the path is a directory, the base name of fallbackPath inside that
//     directory is returned, unless the directory is the fallback itself, in
//     which case an error is returned;
//   - otherwise the existing file must be writable and its path is returned.
func ensureOutputWriteable(outputPath string, fallbackPath string) (string, error) {
	usingFallback := outputPath == ""
	if usingFallback {
		outputPath = fallbackPath
	}

	target, err := homedir.Expand(outputPath)
	if err != nil {
		return "", err
	}

	info, statErr := os.Stat(target)
	switch {
	case os.IsNotExist(statErr):
		// Nothing there yet: the file will be created inside the parent.
		parent := path.Dir(target)
		if accessErr := unix.Access(parent, unix.W_OK); accessErr != nil {
			return "", fmt.Errorf("Output directory is not writable: %s", parent)
		}
		return target, nil

	case statErr != nil:
		return "", fmt.Errorf("Unexpected error checking output path: %w", statErr)

	case info.IsDir():
		// A directory that happens to carry the default output name cannot
		// be written to as a file.
		if usingFallback {
			return "", fmt.Errorf("Default output name %q conflicts with directory, provide --output", target)
		}
		if accessErr := unix.Access(target, unix.W_OK); accessErr != nil {
			return "", accessErr
		}
		return path.Join(target, path.Base(fallbackPath)), nil
	}

	// Existing regular file: it must be writable in place.
	if accessErr := unix.Access(target, unix.W_OK); accessErr != nil {
		return "", accessErr
	}
	return target, nil
}

// prettyJSONMarshal encodes v as JSON indented with two spaces per level. On
// failure it returns an empty byte slice together with the error.
func prettyJSONMarshal(v any) ([]byte, error) {
	compact, marshalErr := json.Marshal(v)
	if marshalErr != nil {
		return []byte(""), fmt.Errorf("Failed to encode JSON: %w", marshalErr)
	}

	pretty := new(bytes.Buffer)
	if indentErr := json.Indent(pretty, compact, "", "  "); indentErr != nil {
		return []byte(""), indentErr
	}
	return pretty.Bytes(), nil
}

// processFileOutputs writes file-typed prediction outputs to disk and returns
// the output with each file replaced by the path it was written to.
//
// Two schema shapes are handled:
//   - a URI string: output must be a string and is written to destination;
//   - an array of URI strings: output must be a list, and element i is
//     written to destination with ".<i>" inserted before the extension.
//
// Any other schema, including a nil schema or an array schema that lacks an
// items definition, leaves output unchanged.
func processFileOutputs(output any, schema *openapi3.Schema, destination string) (any, error) {
	// TODO: This doesn't currently support arbitrary objects.
	switch {
	case isURI(schema):
		outputStr, ok := output.(string)
		if !ok {
			return nil, fmt.Errorf("Failed to convert prediction output to string: %v", output)
		}

		written, err := files.WriteDataURLToFile(outputStr, destination)
		if err != nil {
			return nil, fmt.Errorf("Failed to write output: %w", err)
		}
		console.Infof("Written output to: %s", written)

		return any(written), nil
	// schema and schema.Items are checked before being dereferenced so that a
	// malformed schema results in a pass-through rather than a panic.
	case schema != nil && schema.Type.Is("array") && schema.Items != nil && isURI(schema.Items.Value):
		outputs, ok := (output).([]any)
		if !ok {
			return nil, fmt.Errorf("Failed to decode output: %v", output)
		}

		clone := []any{}
		for i, element := range outputs {
			itemDestination := fmt.Sprintf("%s.%d%s", r8_path.TrimExt(destination), i, path.Ext(destination))
			item, err := processFileOutputs(element, schema.Items.Value, itemDestination)
			if err != nil {
				return nil, fmt.Errorf("Failed to write output %d: %w", i, err)
			}

			clone = append(clone, item)
		}

		return clone, nil
	}

	return output, nil
}

// parseInputFlags turns "name=value" strings into predict.Inputs.
//
// Each entry is split at its first "="; entries without one are rejected.
// A value wrapped in a pair of double quotes has that one pair stripped.
// Repeating a name accumulates its values in the order given. The optional
// isTrain argument selects training mode when its first element is true.
func parseInputFlags(inputs []string, schema *openapi3.T, isTrain ...bool) (predict.Inputs, error) {
	keyVals := map[string][]string{}
	for _, input := range inputs {
		// There is no default input name: every entry must be name=value.
		name, value, found := strings.Cut(input, "=")
		if !found {
			return nil, fmt.Errorf("Failed to parse input '%s', expected format is 'name=value'", input)
		}

		// Strip one pair of surrounding quotes. The length check is required:
		// a value that is a lone `"` has both the prefix and the suffix, and
		// slicing it as [1:0] would panic.
		if len(value) >= 2 && strings.HasPrefix(value, `"`) && strings.HasSuffix(value, `"`) {
			value = value[1 : len(value)-1]
		}

		// Append new values to the slice associated with the key
		keyVals[name] = append(keyVals[name], value)
	}

	train := len(isTrain) > 0 && isTrain[0]
	return predict.NewInputsForMode(keyVals, schema, train)
}

// addSetupTimeoutFlag registers --setup-timeout on cmd. The value is a number
// of seconds (default 300) stored in the package-level setupTimeout.
func addSetupTimeoutFlag(cmd *cobra.Command) {
	const defaultSetupTimeoutSeconds uint32 = 5 * 60

	flags := cmd.Flags()
	flags.Uint32Var(&setupTimeout, "setup-timeout", defaultSetupTimeoutSeconds, "The timeout for a container to setup (in seconds).")
}


================================================
FILE: pkg/cli/predict_test.go
================================================
package cli

import (
	"testing"

	"github.com/getkin/kin-openapi/openapi3"
	"github.com/stretchr/testify/require"
)

// TestExtractOutputSchemaFromMalformedSchema checks that safeExtractOutputSchema
// returns nil, without panicking, for OpenAPI documents that are missing any
// link in the chain paths -> POST -> responses -> 200 -> JSON content -> output.
func TestExtractOutputSchemaFromMalformedSchema(t *testing.T) {
	testCases := []struct {
		name   string
		schema *openapi3.T
	}{
		{
			name:   "nil schema",
			schema: nil,
		},
		{
			name:   "empty schema",
			schema: &openapi3.T{},
		},
		{
			name: "schema with nil paths",
			schema: &openapi3.T{
				Paths: nil,
			},
		},
		{
			name: "schema with empty paths",
			schema: &openapi3.T{
				Paths: &openapi3.Paths{},
			},
		},
		{
			// Paths object exists but holds no entry for the requested URL.
			name: "schema with paths but no matching path",
			schema: &openapi3.T{
				Paths: &openapi3.Paths{
					Extensions: map[string]any{},
				},
			},
		},
		{
			// The path is registered, but it has no POST operation.
			name: "schema with path but no post",
			schema: func() *openapi3.T {
				s := &openapi3.T{
					Paths: openapi3.NewPaths(),
				}
				s.Paths.Set("/predictions", &openapi3.PathItem{})
				return s
			}(),
		},
		{
			name: "schema with post but no responses",
			schema: func() *openapi3.T {
				s := &openapi3.T{
					Paths: openapi3.NewPaths(),
				}
				s.Paths.Set("/predictions", &openapi3.PathItem{
					Post: &openapi3.Operation{},
				})
				return s
			}(),
		},
		{
			name: "schema with response but no content",
			schema: func() *openapi3.T {
				s := &openapi3.T{
					Paths: openapi3.NewPaths(),
				}
				s.Paths.Set("/predictions", &openapi3.PathItem{
					Post: &openapi3.Operation{
						Responses: &openapi3.Responses{},
					},
				})
				return s
			}(),
		},
		{
			name: "schema with content but no output property",
			schema: func() *openapi3.T {
				s := &openapi3.T{
					Paths: openapi3.NewPaths(),
				}
				responses := openapi3.NewResponses()
				responses.Set("200", &openapi3.ResponseRef{
					Value: &openapi3.Response{
						Content: openapi3.Content{
							"application/json": &openapi3.MediaType{
								Schema: &openapi3.SchemaRef{
									Value: &openapi3.Schema{
										Properties: openapi3.Schemas{},
									},
								},
							},
						},
					},
				})
				s.Paths.Set("/predictions", &openapi3.PathItem{
					Post: &openapi3.Operation{
						Responses: responses,
					},
				})
				return s
			}(),
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// This must not panic; every malformed schema yields a nil result.
			outputSchema := safeExtractOutputSchema(tc.schema, "/predictions")
			require.Nil(t, outputSchema, "expected nil output schema for malformed input")
		})
	}
}

// safeExtractOutputSchema returns the schema of the "output" property found in
// the application/json body of the 200 response of the POST operation at url.
// It returns nil as soon as any link in that chain is absent, so it never
// dereferences a nil pointer.
func safeExtractOutputSchema(schema *openapi3.T, url string) *openapi3.Schema {
	if schema == nil || schema.Paths == nil {
		return nil
	}

	item := schema.Paths.Value(url)
	if item == nil || item.Post == nil || item.Post.Responses == nil {
		return nil
	}

	okResponse := item.Post.Responses.Value("200")
	if okResponse == nil || okResponse.Value == nil {
		return nil
	}

	// A missing map key yields a nil *MediaType, so one nil check covers both
	// "no such content type" and "content type present but nil".
	media := okResponse.Value.Content["application/json"]
	if media == nil || media.Schema == nil || media.Schema.Value == nil {
		return nil
	}

	if prop := media.Schema.Value.Properties["output"]; prop != nil {
		return prop.Value
	}
	return nil
}

// TestExtractOutputSchemaFromValidSchema checks that a well-formed document
// yields the schema of its "output" property.
func TestExtractOutputSchemaFromValidSchema(t *testing.T) {
	// Response body: an object whose only property is a string "output".
	bodySchema := &openapi3.Schema{
		Properties: openapi3.Schemas{
			"output": &openapi3.SchemaRef{
				Value: &openapi3.Schema{
					Type: &openapi3.Types{"string"},
				},
			},
		},
	}
	okResponse := &openapi3.ResponseRef{
		Value: &openapi3.Response{
			Content: openapi3.Content{
				"application/json": &openapi3.MediaType{
					Schema: &openapi3.SchemaRef{Value: bodySchema},
				},
			},
		},
	}

	responses := openapi3.NewResponses()
	responses.Set("200", okResponse)

	doc := &openapi3.T{Paths: openapi3.NewPaths()}
	doc.Paths.Set("/predictions", &openapi3.PathItem{
		Post: &openapi3.Operation{Responses: responses},
	})

	got := safeExtractOutputSchema(doc, "/predictions")
	require.NotNil(t, got, "expected non-nil output schema for valid input")
	require.Contains(t, got.Type.Slice(), "string", "expected string type")
}


================================================
FILE: pkg/cli/push.go
================================================
package cli

import (
	"fmt"

	"github.com/spf13/cobra"

	"github.com/replicate/go/uuid"

	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/model"
	"github.com/replicate/cog/pkg/provider"
	"github.com/replicate/cog/pkg/provider/setup"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/util/console"
)

// newPushCommand builds the `cog push [IMAGE]` command, which accepts at most
// one positional argument and runs push.
func newPushCommand() *cobra.Command {
	const (
		longHelp = `Build a Docker image from cog.yaml and push it to a container registry.

Cog can push to any OCI-compliant registry. When pushing to Replicate's
registry (r8.im), run 'cog login' first to authenticate.`

		examples = `  # Push to Replicate
  cog push r8.im/your-username/my-model

  # Push to any OCI registry
  cog push registry.example.com/your-username/model-name

  # Push with model weights in a separate layer (Replicate only)
  cog push r8.im/your-username/my-model --separate-weights`
	)

	pushCmd := &cobra.Command{
		Use:     "push [IMAGE]",
		Short:   "Build and push model in current directory to a Docker registry",
		Long:    longHelp,
		Example: examples,
		RunE:    push,
		Args:    cobra.MaximumNArgs(1),
	}

	// Flag registration helpers, applied in order.
	flagAdders := []func(*cobra.Command){
		addSecretsFlag,
		addNoCacheFlag,
		addSeparateWeightsFlag,
		addSchemaFlag,
		addUseCudaBaseImageFlag,
		addDockerfileFlag,
		addBuildProgressOutputFlag,
		addUseCogBaseImageFlag,
		addStripFlag,
		addPrecompileFlag,
		addConfigFlag,
	}
	for _, addFlag := range flagAdders {
		addFlag(pushCmd)
	}

	return pushCmd
}

// push implements `cog push`: it builds the model image described by the
// config file and pushes it to the registry named by the image reference.
//
// The image name comes from the first positional argument if given, otherwise
// from the 'image' option in the config. The provider matching that image is
// given the final outcome through PostPush, both when the build fails and
// after the push attempt, so it can format errors or print success output.
func push(cmd *cobra.Command, args []string) error {
	ctx := cmd.Context()

	// Initialize the provider registry
	setup.Init()

	dockerClient, err := docker.NewClient(ctx)
	if err != nil {
		return err
	}

	src, err := model.NewSource(configFilename)
	if err != nil {
		return err
	}

	imageName := src.Config.Image
	if len(args) > 0 {
		imageName = args[0]
	}

	if imageName == "" {
		return fmt.Errorf("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. For example, 'cog push registry.example.com/your-username/model-name'")
	}

	// Look up the provider for the target registry
	p := provider.DefaultRegistry().ForImage(imageName)
	if p == nil {
		return fmt.Errorf("no provider found for image '%s'", imageName)
	}

	pushOpts := provider.PushOptions{
		Image:      imageName,
		Config:     src.Config,
		ProjectDir: src.ProjectDir,
	}

	// Annotate the build with a fresh push ID. The annotation is optional, so
	// a failure to generate the ID is only logged; the key is left unset
	// rather than being stamped with a value that identifies nothing.
	annotations := map[string]string{}
	if buildID, idErr := uuid.NewV7(); idErr == nil {
		annotations["run.cog.push_id"] = buildID.String()
	} else {
		console.Debugf("Failed to generate push ID: %s", idErr)
	}

	regClient := registry.NewRegistryClient()
	resolver := model.NewResolver(dockerClient, regClient)

	// Build the model
	console.Infof("Building Docker image from environment in cog.yaml as %s...", console.Bold(imageName))
	console.Info("")
	buildOpts := buildOptionsFromFlags(cmd, imageName, annotations)
	m, err := resolver.Build(ctx, src, buildOpts)
	if err != nil {
		// Call PostPush to handle error logging/analytics. Its return value is
		// deliberately dropped: the build error is what the caller should see.
		_ = p.PostPush(ctx, pushOpts, err)
		return err
	}

	// Log weights info
	weights := m.WeightArtifacts()
	if len(weights) > 0 {
		console.Infof("\n%d weight artifact(s)", len(weights))
	}

	// Push the model (image + optional weights)
	console.Infof("\nPushing image %s...", console.Bold(m.ImageRef()))

	// Set up progress display using Docker's jsonmessage rendering. This uses the
	// same cursor movement and progress display as `docker push`, which handles
	// terminal resizing correctly (each line is erased and rewritten individually,
	// rather than relying on a bulk cursor-up count that can desync on resize).
	pw := docker.NewProgressWriter()
	defer pw.Close()

	pushErr := resolver.Push(ctx, m, model.PushOptions{
		ImageProgressFn: func(prog model.PushProgress) {
			// Truncate digest for display: "sha256:abc123..." → "abc123..."
			displayDigest := prog.LayerDigest
			if len(displayDigest) > 7+12 { // "sha256:" + 12 hex chars
				displayDigest = displayDigest[7:19] + "..."
			}

			pw.Write(displayDigest, "Pushing", prog.Complete, prog.Total)
		},
		OnFallback: func() {
			// Close progress writer to finalize OCI progress bars before Docker
			// push starts its own output. Without this, stale OCI progress lines
			// remain on screen above Docker's progress output.
			pw.Close()
		},
	})

	// Finalize the progress display before PostPush prints anything.
	pw.Close()

	// PostPush: the provider handles formatting errors and showing success messages
	if err := p.PostPush(ctx, pushOpts, pushErr); err != nil {
		return err
	}

	// If there was a push error but PostPush didn't return one,
	// return a generic error
	if pushErr != nil {
		return fmt.Errorf("failed to push image: %w", pushErr)
	}

	return nil
}


================================================
FILE: pkg/cli/root.go
================================================
package cli

import (
	"fmt"
	"os"

	"github.com/spf13/cobra"

	"github.com/replicate/cog/pkg/global"
	"github.com/replicate/cog/pkg/update"
	"github.com/replicate/cog/pkg/util/console"
)

// NewRootCommand assembles the top-level `cog` command: it wires the
// persistent pre-run hook, registers the persistent flags and attaches every
// subcommand. The returned error is always nil.
func NewRootCommand() (*cobra.Command, error) {
	// preRun executes before any subcommand. It applies the global debug and
	// colour settings, silences usage output, and runs the release check.
	preRun := func(cmd *cobra.Command, args []string) {
		if global.Debug {
			console.SetLevel(console.DebugLevel)
		}
		if global.NoColor || !console.ShouldUseColor() {
			console.SetColor(false)
		}
		if global.NoColor {
			os.Setenv("NO_COLOR", "1") //nolint:errcheck,gosec // best-effort
		}
		cmd.SilenceUsage = true
		if err := update.DisplayAndCheckForRelease(cmd.Context()); err != nil {
			console.Debugf("%s", err)
		}
	}

	rootCmd := &cobra.Command{
		Use:   "cog",
		Short: "Cog: Containers for machine learning",
		Long: `Containers for machine learning.

To get started, take a look at the documentation:
https://github.com/replicate/cog`,
		Example: `   To run a command inside a Docker environment defined with Cog:
      $ cog run echo hello world`,
		Version:          fmt.Sprintf("%s (built %s)", global.Version, global.BuildTime),
		PersistentPreRun: preRun,
		// This stops errors being printed because we print them in cmd/cog/cog.go
		SilenceErrors: true,
	}
	setPersistentFlags(rootCmd)

	subcommands := []*cobra.Command{
		newBuildCommand(),
		newDebugCommand(),
		newInitCommand(),
		newInspectCommand(),
		newLoginCommand(),
		newPredictCommand(),
		newPushCommand(),
		newRunCommand(),
		newServeCommand(),
		newTrainCommand(),
		newWeightsCommand(),
	}
	rootCmd.AddCommand(subcommands...)

	return rootCmd, nil
}

// setPersistentFlags registers the flags shared by every cog subcommand on
// cmd's persistent flag set. --profile and --registry are registered but
// hidden from help output.
func setPersistentFlags(cmd *cobra.Command) {
	pf := cmd.PersistentFlags()

	pf.BoolVar(&global.Debug, "debug", false, "Show debugging output")
	pf.BoolVar(&global.NoColor, "no-color", false, "Disable colored output")
	pf.BoolVar(&global.ProfilingEnabled, "profile", false, "Enable profiling")
	pf.Bool("version", false, "Show version of Cog")
	pf.StringVar(&global.ReplicateRegistryHost, "registry", global.ReplicateRegistryHost, "Registry host")

	for _, hidden := range []string{"profile", "registry"} {
		_ = pf.MarkHidden(hidden)
	}
}


================================================
FILE: pkg/cli/run.go
================================================
package cli

import (
	"os"
	"strconv"
	"strings"

	"github.com/spf13/cobra"

	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/model"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/util/console"
)

var (
	// runPorts holds the values of the -p/--publish flag of `cog run`.
	runPorts []string
	// gpusFlag holds the value of the --gpus flag registered by addGpusFlag.
	gpusFlag string
)

// addGpusFlag registers --gpus on cmd, storing its value in the package-level
// gpusFlag. The default is the empty string.
func addGpusFlag(cmd *cobra.Command) {
	const usage = "GPU devices to add to the container, in the same format as `docker run --gpus`."

	flags := cmd.Flags()
	flags.StringVar(&gpusFlag, "gpus", "", usage)
}

// newRunCommand builds the `cog run` command. At least one positional
// argument (the command to execute) is required, and flag parsing stops at
// that first positional argument so the rest is handed to the command
// untouched.
func newRunCommand() *cobra.Command {
	cmd := &cobra.Command{
		// The usage line names the required command; Args below enforces it.
		Use:   "run <command> [arg...]",
		Short: "Run a command inside a Docker environment",
		Long: `Run a command inside a Docker environment defined by cog.yaml.

Cog builds a temporary image from your cog.yaml configuration and runs the
given command inside it. This is useful for debugging, running scripts, or
exploring the environment your model will run in.`,
		Example: `  # Open a Python interpreter inside the model environment
  cog run python

  # Run a script
  cog run python train.py

  # Run with environment variables
  cog run -e HUGGING_FACE_HUB_TOKEN=abc123 python download.py

  # Expose a port (e.g. for Jupyter)
  cog run -p 8888 jupyter notebook`,
		RunE:    run,
		PreRunE: checkMutuallyExclusiveFlags,
		Args:    cobra.MinimumNArgs(1),
	}
	addBuildProgressOutputFlag(cmd)
	addDockerfileFlag(cmd)
	addUseCudaBaseImageFlag(cmd)
	addUseCogBaseImageFlag(cmd)
	addGpusFlag(cmd)
	addConfigFlag(cmd)

	flags := cmd.Flags()

	// This is called `publish` for consistency with `docker run`
	flags.StringArrayVarP(&runPorts, "publish", "p", []string{}, "Publish a container's port to the host, e.g. -p 8000")
	flags.StringArrayVarP(&envFlags, "env", "e", []string{}, "Environment variables, in the form name=value")

	// Flags after first argument are considered args and passed to command
	flags.SetInterspersed(false)

	return cmd
}

// run implements `cog run`: it builds an image from the config (without
// schema validation) and executes args inside it, with the project directory
// mounted at /src as the working directory.
//
// GPUs come from --gpus when set, otherwise "all" if the model reports a GPU.
// Each -p value must be an integer and is published on the same host and
// container port. If the run fails with docker.ErrMissingDeviceDriver while
// GPUs were auto-selected, the command is retried once without GPUs.
func run(cmd *cobra.Command, args []string) error {
	ctx := cmd.Context()

	dockerClient, err := docker.NewClient(ctx)
	if err != nil {
		return err
	}

	src, err := model.NewSource(configFilename)
	if err != nil {
		return err
	}

	resolver := model.NewResolver(dockerClient, registry.NewRegistryClient())

	console.Info("Building Docker image from environment in cog.yaml...")
	console.Info("")
	opts := serveBuildOptions(cmd)
	opts.SkipSchemaValidation = true
	m, err := resolver.Build(ctx, src, opts)
	if err != nil {
		return err
	}

	gpus := ""
	if gpusFlag != "" {
		gpus = gpusFlag
	} else if m.HasGPU() {
		gpus = "all"
	}

	// Automatically propagate RUST_LOG for Rust coglet debugging.
	// The full slice expression caps capacity at the current length, so the
	// append below always allocates instead of writing into the backing array
	// of the package-level envFlags.
	env := envFlags[:len(envFlags):len(envFlags)]
	if rustLog := os.Getenv("RUST_LOG"); rustLog != "" {
		env = append(env, "RUST_LOG="+rustLog)
	}

	runOptions := command.RunOptions{
		Args:    args,
		Env:     env,
		GPUs:    gpus,
		Image:   m.ImageRef(),
		Volumes: []command.Volume{{Source: src.ProjectDir, Destination: "/src"}},
		Workdir: "/src",
	}

	for _, portString := range runPorts {
		port, err := strconv.Atoi(portString)
		if err != nil {
			return err
		}

		runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: port})
	}

	console.Info("")
	console.Infof("Running %s in Docker with the current directory mounted as a volume...", console.Bold(strings.Join(args, " ")))
	console.Info("")

	err = docker.Run(ctx, dockerClient, runOptions)
	// Only retry if we're using a GPU but the user didn't explicitly select a GPU with --gpus
	// If the user specified the wrong GPU, they are explicitly selecting a GPU and they'll want to hear about it
	// NOTE(review): this compares by identity, so a wrapped ErrMissingDeviceDriver would not match.
	if runOptions.GPUs == "all" && err == docker.ErrMissingDeviceDriver {
		console.Info("Missing device driver, re-trying without GPU")

		runOptions.GPUs = ""
		err = docker.Run(ctx, dockerClient, runOptions)
	}

	return err
}


================================================
FILE: pkg/cli/serve.go
================================================
package cli

import (
	"fmt"
	"os"
	"strings"

	"github.com/spf13/cobra"

	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/model"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/util/console"
)

var (
	// port is the host port for `cog serve`, set by -p/--port; defaults to 8393.
	port      = 8393
	// uploadURL holds the value of the --upload-url flag of `cog serve`.
	uploadURL = ""
)

// newServeCommand builds the `cog serve` command, which takes no positional
// arguments and runs cmdServe. It is also suggested when the user types "http".
func newServeCommand() *cobra.Command {
	const (
		longHelp = `Run a prediction HTTP server.

Builds the model and starts an HTTP server that exposes the model's inputs
and outputs as a REST API. Compatible with the Cog HTTP protocol.`

		examples = `  # Start the server on the default port (8393)
  cog serve

  # Start on a custom port
  cog serve -p 5000

  # Test the server
  curl http://localhost:8393/predictions \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"input": {"prompt": "a cat"}}'`
	)

	serveCmd := &cobra.Command{
		Use:        "serve",
		Short:      "Run a prediction HTTP server",
		Long:       longHelp,
		Example:    examples,
		RunE:       cmdServe,
		Args:       cobra.MaximumNArgs(0),
		SuggestFor: []string{"http"},
	}

	addBuildProgressOutputFlag(serveCmd)
	addUseCudaBaseImageFlag(serveCmd)
	addUseCogBaseImageFlag(serveCmd)
	addGpusFlag(serveCmd)
	addConfigFlag(serveCmd)

	flags := serveCmd.Flags()
	// The current value of port doubles as the flag's default.
	flags.IntVarP(&port, "port", "p", port, "Port on which to listen")
	flags.StringVar(&uploadURL, "upload-url", "", "Upload URL for file outputs (e.g. https://example.com/upload/)")

	return serveCmd
}

// serveBuildOptions returns the BuildOptions used by commands that run the
// model from a volume-mounted source tree (cog serve, and also cog run and
// cog train). It follows the same build path as cog build but sets
// ExcludeSource so COPY . /src is skipped — source is volume-mounted at
// runtime instead. All other layers (wheels, apt, etc.) share Docker layer
// cache with cog build.
func serveBuildOptions(cmd *cobra.Command) model.BuildOptions {
	var opts model.BuildOptions

	// Values taken from the build flags.
	opts.UseCudaBaseImage = buildUseCudaBaseImage
	opts.UseCogBaseImage = DetermineUseCogBaseImage(cmd)
	opts.ProgressOutput = buildProgressOutput

	// Fixed settings for a volume-mounted run.
	opts.ExcludeSource = true
	opts.SkipLabels = true

	return opts
}

// cmdServe implements `cog serve`: it builds an image from the config and
// starts the cog HTTP server inside it, with the project directory mounted at
// /src and the chosen host port mapped to container port 5000.
//
// GPUs come from --gpus when set, otherwise "all" if the model reports a GPU.
// If the run fails with docker.ErrMissingDeviceDriver while GPUs were
// auto-selected, the server is started once more without GPUs.
func cmdServe(cmd *cobra.Command, arg []string) error {
	ctx := cmd.Context()

	dockerClient, err := docker.NewClient(ctx)
	if err != nil {
		return err
	}

	src, err := model.NewSource(configFilename)
	if err != nil {
		return err
	}

	console.Info("Building Docker image from environment in cog.yaml...")
	console.Info("")
	resolver := model.NewResolver(dockerClient, registry.NewRegistryClient())
	m, err := resolver.Build(ctx, src, serveBuildOptions(cmd))
	if err != nil {
		return err
	}

	gpus := ""
	if gpusFlag != "" {
		gpus = gpusFlag
	} else if m.HasGPU() {
		gpus = "all"
	}

	args := []string{
		"python",
		"--check-hash-based-pycs", "never",
		"-m", "cog.server.http",
		"--await-explicit-shutdown", "true",
	}

	if uploadURL != "" {
		args = append(args, "--upload-url", uploadURL)
	}

	// Automatically propagate RUST_LOG for Rust coglet debugging.
	// The full slice expression caps capacity at the current length, so the
	// append below always allocates instead of writing into the backing array
	// of the package-level envFlags.
	env := envFlags[:len(envFlags):len(envFlags)]
	if rustLog := os.Getenv("RUST_LOG"); rustLog != "" {
		env = append(env, "RUST_LOG="+rustLog)
	}

	runOptions := command.RunOptions{
		Args:    args,
		Env:     env,
		GPUs:    gpus,
		Image:   m.ImageRef(),
		Volumes: []command.Volume{{Source: src.ProjectDir, Destination: "/src"}},
		Workdir: "/src",
	}

	// On Linux, host.docker.internal is not available by default — add it.
	// This allows the container to reach services running on the host,
	// e.g. when --upload-url points to a local upload server.
	if uploadURL != "" {
		runOptions.ExtraHosts = []string{"host.docker.internal:host-gateway"}
	}

	runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: 5000})

	console.Info("")
	console.Infof("Running %[1]s in Docker with the current directory mounted as a volume...", console.Bold(strings.Join(args, " ")))
	console.Info("")
	console.Infof("Serving at %s", console.Bold(fmt.Sprintf("http://127.0.0.1:%v", port)))
	console.Info("")

	err = docker.Run(ctx, dockerClient, runOptions)
	// Only retry if we're using a GPU but the user didn't explicitly select a GPU with --gpus
	// If the user specified the wrong GPU, they are explicitly selecting a GPU and they'll want to hear about it
	// NOTE(review): this compares by identity, so a wrapped ErrMissingDeviceDriver would not match.
	if runOptions.GPUs == "all" && err == docker.ErrMissingDeviceDriver {
		console.Info("Missing device driver, re-trying without GPU")

		runOptions.GPUs = ""
		err = docker.Run(ctx, dockerClient, runOptions)
	}

	return err
}


================================================
FILE: pkg/cli/train.go
================================================
package cli

import (
	"context"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/spf13/cobra"

	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/model"
	"github.com/replicate/cog/pkg/predict"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/util/console"
)

var (
	// trainEnvFlags holds the values of the -e/--env flag of `cog train`.
	trainEnvFlags   []string
	// trainInputFlags holds the values of the -i/--input flag of `cog train`.
	trainInputFlags []string
	// trainOutPath holds the value of the -o/--output flag of `cog train`; defaults to "weights".
	trainOutPath    string
)

// newTrainCommand builds the hidden, deprecated `cog train [image]` command,
// which accepts at most one positional argument and runs cmdTrain.
func newTrainCommand() *cobra.Command {
	const longHelp = `Run a training.

If 'image' is passed, it will run the training on that Docker image.
It must be an image that has been built by Cog.

Otherwise, it will build the model in the current directory and train it.`

	trainCmd := &cobra.Command{
		Use:        "train [image]",
		Short:      "Run a training",
		Long:       longHelp,
		RunE:       cmdTrain,
		Args:       cobra.MaximumNArgs(1),
		Hidden:     true,
		Deprecated: "the train command will be removed in a future version of Cog",
	}

	addBuildProgressOutputFlag(trainCmd)
	addDockerfileFlag(trainCmd)
	addUseCudaBaseImageFlag(trainCmd)
	addGpusFlag(trainCmd)
	addUseCogBaseImageFlag(trainCmd)
	addConfigFlag(trainCmd)

	flags := trainCmd.Flags()
	flags.StringArrayVarP(&trainInputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg")
	flags.StringArrayVarP(&trainEnvFlags, "env", "e", []string{}, "Environment variables, in the form name=value")
	flags.StringVarP(&trainOutPath, "output", "o", "weights", "Output path")

	return trainCmd
}

// cmdTrain implements `cog train`.
//
// With no positional argument it builds an image from the config and mounts
// the project directory at /src; with one argument it pulls that image and
// uses it as-is. In both cases GPUs come from --gpus when set, otherwise
// "all" if the model reports a GPU. It then starts the container in train
// mode, waits up to setupTimeout seconds for it to start, and runs the
// training with the -i inputs, writing results to the -o path.
//
// The container is stopped when the function returns and also on SIGINT.
func cmdTrain(cmd *cobra.Command, args []string) error {
	ctx := cmd.Context()

	dockerClient, err := docker.NewClient(ctx)
	if err != nil {
		return err
	}

	imageName := ""
	volumes := []command.Volume{}
	gpus := gpusFlag

	resolver := model.NewResolver(dockerClient, registry.NewRegistryClient())

	if len(args) == 0 {
		// Build image
		src, err := model.NewSource(configFilename)
		if err != nil {
			return err
		}

		console.Info("Building Docker image from environment in cog.yaml...")
		console.Info("")
		m, err := resolver.Build(ctx, src, serveBuildOptions(cmd))
		if err != nil {
			return err
		}
		imageName = m.ImageRef()

		// ExcludeSource build doesn't have /src in it, so mount as volume
		volumes = append(volumes, command.Volume{
			Source:      src.ProjectDir,
			Destination: "/src",
		})

		if gpus == "" && m.HasGPU() {
			gpus = "all"
		}
	} else {
		// Use existing image
		imageName = args[0]

		// Pull the image (if needed) and validate it's a Cog model
		ref, err := model.ParseRef(imageName)
		if err != nil {
			return err
		}
		m, err := resolver.Pull(ctx, ref)
		if err != nil {
			return err
		}

		if gpus == "" && m.HasGPU() {
			gpus = "all"
		}
	}

	console.Info("")
	console.Info("Starting Docker image and running setup()...")

	predictor, err := predict.NewPredictor(ctx, command.RunOptions{
		GPUs:    gpus,
		Image:   imageName,
		Volumes: volumes,
		Env:     trainEnvFlags,
		Args:    []string{"python", "-m", "cog.server.http", "--x-mode", "train"},
	}, true, dockerClient)
	if err != nil {
		return err
	}

	go func() {
		captureSignal := make(chan os.Signal, 1)
		signal.Notify(captureSignal, syscall.SIGINT)

		<-captureSignal

		console.Info("Stopping container...")
		// Use a background context, as the deferred stop below does: the
		// command context may already be canceled by the time an interrupt
		// is handled, and the stop request must still be sent.
		if err := predictor.Stop(context.Background()); err != nil {
			console.Warnf("Failed to stop container: %s", err)
		}
	}()

	if err := predictor.Start(ctx, os.Stderr, time.Duration(setupTimeout)*time.Second); err != nil {
		return err
	}

	// FIXME: will not run on signal
	defer func() {
		console.Debugf("Stopping container...")
		// use background context to ensure stop signal is still sent after root context is canceled
		if err := predictor.Stop(context.Background()); err != nil {
			console.Warnf("Failed to stop container: %s", err)
		}
	}()

	return predictIndividualInputs(*predictor, trainInputFlags, trainOutPath, true)
}


================================================
FILE: pkg/cli/train_test.go
================================================
package cli

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// TestTrainCommandIsDeprecated checks that the train command carries a
// deprecation message announcing its future removal.
func TestTrainCommandIsDeprecated(t *testing.T) {
	deprecation := newTrainCommand().Deprecated

	require.NotEmpty(t, deprecation, "train command should have a deprecation message")
	require.Contains(t, deprecation, "will be removed in a future version")
}


================================================
FILE: pkg/cli/weights.go
================================================
package cli

import (
	"fmt"
	"path/filepath"
	"time"

	"github.com/google/go-containerregistry/pkg/name"
	"github.com/spf13/cobra"
	"golang.org/x/sync/errgroup"

	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/global"
	"github.com/replicate/cog/pkg/model"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/util/console"
)

// newWeightsCommand builds the hidden `cog weights` parent command and
// attaches its build, inspect and push subcommands.
func newWeightsCommand() *cobra.Command {
	weightsCmd := &cobra.Command{
		Use:    "weights",
		Short:  "Manage model weights",
		Long:   "Commands for managing model weight files.",
		Hidden: true,
	}

	weightsCmd.AddCommand(
		newWeightsBuildCommand(),
		newWeightsInspectCommand(),
		newWeightsPushCommand(),
	)

	return weightsCmd
}

// newWeightsBuildCommand builds the `cog weights build` command, which takes
// no positional arguments and runs weightsBuildCommand.
func newWeightsBuildCommand() *cobra.Command {
	const longHelp = `Reads the weights section from cog.yaml, processes each weight source,
and generates a weights.lock file containing metadata (digests, sizes) for each file.`

	buildCmd := &cobra.Command{
		Use:   "build",
		Short: "Generate weights.lock from weight sources in cog.yaml",
		Long:  longHelp,
		Args:  cobra.NoArgs,
		RunE:  weightsBuildCommand,
	}
	addConfigFlag(buildCmd)

	return buildCmd
}

// weightsBuildCommand implements `cog weights build`. It loads the config,
// fails if it declares no weights, builds every weight spec through a
// model.WeightBuilder pointed at the lockfile in the project directory, and
// prints one line per artifact followed by a total.
func weightsBuildCommand(cmd *cobra.Command, args []string) error {
	ctx := cmd.Context()

	src, err := model.NewSource(configFilename)
	if err != nil {
		return fmt.Errorf("failed to read config: %w", err)
	}
	if len(src.Config.Weights) == 0 {
		return fmt.Errorf("no weights defined in %s", configFilename)
	}

	// Keep only the artifact specs that are weight specs.
	var specs []*model.WeightSpec
	for _, candidate := range src.ArtifactSpecs() {
		ws, isWeight := candidate.(*model.WeightSpec)
		if !isWeight {
			continue
		}
		specs = append(specs, ws)
	}

	console.Infof("Processing %d weight source(s)...", len(specs))

	builder := model.NewWeightBuilder(
		src,
		global.Version,
		filepath.Join(src.ProjectDir, model.WeightsLockFilename),
	)

	// Build each weight artifact (hashes file, updates lockfile)
	var totalSize int64
	for _, ws := range specs {
		built, failure := builder.Build(ctx, ws)
		if failure != nil {
			return fmt.Errorf("failed to build weight %q: %w", ws.Name(), failure)
		}

		wa, isWeightArtifact := built.(*model.WeightArtifact)
		if !isWeightArtifact {
			return fmt.Errorf("unexpected artifact type %T for weight %q", built, ws.Name())
		}

		artifactSize := wa.Descriptor().Size
		totalSize += artifactSize
		console.Infof("  %s -> %s (%s)", wa.Name(), wa.Target, formatSize(artifactSize))
	}

	console.Infof("\nGenerated %s with %d file(s) (%s total)",
		model.WeightsLockFilename, len(specs), formatSize(totalSize))

	return nil
}

// formatSize renders a byte count using binary (1024-based) units. Values of
// at least 1 KB are shown with one decimal place in the largest unit among
// KB, MB and GB that does not exceed them; smaller values are shown as a
// whole number of bytes.
func formatSize(bytes int64) string {
	// Ordered from largest to smallest so the first match is the right unit.
	units := [...]struct {
		threshold int64
		suffix    string
	}{
		{1 << 30, "GB"},
		{1 << 20, "MB"},
		{1 << 10, "KB"},
	}

	for _, u := range units {
		if bytes >= u.threshold {
			return fmt.Sprintf("%.1f%s", float64(bytes)/float64(u.threshold), u.suffix)
		}
	}
	return fmt.Sprintf("%dB", bytes)
}

// newWeightsPushCommand builds the `cog weights push [IMAGE]` command, which
// accepts at most one positional argument and runs weightsPushCommand.
func newWeightsPushCommand() *cobra.Command {
	const longHelp = `Reads weights.lock and pushes weight files as an OCI artifact to a registry.

The registry is determined from the image name, which can be:
- Specified as an argument: cog weights push registry.example.com/user/model
- Set in cog.yaml as the 'image' field`

	pushCmd := &cobra.Command{
		Use:   "push [IMAGE]",
		Short: "Push weights to a registry",
		Long:  longHelp,
		Args:  cobra.MaximumNArgs(1),
		RunE:  weightsPushCommand,
	}
	addConfigFlag(pushCmd)

	return pushCmd
}

// weightsPushCommand implements `cog weights push`.
//
// The target repository comes from the first positional argument if given,
// otherwise from the 'image' option in the config; it must be a bare
// repository without a tag or digest. Every weight spec in the config is
// built into a weight artifact, and the artifacts are then pushed
// concurrently. The first push failure is returned; on success a per-artifact
// summary and the total size are printed.
func weightsPushCommand(cmd *cobra.Command, args []string) error {
	ctx := cmd.Context()

	src, err := model.NewSource(configFilename)
	if err != nil {
		return fmt.Errorf("failed to read config: %w", err)
	}

	cfg := src.Config

	// Determine image name
	imageName := cfg.Image
	if len(args) > 0 {
		imageName = args[0]
	}
	if imageName == "" {
		return fmt.Errorf("To push weights, you must either set the 'image' option in cog.yaml or pass an image name as an argument. For example, 'cog weights push registry.example.com/your-username/model-name'")
	}

	// Parse as repository only — reject tags/digests since weight tags are auto-generated.
	parsedRepo, err := name.NewRepository(imageName, name.Insecure)
	if err != nil {
		// NewRepository fails for inputs with :tag or @digest — check if it's a valid ref
		if ref, refErr := name.ParseReference(imageName, name.Insecure); refErr == nil {
			return fmt.Errorf("image reference %q includes a tag or digest — provide only the repository (e.g., %q)", imageName, ref.Context().Name())
		}
		return fmt.Errorf("invalid repository %q: %w", imageName, err)
	}
	repo := parsedRepo.Name()

	if len(cfg.Weights) == 0 {
		return fmt.Errorf("no weights defined in %s", configFilename)
	}

	// Build weight artifacts (reads lockfile as cache, hashes files)
	lockPath := filepath.Join(src.ProjectDir, model.WeightsLockFilename)
	builder := model.NewWeightBuilder(src, global.Version, lockPath)

	var artifacts []*model.WeightArtifact
	for _, spec := range src.ArtifactSpecs() {
		// Non-weight artifact specs are ignored here.
		ws, ok := spec.(*model.WeightSpec)
		if !ok {
			continue
		}
		artifact, buildErr := builder.Build(ctx, ws)
		if buildErr != nil {
			return fmt.Errorf("failed to build weight %q: %w", ws.Name(), buildErr)
		}
		wa, ok := artifact.(*model.WeightArtifact)
		if !ok {
			return fmt.Errorf("unexpected artifact type %T for weight %q", artifact, ws.Name())
		}
		artifacts = append(artifacts, wa)
	}

	if len(artifacts) == 0 {
		return fmt.Errorf("no weight artifacts to push")
	}

	console.Infof("Pushing %d weight file(s) to %s...", len(artifacts), repo)

	regClient := registry.NewRegistryClient()
	pusher := model.NewWeightPusher(regClient)

	// Set up progress display using Docker's jsonmessage rendering.
	pw := docker.NewProgressWriter()
	defer pw.Close()

	// Push each weight artifact concurrently using errgroup for
	// bounded concurrency and first-error cancellation.
	type pushResult struct {
		ref  string
		size int64
	}

	// Results are stored by artifact index so the summary below is printed in
	// the same order as artifacts, regardless of which push finishes first.
	ordered := make([]pushResult, len(artifacts))

	// From here on ctx is the errgroup's derived context.
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(model.GetPushConcurrency())

	for i, wa := range artifacts {
		artName := wa.Name()
		artSize := wa.Descriptor().Size

		g.Go(func() error {
			result, pushErr := pusher.Push(ctx, repo, wa, model.WeightPushOptions{
				ProgressFn: func(prog model.PushProgress) {
					pw.Write(artName, "Pushing", prog.Complete, prog.Total)
				},
				// RetryFn reports each retry and always returns true.
				RetryFn: func(event model.WeightRetryEvent) bool {
					status := fmt.Sprintf("Retrying (%d/%d) in %s",
						event.Attempt, event.MaxAttempts,
						event.NextRetryIn.Round(time.Second))
					pw.WriteStatus(event.Name, status)
					// In non-TTY mode, also log the error detail since the
					// progress writer output won't be visible.
					if !console.IsTerminal() {
						console.Warnf("  %s: retrying (%d/%d) in %s: %v",
							event.Name, event.Attempt, event.MaxAttempts,
							event.NextRetryIn.Round(time.Second), event.Err)
					}
					return true
				},
			})

			if pushErr != nil {
				pw.WriteStatus(artName, "FAILED")
				return fmt.Errorf("push weight %q: %w", artName, pushErr)
			}

			pw.WriteStatus(artName, "Pushed")
			// Each goroutine writes only its own index, so the slots do not overlap.
			ordered[i] = pushResult{ref: result.Ref, size: artSize}
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		// Close the progress display before the error is returned.
		pw.Close()
		return err
	}

	// Close progress display
	pw.Close()

	// Print final summary
	var totalSize int64
	for i, wa := range artifacts {
		console.Infof("  %s: %s", wa.Name(), ordered[i].ref)
		totalSize += ordered[i].size
	}

	console.Infof("\nPushed %d weight artifact(s) to %s", len(artifacts), repo)
	console.Infof("Total: %s", formatSize(totalSize))

	return nil
}


================================================
FILE: pkg/cli/weights_inspect.go
================================================
package cli

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"

	"github.com/google/go-containerregistry/pkg/name"
	"github.com/spf13/cobra"

	"github.com/replicate/cog/pkg/model"
	"github.com/replicate/cog/pkg/registry"
)

// localWeight tracks the local state of a weight from cog.yaml + weights.lock.
//
// target and source are copied from the weight's cog.yaml entry. lockFile
// points at the weights.lock entry with the same name and stays nil when the
// lockfile could not be loaded or has no such entry.
type localWeight struct {
	target   string
	source   string
	lockFile *model.WeightFile
}

// WeightsInspectOutput is the structured output for cog weights inspect --json.
//
// Reference is the repository name the comparison was run against, and Weights
// holds one entry per compared weight. The same value also drives the
// plain-text rendering in printWeightsInspectText.
type WeightsInspectOutput struct {
	Reference string               `json:"reference"`
	Weights   []WeightInspectEntry `json:"weights"`
}

// WeightInspectEntry represents one weight's comparison between local and remote state.
//
// Local is nil for weights that only exist remotely; Remote is nil when no
// matching manifest was found in the registry. Both are omitted from the JSON
// output when nil.
type WeightInspectEntry struct {
	Name   string             `json:"name"`
	Status string             `json:"status"` // synced, local-only, remote-only, digest-mismatch, missing-lockfile
	Local  *WeightLocalState  `json:"local,omitempty"`
	Remote *WeightRemoteState `json:"remote,omitempty"`
}

// WeightLocalState represents the local state of a weight from cog.yaml + weights.lock.
//
// Digest and Size are taken from the lockfile entry and stay at their zero
// values when the weight has no lockfile entry. Target is the lockfile's
// destination when an entry exists and the cog.yaml target otherwise.
// FileExists reports whether the weight's source path exists under the
// project directory.
type WeightLocalState struct {
	Digest     string `json:"digest"`
	Size       int64  `json:"size"`
	Target     string `json:"target"`
	FileExists bool   `json:"fileExists"`
}

// WeightRemoteLayer represents a single layer in a remote weight manifest.
//
// The three fields are copied from the corresponding layer descriptor of the
// manifest fetched from the registry.
type WeightRemoteLayer struct {
	Digest    string `json:"digest"`
	Size      int64  `json:"size"`
	MediaType string `json:"mediaType"`
}

// WeightRemoteState represents the remote state of a weight from the registry.
//
// Ref is the full "repo:tag" reference that was fetched and Tag is its tag
// part. Digest is the digest reported for the fetched image, Size is the byte
// length of its raw manifest (not the sum of the layer sizes), and MediaType
// is the manifest's media type. Layers lists the manifest's layers.
// MatchedByContent is set when the weight was located through its
// name-and-digest tag, in which case the comparison treats it as synced.
type WeightRemoteState struct {
	Ref              string              `json:"ref"`
	Tag              string              `json:"tag"`
	Digest           string              `json:"digest"`
	Size             int64               `json:"size"`
	MediaType        string              `json:"mediaType"`
	Layers           []WeightRemoteLayer `json:"layers,omitempty"`
	MatchedByContent bool                `json:"matchedByContent,omitempty"`
}

// newWeightsInspectCommand builds the cobra command for the weights inspect
// subcommand. The command takes exactly one positional argument, registers a
// --json flag, and receives the shared config flag via addConfigFlag.
func newWeightsInspectCommand() *cobra.Command {
	var asJSON bool

	// The handler reads asJSON at run time, after cobra has parsed the flags.
	run := func(cmd *cobra.Command, args []string) error {
		return weightsInspectCommand(cmd, args, asJSON)
	}

	inspectCmd := &cobra.Command{
		Use:   "inspect ",
		Short: "Compare local weights against remote registry state",
		Args:  cobra.ExactArgs(1),
		RunE:  run,
	}

	inspectCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON")
	addConfigFlag(inspectCmd)

	return inspectCmd
}

// weightsInspectCommand compares the weights declared in cog.yaml (together
// with their weights.lock entries) against what the registry holds for the
// repository named by args[0], then prints the comparison as indented JSON
// when jsonOutput is set and as text otherwise.
//
// It returns an error when the config cannot be read, when args[0] is not a
// bare repository name, or when JSON encoding fails. A lockfile that fails to
// load is not an error: every weight is then reported as "missing-lockfile".
func weightsInspectCommand(cmd *cobra.Command, args []string, jsonOutput bool) error {
	ctx := cmd.Context()

	// Step 1: gather local state from the config and the lockfile.
	src, err := model.NewSource(configFilename)
	if err != nil {
		return fmt.Errorf("failed to read config: %w", err)
	}

	// A load failure is tolerated here; the lockfile may simply not exist yet.
	lock, lockErr := model.LoadWeightsLock(filepath.Join(src.ProjectDir, model.WeightsLockFilename))

	// Index the configured weights by name so lockfile entries can be attached.
	localWeights := make(map[string]*localWeight)
	for _, w := range src.Config.Weights {
		localWeights[w.Name] = &localWeight{target: w.Target, source: w.Source}
	}

	if lockErr == nil && lock != nil {
		for i := range lock.Files {
			lockEntry := &lock.Files[i]
			if lw, known := localWeights[lockEntry.Name]; known {
				lw.lockFile = lockEntry
			}
		}
	}

	// Step 2: resolve remote state. Only a bare repository is accepted because
	// the weight tags are generated, not supplied by the user.
	parsedRepo, err := name.NewRepository(args[0], name.Insecure)
	if err != nil {
		if ref, refErr := name.ParseReference(args[0], name.Insecure); refErr == nil {
			return fmt.Errorf("image reference %q includes a tag or digest — provide only the repository (e.g., %q)", args[0], ref.Context().Name())
		}
		return fmt.Errorf("invalid repository %q: %w", args[0], err)
	}
	repo := parsedRepo.Name()

	remoteWeights := resolveWeightsByTag(ctx, repo, localWeights, registry.NewRegistryClient())

	// Step 3: build the comparison, local weights first in cog.yaml order.
	out := &WeightsInspectOutput{Reference: repo}
	matchedRemote := make(map[string]bool)

	for _, w := range src.Config.Weights {
		lw := localWeights[w.Name]
		onDisk := fileExists(filepath.Join(src.ProjectDir, lw.source))
		entry := WeightInspectEntry{Name: w.Name}

		if lw.lockFile == nil {
			// No lockfile entry — needs `cog weights build`.
			entry.Status = "missing-lockfile"
			entry.Local = &WeightLocalState{
				Target:     lw.target,
				FileExists: onDisk,
			}
			out.Weights = append(out.Weights, entry)
			continue
		}

		entry.Local = &WeightLocalState{
			Digest:     lw.lockFile.Digest,
			Size:       lw.lockFile.Size,
			Target:     lw.lockFile.Dest,
			FileExists: onDisk,
		}

		if remote, pushed := remoteWeights[w.Name]; !pushed {
			entry.Status = "local-only"
		} else {
			matchedRemote[w.Name] = true
			entry.Remote = remote
			entry.Status = "digest-mismatch"
			if remote.MatchedByContent || lw.lockFile.Digest == remote.Digest {
				entry.Status = "synced"
			}
		}

		out.Weights = append(out.Weights, entry)
	}

	// Anything the registry returned that no local weight claimed is remote-only.
	for weightName, remote := range remoteWeights {
		if matchedRemote[weightName] {
			continue
		}
		out.Weights = append(out.Weights, WeightInspectEntry{
			Name:   weightName,
			Status: "remote-only",
			Remote: remote,
		})
	}

	// Step 4: render.
	if !jsonOutput {
		printWeightsInspectText(out)
		return nil
	}

	enc := json.NewEncoder(os.Stdout)
	enc.SetIndent("", "  ")
	return enc.Encode(out)
}

// resolveWeightsByTag checks the registry for each local weight that has a
// lockfile entry. This is the fallback path when no OCI index exists (e.g.,
// after `cog weights push` but before `cog push`).
//
// For every such weight it fetches repo:<tag>, where the tag is the value
// returned by model.WeightTag for the weight's name and lockfile digest. A hit
// is recorded with MatchedByContent set, since the tag was derived from the
// local digest. Weights without a lockfile entry, and weights whose image,
// manifest, digest, or raw manifest cannot be read, are silently left out.
//
// The result is keyed by weight name and is nil when nothing was found.
func resolveWeightsByTag(ctx context.Context, repo string, localWeights map[string]*localWeight, reg registry.Client) map[string]*WeightRemoteState {
	found := make(map[string]*WeightRemoteState)

	for weightName, lw := range localWeights {
		if lw.lockFile == nil {
			continue
		}

		tag := model.WeightTag(weightName, lw.lockFile.Digest)
		tagRef := repo + ":" + tag

		// Fetch the full manifest (not just HEAD) so the layer sizes are available.
		img, imgErr := reg.GetImage(ctx, tagRef, nil)
		if imgErr != nil {
			continue
		}
		manifest, manifestErr := img.Manifest()
		if manifestErr != nil {
			continue
		}
		digest, digestErr := img.Digest()
		if digestErr != nil {
			continue
		}
		rawManifest, rawErr := img.RawManifest()
		if rawErr != nil {
			continue
		}

		var layers []WeightRemoteLayer
		for _, layer := range manifest.Layers {
			layers = append(layers, WeightRemoteLayer{
				Digest:    layer.Digest.String(),
				Size:      layer.Size,
				MediaType: string(layer.MediaType),
			})
		}

		found[weightName] = &WeightRemoteState{
			Ref:              tagRef,
			Tag:              tag,
			Digest:           digest.String(),
			Size:             int64(len(rawManifest)),
			MediaType:        string(manifest.MediaType),
			Layers:           layers,
			MatchedByContent: true,
		}
	}

	if len(found) == 0 {
		return nil
	}
	return found
}

// printWeightsInspectText writes the human-readable form of out to stdout:
// a header naming the repository, then one section per weight with its
// status (plus a short hint for some statuses), its local digest/size/target,
// and one line per remote layer.
func printWeightsInspectText(out *WeightsInspectOutput) {
	// Parenthetical hints appended after certain statuses.
	statusHints := map[string]string{
		"local-only":       " (not pushed)",
		"remote-only":      " (not in cog.yaml)",
		"missing-lockfile": " (run cog weights build)",
	}

	fmt.Printf("Weights for: %s\n\n", out.Reference)

	for _, entry := range out.Weights {
		local, remote := entry.Local, entry.Remote

		// Heading: weight name, plus the remote tag when one is known.
		if remote == nil || remote.Tag == "" {
			fmt.Printf("  %s\n", entry.Name)
		} else {
			fmt.Printf("  %s  :%s\n", entry.Name, remote.Tag)
		}

		fmt.Printf("    Status:  %s", entry.Status)
		fmt.Print(statusHints[entry.Status])
		fmt.Println()

		switch {
		case local == nil:
			fmt.Println("    Local:   -")
		case local.Digest != "":
			fmt.Printf("    Local:   %s (%s) -> %s\n", local.Digest, formatSize(local.Size), local.Target)
		default:
			fmt.Printf("    Local:   (no lockfile entry) -> %s\n", local.Target)
		}

		if remote == nil {
			fmt.Println("    Remote:  -")
		} else {
			for _, layer := range remote.Layers {
				fmt.Printf("    Layer:   %s (%s)\n", layer.Digest, formatSize(layer.Size))
			}
		}

		fmt.Println()
	}
}

// fileExists reports whether os.Stat succeeds for path. Any stat error,
// including permission problems, is reported as "does not exist".
func fileExists(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return false
	}
	return true
}


================================================
FILE: pkg/config/build_options.go
================================================
package config

// BuildOptions contains runtime options passed via CLI flags, not from cog.yaml.
// These are separate from the Config struct because they are not part of the
// model configuration - they are build-time settings that affect how the
// container is built but not what's in it.
//
// The zero value is not the default configuration: DefaultBuildOptions sets
// SourceEpochTimestamp to -1, whereas a zero-valued struct carries 0.
type BuildOptions struct {
	// SourceEpochTimestamp is the number of seconds since Unix epoch to use
	// for the build timestamp. Set to -1 to disable timestamp rewrites.
	// This is useful for reproducible builds.
	SourceEpochTimestamp int64

	// XCachePath is the path to the BuildKit cache directory.
	// If empty, inline caching is used instead of local cache.
	XCachePath string
}

// DefaultBuildOptions returns the baseline BuildOptions: SourceEpochTimestamp
// is -1 and XCachePath is empty.
func DefaultBuildOptions() BuildOptions {
	opts := BuildOptions{XCachePath: ""}
	opts.SourceEpochTimestamp = -1
	return opts
}


================================================
FILE: pkg/config/compatibility.go
================================================
package config

import (
	// blank import for embeds
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strings"

	"golang.org/x/exp/slices"

	"github.com/replicate/cog/pkg/requirements"
	"github.com/replicate/cog/pkg/util"
	"github.com/replicate/cog/pkg/util/console"

	"github.com/replicate/cog/pkg/util/version"
)

// TODO(andreas): check tf/py versions. tf 1.5.0 didn't install on py 3.10
// TODO(andreas): support more tf versions. No matching tensorflow CPU package for version 1.15.4, etc.
// TODO(andreas): allow user to install versions that aren't compatible
// TODO(andreas): allow user to install tf cpu package on gpu

// TFCompatibility is one row of the embedded TensorFlow compatibility matrix.
//
// When decoded through UnmarshalJSON, CUDA is normalised to "major.minor" and
// CuDNN to the major version only; the remaining fields are copied verbatim
// from the JSON data.
type TFCompatibility struct {
	TF           string
	TFCPUPackage string
	TFGPUPackage string
	CUDA         string
	CuDNN        string
	Pythons      []string
}

// UnmarshalJSON decodes one TensorFlow compatibility row and normalises its
// version fields: CUDA is reduced to "major.minor" and CuDNN to its major
// version so that they line up with the nvidia base image tags.
//
// It returns an error when data is not valid JSON for the row or when the
// CUDA or CuDNN field is not a parseable version. The receiver is only
// modified once every field has been validated.
func (compat *TFCompatibility) UnmarshalJSON(data []byte) error {
	// to avoid unmarshalling stack overflow https://stackoverflow.com/questions/34859449/unmarshaljson-results-in-stack-overflow
	type tempType TFCompatibility
	c := new(tempType)
	if err := json.Unmarshal(data, c); err != nil {
		return err
	}

	// Report malformed versions as errors instead of going through
	// MustVersion, so bad data surfaces through the caller's error handling.
	cuda, err := version.NewVersion(c.CUDA)
	if err != nil {
		return fmt.Errorf("invalid CUDA version %q for tensorflow %q: %w", c.CUDA, c.TF, err)
	}
	cuDNN, err := version.NewVersion(c.CuDNN)
	if err != nil {
		return fmt.Errorf("invalid CuDNN version %q for tensorflow %q: %w", c.CuDNN, c.TF, err)
	}

	compat.TF = c.TF
	compat.TFCPUPackage = c.TFCPUPackage
	compat.TFGPUPackage = c.TFGPUPackage
	// include minor version
	compat.CUDA = fmt.Sprintf("%d.%d", cuda.Major, cuda.Minor)
	// strip cuDNN minor version to match nvidia images
	compat.CuDNN = fmt.Sprintf("%d", cuDNN.Major)
	compat.Pythons = c.Pythons
	return nil
}

// TorchCompatibility is one row of the embedded PyTorch compatibility matrix.
//
// CUDA is a pointer so a row can carry no CUDA version at all; the CPU
// package lookups in this file select exactly the rows whose CUDA is nil.
// Torch and Torchvision may carry a version modifier, which TorchVersion and
// TorchvisionVersion strip.
type TorchCompatibility struct {
	Torch         string
	Torchvision   string
	Torchaudio    string
	FindLinks     string
	ExtraIndexURL string
	CUDA          *string
	Pythons       []string
}

// TorchVersion returns the Torch field passed through version.StripModifier.
func (c *TorchCompatibility) TorchVersion() string {
	bare := version.StripModifier(c.Torch)
	return bare
}

// TorchvisionVersion returns the Torchvision field passed through
// version.StripModifier.
func (c *TorchCompatibility) TorchvisionVersion() string {
	bare := version.StripModifier(c.Torchvision)
	return bare
}

// CUDABaseImage is one row of the embedded CUDA base image list.
//
// Tag is the image tag within the nvidia/cuda repository (see ImageTag). CUDA
// and CuDNN are the versions used to select an image, and Ubuntu is used as a
// tie-breaker when several images share a CUDA version. IsDevel presumably
// marks "devel" image variants — TODO confirm against the JSON data.
type CUDABaseImage struct {
	Tag     string
	CUDA    string
	CuDNN   string
	IsDevel bool
	Ubuntu  string
}

// ImageTag returns the full image reference for this base image, i.e. the
// nvidia/cuda repository followed by the image's Tag.
func (i *CUDABaseImage) ImageTag() string {
	const repository = "nvidia/cuda:"
	return repository + i.Tag
}

//go:embed cuda_compatibility.json
var cudaBaseImagesData []byte

// CUDABaseImages is decoded from cudaBaseImagesData by init.
var CUDABaseImages []CUDABaseImage

//go:embed tf_compatibility.json
var tfCompatibilityMatrixData []byte

// TFCompatibilityMatrix is decoded from tfCompatibilityMatrixData by init.
var TFCompatibilityMatrix []TFCompatibility

//go:embed torch_compatibility.json
var torchCompatibilityMatrixData []byte

// TorchCompatibilityMatrix is decoded from torchCompatibilityMatrixData by
// init, which keeps only the rows that have no CUDA version or whose CUDA
// version matches at least one entry of CUDABaseImages.
var TorchCompatibilityMatrix []TorchCompatibility

// init decodes the three embedded compatibility tables into their
// package-level slices and calls console.Fatalf if any of them fails to
// decode. The torch matrix is additionally filtered down to rows that can be
// paired with a known CUDA base image.
func init() {
	if err := json.Unmarshal(cudaBaseImagesData, &CUDABaseImages); err != nil {
		console.Fatalf("Failed to load embedded CUDA base images: %s", err)
	}

	if err := json.Unmarshal(tfCompatibilityMatrixData, &TFCompatibilityMatrix); err != nil {
		console.Fatalf("Failed to load embedded Tensorflow compatibility matrix: %s", err)
	}

	var allTorchRows []TorchCompatibility
	if err := json.Unmarshal(torchCompatibilityMatrixData, &allTorchRows); err != nil {
		console.Fatalf("Failed to load embedded PyTorch compatibility matrix: %s", err)
	}

	// A row is usable when it has no CUDA version or its CUDA version matches
	// some base image. With no base images at all, nothing is usable.
	usable := func(cuda *string) bool {
		for _, baseImage := range CUDABaseImages {
			if cuda == nil || version.Matches(*cuda, baseImage.CUDA) {
				return true
			}
		}
		return false
	}

	kept := []TorchCompatibility{}
	for _, row := range allTorchRows {
		if usable(row.CUDA) {
			kept = append(kept, row)
		}
	}
	TorchCompatibilityMatrix = kept
}

// cudaVersionFromTorchPlusVersion splits a torch version such as
// "2.0.1+cu118" into its CUDA version and its base version.
//
// The first result is the CUDA version with a dot inserted before the last
// digit ("cu118" becomes "11.8"); the second is the text before the first
// '+'. When ver has no '+', when the text after the last '+' does not start
// with "cu", or when fewer than two characters follow "cu", the function
// returns an empty CUDA version together with ver unchanged.
func cudaVersionFromTorchPlusVersion(ver string) (string, string) {
	const cudaVersionPrefix = "cu"

	segments := strings.Split(ver, "+")
	if len(segments) < 2 {
		// No modifier at all.
		return "", ver
	}
	base, modifier := segments[0], segments[len(segments)-1]

	if !strings.HasPrefix(modifier, cudaVersionPrefix) {
		// Some other modifier, e.g. "+cpu".
		return "", ver
	}

	digits := modifier[len(cudaVersionPrefix):]
	if len(digits) < 2 {
		// Too short to hold both a major and a minor part.
		return "", ver
	}

	// The last character is the minor version; everything before it is the major.
	minorAt := len(digits) - 1
	return digits[:minorAt] + "." + digits[minorAt:], base
}

// cudasFromTorch returns the CUDA versions listed in TorchCompatibilityMatrix
// for the given torch version.
//
// If ver carries a "+cuXYZ" modifier and the matrix has a row for that exact
// torch/CUDA pair, only that CUDA version is returned. Otherwise every CUDA
// version of the rows matching the base torch version is returned in sorted
// order (possibly an empty list). An empty ver is an error.
func cudasFromTorch(ver string) ([]string, error) {
	if ver == "" {
		return nil, errors.New(
			"torch version must be specified when using CUDA",
		)
	}

	// Check the version modifier on torch (such as +cu118).
	pinnedCUDA, baseVer := cudaVersionFromTorchPlusVersion(ver)
	if pinnedCUDA != "" {
		for _, row := range TorchCompatibilityMatrix {
			if row.CUDA == nil {
				continue
			}
			if version.Matches(baseVer, row.TorchVersion()) && *row.CUDA == pinnedCUDA {
				return []string{*row.CUDA}, nil
			}
		}
	}

	// No pinned match: collect every CUDA version known for this torch version.
	matches := []string{}
	for _, row := range TorchCompatibilityMatrix {
		if !version.Matches(baseVer, row.TorchVersion()) || row.CUDA == nil {
			continue
		}
		matches = append(matches, *row.CUDA)
	}
	slices.Sort(matches)

	return matches, nil
}

// cudaFromTF returns the CUDA and CuDNN versions of the first
// TFCompatibilityMatrix row whose TF field equals ver exactly. When no row
// matches, both strings are empty. The error result is always nil.
func cudaFromTF(ver string) (cuda string, cuDNN string, err error) {
	for i := range TFCompatibilityMatrix {
		row := TFCompatibilityMatrix[i]
		if row.TF != ver {
			continue
		}
		return row.CUDA, row.CuDNN, nil
	}
	return "", "", nil
}

// compatibleCuDNNsForCUDA lists the CuDNN version of every CUDABaseImages
// entry whose CUDA field equals cuda exactly, in table order. The result is
// an empty, non-nil slice when nothing matches.
func compatibleCuDNNsForCUDA(cuda string) []string {
	matches := []string{}
	for _, baseImage := range CUDABaseImages {
		if baseImage.CUDA != cuda {
			continue
		}
		matches = append(matches, baseImage.CuDNN)
	}
	return matches
}

// defaultCUDA returns the CUDA version used when none can be derived.
func defaultCUDA() string {
	// TODO: change this to latestTF().CUDA once replicate supports >= 12 everywhere
	const fallback = "11.8"
	return fallback
}

// latestCUDAFrom returns the greatest version in cudas according to
// versionGreater, or "" when cudas is empty. It panics if a version cannot be
// compared, which is not expected for values taken from the embedded tables.
func latestCUDAFrom(cudas []string) string {
	latest := ""
	for _, candidate := range cudas {
		if latest == "" {
			latest = candidate
			continue
		}

		newer, err := versionGreater(candidate, latest)
		if err != nil {
			// should never happen
			panic(fmt.Sprintf("Invalid CUDA version: %s", err))
		}
		if newer {
			latest = candidate
		}
	}
	return latest
}

// latestCuDNNForCUDA returns the greatest CuDNN version among the
// CUDABaseImages entries whose CUDA version matches cuda. It returns an error
// when no base image matches.
func latestCuDNNForCUDA(cuda string) (string, error) {
	candidates := []string{}
	for _, baseImage := range CUDABaseImages {
		if !version.Matches(cuda, baseImage.CUDA) {
			continue
		}
		candidates = append(candidates, baseImage.CuDNN)
	}

	if len(candidates) == 0 {
		// TODO: return a list of supported cuda versions
		return "", fmt.Errorf("CUDA %s is not supported by Cog", cuda)
	}

	// Greatest first, so the answer sits at index 0.
	sort.Slice(candidates, func(a, b int) bool {
		return version.Greater(candidates[a], candidates[b])
	})
	return candidates[0], nil
}

// versionGreater reports whether version a is greater than version b. It
// returns an error when either string fails to parse; a is parsed first.
func versionGreater(a string, b string) (bool, error) {
	// TODO(andreas): use library
	left, parseErr := version.NewVersion(a)
	if parseErr != nil {
		return false, parseErr
	}

	right, parseErr := version.NewVersion(b)
	if parseErr != nil {
		return false, parseErr
	}

	return left.Greater(right), nil
}

// cudaBaseImageFor picks the base image for the given CUDA and CuDNN versions
// and returns its full image tag.
//
// Candidates are the CUDABaseImages entries whose CUDA version matches cuda
// and whose CuDNN equals cuDNN exactly. Among them the highest CUDA version
// wins, with the greater Ubuntu string breaking ties. An error is returned
// when there is no candidate.
func cudaBaseImageFor(cuda string, cuDNN string) (string, error) {
	var candidates []CUDABaseImage
	for _, baseImage := range CUDABaseImages {
		if !version.Matches(cuda, baseImage.CUDA) || baseImage.CuDNN != cuDNN {
			continue
		}
		candidates = append(candidates, baseImage)
	}
	if len(candidates) == 0 {
		return "", fmt.Errorf("no matching base image for CUDA %s and CuDNN %s", cuda, cuDNN)
	}

	sort.Slice(candidates, func(a, b int) bool {
		left, right := candidates[a], candidates[b]
		if left.CUDA == right.CUDA {
			return left.Ubuntu > right.Ubuntu
		}
		return version.MustVersion(left.CUDA).Greater(version.MustVersion(right.CUDA))
	})

	best := candidates[0]
	return best.ImageTag(), nil
}

// tfGPUPackage looks up the GPU package for TensorFlow version ver and CUDA
// version cuda. It uses the first TFCompatibilityMatrix row whose TF equals
// ver and whose CUDA equals cuda, and returns the package name and version
// parsed from that row's TFGPUPackage. When no row matches, all results are
// empty and the error is nil.
func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err error) {
	for _, row := range TFCompatibilityMatrix {
		if row.TF != ver || !version.Equal(row.CUDA, cuda) {
			continue
		}
		pkgName, pkgVersion, _, _, splitErr := requirements.SplitPinnedPythonRequirement(row.TFGPUPackage)
		return pkgName, pkgVersion, splitErr
	}
	// We've already warned user if they're doing something stupid in validateAndCompleteCUDA(), so fail silently
	return "", "", nil
}

// torchCPUPackage resolves the CPU build of torch for version ver. It uses
// the first TorchCompatibilityMatrix row that has no CUDA version and whose
// stripped torch version equals ver, returning that row's torch version
// (passed through torchStripCPUSuffixForM1) and index settings. When no row
// matches, ver is returned as-is with no index settings. The error result is
// always nil.
func torchCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
	for _, row := range TorchCompatibilityMatrix {
		if row.CUDA != nil || row.TorchVersion() != ver {
			continue
		}
		resolved := torchStripCPUSuffixForM1(row.Torch, goos, goarch)
		return "torch", resolved, row.FindLinks, row.ExtraIndexURL, nil
	}

	// Fall back to just installing default version. For older pytorch versions, they don't have any CPU versions.
	return "torch", ver, "", "", nil
}

// torchGPUPackage resolves the GPU build of torch for version ver.
//
// Among the TorchCompatibilityMatrix rows that match ver and carry a CUDA
// version, it picks the one with the highest CUDA version that does not
// exceed cuda, and returns that row's torch version (modifier stripped) and
// index settings. When no row qualifies, ver is returned as-is with no index
// settings. The error result is always nil; an unparseable CUDA version
// causes a panic.
func torchGPUPackage(ver string, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
	// Point at the slice elements rather than at the range variable: taking
	// the address of the range variable is only correct with per-iteration
	// loop variables (Go 1.22+), and indexing also avoids copying each row.
	var latest *TorchCompatibility
	for i := range TorchCompatibilityMatrix {
		compat := &TorchCompatibilityMatrix[i]
		if !version.Matches(compat.TorchVersion(), ver) || compat.CUDA == nil {
			continue
		}

		tooNew, cmpErr := versionGreater(*compat.CUDA, cuda)
		if cmpErr != nil {
			panic(fmt.Sprintf("Invalid CUDA version: %s", cmpErr))
		}
		if tooNew {
			continue
		}

		if latest == nil {
			latest = compat
			continue
		}
		newer, cmpErr := versionGreater(*compat.CUDA, *latest.CUDA)
		if cmpErr != nil {
			// should never happen
			panic(fmt.Sprintf("Invalid CUDA version: %s", cmpErr))
		}
		if newer {
			latest = compat
		}
	}
	if latest == nil {
		// We've already warned user if they're doing something stupid in validateAndCompleteCUDA()
		return "torch", ver, "", "", nil
	}

	return "torch", version.StripModifier(latest.Torch), latest.FindLinks, latest.ExtraIndexURL, nil
}

// torchvisionCPUPackage resolves the CPU build of torchvision for version
// ver. It uses the first TorchCompatibilityMatrix row that has no CUDA
// version and whose stripped torchvision version equals ver, returning that
// row's torchvision version (passed through torchStripCPUSuffixForM1) and
// index settings. When no row matches, ver is returned as-is with no index
// settings. The error result is always nil.
func torchvisionCPUPackage(ver, goos, goarch string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
	for _, row := range TorchCompatibilityMatrix {
		if row.CUDA != nil || row.TorchvisionVersion() != ver {
			continue
		}
		resolved := torchStripCPUSuffixForM1(row.Torchvision, goos, goarch)
		return "torchvision", resolved, row.FindLinks, row.ExtraIndexURL, nil
	}
	// Fall back to just installing default version. For older torchvision versions, they don't have any CPU versions.
	return "torchvision", ver, "", "", nil
}

// torchvisionGPUPackage resolves the GPU build of torchvision for version
// ver.
//
// Among the TorchCompatibilityMatrix rows whose stripped torchvision version
// equals ver and that carry a CUDA version, it picks the one with the highest
// CUDA version that does not exceed cuda, and returns that row's torchvision
// version (modifier stripped) and index settings. When no row qualifies it
// logs a warning and returns ver as-is with no index settings. The error
// result is always nil; an unparseable CUDA version causes a panic.
func torchvisionGPUPackage(ver, cuda string) (name, cpuVersion, findLinks, extraIndexURL string, err error) {
	// Point at the slice elements rather than at the range variable: taking
	// the address of the range variable is only correct with per-iteration
	// loop variables (Go 1.22+), and indexing also avoids copying each row.
	var latest *TorchCompatibility
	for i := range TorchCompatibilityMatrix {
		compat := &TorchCompatibilityMatrix[i]
		if compat.TorchvisionVersion() != ver || compat.CUDA == nil {
			continue
		}

		tooNew, cmpErr := versionGreater(*compat.CUDA, cuda)
		if cmpErr != nil {
			panic(fmt.Sprintf("Invalid CUDA version: %s", cmpErr))
		}
		if tooNew {
			continue
		}

		if latest == nil {
			latest = compat
			continue
		}
		newer, cmpErr := versionGreater(*compat.CUDA, *latest.CUDA)
		if cmpErr != nil {
			// should never happen
			panic(fmt.Sprintf("Invalid CUDA version: %s", cmpErr))
		}
		if newer {
			latest = compat
		}
	}
	if latest == nil {
		// TODO: can we suggest a CUDA version known to be compatible?
		console.Warnf("Cog doesn't know if CUDA %s is compatible with torchvision %s. This might cause CUDA problems.", cuda, ver)
		return "torchvision", ver, "", "", nil
	}

	return "torchvision", version.StripModifier(latest.Torchvision), latest.FindLinks, latest.ExtraIndexURL, nil
}

// torchStripCPUSuffixForM1 removes every "+cpu" from version when
// util.IsAppleSiliconMac reports true for goos/goarch, and returns version
// unchanged otherwise.
//
// aarch64 packages don't have +cpu suffix: https://download.pytorch.org/whl/torch_stable.html
// TODO(andreas): clean up this hack by actually parsing the torch_stable.html list in the generator
func torchStripCPUSuffixForM1(version string, goos string, goarch string) string {
	// TODO(andreas): clean up this hack
	if !util.IsAppleSiliconMac(goos, goarch) {
		return version
	}
	return strings.ReplaceAll(version, "+cpu", "")
}


================================================
FILE: pkg/config/compatibility_test.go
================================================
package config

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// TestLatestCuDNNForCUDA checks that CUDA 11.8 resolves to CuDNN major
// version 8 in the embedded base image table.
func TestLatestCuDNNForCUDA(t *testing.T) {
	const wantCuDNN = "8"

	gotCuDNN, err := latestCuDNNForCUDA("11.8")
	require.NoError(t, err)
	require.Equal(t, wantCuDNN, gotCuDNN)
}

// TestCudasFromTorchWithCUVersionModifier checks that a torch version with a
// "+cu118" modifier resolves to CUDA 11.8.
func TestCudasFromTorchWithCUVersionModifier(t *testing.T) {
	cudas, err := cudasFromTorch("2.0.1+cu118")
	// Check the error before touching the result, so a failure is reported as
	// the error itself rather than as an index-out-of-range panic.
	require.NoError(t, err)
	require.NotEmpty(t, cudas)
	// require.Equal takes the expected value first.
	require.Equal(t, "11.8", cudas[0])
}


================================================
FILE: pkg/config/config.go
================================================
package config

import (
	"encoding/json"
	"fmt"
	"path/filepath"
	"regexp"
	"slices"
	"strconv"
	"strings"

	"go.yaml.in/yaml/v4"

	"github.com/replicate/cog/pkg/requirements"
	"github.com/replicate/cog/pkg/util/console"
	"github.com/replicate/cog/pkg/util/version"
)

// Package-level build settings and helpers.
//
// BuildSourceEpochTimestamp starts at -1 and BuildXCachePath starts empty;
// they presumably mirror the fields of BuildOptions and are set from CLI
// flags — TODO confirm against the callers.
//
// PipPackageNameRegex captures the leading run of characters of a requirement
// line up to the first '>', '=', '<', '~', space, newline, '[' or '#'.
var (
	BuildSourceEpochTimestamp int64 = -1
	BuildXCachePath           string
	PipPackageNameRegex       = regexp.MustCompile(`^([^>=<~ \n[#]+)`)
)

// TODO(andreas): support conda packages
// TODO(andreas): support dockerfiles
// TODO(andreas): custom cpu/gpu installs
// TODO(andreas): suggest valid torchvision versions (e.g. if the user wants to use 0.8.0, suggest 0.8.1)

// Version floors and defaults for Python and CUDA.
//
// Going by the names, the Minimum* values are lower bounds on supported
// Python and CUDA versions (with a separate Python minor-version floor for
// concurrency), presumably enforced by validation elsewhere in the package.
// DefaultPythonVersion is the Python version assumed when none is given.
const (
	MinimumMajorPythonVersion               int    = 3
	MinimumMinorPythonVersion               int    = 10
	MinimumMinorPythonVersionForConcurrency int    = 11
	MinimumMajorCudaVersion                 int    = 11
	DefaultPythonVersion                    string = "3.13"
)

// RunItem is one entry of the build "run" list: a command plus optional
// mounts. Its custom YAML and JSON unmarshallers accept either a bare string
// (the command only) or a mapping with "command" and "mounts" keys.
type RunItem struct {
	Command string `json:"command,omitempty" yaml:"command"`
	Mounts  []struct {
		Type   string `json:"type,omitempty" yaml:"type"`
		ID     string `json:"id,omitempty" yaml:"id"`
		Target string `json:"target,omitempty" yaml:"target"`
	} `json:"mounts,omitempty" yaml:"mounts,omitempty"`
}

// Build holds the "build" section of cog.yaml: whether a GPU is required, the
// Python version and requirements, system packages, extra run commands, and
// the CUDA/CuDNN versions.
//
// PythonRequirements and PythonPackages are mutually exclusive; Config.Complete
// rejects a config that sets both. pythonRequirementsContent is the in-memory
// requirement list that Config.Complete fills from whichever of the two is set.
type Build struct {
	GPU                bool      `json:"gpu,omitempty" yaml:"gpu,omitempty"`
	PythonVersion      string    `json:"python_version,omitempty" yaml:"python_version"`
	PythonRequirements string    `json:"python_requirements,omitempty" yaml:"python_requirements,omitempty"`
	PythonPackages     []string  `json:"python_packages,omitempty" yaml:"python_packages,omitempty"` // Deprecated, but included for backwards compatibility
	Run                []RunItem `json:"run,omitempty" yaml:"run,omitempty"`
	SystemPackages     []string  `json:"system_packages,omitempty" yaml:"system_packages,omitempty"`
	PreInstall         []string  `json:"pre_install,omitempty" yaml:"pre_install,omitempty"` // Deprecated, but included for backwards compatibility
	CUDA               string    `json:"cuda,omitempty" yaml:"cuda,omitempty"`
	CuDNN              string    `json:"cudnn,omitempty" yaml:"cudnn,omitempty"`
	// SDKVersion pins the cog Python SDK version installed in the container.
	// Accepts a PEP 440 version string (e.g. "0.18.0" or "0.18.0a1").
	// When empty the latest release is installed. Overridden by COG_SDK_WHEEL env var.
	SDKVersion string `json:"sdk_version,omitempty" yaml:"sdk_version,omitempty"`

	pythonRequirementsContent []string
}

// Concurrency holds the "concurrency" section of cog.yaml. Max is presumably
// the maximum number of predictions handled at the same time — TODO confirm
// against the code that consumes it.
type Concurrency struct {
	Max int `json:"max,omitempty" yaml:"max"`
}

// WeightSource defines a weight file or directory to include in the model.
//
// Source is the only field that is always serialized; Name and Target are
// omitted when empty. Source is resolved relative to the project directory by
// the weights tooling, and Target is presumably the destination path inside
// the container — TODO confirm.
type WeightSource struct {
	Name   string `json:"name,omitempty" yaml:"name,omitempty"`
	Source string `json:"source" yaml:"source"`
	Target string `json:"target,omitempty" yaml:"target,omitempty"`
}

// Config is the in-memory form of cog.yaml.
//
// Build is a pointer and is dereferenced without a nil check by the methods
// in this file, so it must be set before calling them (defaultConfig provides
// one). parsedEnvironment is unexported state, presumably filled from
// Environment when Complete calls loadEnvironment — TODO confirm.
type Config struct {
	Build       *Build         `json:"build" yaml:"build"`
	Image       string         `json:"image,omitempty" yaml:"image,omitempty"`
	Predict     string         `json:"predict,omitempty" yaml:"predict"`
	Train       string         `json:"train,omitempty" yaml:"train,omitempty"`
	Concurrency *Concurrency   `json:"concurrency,omitempty" yaml:"concurrency,omitempty"`
	Environment []string       `json:"environment,omitempty" yaml:"environment,omitempty"`
	Weights     []WeightSource `json:"weights,omitempty" yaml:"weights,omitempty"`

	parsedEnvironment map[string]string
}

// defaultConfig returns a Config with a non-nil Build section that has GPU
// disabled and the Python version set to DefaultPythonVersion.
func defaultConfig() *Config {
	return &Config{
		Build: &Build{
			GPU: false,
			// Use the shared constant so the default cannot drift from it.
			PythonVersion: DefaultPythonVersion,
		},
	}
}

// UnmarshalYAML decodes a run item that is written either as a bare string,
// which becomes Command, or as a mapping with "command" and "mounts" keys.
// Any other YAML shape is rejected with an error naming the decoded type.
func (r *RunItem) UnmarshalYAML(unmarshal func(any) error) error {
	var raw any
	if err := unmarshal(&raw); err != nil {
		return err
	}

	// Short form: just the command.
	if command, isString := raw.(string); isString {
		r.Command = command
		return nil
	}

	fields, isMap := raw.(map[string]any)
	if !isMap {
		return fmt.Errorf("unexpected type %T for RunItem", raw)
	}

	// Long form: re-encode the generic map and decode it into a plain struct
	// that has the same layout as RunItem but no custom unmarshaller.
	encoded, err := yaml.Marshal(fields)
	if err != nil {
		return err
	}

	plain := struct {
		Command string `yaml:"command"`
		Mounts  []struct {
			Type   string `yaml:"type"`
			ID     string `yaml:"id"`
			Target string `yaml:"target"`
		} `yaml:"mounts,omitempty"`
	}{}
	if err := yaml.Unmarshal(encoded, &plain); err != nil {
		return err
	}

	*r = RunItem(plain)
	return nil
}

// UnmarshalJSON decodes a run item that is written either as a JSON string,
// which becomes Command, or as an object with "command" and "mounts" keys.
// Any other JSON value is rejected with an error naming the decoded type.
func (r *RunItem) UnmarshalJSON(data []byte) error {
	var raw any
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}

	// Short form: just the command.
	if command, isString := raw.(string); isString {
		r.Command = command
		return nil
	}

	fields, isMap := raw.(map[string]any)
	if !isMap {
		return fmt.Errorf("unexpected type %T for RunItem", raw)
	}

	// Long form: re-encode the generic map and decode it into a plain struct
	// that has the same layout as RunItem but no custom unmarshaller.
	plain := struct {
		Command string `json:"command"`
		Mounts  []struct {
			Type   string `json:"type"`
			ID     string `json:"id"`
			Target string `json:"target"`
		} `json:"mounts,omitempty"`
	}{}

	encoded, err := json.Marshal(fields)
	if err != nil {
		return err
	}
	if err := json.Unmarshal(encoded, &plain); err != nil {
		return err
	}

	*r = RunItem(plain)
	return nil
}

// CUDABaseImageTag returns the base image tag for the config's CUDA and CuDNN
// versions, as resolved by cudaBaseImageFor.
func (c *Config) CUDABaseImageTag() (string, error) {
	build := c.Build
	return cudaBaseImageFor(build.CUDA, build.CuDNN)
}

// TorchVersion returns the version of the "torch" requirement and whether
// such a requirement is present.
func (c *Config) TorchVersion() (string, bool) {
	ver, found := c.pythonPackageVersion("torch")
	return ver, found
}

// TorchvisionVersion returns the version of the "torchvision" requirement and
// whether such a requirement is present.
func (c *Config) TorchvisionVersion() (string, bool) {
	ver, found := c.pythonPackageVersion("torchvision")
	return ver, found
}

// TorchaudioVersion returns the version of the "torchaudio" requirement and
// whether such a requirement is present.
func (c *Config) TorchaudioVersion() (string, bool) {
	ver, found := c.pythonPackageVersion("torchaudio")
	return ver, found
}

// TensorFlowVersion returns the version of the "tensorflow" requirement and
// whether such a requirement is present.
func (c *Config) TensorFlowVersion() (string, bool) {
	ver, found := c.pythonPackageVersion("tensorflow")
	return ver, found
}

// cudasFromTorch returns the configured torch version together with the CUDA
// versions the compatibility matrix lists for it. All results are zero when
// the requirements contain no torch entry.
func (c *Config) cudasFromTorch() (torchVersion string, torchCUDAs []string, err error) {
	torchVer, hasTorch := c.TorchVersion()
	if !hasTorch {
		return "", nil, nil
	}

	cudas, lookupErr := cudasFromTorch(torchVer)
	if lookupErr != nil {
		return "", nil, lookupErr
	}
	return torchVer, cudas, nil
}

// cudaFromTF returns the configured tensorflow version together with the CUDA
// and CuDNN versions the compatibility matrix lists for it. All results are
// empty when the requirements contain no tensorflow entry.
func (c *Config) cudaFromTF() (tfVersion string, tfCUDA string, tfCuDNN string, err error) {
	tfVer, hasTF := c.TensorFlowVersion()
	if !hasTF {
		return "", "", "", nil
	}

	cuda, cudnn, lookupErr := cudaFromTF(tfVer)
	if lookupErr != nil {
		return "", "", "", lookupErr
	}
	return tfVer, cuda, cudnn, nil
}

// pythonPackageVersion looks for the first loaded requirement whose package
// name equals name. ok reports whether one was found; version is the first
// version that requirements.Versions extracts from that line, or "" when the
// line carries no version.
func (c *Config) pythonPackageVersion(name string) (version string, ok bool) {
	for _, line := range c.Build.pythonRequirementsContent {
		if requirements.PackageName(line) != name {
			continue
		}

		found := requirements.Versions(line)
		if len(found) == 0 {
			return "", true
		}
		return found[0], true
	}
	return "", false
}

// splitPythonVersion parses the major and minor numbers out of a Python
// version such as "3.11" or "3.11.4". Surrounding whitespace is ignored and
// anything after the second dot is discarded. It returns an error when the
// version has no dot or when either part is not an integer.
func splitPythonVersion(version string) (major int, minor int, err error) {
	version = strings.TrimSpace(version)

	majorStr, rest, hasMinor := strings.Cut(version, ".")
	if !hasMinor {
		return 0, 0, fmt.Errorf("missing minor version in %s", version)
	}
	// Drop the patch component (and anything after it), if present.
	minorStr, _, _ := strings.Cut(rest, ".")

	if major, err = strconv.Atoi(majorStr); err != nil {
		return 0, 0, err
	}
	if minor, err = strconv.Atoi(minorStr); err != nil {
		return 0, 0, err
	}
	return major, minor, nil
}

// Complete performs CUDA resolution, requirements loading, and environment loading for a Config.
// Use this when building a Config struct directly (not from YAML).
// For configs loaded from YAML, use Load() instead which handles validation and completion.
//
// projectDir is the directory against which a relative python_requirements
// path is resolved. c.Build must be non-nil.
//
// It returns an error when both python_packages and python_requirements are
// set, when the requirements file cannot be read, or when CUDA validation
// (GPU builds only) or environment loading fails.
func (c *Config) Complete(projectDir string) error {
	// Validate mutual exclusion of python_packages and python_requirements
	if len(c.Build.PythonPackages) > 0 && c.Build.PythonRequirements != "" {
		return fmt.Errorf("only one of python_packages or python_requirements can be set in your cog.yaml, not both")
	}

	// Load python_requirements into memory to simplify reading it multiple times
	switch {
	case c.Build.PythonRequirements != "":
		requirementsFilePath := c.Build.PythonRequirements
		// Treat OS-native absolute paths as absolute too. The leading-slash
		// check is kept so paths accepted before are still handled the same
		// way on every platform.
		isAbsolute := filepath.IsAbs(requirementsFilePath) || strings.HasPrefix(requirementsFilePath, "/")
		if !isAbsolute {
			requirementsFilePath = filepath.Join(projectDir, requirementsFilePath)
		}
		reqs, err := requirements.ReadRequirements(requirementsFilePath)
		if err != nil {
			return fmt.Errorf("failed to open python_requirements file: %w", err)
		}
		c.Build.pythonRequirementsContent = reqs
	case len(c.Build.PythonPackages) > 0:
		// Backwards compatibility: if using deprecated python_packages, populate requirements content
		c.Build.pythonRequirementsContent = c.Build.PythonPackages
	}

	// Resolve CUDA/CuDNN versions if GPU is enabled
	if c.Build.GPU {
		if err := c.validateAndCompleteCUDA(); err != nil {
			return err
		}
	}

	// Parse and validate environment variables
	return c.loadEnvironment()
}

// PythonRequirementsForArch returns a requirements.txt file with all the GPU packages resolved for given OS and architecture.
//
// Every loaded requirement is passed through pythonPackageForArch. Entries of
// includePackages are appended at the end unless a requirement with the same
// package name is already present. The caller's includePackages slice is not
// modified.
//
// The output lists "--find-links" lines first, then "--extra-index-url"
// lines, each group sorted so the result is identical from run to run, and
// then the packages in their original order. Lines are joined with "\n"
// without a trailing newline.
func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePackages []string) (string, error) {
	// Work on a private copy: entries are deleted below, and deleting in
	// place would rewrite the backing array of the caller's slice.
	includePackages = slices.Clone(includePackages)
	includePackageNames := make([]string, 0, len(includePackages))
	for _, pkg := range includePackages {
		includePackageNames = append(includePackageNames, requirements.PackageName(pkg))
	}

	packages := make([]string, 0, len(c.Build.pythonRequirementsContent)+len(includePackages))
	findLinksSet := map[string]bool{}
	extraIndexURLSet := map[string]bool{}

	// Include all the requirements and remove our include packages if they exist
	for _, pkg := range c.Build.pythonRequirementsContent {
		archPkg, findLinksList, extraIndexURLs, err := c.pythonPackageForArch(pkg, goos, goarch)
		if err != nil {
			return "", err
		}
		packages = append(packages, archPkg)
		for _, fl := range findLinksList {
			findLinksSet[fl] = true
		}
		for _, u := range extraIndexURLs {
			extraIndexURLSet[u] = true
		}

		packageName := requirements.PackageName(archPkg)
		if packageName == "" {
			continue
		}
		if idx := slices.Index(includePackageNames, packageName); idx != -1 {
			includePackageNames = slices.Delete(includePackageNames, idx, idx+1)
			includePackages = slices.Delete(includePackages, idx, idx+1)
		}
	}

	// If we still have some include packages add them in
	packages = append(packages, includePackages...)

	// Map iteration order is randomized, so sort the keys to keep the
	// generated file stable across runs.
	sortedKeys := func(set map[string]bool) []string {
		keys := make([]string, 0, len(set))
		for key := range set {
			keys = append(keys, key)
		}
		slices.Sort(keys)
		return keys
	}

	// Create final requirements.txt output
	// Put index URLs first
	lines := make([]string, 0, len(findLinksSet)+len(extraIndexURLSet)+len(packages))
	for _, findLinks := range sortedKeys(findLinksSet) {
		lines = append(lines, "--find-links "+findLinks)
	}
	for _, extraIndexURL := range sortedKeys(extraIndexURLSet) {
		lines = append(lines, "--extra-index-url "+extraIndexURL)
	}

	// Then, everything else
	lines = append(lines, packages...)

	return strings.Join(lines, "\n"), nil
}

// pythonPackageForArch resolves a single requirements line to the package that
// should be installed for the given OS and architecture.
//
// Lines that are not a pinned "package==version" requirement are returned
// verbatim with empty find-links and index URL lists. A pinned line that
// already carries its own extra index URLs is returned as "name==version"
// together with the parsed lists. Otherwise tensorflow (GPU builds only),
// torch and torchvision are mapped through their GPU/CPU package helpers,
// which may rewrite the name and version and contribute one find-links or
// extra index URL.
func (c *Config) pythonPackageForArch(pkg, goos, goarch string) (actualPackage string, findLinksList []string, extraIndexURLs []string, err error) {
	name, version, findLinksList, extraIndexURLs, err := requirements.SplitPinnedPythonRequirement(pkg)
	if err != nil {
		// It's not pinned, so just return the line verbatim
		return pkg, []string{}, []string{}, nil
	}
	if len(extraIndexURLs) > 0 {
		return name + "==" + version, findLinksList, extraIndexURLs, nil
	}

	var findLinks, extraIndexURL string
	switch {
	case name == "tensorflow" && c.Build.GPU:
		// There is no CPU case for tensorflow because the default package is
		// just the CPU package, so no transformation of version is needed.
		name, version, err = tfGPUPackage(version, c.Build.CUDA)
	case name == "torch" && c.Build.GPU:
		name, version, findLinks, extraIndexURL, err = torchGPUPackage(version, c.Build.CUDA)
	case name == "torch":
		name, version, findLinks, extraIndexURL, err = torchCPUPackage(version, goos, goarch)
	case name == "torchvision" && c.Build.GPU:
		name, version, findLinks, extraIndexURL, err = torchvisionGPUPackage(version, c.Build.CUDA)
	case name == "torchvision":
		name, version, findLinks, extraIndexURL, err = torchvisionCPUPackage(version, goos, goarch)
	}
	if err != nil {
		return "", nil, nil, err
	}

	resolved := name
	if version != "" {
		resolved = name + "==" + version
	}
	if extraIndexURL != "" {
		extraIndexURLs = []string{extraIndexURL}
	}
	if findLinks != "" {
		findLinksList = []string{findLinks}
	}
	return resolved, findLinksList, extraIndexURLs, nil
}

// validateCudaVersion checks that cudaVersion has the form
// "MAJOR.MINOR[.PATCH...]" where MAJOR and MINOR are integers, MINOR is not
// negative, and MAJOR is at least MinimumMajorCudaVersion. Components after
// the minor version are not inspected.
//
// It returns nil when the version is acceptable and a descriptive error
// otherwise.
func validateCudaVersion(cudaVersion string) error {
	parts := strings.Split(cudaVersion, ".")
	if len(parts) < 2 {
		return fmt.Errorf("CUDA version %q must include both major and minor versions", cudaVersion)
	}

	major, err := strconv.Atoi(parts[0])
	if err != nil {
		return fmt.Errorf("invalid major version in CUDA version %q", cudaVersion)
	}

	// A dot alone is not enough: reject inputs such as "12." or "12.x" whose
	// minor component is missing or not a number.
	if minor, err := strconv.Atoi(parts[1]); err != nil || minor < 0 {
		return fmt.Errorf("invalid minor version in CUDA version %q", cudaVersion)
	}

	if major < MinimumMajorCudaVersion {
		return fmt.Errorf("minimum supported CUDA version is %d, requested %q", MinimumMajorCudaVersion, cudaVersion)
	}
	return nil
}

// validateAndCompleteCUDA validates any CUDA/CuDNN versions already set on
// c.Build and fills in the ones that were left empty.
//
// A non-empty CUDA version must pass validateCudaVersion, and when both CUDA
// and CuDNN are set the CuDNN version must be listed by
// compatibleCuDNNsForCUDA. Missing values are then derived, in order of
// precedence, from the configured TensorFlow version, the configured Torch
// version, or the defaults (defaultCUDA / latestCuDNNForCUDA).
//
// A user-specified CUDA version that is not known to match the framework only
// produces a warning; a user-specified CuDNN version that differs from the one
// TensorFlow requires is an error.
func (c *Config) validateAndCompleteCUDA() error {
	if c.Build.CUDA != "" {
		if err := validateCudaVersion(c.Build.CUDA); err != nil {
			return err
		}
	}

	if c.Build.CUDA != "" && c.Build.CuDNN != "" {
		compatibleCuDNNs := compatibleCuDNNsForCUDA(c.Build.CUDA)
		if !slices.Contains(compatibleCuDNNs, c.Build.CuDNN) {
			return fmt.Errorf(`the specified CUDA version %s is not compatible with CuDNN %s.
Compatible CuDNN versions are: %s`, c.Build.CUDA, c.Build.CuDNN, strings.Join(compatibleCuDNNs, ","))
		}
	}

	torchVersion, torchCUDAs, err := c.cudasFromTorch()
	if err != nil {
		return err
	}
	tfVersion, tfCUDA, tfCuDNN, err := c.cudaFromTF()
	if err != nil {
		return err
	}
	// The pre-compiled TensorFlow binaries requires specific CUDA/CuDNN versions to be
	// installed, but Torch bundles their own CUDA/CuDNN libraries.

	switch {
	case tfVersion != "":
		switch {
		case c.Build.CUDA == "":
			// The value about to be assigned is tfCUDA, so that is the one
			// that must be known.
			if tfCUDA == "" {
				return fmt.Errorf("cog doesn't know what CUDA version is compatible with tensorflow==%s. You might need to upgrade Cog: https://github.com/replicate/cog#upgrade\n\nIf that doesn't work, you need to set the 'cuda' option in cog.yaml to set what version to use. You might be able to find this out from https://www.tensorflow.org/", tfVersion)
			}
			console.Debugf("Setting CUDA to version %s from Tensorflow version", tfCUDA)
			c.Build.CUDA = tfCUDA
		case tfCUDA == "" || !version.EqualMinor(tfCUDA, c.Build.CUDA):
			// Warn only when the configured CUDA is unknown to be compatible,
			// i.e. it does NOT match the version TensorFlow was built for.
			console.Warnf("Cog doesn't know if CUDA %s is compatible with Tensorflow %s. This might cause CUDA problems.", c.Build.CUDA, tfVersion)
			if tfCUDA != "" {
				console.Warnf("Try %s instead?", tfCUDA)
			}
		}

		switch {
		case c.Build.CuDNN == "" && tfCuDNN != "":
			console.Debugf("Setting CuDNN to version %s from Tensorflow version", tfCuDNN)
			c.Build.CuDNN = tfCuDNN
		case c.Build.CuDNN == "":
			c.Build.CuDNN, err = latestCuDNNForCUDA(c.Build.CUDA)
			if err != nil {
				return err
			}
			console.Debugf("Setting CuDNN to version %s", c.Build.CuDNN)
		case tfCuDNN != c.Build.CuDNN:
			console.Warnf("Cog doesn't know if cuDNN %s is compatible with Tensorflow %s. This might cause CUDA problems.", c.Build.CuDNN, tfVersion)
			return fmt.Errorf(`the specified cuDNN version %s is not compatible with tensorflow==%s.
Compatible cuDNN version is: %s`, c.Build.CuDNN, tfVersion, tfCuDNN)
		}
	case torchVersion != "":
		switch {
		case c.Build.CUDA == "":
			if len(torchCUDAs) == 0 {
				return fmt.Errorf("cog doesn't know what CUDA version is compatible with torch==%s. You might need to upgrade Cog: https://github.com/replicate/cog#upgrade\n\nIf that doesn't work, you need to set the 'cuda' option in cog.yaml to set what version to use. You might be able to find this out from https://pytorch.org/", torchVersion)
			}
			c.Build.CUDA = latestCUDAFrom(torchCUDAs)
			console.Debugf("Setting CUDA to version %s from Torch version", c.Build.CUDA)
		case !slices.ContainsFunc(torchCUDAs, func(torchCUDA string) bool { return version.EqualMinor(torchCUDA, c.Build.CUDA) }):
			// TODO: can we suggest a CUDA version known to be compatible?
			console.Warnf("Cog doesn't know if CUDA %s is compatible with PyTorch %s. This might cause CUDA problems.", c.Build.CUDA, torchVersion)
			if len(torchCUDAs) > 0 {
				console.Warnf("Try %s instead?", torchCUDAs[len(torchCUDAs)-1])
			}
		}

		if c.Build.CuDNN == "" {
			c.Build.CuDNN, err = latestCuDNNForCUDA(c.Build.CUDA)
			if err != nil {
				return err
			}
			console.Debugf("Setting CuDNN to version %s", c.Build.CuDNN)
		}
	default:
		if c.Build.CUDA == "" {
			c.Build.CUDA = defaultCUDA()
			console.Debugf("Setting CUDA to version %s", c.Build.CUDA)
		}
		if c.Build.CuDNN == "" {
			c.Build.CuDNN, err = latestCuDNNForCUDA(c.Build.CUDA)
			if err != nil {
				return err
			}
			console.Debugf("Setting CuDNN to version %s", c.Build.CuDNN)
		}
	}

	return nil
}

// RequirementsFile returns the path of the Python requirements file named by
// c.Build.PythonRequirements.
//
// A relative path is resolved against projectDir. An absolute path is returned
// as-is (cleaned) instead of being nested underneath projectDir, which would
// point at a file that does not exist.
func (c *Config) RequirementsFile(projectDir string) string {
	requirementsPath := c.Build.PythonRequirements
	if filepath.IsAbs(requirementsPath) {
		return filepath.Clean(requirementsPath)
	}
	return filepath.Join(projectDir, requirementsPath)
}

// ParsedEnvironment returns the parsed environment map held on the Config,
// i.e. the value loadEnvironment stores after a successful parse. The map is
// returned as-is, not copied.
func (c *Config) ParsedEnvironment() map[string]string {
	return c.parsedEnvironment
}

// loadEnvironment runs c.Environment through parseAndValidateEnvironment and,
// on success, stores the result for ParsedEnvironment to return. On failure
// the error is returned and the previously stored value is left untouched.
func (c *Config) loadEnvironment() error {
	parsed, err := parseAndValidateEnvironment(c.Environment)
	if err == nil {
		c.parsedEnvironment = parsed
	}
	return err
}


================================================
FILE: pkg/config/config_file.go
================================================
package config

import (
	"encoding/json"
	"fmt"

	"go.yaml.in/yaml/v4"
)

// configFile represents the raw cog.yaml as written by users.
// Scalar fields are pointers and every field is omitempty, to distinguish
// "not set" from "set to zero value".
// This struct is only used during parsing - validation produces errors,
// completion produces a Config.
type configFile struct {
	// Build is the "build" section.
	Build       *buildFile       `json:"build,omitempty" yaml:"build,omitempty"`
	// Image is the "image" key.
	Image       *string          `json:"image,omitempty" yaml:"image,omitempty"`
	// Predict is the "predict" key.
	Predict     *string          `json:"predict,omitempty" yaml:"predict,omitempty"`
	// Train is the "train" key.
	Train       *string          `json:"train,omitempty" yaml:"train,omitempty"`
	// Concurrency is the "concurrency" section.
	Concurrency *concurrencyFile `json:"concurrency,omitempty" yaml:"concurrency,omitempty"`
	// Environment is the "environment" list, kept as raw strings.
	Environment []string         `json:"environment,omitempty" yaml:"environment,omitempty"`
	// Weights is the "weights" list.
	Weights     []weightFile     `json:"weights,omitempty" yaml:"weights,omitempty"`
}

// buildFile represents the raw build configuration from cog.yaml.
// Scalar options are pointers so that an omitted key (nil) can be told apart
// from an explicit zero value.
type buildFile struct {
	// GPU is the "gpu" key; see GetGPU for the nil-safe accessor.
	GPU                *bool         `json:"gpu,omitempty" yaml:"gpu,omitempty"`
	// PythonVersion is the "python_version" key.
	PythonVersion      *string       `json:"python_version,omitempty" yaml:"python_version,omitempty"`
	// PythonRequirements is the "python_requirements" key.
	PythonRequirements *string       `json:"python_requirements,omitempty" yaml:"python_requirements,omitempty"`
	// Run is the "run" list; each item may be a plain string or an object.
	Run                []runItemFile `json:"run,omitempty" yaml:"run,omitempty"`
	// SystemPackages is the "system_packages" list.
	SystemPackages     []string      `json:"system_packages,omitempty" yaml:"system_packages,omitempty"`
	// CUDA is the "cuda" key.
	CUDA               *string       `json:"cuda,omitempty" yaml:"cuda,omitempty"`
	// CuDNN is the "cudnn" key.
	CuDNN              *string       `json:"cudnn,omitempty" yaml:"cudnn,omitempty"`
	// SDKVersion is the "sdk_version" key.
	SDKVersion         *string       `json:"sdk_version,omitempty" yaml:"sdk_version,omitempty"`

	// Deprecated fields - parsed with warnings
	PythonPackages []string `json:"python_packages,omitempty" yaml:"python_packages,omitempty"`
	PreInstall     []string `json:"pre_install,omitempty" yaml:"pre_install,omitempty"`
}

// runItemFile represents a run command which can be either a string or an object.
// The custom UnmarshalYAML/UnmarshalJSON methods accept both forms: a bare
// string populates Command only, an object populates Command and Mounts.
type runItemFile struct {
	// Command is the command text.
	Command string      `json:"command,omitempty" yaml:"command,omitempty"`
	// Mounts lists the mounts declared for the command (object form only).
	Mounts  []mountFile `json:"mounts,omitempty" yaml:"mounts,omitempty"`
}

// mountFile represents a mount configuration in a run command.
// All three values are stored as raw strings; nothing is validated here.
type mountFile struct {
	// Type is the "type" key of the mount.
	Type   string `json:"type,omitempty" yaml:"type,omitempty"`
	// ID is the "id" key of the mount.
	ID     string `json:"id,omitempty" yaml:"id,omitempty"`
	// Target is the "target" key of the mount.
	Target string `json:"target,omitempty" yaml:"target,omitempty"`
}

// weightFile represents a weight source configuration.
type weightFile struct {
	// Name is the optional "name" key; omitted from output when empty.
	Name   string `json:"name,omitempty" yaml:"name,omitempty"`
	// Source is the "source" key. It is the only field without omitempty, so
	// it is always written when the struct is serialized.
	Source string `json:"source" yaml:"source"`
	// Target is the optional "target" key; omitted from output when empty.
	Target string `json:"target,omitempty" yaml:"target,omitempty"`
}

// concurrencyFile represents concurrency configuration.
type concurrencyFile struct {
	// Max is the "max" key; nil means the key was not set.
	Max *int `json:"max,omitempty" yaml:"max,omitempty"`
}

// UnmarshalYAML decodes a run item from YAML. The node may be either a plain
// string, which becomes Command, or a mapping with "command" and optional
// "mounts" keys. Any other node type is rejected with an error.
func (r *runItemFile) UnmarshalYAML(unmarshal func(any) error) error {
	var raw any
	if err := unmarshal(&raw); err != nil {
		return err
	}

	// String form: the whole item is the command.
	if command, ok := raw.(string); ok {
		r.Command = command
		return nil
	}

	fields, ok := raw.(map[string]any)
	if !ok {
		return fmt.Errorf("unexpected type %T for runItemFile", raw)
	}

	// Object form: re-encode the generic map and decode it into a typed
	// struct so that field types are checked by the YAML decoder.
	encoded, err := yaml.Marshal(fields)
	if err != nil {
		return err
	}

	var decoded struct {
		Command string `yaml:"command"`
		Mounts  []struct {
			Type   string `yaml:"type"`
			ID     string `yaml:"id"`
			Target string `yaml:"target"`
		} `yaml:"mounts,omitempty"`
	}
	if err := yaml.Unmarshal(encoded, &decoded); err != nil {
		return err
	}

	mounts := make([]mountFile, 0, len(decoded.Mounts))
	for _, m := range decoded.Mounts {
		mounts = append(mounts, mountFile{Type: m.Type, ID: m.ID, Target: m.Target})
	}
	r.Command = decoded.Command
	r.Mounts = mounts
	return nil
}

// UnmarshalJSON decodes a run item from JSON. The value may be either a
// string, which becomes Command, or an object with "command" and optional
// "mounts" members. Any other JSON value (number, array, null, ...) is
// rejected with an error.
func (r *runItemFile) UnmarshalJSON(data []byte) error {
	var raw any
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}

	// String form: the whole item is the command.
	if command, ok := raw.(string); ok {
		r.Command = command
		return nil
	}

	fields, ok := raw.(map[string]any)
	if !ok {
		return fmt.Errorf("unexpected type %T for runItemFile", raw)
	}

	// Object form: re-encode the generic map and decode it into a typed
	// struct so that field types are checked by the JSON decoder.
	encoded, err := json.Marshal(fields)
	if err != nil {
		return err
	}

	var decoded struct {
		Command string `json:"command"`
		Mounts  []struct {
			Type   string `json:"type"`
			ID     string `json:"id"`
			Target string `json:"target"`
		} `json:"mounts,omitempty"`
	}
	if err := json.Unmarshal(encoded, &decoded); err != nil {
		return err
	}

	mounts := make([]mountFile, 0, len(decoded.Mounts))
	for _, m := range decoded.Mounts {
		mounts = append(mounts, mountFile{Type: m.Type, ID: m.ID, Target: m.Target})
	}
	r.Command = decoded.Command
	r.Mounts = mounts
	return nil
}

// Helper methods for working with the raw config file types

// GetGPU reports whether GPU is enabled. It is safe to call on a nil receiver
// and yields false when either the receiver or the GPU field is unset.
func (b *buildFile) GetGPU() bool {
	return b != nil && b.GPU != nil && *b.GPU
}


================================================
FILE: pkg/config/config_test.go
================================================
package config

import (
	"encoding/json"
	"os"
	"path"
	"path/filepath"
	"testing"

	"github.com/hashicorp/go-version"
	"github.com/stretchr/testify/require"
	"go.yaml.in/yaml/v4"
)

// TestValidateCudaVersion runs validateCudaVersion over a table of version
// strings and checks only whether an error is returned.
func TestValidateCudaVersion(t *testing.T) {
	cases := []struct {
		name    string
		input   string
		wantErr bool
	}{
		{name: "ValidVersion", input: "12.4", wantErr: false},
		{name: "MinimumVersion", input: "11.0", wantErr: false},
		{name: "FullyQualifiedVersion", input: "12.4.1", wantErr: false},
		{name: "InvalidFormat", input: "11-2", wantErr: true},
		{name: "InvalidMissingMinor", input: "11", wantErr: true},
		{name: "LessThanMinimum", input: "9.1", wantErr: true},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			err := validateCudaVersion(c.input)
			if !c.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
		})
	}
}

// assertMinorVersion fails the test (without stopping it) when expected and
// actual do not share the same major and minor version components. Patch and
// later components are ignored. A version string that cannot be parsed is
// reported as a test error as well.
func assertMinorVersion(t *testing.T, expected, actual string) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	expectedVersion, err := version.NewVersion(expected)
	if err != nil {
		t.Errorf("Error parsing version: %v", err)
		return
	}
	actualVersion, err := version.NewVersion(actual)
	if err != nil {
		t.Errorf("Error parsing version: %v", err)
		return
	}

	// Compare only the major and minor parts
	want, got := expectedVersion.Segments(), actualVersion.Segments()
	if want[0] != got[0] || want[1] != got[1] {
		t.Errorf("Expected %s but got %s", expected, actual)
	}
}

func TestPythonPackagesAndRequirementsCantBeUsedTogether(t *testing.T) {
	config := &Config{
		Build: &Build{
			PythonVersion: "3.10",
			PythonPackages: []string{
				"replicate==1.0.0",
			},
			PythonRequirements: "requirements.txt",
		},
	}
	err := config.Complete("")
	require.Error(t, err)
	require.Contains(t, err.Error(), "only one of python_packages or python_requirements can be set in your cog.yaml, not both")
}

// TestPythonRequirementsResolvesPythonPackagesAndCudaVersions writes a
// requirements.txt pinning torch 1.13.1 and checks that a GPU build completes
// to CUDA 11.7 / CuDNN 8 and that the generated requirements carry the cu117
// extra index URL.
func TestPythonRequirementsResolvesPythonPackagesAndCudaVersions(t *testing.T) {
	projectDir := t.TempDir()
	requirementsPath := path.Join(projectDir, "requirements.txt")
	require.NoError(t, os.WriteFile(requirementsPath, []byte(`torch==1.13.1
torchvision==0.14.1
torchaudio==0.13.1
foo==1.0.0`), 0o644))

	cfg := &Config{
		Build: &Build{
			GPU:                true,
			PythonVersion:      "3.10",
			PythonRequirements: "requirements.txt",
		},
	}
	require.NoError(t, cfg.Complete(projectDir))
	require.Equal(t, "11.7", cfg.Build.CUDA)
	require.Equal(t, "8", cfg.Build.CuDNN)

	want := `--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1
torchvision==0.14.1
torchaudio==0.13.1
foo==1.0.0`
	got, err := cfg.PythonRequirementsForArch("", "", []string{})
	require.NoError(t, err)
	require.Equal(t, want, got)
}

// TestPythonRequirementsResolvesPythonPackagesAndCudaVersionsWithExtraIndexURL
// writes a requirements.txt pinning torch 1.12.1 and checks that a GPU build
// completes to CUDA 11.6 / CuDNN 8 and that the generated requirements carry
// the cu116 extra index URL.
func TestPythonRequirementsResolvesPythonPackagesAndCudaVersionsWithExtraIndexURL(t *testing.T) {
	projectDir := t.TempDir()
	requirementsPath := path.Join(projectDir, "requirements.txt")
	require.NoError(t, os.WriteFile(requirementsPath, []byte(`torch==1.12.1
torchvision==0.13.1
torchaudio==0.12.1
foo==1.0.0`), 0o644))

	cfg := &Config{
		Build: &Build{
			GPU:                true,
			PythonVersion:      "3.10",
			PythonRequirements: "requirements.txt",
		},
	}
	require.NoError(t, cfg.Complete(projectDir))
	require.Equal(t, "11.6", cfg.Build.CUDA)
	require.Equal(t, "8", cfg.Build.CuDNN)

	want := `--extra-index-url https://download.pytorch.org/whl/cu116
torch==1.12.1
torchvision==0.13.1
torchaudio==0.12.1
foo==1.0.0`
	got, err := cfg.PythonRequirementsForArch("", "", []string{})
	require.NoError(t, err)
	require.Equal(t, want, got)
}

// TestPythonRequirementsWorksWithLinesCogCannotParse checks that requirement
// lines which are not simple pins (version ranges, pip arguments) come out
// unchanged, while comments and blank lines are absent from the result.
func TestPythonRequirementsWorksWithLinesCogCannotParse(t *testing.T) {
	projectDir := t.TempDir()
	requirementsPath := path.Join(projectDir, "requirements.txt")
	require.NoError(t, os.WriteFile(requirementsPath, []byte(`foo==1.0.0
# complex requirements
fastapi>=0.6,<1
flask>0.4
# comments!
# blank lines!

# arguments
-f http://example.com`), 0o644))

	cfg := &Config{
		Build: &Build{
			GPU:                true,
			PythonVersion:      "3.10",
			PythonRequirements: "requirements.txt",
		},
	}
	require.NoError(t, cfg.Complete(projectDir))

	want := `foo==1.0.0
fastapi>=0.6,<1
flask>0.4
-f http://example.com`
	got, err := cfg.PythonRequirementsForArch("", "", []string{})
	require.NoError(t, err)
	require.Equal(t, want, got)
}

// TestValidateAndCompleteCUDAForAllTF completes a GPU config for every entry
// of TFCompatibilityMatrix and checks that the resulting CUDA (major.minor)
// and CuDNN versions match the entry.
func TestValidateAndCompleteCUDAForAllTF(t *testing.T) {
	for _, compat := range TFCompatibilityMatrix {
		build := &Build{
			GPU:            true,
			PythonVersion:  "3.10",
			PythonPackages: []string{"tensorflow==" + compat.TF},
		}
		cfg := &Config{Build: build}

		require.NoError(t, cfg.Complete(""))
		assertMinorVersion(t, compat.CUDA, cfg.Build.CUDA)
		require.Equal(t, compat.CuDNN, cfg.Build.CuDNN)
	}
}

// TestValidateAndCompleteCUDAForAllTorch completes a config for every entry of
// TorchCompatibilityMatrix. Entries with a CUDA value are built with GPU
// enabled and must end up with non-empty CUDA and CuDNN; entries without one
// are built without GPU and must leave both empty.
func TestValidateAndCompleteCUDAForAllTorch(t *testing.T) {
	for _, compat := range TorchCompatibilityMatrix {
		hasCUDA := compat.CUDA != nil
		cfg := &Config{
			Build: &Build{
				GPU:            hasCUDA,
				PythonVersion:  "3.10",
				PythonPackages: []string{"torch==" + compat.TorchVersion()},
			},
		}

		require.NoError(t, cfg.Complete(""))
		if hasCUDA {
			require.NotEqual(t, "", cfg.Build.CUDA)
			require.NotEqual(t, "", cfg.Build.CuDNN)
			continue
		}
		require.Equal(t, "", cfg.Build.CUDA)
		require.Equal(t, "", cfg.Build.CuDNN)
	}
}

// TestValidateAndCompleteCUDAForSelectedTorch pins the exact CUDA and CuDNN
// versions that Complete selects for a few specific torch releases.
func TestValidateAndCompleteCUDAForSelectedTorch(t *testing.T) {
	cases := []struct {
		torch string
		cuda  string
		cuDNN string
	}{
		{torch: "2.0.1", cuda: "11.8", cuDNN: "8"},
		{torch: "1.13.1", cuda: "11.7", cuDNN: "8"},
		{torch: "1.11.0", cuda: "11.3", cuDNN: "8"},
	}

	for _, tc := range cases {
		cfg := &Config{
			Build: &Build{
				GPU:            true,
				PythonVersion:  "3.10",
				PythonPackages: []string{"torch==" + tc.torch},
			},
		}
		require.NoError(t, cfg.Complete(""))
		require.Equal(t, tc.cuda, cfg.Build.CUDA)
		require.Equal(t, tc.cuDNN, cfg.Build.CuDNN)
	}
}

// TestUnsupportedTorch checks the behaviour for a torch version that has no
// known CUDA mapping: Complete fails when no CUDA version is configured and
// succeeds, keeping the configured CUDA, when one is.
func TestUnsupportedTorch(t *testing.T) {
	// Ensure version is not known by Cog
	cudas, err := cudasFromTorch("0.4.1")
	require.NoError(t, err)
	require.Empty(t, cudas)

	newConfig := func(cuda string) *Config {
		return &Config{
			Build: &Build{
				GPU:            true,
				CUDA:           cuda,
				PythonVersion:  "3.10",
				PythonPackages: []string{"torch==0.4.1"},
			},
		}
	}

	// Unknown versions require cuda
	err = newConfig("").Complete("")
	require.Error(t, err)
	require.Contains(t, err.Error(), "cog doesn't know what CUDA version is compatible with torch==0.4.1.")

	withCUDA := newConfig("11.8")
	require.NoError(t, withCUDA.Complete(""))
	assertMinorVersion(t, "11.8", withCUDA.Build.CUDA)
	require.Equal(t, "8", withCUDA.Build.CuDNN)
}

// TestUnsupportedTensorflow checks the behaviour for a tensorflow version that
// has no known CUDA mapping: Complete fails when no CUDA version is configured
// and succeeds, keeping the configured CUDA, when one is.
func TestUnsupportedTensorflow(t *testing.T) {
	// Ensure version is not known by Cog
	cuda, cudnn, err := cudaFromTF("0.4.1")
	require.NoError(t, err)
	// Expected value first, actual second, so a failure message reads correctly.
	require.Equal(t, "", cuda)
	require.Equal(t, "", cudnn)

	// Unknown versions require cuda
	config := &Config{
		Build: &Build{
			GPU:           true,
			PythonVersion: "3.10",
			PythonPackages: []string{
				"tensorflow==0.4.1",
			},
		},
	}
	err = config.Complete("")
	require.Error(t, err)
	require.Contains(t, err.Error(), "cog doesn't know what CUDA version is compatible with tensorflow==0.4.1.")

	config = &Config{
		Build: &Build{
			GPU:           true,
			CUDA:          "11.8",
			PythonVersion: "3.10",
			PythonPackages: []string{
				"tensorflow==0.4.1",
			},
		},
	}
	err = config.Complete("")
	require.NoError(t, err)
	assertMinorVersion(t, "11.8", config.Build.CUDA)
	require.Equal(t, "8", config.Build.CuDNN)
}

// TestPythonPackagesForArchTorchGPU checks that a GPU build with CUDA 11.8 and
// torch 2.0.1 yields requirements prefixed with the cu118 extra index URL.
func TestPythonPackagesForArchTorchGPU(t *testing.T) {
	cfg := &Config{
		Build: &Build{
			GPU:           true,
			CUDA:          "11.8",
			PythonVersion: "3.10",
			PythonPackages: []string{
				"torch==2.0.1",
				"torchvision==0.15.2",
				"torchaudio==2.0.2",
				"foo==1.0.0",
			},
		},
	}
	require.NoError(t, cfg.Complete(""))
	assertMinorVersion(t, "11.8", cfg.Build.CUDA)
	require.Equal(t, "8", cfg.Build.CuDNN)

	want := `--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.0.1
torchvision==0.15.2
torchaudio==2.0.2
foo==1.0.0`
	got, err := cfg.PythonRequirementsForArch("", "", []string{})
	require.NoError(t, err)
	require.Equal(t, want, got)
}

// TestPythonPackagesForArchTorchCPU checks that a non-GPU build with
// torch 2.0.1 yields requirements prefixed with the cpu extra index URL, even
// though a CUDA version is present in the config.
func TestPythonPackagesForArchTorchCPU(t *testing.T) {
	cfg := &Config{
		Build: &Build{
			GPU:           false,
			CUDA:          "11.8",
			PythonVersion: "3.10",
			PythonPackages: []string{
				"torch==2.0.1",
				"torchvision==0.15.2",
				"torchaudio==2.0.2",
				"foo==1.0.0",
			},
		},
	}
	require.NoError(t, cfg.Complete(""))

	want := `--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.0.1
torchvision==0.15.2
torchaudio==2.0.2
foo==1.0.0`
	got, err := cfg.PythonRequirementsForArch("", "", []string{})
	require.NoError(t, err)
	require.Equal(t, want, got)
}

// TestPythonPackagesForArchTensorflowGPU checks that a GPU build with
// tensorflow 2.12.0 keeps the plain "tensorflow" package name in the generated
// requirements and never emits "tensorflow_gpu".
func TestPythonPackagesForArchTensorflowGPU(t *testing.T) {
	cfg := &Config{
		Build: &Build{
			GPU:           true,
			CUDA:          "11.8",
			PythonVersion: "3.10",
			PythonPackages: []string{
				"tensorflow==2.12.0",
				"foo==1.0.0",
			},
		},
	}
	require.NoError(t, cfg.Complete(""))
	assertMinorVersion(t, "11.8", cfg.Build.CUDA)
	require.Equal(t, "8", cfg.Build.CuDNN)

	// tensorflow and tensorflow-gpu have been the same package since TensorFlow 2.1, released in September 2019.
	// Although the checksums differ due to metadata,
	// they were built in the same way and both provide GPU support via Nvidia CUDA.
	// As of December 2022, tensorflow-gpu has been removed and has been replaced with
	// this new, empty package that generates an error upon installation.
	want := `tensorflow==2.12.0
foo==1.0.0`
	got, err := cfg.PythonRequirementsForArch("", "", []string{})
	require.NoError(t, err)
	require.Equal(t, want, got)
	require.NotContains(t, got, "tensorflow_gpu")
}

// TestPythonPackagesBothTorchAndTensorflow checks a GPU build that lists both
// tensorflow and torch with CUDA 12.3 configured: the CUDA version is kept,
// CuDNN resolves to 8, and the requirements carry the cu121 extra index URL.
func TestPythonPackagesBothTorchAndTensorflow(t *testing.T) {
	cfg := &Config{
		Build: &Build{
			GPU:           true,
			CUDA:          "12.3",
			PythonVersion: "3.11",
			PythonPackages: []string{
				"tensorflow==2.16.1",
				"torch==2.3.1",
			},
		},
	}
	require.NoError(t, cfg.Complete(""))
	require.Equal(t, "12.3", cfg.Build.CUDA)
	require.Equal(t, "8", cfg.Build.CuDNN)

	want := `--extra-index-url https://download.pytorch.org/whl/cu121
tensorflow==2.16.1
torch==2.3.1`
	got, err := cfg.PythonRequirementsForArch("", "", []string{})
	require.NoError(t, err)
	require.Equal(t, want, got)
}

// TestCUDABaseImageTag checks the CUDA base image tag produced for a GPU build
// that only pins tensorflow 2.12.0.
func TestCUDABaseImageTag(t *testing.T) {
	cfg := &Config{
		Build: &Build{
			GPU:            true,
			PythonVersion:  "3.10",
			PythonPackages: []string{"tensorflow==2.12.0"},
		},
	}
	require.NoError(t, cfg.Complete(""))

	tag, err := cfg.CUDABaseImageTag()
	require.NoError(t, err)
	require.Equal(t, "nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04", tag)
}

// TestBuildRunItemStringYAML checks that a YAML run item written as a plain
// string is decoded into the item's Command.
func TestBuildRunItemStringYAML(t *testing.T) {
	var wrapper struct {
		Build *Build `yaml:"build"`
	}

	input := `
build:
  run:
    - "echo 'Hello, World!'"
`

	require.NoError(t, yaml.Unmarshal([]byte(input), &wrapper))
	require.NotNil(t, wrapper.Build)
	require.Len(t, wrapper.Build.Run, 1)
	require.Equal(t, "echo 'Hello, World!'", wrapper.Build.Run[0].Command)
}

// TestBuildRunItemStringJSON checks that a JSON run item written as a plain
// string is decoded into the item's Command.
func TestBuildRunItemStringJSON(t *testing.T) {
	var wrapper struct {
		Build *Build `json:"build"`
	}

	input := `{
	"build": {
		"run": [
			"echo 'Hello, World!'"
		]
	}
}`

	require.NoError(t, json.Unmarshal([]byte(input), &wrapper))
	require.NotNil(t, wrapper.Build)
	require.Len(t, wrapper.Build.Run, 1)
	require.Equal(t, "echo 'Hello, World!'", wrapper.Build.Run[0].Command)
}

// TestBuildRunItemDictYAML checks that a YAML run item written as an object
// is decoded into Command plus its single mount.
func TestBuildRunItemDictYAML(t *testing.T) {
	var wrapper struct {
		Build *Build `yaml:"build"`
	}

	input := `
build:
  run:
  - command: "echo 'Hello, World!'"
    mounts:
    - type: bind
      id: my-volume
      target: /mnt/data
`

	require.NoError(t, yaml.Unmarshal([]byte(input), &wrapper))
	require.NotNil(t, wrapper.Build)
	require.Len(t, wrapper.Build.Run, 1)

	item := wrapper.Build.Run[0]
	require.Equal(t, "echo 'Hello, World!'", item.Command)
	require.Len(t, item.Mounts, 1)
	require.Equal(t, "bind", item.Mounts[0].Type)
	require.Equal(t, "my-volume", item.Mounts[0].ID)
	require.Equal(t, "/mnt/data", item.Mounts[0].Target)
}

// TestBuildRunItemDictJSON checks that a JSON run item written as an object
// is decoded into Command plus its single mount.
func TestBuildRunItemDictJSON(t *testing.T) {
	var wrapper struct {
		Build *Build `json:"build"`
	}

	input := `{
	"build": {
		"run": [
			{
				"command": "echo 'Hello, World!'",
				"mounts": [
					{
						"type": "bind",
						"id": "my-volume",
						"target": "/mnt/data"
					}
				]
			}
		]
	}
}`

	require.NoError(t, json.Unmarshal([]byte(input), &wrapper))
	require.NotNil(t, wrapper.Build)
	require.Len(t, wrapper.Build.Run, 1)

	item := wrapper.Build.Run[0]
	require.Equal(t, "echo 'Hello, World!'", item.Command)
	require.Len(t, item.Mounts, 1)
	require.Equal(t, "bind", item.Mounts[0].Type)
	require.Equal(t, "my-volume", item.Mounts[0].ID)
	require.Equal(t, "/mnt/data", item.Mounts[0].Target)
}

// TestTorchWithExistingExtraIndexURL checks that a torch pin which already
// carries an --extra-index-url keeps that URL in the generated requirements
// and that the configured CUDA version is left as written.
func TestTorchWithExistingExtraIndexURL(t *testing.T) {
	cfg := &Config{
		Build: &Build{
			GPU:           true,
			CUDA:          "11.6.2",
			PythonVersion: "3.10",
			PythonPackages: []string{
				"torch==1.12.1 --extra-index-url=https://download.pytorch.org/whl/cu116",
			},
		},
	}
	require.NoError(t, cfg.Complete(""))
	require.Equal(t, "11.6.2", cfg.Build.CUDA)

	want := `--extra-index-url https://download.pytorch.org/whl/cu116
torch==1.12.1`
	got, err := cfg.PythonRequirementsForArch("", "", []string{})
	require.NoError(t, err)
	require.Equal(t, want, got)
}

// TestBlankBuild checks that a cog.yaml containing only an empty `build:` key
// still produces a Config with a non-nil Build whose GPU flag is false.
func TestBlankBuild(t *testing.T) {
	// Naively, this turns into nil, so make sure it's a real build object
	projectDir := t.TempDir()
	cogYAML := path.Join(projectDir, "cog.yaml")
	require.NoError(t, os.WriteFile(cogYAML, []byte(`build:`), 0o644))

	// Note: `build:` by itself in YAML parses to Build: nil (empty map becomes nil pointer)
	// The completion step should create a default Build
	parsed, err := parseFile(cogYAML)
	require.NoError(t, err)

	cfg, err := configFileToConfig(parsed)
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(projectDir))
	require.NotNil(t, cfg.Build)
	require.False(t, cfg.Build.GPU)
}

// TestPythonRequirementsForArchWithAddedPackage checks that a package passed
// through the includePackages argument is appended after the configured
// requirements.
func TestPythonRequirementsForArchWithAddedPackage(t *testing.T) {
	cfg := &Config{
		Build: &Build{
			GPU:           true,
			CUDA:          "11.6.2",
			PythonVersion: "3.10",
			PythonPackages: []string{
				"torch==2.4.0 --extra-index-url=https://download.pytorch.org/whl/cu116",
			},
		},
	}
	require.NoError(t, cfg.Complete(""))
	require.Equal(t, "11.6.2", cfg.Build.CUDA)

	want := `--extra-index-url https://download.pytorch.org/whl/cu116
torch==2.4.0
torchvision==2.4.0`
	extra := []string{"torchvision==2.4.0"}
	got, err := cfg.PythonRequirementsForArch("", "", extra)
	require.NoError(t, err)
	require.Equal(t, want, got)
}

// TestParseTests checks that parseFile accepts a cog.yaml whose run entry uses
// the object form.
func TestParseTests(t *testing.T) {
	contents := `
build:
  run:
  - command: "echo 'Hello, World!'"
`
	cogYAML := path.Join(t.TempDir(), "cog.yaml")
	require.NoError(t, os.WriteFile(cogYAML, []byte(contents), 0o644))

	_, err := parseFile(cogYAML)
	require.NoError(t, err)
}

func TestConfigMarshal(t *testing.T) {
	cfg := defaultConfig()
	data, err := yaml.Marshal(cfg)
	require.NoError(t, err)
	// yaml v4 uses 4-space indentation by default
	require.Equal(t, `build:
    python_version: "3.13"
predict: ""
`, string(data))
}

// TestAbsolutePathInPythonRequirements checks that PythonRequirements may be
// an absolute path: after Complete, the torch version pinned in that file is
// reported by TorchVersion.
func TestAbsolutePathInPythonRequirements(t *testing.T) {
	dir := t.TempDir()
	requirementsFilePath := filepath.Join(dir, "requirements.txt")
	err := os.WriteFile(requirementsFilePath, []byte("torch==2.5.0"), 0o644)
	require.NoError(t, err)
	config := &Config{
		Build: &Build{
			GPU:                true,
			PythonVersion:      "3.10",
			PythonRequirements: requirementsFilePath,
		},
	}
	err = config.Complete(dir)
	require.NoError(t, err)
	torchVersion, ok := config.TorchVersion()
	// Check the found flag first; the version is only meaningful when ok is true.
	require.True(t, ok)
	require.Equal(t, "2.5.0", torchVersion)
}

// TestWeightsWithNameYAML checks that named weight entries in YAML are parsed
// in order with their name, source and target.
func TestWeightsWithNameYAML(t *testing.T) {
	input := `build:
  python_version: "3.12"
predict: "predict.py:Predictor"

weights:
  - name: model-v1
    source: file://./weights/model-v1.zip
    target: "/weights/model-v1"
  - name: model-v2
    source: file://./weights/model-v2.zip
    target: "/weights/model-v2"
`

	cfg, err := FromYAML([]byte(input))
	require.NoError(t, err)
	require.Len(t, cfg.Weights, 2)

	first, second := cfg.Weights[0], cfg.Weights[1]

	require.Equal(t, "model-v1", first.Name)
	require.Equal(t, "file://./weights/model-v1.zip", first.Source)
	require.Equal(t, "/weights/model-v1", first.Target)

	require.Equal(t, "model-v2", second.Name)
	require.Equal(t, "file://./weights/model-v2.zip", second.Source)
	require.Equal(t, "/weights/model-v2", second.Target)
}

// TestWeightsWithoutNameYAML checks that a weight entry without a name parses
// with an empty Name and the given source and target.
func TestWeightsWithoutNameYAML(t *testing.T) {
	input := `build:
  python_version: "3.12"
predict: "predict.py:Predictor"

weights:
  - source: file://./weights/model.zip
    target: "/weights/model"
`

	cfg, err := FromYAML([]byte(input))
	require.NoError(t, err)
	require.Len(t, cfg.Weights, 1)

	only := cfg.Weights[0]
	require.Equal(t, "", only.Name)
	require.Equal(t, "file://./weights/model.zip", only.Source)
	require.Equal(t, "/weights/model", only.Target)
}

// TestWeightsWithNameJSON checks that named weight entries in JSON unmarshal
// into Config in order with their name, source and target.
func TestWeightsWithNameJSON(t *testing.T) {
	input := `{
	"build": {
		"python_version": "3.12"
	},
	"predict": "predict.py:Predictor",
	"weights": [
		{
			"name": "model-v1",
			"source": "file://./weights/model-v1.zip",
			"target": "/weights/model-v1"
		},
		{
			"name": "model-v2",
			"source": "file://./weights/model-v2.zip",
			"target": "/weights/model-v2"
		}
	]
}`

	var cfg Config
	require.NoError(t, json.Unmarshal([]byte(input), &cfg))
	require.Len(t, cfg.Weights, 2)

	first, second := cfg.Weights[0], cfg.Weights[1]

	require.Equal(t, "model-v1", first.Name)
	require.Equal(t, "file://./weights/model-v1.zip", first.Source)
	require.Equal(t, "/weights/model-v1", first.Target)

	require.Equal(t, "model-v2", second.Name)
	require.Equal(t, "file://./weights/model-v2.zip", second.Source)
	require.Equal(t, "/weights/model-v2", second.Target)
}

// TestSDKVersionConfig checks that build.sdk_version is parsed and stored on
// Build.SDKVersion.
func TestSDKVersionConfig(t *testing.T) {
	input := []byte(`
build:
  python_version: "3.12"
  sdk_version: "0.18.0"
predict: predict.py:Predictor
`)

	cfg, err := FromYAML(input)
	require.NoError(t, err)
	require.Equal(t, "0.18.0", cfg.Build.SDKVersion)
}

// TestSDKVersionConfigEmpty checks that omitting build.sdk_version leaves
// Build.SDKVersion empty.
func TestSDKVersionConfigEmpty(t *testing.T) {
	input := []byte(`
build:
  python_version: "3.12"
predict: predict.py:Predictor
`)

	cfg, err := FromYAML(input)
	require.NoError(t, err)
	require.Equal(t, "", cfg.Build.SDKVersion)
}

// TestSDKVersionConfigPreRelease verifies that a PEP 440 pre-release version is
// accepted by the parser and stored verbatim.
func TestSDKVersionConfigPreRelease(t *testing.T) {
	const cogYAML = `
build:
  python_version: "3.12"
  sdk_version: "0.18.0a1"
predict: predict.py:Predictor
`

	parsed, err := FromYAML([]byte(cogYAML))
	require.NoError(t, err)
	require.Equal(t, "0.18.0a1", parsed.Build.SDKVersion)
}

// TestSDKVersionConfigBelowMinimumExplodesInGenerator documents that a
// build.sdk_version below 0.16.0 is NOT rejected at parse time: the value is
// stored as written, and rejecting it is left to the Dockerfile generator so
// that the build never proceeds. Only the parsing half is asserted here.
func TestSDKVersionConfigBelowMinimumExplodesInGenerator(t *testing.T) {
	const cogYAML = `
build:
  python_version: "3.12"
  sdk_version: "0.15.0"
predict: predict.py:Predictor
`

	parsed, err := FromYAML([]byte(cogYAML))

	// Parsing succeeds; enforcement happens at Dockerfile generation time.
	require.NoError(t, err)
	require.Equal(t, "0.15.0", parsed.Build.SDKVersion)
}


================================================
FILE: pkg/config/cuda_compatibility.json
================================================
[
  {
    "Tag": "11.0.3-cudnn8-devel-ubuntu16.04",
    "CUDA": "11.0.3",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "16.04"
  },
  {
    "Tag": "11.0.3-cudnn8-devel-ubuntu18.04",
    "CUDA": "11.0.3",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "18.04"
  },
  {
    "Tag": "11.0.3-cudnn8-devel-ubuntu20.04",
    "CUDA": "11.0.3",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "11.1.1-cudnn8-devel-ubuntu16.04",
    "CUDA": "11.1.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "16.04"
  },
  {
    "Tag": "11.1.1-cudnn8-devel-ubuntu18.04",
    "CUDA": "11.1.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "18.04"
  },
  {
    "Tag": "11.1.1-cudnn8-devel-ubuntu20.04",
    "CUDA": "11.1.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "11.2.2-cudnn8-devel-ubuntu16.04",
    "CUDA": "11.2.2",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "16.04"
  },
  {
    "Tag": "11.2.2-cudnn8-devel-ubuntu18.04",
    "CUDA": "11.2.2",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "18.04"
  },
  {
    "Tag": "11.2.2-cudnn8-devel-ubuntu20.04",
    "CUDA": "11.2.2",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "11.3.1-cudnn8-devel-ubuntu16.04",
    "CUDA": "11.3.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "16.04"
  },
  {
    "Tag": "11.3.1-cudnn8-devel-ubuntu18.04",
    "CUDA": "11.3.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "18.04"
  },
  {
    "Tag": "11.3.1-cudnn8-devel-ubuntu20.04",
    "CUDA": "11.3.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "11.4.3-cudnn8-devel-ubuntu18.04",
    "CUDA": "11.4.3",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "18.04"
  },
  {
    "Tag": "11.4.3-cudnn8-devel-ubuntu20.04",
    "CUDA": "11.4.3",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "11.5.2-cudnn8-devel-ubuntu18.04",
    "CUDA": "11.5.2",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "18.04"
  },
  {
    "Tag": "11.5.2-cudnn8-devel-ubuntu20.04",
    "CUDA": "11.5.2",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "11.6.1-cudnn8-devel-ubuntu20.04",
    "CUDA": "11.6.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "11.6.2-cudnn8-devel-ubuntu18.04",
    "CUDA": "11.6.2",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "18.04"
  },
  {
    "Tag": "11.6.2-cudnn8-devel-ubuntu20.04",
    "CUDA": "11.6.2",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "11.7.1-cudnn8-devel-ubuntu18.04",
    "CUDA": "11.7.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "18.04"
  },
  {
    "Tag": "11.7.1-cudnn8-devel-ubuntu20.04",
    "CUDA": "11.7.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "11.7.1-cudnn8-devel-ubuntu22.04",
    "CUDA": "11.7.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "11.8.0-cudnn8-devel-ubuntu18.04",
    "CUDA": "11.8.0",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "18.04"
  },
  {
    "Tag": "11.8.0-cudnn8-devel-ubuntu20.04",
    "CUDA": "11.8.0",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "11.8.0-cudnn8-devel-ubuntu22.04",
    "CUDA": "11.8.0",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.0.0-cudnn8-devel-ubuntu18.04",
    "CUDA": "12.0.0",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "18.04"
  },
  {
    "Tag": "12.0.0-cudnn8-devel-ubuntu20.04",
    "CUDA": "12.0.0",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.0.0-cudnn8-devel-ubuntu22.04",
    "CUDA": "12.0.0",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.0.1-cudnn8-devel-ubuntu18.04",
    "CUDA": "12.0.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "18.04"
  },
  {
    "Tag": "12.0.1-cudnn8-devel-ubuntu20.04",
    "CUDA": "12.0.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.0.1-cudnn8-devel-ubuntu22.04",
    "CUDA": "12.0.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.1.0-cudnn8-devel-ubuntu20.04",
    "CUDA": "12.1.0",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.1.0-cudnn8-devel-ubuntu22.04",
    "CUDA": "12.1.0",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.1.1-cudnn8-devel-ubuntu20.04",
    "CUDA": "12.1.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.1.1-cudnn8-devel-ubuntu22.04",
    "CUDA": "12.1.1",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.2.2-cudnn8-devel-ubuntu20.04",
    "CUDA": "12.2.2",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.2.2-cudnn8-devel-ubuntu22.04",
    "CUDA": "12.2.2",
    "CuDNN": "8",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.3.2-cudnn9-devel-ubuntu20.04",
    "CUDA": "12.3.2",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.3.2-cudnn9-devel-ubuntu22.04",
    "CUDA": "12.3.2",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.4.1-cudnn-devel-ubuntu20.04",
    "CUDA": "12.4.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.4.1-cudnn-devel-ubuntu22.04",
    "CUDA": "12.4.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.5.1-cudnn-devel-ubuntu20.04",
    "CUDA": "12.5.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.5.1-cudnn-devel-ubuntu22.04",
    "CUDA": "12.5.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.6.0-cudnn-devel-ubuntu20.04",
    "CUDA": "12.6.0",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.6.0-cudnn-devel-ubuntu22.04",
    "CUDA": "12.6.0",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.6.0-cudnn-devel-ubuntu24.04",
    "CUDA": "12.6.0",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  },
  {
    "Tag": "12.6.1-cudnn-devel-ubuntu20.04",
    "CUDA": "12.6.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.6.1-cudnn-devel-ubuntu22.04",
    "CUDA": "12.6.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.6.1-cudnn-devel-ubuntu24.04",
    "CUDA": "12.6.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  },
  {
    "Tag": "12.6.2-cudnn-devel-ubuntu20.04",
    "CUDA": "12.6.2",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.6.2-cudnn-devel-ubuntu22.04",
    "CUDA": "12.6.2",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.6.2-cudnn-devel-ubuntu24.04",
    "CUDA": "12.6.2",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  },
  {
    "Tag": "12.6.3-cudnn-devel-ubuntu20.04",
    "CUDA": "12.6.3",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.6.3-cudnn-devel-ubuntu22.04",
    "CUDA": "12.6.3",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.6.3-cudnn-devel-ubuntu24.04",
    "CUDA": "12.6.3",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  },
  {
    "Tag": "12.8.0-cudnn-devel-ubuntu20.04",
    "CUDA": "12.8.0",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.8.0-cudnn-devel-ubuntu22.04",
    "CUDA": "12.8.0",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.8.0-cudnn-devel-ubuntu24.04",
    "CUDA": "12.8.0",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  },
  {
    "Tag": "12.8.1-cudnn-devel-ubuntu20.04",
    "CUDA": "12.8.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.8.1-cudnn-devel-ubuntu22.04",
    "CUDA": "12.8.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.8.1-cudnn-devel-ubuntu24.04",
    "CUDA": "12.8.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  },
  {
    "Tag": "12.9.0-cudnn-devel-ubuntu20.04",
    "CUDA": "12.9.0",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.9.0-cudnn-devel-ubuntu22.04",
    "CUDA": "12.9.0",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.9.0-cudnn-devel-ubuntu24.04",
    "CUDA": "12.9.0",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  },
  {
    "Tag": "12.9.1-cudnn-devel-ubuntu20.04",
    "CUDA": "12.9.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "20.04"
  },
  {
    "Tag": "12.9.1-cudnn-devel-ubuntu22.04",
    "CUDA": "12.9.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "12.9.1-cudnn-devel-ubuntu24.04",
    "CUDA": "12.9.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  },
  {
    "Tag": "13.0.0-cudnn-devel-ubuntu22.04",
    "CUDA": "13.0.0",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "13.0.0-cudnn-devel-ubuntu24.04",
    "CUDA": "13.0.0",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  },
  {
    "Tag": "13.0.1-cudnn-devel-ubuntu22.04",
    "CUDA": "13.0.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "13.0.1-cudnn-devel-ubuntu24.04",
    "CUDA": "13.0.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  },
  {
    "Tag": "13.0.2-cudnn-devel-ubuntu22.04",
    "CUDA": "13.0.2",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "13.0.2-cudnn-devel-ubuntu24.04",
    "CUDA": "13.0.2",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  },
  {
    "Tag": "13.1.1-cudnn-devel-ubuntu22.04",
    "CUDA": "13.1.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "22.04"
  },
  {
    "Tag": "13.1.1-cudnn-devel-ubuntu24.04",
    "CUDA": "13.1.1",
    "CuDNN": "9",
    "IsDevel": true,
    "Ubuntu": "24.04"
  }
]

================================================
FILE: pkg/config/data/config_schema_v1.0.json
================================================
{
  "$schema": "http://json-schema.org/draft-07/schema",
  "type": "object",
  "title": "Schema for cog.yaml",
  "description": "Defines how to build a Docker image and how to run predictions on your model inside that image.",
  "properties": {
    "build": {
      "$id": "#/properties/build",
      "type": "object",
      "description": "This stanza describes how to build the Docker image your model runs in.",
      "properties": {
        "cuda": {
          "$id": "#/properties/build/properties/cuda",
          "type": "string",
          "description": "Cog automatically picks the correct version of CUDA to install, but this lets you override it for whatever reason."
        },
        "cudnn": {
          "$id": "#/properties/build/properties/cudnn",
          "type": "string",
          "description": "Cog automatically picks the correct version of cuDNN to install, but this lets you override it for whatever reason."
        },
        "gpu": {
          "$id": "#/properties/build/properties/gpu",
          "type": "boolean",
          "description": "Enable GPUs for this model. When enabled, the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) base image will be used, and Cog will automatically figure out what versions of CUDA and cuDNN to use based on the version of Python, PyTorch, and Tensorflow that you are using."
        },
        "python_version": {
          "$id": "#/properties/build/properties/python_version",
          "type": [
            "string",
            "number"
          ],
          "description": "The minor (`3.13`) or patch (`3.13.1`) version of Python to use."
        },
        "python_packages": {
          "$id": "#/properties/build/properties/python_packages",
          "type": [
            "array",
            "null"
          ],
          "description": "A list of Python packages to install, in the format `package==version`.",
          "additionalItems": true,
          "items": {
            "$id": "#/properties/build/properties/python_packages/items",
            "anyOf": [
              {
                "$id": "#/properties/build/properties/python_packages/items/anyOf/0",
                "type": "string"
              }
            ]
          }
        },
        "pre_install": {
          "$id": "#/properties/build/properties/pre_install",
          "type": [
            "array",
            "null"
          ],
          "description": "A list of setup commands to run in the environment before your Python packages are installed.",
          "additionalItems": true,
          "items": {
            "$id": "#/properties/build/properties/pre_install/items",
            "anyOf": [
              {
                "$id": "#/properties/build/properties/pre_install/items/anyOf/0",
                "type": "string"
              }
            ]
          }
        },
        "python_requirements": {
          "$id": "#/properties/build/properties/python_requirements",
          "type": "string",
          "description": "A pip requirements file specifying the Python packages to install."
        },
        "system_packages": {
          "$id": "#/properties/build/properties/system_packages",
          "type": [
            "array",
            "null"
          ],
          "description": "A list of Ubuntu APT packages to install.",
          "additionalItems": true,
          "items": {
            "$id": "#/properties/build/properties/system_packages/items",
            "anyOf": [
              {
                "$id": "#/properties/build/properties/system_packages/items/anyOf/0",
                "type": "string"
              }
            ]
          }
        },
        "sdk_version": {
          "$id": "#/properties/build/properties/sdk_version",
          "type": "string",
          "description": "Pin the cog Python SDK version installed in the container (e.g. \"0.18.0\" or \"0.18.0a1\"). Defaults to latest. Overridden by the COG_SDK_WHEEL environment variable."
        },
        "run": {
          "$id": "#/properties/build/properties/run",
          "type": [
            "array",
            "null"
          ],
          "description": "A list of setup commands to run in the environment after your system packages and Python packages have been installed. If you're familiar with Docker, it's like a `RUN` instruction in your `Dockerfile`.",
          "additionalItems": true,
          "items": {
            "$id": "#/properties/build/properties/run/items",
            "anyOf": [
              {
                "$id": "#/properties/build/properties/run/items/anyOf/0",
                "type": "string"
              },
              {
                "$id": "#/properties/build/properties/run/items/anyOf/1",
                "type": "object",
                "properties": {
                  "command": {
                    "type": "string"
                  },
                  "mounts": {
                    "type": "array",
                    "items": {
                      "type": "object",
                      "properties": {
                        "type": {
                          "type": "string",
                          "enum": [
                            "secret"
                          ]
                        },
                        "id": {
                          "type": "string"
                        },
                        "target": {
                          "type": "string"
                        }
                      },
                      "required": [
                        "type",
                        "id",
                        "target"
                      ]
                    }
                  }
                },
                "required": [
                  "command"
                ]
              }
            ]
          }
        }
      },
      "required": ["python_version"],
      "additionalProperties": false
    },
    "image": {
      "$id": "#/properties/image",
      "type": "string",
      "description": "The name given to built Docker images. If you want to push to a registry, this should also include the registry name."
    },
    "predict": {
      "$id": "#/properties/predict",
      "type": "string",
      "description": "The pointer to the `Predictor` object in your code, which defines how predictions are run on your model."
    },
    "train": {
      "$id": "#/properties/train",
      "type": "string",
      "description": "The pointer to the training function or object in your code, which defines how training is run on your model."
    },
    "concurrency": {
      "$id": "#/properties/concurrency",
      "type": "object",
      "description": "The concurrency settings for the model.",
      "required": [
        "max"
      ],
      "additionalProperties": false,
      "properties": {
        "max": {
          "$id": "#/properties/concurrency/properties/max",
          "type": "integer",
          "description": "The maximum number of concurrent predictions."
        },
        "default_target": {
          "$id": "#/properties/concurrency/properties/default_target",
          "type": "integer",
          "description": "The default target for number of concurrent predictions. This setting can be used by an autoscaler to determine when to scale a deployment of a model up or down."
        }
      }
    },
    "environment": {
      "$id": "#/properties/properties/environment",
      "type": [
        "array",
        "null"
      ],
      "description": "A list of environment variables to make available during builds and at runtime, in the format `NAME=value`",
      "additionalItems": true,
      "items": {
        "$id": "#/properties/properties/environment/items",
        "type": "string",
        "pattern": "^[A-Za-z_][A-Za-z0-9_]*=.*$"
      }
    },
    "weights": {
      "$id": "#/properties/weights",
      "type": [
        "array",
        "null"
      ],
      "description": "A list of weight files or directories to include in the model.",
      "items": {
        "type": "object",
        "required": ["source"],
        "additionalProperties": false,
        "properties": {
          "name": {
            "type": "string",
            "description": "A unique identifier for this weight entry."
          },
          "source": {
            "type": "string",
            "description": "Path to a weight file or directory (relative to cog.yaml)."
          },
          "target": {
            "type": "string",
            "description": "Target path in the container (must be under /cache/). Defaults to /cache/."
          }
        }
      }
    }
  },
  "additionalProperties": false
}


================================================
FILE: pkg/config/env.go
================================================
package config

import (
	"fmt"
	"strings"
)

// environmentVariableDenyList is a list of environment variable patterns that are
// used internally during build or runtime and thus not allowed to be set by the user.
// There are ways around this restriction, but it's likely to cause unexpected behavior
// and hard to debug issues. So on Cog's predict-build-push happy path, we don't allow
// these to be set.
//
// An entry is either an exact variable name or, when it ends in "*", a prefix
// pattern: validateEnvName rejects every name that starts with the text before
// the "*".
//
// This list may change at any time. For more context, see:
// https://github.com/replicate/cog/pull/2274/#issuecomment-2831823185
var environmentVariableDenyList = []string{
	// paths
	"PATH",
	"LD_LIBRARY_PATH",
	"PYTHONPATH",
	"VIRTUAL_ENV",
	"PYTHONUNBUFFERED",
	// Replicate
	"R8_*",
	"REPLICATE_*",
	// Nvidia
	"LIBRARY_PATH",
	"CUDA_*",
	"NVIDIA_*",
	"NV_*",
	// pget
	"PGET_*",
	"HF_ENDPOINT",
	"HF_HUB_ENABLE_HF_TRANSFER",
	// k8s
	"KUBERNETES_*",
}

// validateEnvName decides whether name may be set by the user. It returns an
// error when name equals a deny-list entry, or when it begins with the prefix
// of a wildcard entry (one that ends in "*"); otherwise it returns nil.
func validateEnvName(name string) error {
	for _, denied := range environmentVariableDenyList {
		// A trailing "*" turns the entry into a prefix pattern. The literal
		// entry text itself is still rejected through the exact comparison.
		prefix, wildcard := strings.CutSuffix(denied, "*")
		if denied == name || (wildcard && strings.HasPrefix(name, prefix)) {
			return fmt.Errorf("environment variable %q is not allowed", name)
		}
	}
	return nil
}

// parseAndValidateEnvironment turns KEY=VALUE entries into a map keyed by
// variable name. Only the first "=" separates key from value, so a value may
// itself contain "=". It returns an error (and a nil map) when an entry has no
// "=" or an empty key, when validateEnvName rejects the key, or when the same
// key appears more than once.
func parseAndValidateEnvironment(input []string) (map[string]string, error) {
	parsed := make(map[string]string)
	for _, entry := range input {
		key, value, found := strings.Cut(entry, "=")
		if !found || key == "" {
			return nil, fmt.Errorf("environment variable %q is not in the KEY=VALUE format", entry)
		}
		if err := validateEnvName(key); err != nil {
			return nil, err
		}
		if _, duplicate := parsed[key]; duplicate {
			return nil, fmt.Errorf("environment variable %q is already defined", key)
		}
		parsed[key] = value
	}
	return parsed, nil
}


================================================
FILE: pkg/config/env_variables_test.go
================================================
package config

import (
	"fmt"
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestEnvironmentConfig exercises parseAndValidateEnvironment with well-formed
// KEY=VALUE entries, malformed entries, deny-listed names, and duplicate names.
func TestEnvironmentConfig(t *testing.T) {
	t.Run("ParsingValidInput", func(t *testing.T) {
		const userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"

		type validCase struct {
			name  string
			input []string
			want  map[string]string
		}
		tests := []validCase{
			{"ValidInput", []string{"NAME=VALUE"}, map[string]string{"NAME": "VALUE"}},
			{"ValidInputWithSpaces", []string{"NAME=VALUE WITH SPACES"}, map[string]string{"NAME": "VALUE WITH SPACES"}},
			{"ValidInputWithQuotes", []string{"NAME=\"VALUE WITH QUOTES\""}, map[string]string{"NAME": `"VALUE WITH QUOTES"`}},
			{"DelimitedValue", []string{"NAME=VALUE1,VALUE2"}, map[string]string{"NAME": "VALUE1,VALUE2"}},
			{"EmptyValue", []string{"NAME="}, map[string]string{"NAME": ""}},
			{"EmptyValueWithSpaces", []string{"NAME= "}, map[string]string{"NAME": " "}},
			{"LowerCaseName", []string{"name=VALUE"}, map[string]string{"name": "VALUE"}},
			{"MixedCaseName", []string{"MiXeD_Case=VALUE"}, map[string]string{"MiXeD_Case": "VALUE"}},
			{"EqualSignInValue", []string{"NAME=VALUE=EQUAL"}, map[string]string{"NAME": "VALUE=EQUAL"}},
			{"EqualSignInValueWithSpaces", []string{"NAME=VALUE=EQUAL WITH SPACES"}, map[string]string{"NAME": "VALUE=EQUAL WITH SPACES"}},
			{"MultiLineValue", []string{"NAME=VALUE1\nVALUE2"}, map[string]string{"NAME": "VALUE1\nVALUE2"}},
			{"UserAgentWithSpaces", []string{"COG_USER_AGENT=" + userAgent}, map[string]string{"COG_USER_AGENT": userAgent}},
			{"MultiplePairs", []string{"NAME1=VALUE1", "NAME2=VALUE2"}, map[string]string{"NAME1": "VALUE1", "NAME2": "VALUE2"}},
		}

		for _, tc := range tests {
			t.Run(tc.name, func(t *testing.T) {
				got, err := parseAndValidateEnvironment(tc.input)
				require.NoError(t, err)
				require.Equal(t, tc.want, got)
			})
		}
	})

	t.Run("ParsingInvalidInput", func(t *testing.T) {
		tests := []struct {
			name    string
			input   []string
			wantErr string
		}{
			{"NameWithoutValue", []string{"NAME"}, `environment variable "NAME" is not in the KEY=VALUE format`},
			{"EmptyName", []string{"=VALUE"}, `environment variable "=VALUE" is not in the KEY=VALUE format`},
		}

		for _, tc := range tests {
			t.Run(tc.name, func(t *testing.T) {
				_, err := parseAndValidateEnvironment(tc.input)
				require.Error(t, err)
				require.ErrorContains(t, err, tc.wantErr)
			})
		}
	})

	t.Run("EnforceDenyList", func(t *testing.T) {
		// assertRejected checks that setting name is refused with the
		// deny-list error message.
		assertRejected := func(t *testing.T, name string) {
			_, err := parseAndValidateEnvironment([]string{fmt.Sprintf("%s=VALUE", name)})
			require.Error(t, err)
			require.ErrorContains(t, err, fmt.Sprintf("environment variable %q is not allowed", name))
		}

		for _, pattern := range environmentVariableDenyList {
			// The entry text itself must always be rejected.
			t.Run(fmt.Sprintf("Rejects %q", pattern), func(t *testing.T) {
				assertRejected(t, pattern)
			})

			// Wildcard entries must additionally reject any name sharing the prefix.
			if !strings.HasSuffix(pattern, "*") {
				continue
			}
			t.Run(fmt.Sprintf("Rejects %q prefix", pattern), func(t *testing.T) {
				assertRejected(t, strings.TrimSuffix(pattern, "*")+"SUFFIX")
			})
		}
	})

	t.Run("DuplicateNamesAreRejected", func(t *testing.T) {
		_, err := parseAndValidateEnvironment([]string{"NAME=VALUE", "NAME=VALUE2"})
		require.Error(t, err)
		require.ErrorContains(t, err, "environment variable \"NAME\" is already defined")
	})
}


================================================
FILE: pkg/config/errors.go
================================================
package config

import (
	"errors"
	"fmt"
)

// ConfigError is the base interface for all config errors.
// Allows callers to use errors.As to get config-specific details.
//
// The error types declared in this file satisfy it by pairing their Error
// method with an empty ConfigError method.
type ConfigError interface {
	error
	ConfigError() // marker method: carries no behavior, only tags the type as a config error
}

// ParseError reports that a YAML config file could not be parsed.
// Filename names the file; Err is the underlying cause.
type ParseError struct {
	Filename string
	Err      error
}

// Error renders the file name followed by the underlying cause.
func (e *ParseError) Error() string {
	return "failed to parse " + e.Filename + ": " + fmt.Sprint(e.Err)
}

// Unwrap exposes the underlying cause to errors.Is and errors.As.
func (e *ParseError) Unwrap() error { return e.Err }

// ConfigError marks ParseError as a config error.
func (e *ParseError) ConfigError() {}

// SchemaError reports that the config's structure does not match the schema,
// such as a field with the wrong type or a field that is not recognized.
// Field identifies the offending field and Message describes the mismatch.
type SchemaError struct {
	Field   string
	Message string
}

// Error renders the quoted field name followed by the message.
func (e *SchemaError) Error() string {
	location := fmt.Sprintf("schema error in %q", e.Field)
	return location + ": " + e.Message
}

// ConfigError marks SchemaError as a config error.
func (e *SchemaError) ConfigError() {}

// ValidationError reports a semantic validation failure: the config parsed,
// but a value in it is not acceptable. Field names the setting, Value is the
// offending value (may be empty), and Message explains the problem.
type ValidationError struct {
	Field   string
	Value   string
	Message string
}

// Error renders the field and message, including the quoted value only when
// one was recorded.
func (e *ValidationError) Error() string {
	if e.Value == "" {
		return fmt.Sprintf("invalid %s: %s", e.Field, e.Message)
	}
	return fmt.Sprintf("invalid %s %q: %s", e.Field, e.Value, e.Message)
}

// ConfigError marks ValidationError as a config error.
func (e *ValidationError) ConfigError() {}

// DeprecationWarning reports use of a deprecated field. It is a warning, not
// an error: validation still succeeds. Field names the deprecated setting,
// Replacement (optional) names what to use instead, and Message carries
// free-form guidance.
type DeprecationWarning struct {
	Field       string
	Replacement string
	Message     string
}

// Error renders the warning text. When a Replacement is set it is suggested
// and Message is not included; otherwise Message is used.
func (w *DeprecationWarning) Error() string {
	advice := w.Message
	if w.Replacement != "" {
		advice = fmt.Sprintf("use %q instead", w.Replacement)
	}
	return fmt.Sprintf("deprecated field %q: %s", w.Field, advice)
}

// ConfigError marks DeprecationWarning as a config error type.
func (w *DeprecationWarning) ConfigError() {}

// CompatibilityError reports that two components cannot be used together at
// the given versions. Message explains the incompatibility.
type CompatibilityError struct {
	Component1 string
	Version1   string
	Component2 string
	Version2   string
	Message    string
}

// Error renders "<component> <version>" for both sides followed by the message.
func (e *CompatibilityError) Error() string {
	first := e.Component1 + " " + e.Version1
	second := e.Component2 + " " + e.Version2
	return fmt.Sprintf("%s is incompatible with %s: %s", first, second, e.Message)
}

// ConfigError marks CompatibilityError as a config error.
func (e *CompatibilityError) ConfigError() {}

// ValidationResult accumulates the errors and deprecation warnings produced
// while validating a config.
type ValidationResult struct {
	Errors   []error
	Warnings []DeprecationWarning
}

// HasErrors reports whether at least one validation error was recorded.
func (r *ValidationResult) HasErrors() bool { return len(r.Errors) != 0 }

// HasWarnings reports whether at least one deprecation warning was recorded.
func (r *ValidationResult) HasWarnings() bool { return len(r.Warnings) != 0 }

// Err joins all recorded errors into a single error via errors.Join.
// It returns nil when no errors were recorded.
func (r *ValidationResult) Err() error {
	if r.HasErrors() {
		return errors.Join(r.Errors...)
	}
	return nil
}

// AddError records a validation error.
func (r *ValidationResult) AddError(err error) {
	r.Errors = append(r.Errors, err)
}

// AddWarning records a deprecation warning.
func (r *ValidationResult) AddWarning(w DeprecationWarning) {
	r.Warnings = append(r.Warnings, w)
}

// NewValidationResult returns a ValidationResult whose Errors and Warnings are
// empty but non-nil slices.
func NewValidationResult() *ValidationResult {
	result := new(ValidationResult)
	result.Errors = make([]error, 0)
	result.Warnings = make([]DeprecationWarning, 0)
	return result
}


================================================
FILE: pkg/config/image_name.go
================================================
package config

import (
	"path"
	"regexp"
	"strings"
)

// imageNameDisallowedChars matches every run of characters that may not appear
// in a generated image name (anything other than a-z, 0-9 and "-"). It is
// compiled once at package initialization rather than on every call.
var imageNameDisallowedChars = regexp.MustCompile(`[^a-z0-9\-]+`)

// DockerImageName returns the default Docker image name for the project
// located at projectDir.
//
// The name is derived from the last element of projectDir: it is lower-cased,
// spaces become dashes, every other character outside a-z, 0-9 and "-" is
// dropped, and the result is prefixed with "cog-" unless it already starts
// with that prefix. The returned name is at most 30 characters long.
//
// Trailing dashes are removed from the derived name, because a Docker
// repository name may not end in a separator; a directory such as "my model "
// or a truncation that lands on a dash would otherwise yield an unusable name.
func DockerImageName(projectDir string) string {
	const (
		prefix    = "cog-"
		maxLength = 30
	)

	name := strings.ToLower(path.Base(projectDir))

	// Convert spaces to dashes, then drop everything that is not allowed.
	name = strings.ReplaceAll(name, " ", "-")
	name = imageNameDisallowedChars.ReplaceAllString(name, "")

	// Reserve room for the prefix so the final name never exceeds maxLength.
	if limit := maxLength - len(prefix); len(name) > limit {
		name = name[:limit]
	}

	// Must run after truncation: cutting the name can expose a dash at the end.
	name = strings.TrimRight(name, "-")

	if !strings.HasPrefix(name, prefix) {
		name = prefix + name
	}
	return name
}


================================================
FILE: pkg/config/image_name_test.go
================================================
package config

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// TestDockerImageName checks the names DockerImageName derives from a handful
// of project directories, plus the 30-character cap for a very long one.
func TestDockerImageName(t *testing.T) {
	cases := []struct {
		projectDir string
		want       string
	}{
		{projectDir: "/home/joe/foo", want: "cog-foo"},
		{projectDir: "/home/joe/Foo", want: "cog-foo"},
		{projectDir: "/home/joe/cog-foo", want: "cog-foo"},
		{projectDir: "/home/joe/my great model", want: "cog-my-great-model"},
	}
	for _, tc := range cases {
		require.Equal(t, tc.want, DockerImageName(tc.projectDir))
	}

	longName := DockerImageName("/home/joe/verylongverylongverylongverylongverylongverylongverylong")
	require.Equal(t, 30, len(longName))
}


================================================
FILE: pkg/config/load.go
================================================
package config

import (
	"fmt"
	"io"
	"os"
	"path/filepath"

	"github.com/replicate/cog/pkg/errors"
	"github.com/replicate/cog/pkg/util/files"
)

// maxSearchDepth bounds how many directories findProjectRootDir examines
// while walking up from its start directory.
const maxSearchDepth = 100

// LoadResult contains the loaded config and any warnings.
type LoadResult struct {
	// Config is the converted and completed configuration.
	Config *Config
	// Warnings holds the deprecation warnings produced during validation.
	Warnings []DeprecationWarning
	// RootDir is the project directory that was passed to Load.
	RootDir string
}

// Load reads a config from r and runs it through the full pipeline: parse,
// validate, convert to a Config, and complete.
//
// projectDir is handed to validation (via WithProjectDir) and to
// Config.Complete, and is echoed back as LoadResult.RootDir.
//
// On success the result carries any deprecation warnings produced by
// validation. If any stage fails, a nil result and that stage's error are
// returned.
func Load(r io.Reader, projectDir string) (*LoadResult, error) {
	parsed, err := parse(r)
	if err != nil {
		return nil, err
	}

	validation := ValidateConfigFile(parsed, WithProjectDir(projectDir))
	if validation.HasErrors() {
		return nil, validation.Err()
	}

	cfg, err := configFileToConfig(parsed)
	if err != nil {
		return nil, err
	}

	// Completion resolves CUDA versions, loads requirements files, etc.
	if err = cfg.Complete(projectDir); err != nil {
		return nil, err
	}

	result := &LoadResult{
		Config:   cfg,
		Warnings: validation.Warnings,
		RootDir:  projectDir,
	}
	return result, nil
}

// GetProjectDir locates the project's root directory by looking for
// configFilename in the current working directory and then in each of its
// parents. An empty configFilename defaults to "cog.yaml".
func GetProjectDir(configFilename string) (string, error) {
	name := configFilename
	if name == "" {
		name = "cog.yaml"
	}

	workingDir, err := os.Getwd()
	if err != nil {
		return "", err
	}

	return findProjectRootDir(workingDir, name)
}

// findConfigPathInDirectory looks for configFilename directly inside dir.
// It returns the joined path when the file exists, a ConfigNotFound error
// when it does not, and a wrapped error when the existence check itself fails.
func findConfigPathInDirectory(dir string, configFilename string) (configPath string, err error) {
	candidate := filepath.Join(dir, configFilename)

	found, err := files.Exists(candidate)
	if err != nil {
		return "", fmt.Errorf("failed to scan directory %s for %s: %w", dir, candidate, err)
	}
	if !found {
		return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s", configFilename, dir))
	}

	return candidate, nil
}

// findProjectRootDir walks up the directory tree from startDir to find the
// root of the project, defined as the first directory that contains
// configFilename.
//
// It returns that directory on success. A ConfigNotFound error is returned
// when the top of the tree is reached, or when maxSearchDepth directories
// have been examined, without finding the file. Any other error from the
// lookup is returned unchanged.
func findProjectRootDir(startDir string, configFilename string) (string, error) {
	dir := startDir
	for range maxSearchDepth {
		_, err := findConfigPathInDirectory(dir, configFilename)
		if err == nil {
			return dir, nil
		}
		if !errors.IsConfigNotFound(err) {
			return "", err
		}

		// filepath.Dir returns its argument unchanged once there is no parent
		// left ("." or "/" on Unix, a volume root on Windows), so comparing
		// against the parent detects the top of the tree on every platform.
		parent := filepath.Dir(dir)
		if parent == dir {
			return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s (or in any parent directories)", configFilename, startDir))
		}
		dir = parent
	}

	// Depth limit hit: report the file that was actually searched for.
	return "", errors.ConfigNotFound(fmt.Sprintf("No %s found in parent directories.", configFilename))
}


================================================
FILE: pkg/config/load_test.go
================================================
package config

import (
	"os"
	"path"
	"testing"

	"github.com/stretchr/testify/require"
)

// testConfig is the cog.yaml content that tests in this file write to disk.
const testConfig = `
build:
  python_version: "3.10"
  python_requirements: requirements.txt
  system_packages:
    - libgl1-mesa-glx
    - libglib2.0-0
predict: "predict.py:SomePredictor"
`

// TestFindProjectRootDirShouldFindParentDir writes a cog.yaml into a temp
// project directory and checks that a search started from a nested
// subdirectory resolves back to that project directory.
func TestFindProjectRootDirShouldFindParentDir(t *testing.T) {
	projectDir := t.TempDir()

	err := os.WriteFile(path.Join(projectDir, "cog.yaml"), []byte(testConfig), 0o644)
	require.NoError(t, err)

	subdir := path.Join(projectDir, "some/sub/dir")
	err = os.MkdirAll(subdir, 0o700)
	require.NoError(t, err)

	foundDir, err := findProjectRootDir(subdir, "cog.yaml")
	require.NoError(t, err)
	// require.Equal takes (expected, actual); this order keeps failure output
	// correctly labelled.
	require.Equal(t, projectDir, foundDir)
}

// TestFindProjectRootDirShouldReturnErrIfNoConfig checks that searching from a
// nested directory inside a temp dir holding no cog.yaml yields an error.
func TestFindProjectRootDirShouldReturnErrIfNoConfig(t *testing.T) {
	root := t.TempDir()
	nested := path.Join(root, "some/sub/dir")
	require.NoError(t, os.MkdirAll(nested, 0o700))

	_, searchErr := findProjectRootDir(nested, "cog.yaml")
	require.Error(t, searchErr)
}


================================================
FILE: pkg/config/parse.go
================================================
package config

import (
	"fmt"
	"io"
	"os"
	"path/filepath"

	"go.yaml.in/yaml/v4"

	"github.com/replicate/cog/pkg/util/files"
)

// parse consumes all of r and decodes it as YAML into a configFile.
// It performs YAML parsing only: no validation and no defaults.
// A failed read is reported as a ParseError; decoding is delegated to
// parseBytes.
func parse(r io.Reader) (*configFile, error) {
	raw, readErr := io.ReadAll(r)
	if readErr != nil {
		return nil, &ParseError{Err: readErr}
	}
	return parseBytes(raw)
}

// parseFile opens the cog.yaml file at filename and decodes it into a
// configFile. It performs YAML parsing only: no validation and no defaults.
// Every failure is returned as a ParseError carrying filename.
func parseFile(filename string) (*configFile, error) {
	found, err := files.Exists(filename)
	switch {
	case err != nil:
		return nil, &ParseError{Filename: filename, Err: err}
	case !found:
		missing := fmt.Errorf("%s does not exist in %s", filepath.Base(filename), filepath.Dir(filename))
		return nil, &ParseError{Filename: filename, Err: missing}
	}

	file, err := os.Open(filename)
	if err != nil {
		return nil, &ParseError{Filename: filename, Err: err}
	}
	defer file.Close()

	parsed, err := parse(file)
	if err == nil {
		return parsed, nil
	}

	// Attach the filename: reuse an existing ParseError, otherwise wrap the
	// error in a new one.
	if parseErr, ok := err.(*ParseError); ok {
		parseErr.Filename = filename
		return nil, parseErr
	}
	return nil, &ParseError{Filename: filename, Err: err}
}

// parseBytes decodes YAML content into a configFile.
// Empty content is valid and yields an empty configFile; content that cannot
// be unmarshalled is reported as a ParseError.
func parseBytes(contents []byte) (*configFile, error) {
	parsed := new(configFile)

	// Only non-empty input is handed to the YAML decoder.
	if len(contents) > 0 {
		if err := yaml.Unmarshal(contents, parsed); err != nil {
			return nil, &ParseError{Err: fmt.Errorf("invalid YAML: %w", err)}
		}
	}

	return parsed, nil
}

// FromYAML decodes YAML content into an uncompleted Config.
//
// It is a convenience helper aimed mainly at tests: no validation is run
// (there is no project directory to validate against) and no completion is
// performed. Callers should call Complete() on the result with the right
// project directory to resolve CUDA versions etc. Production code should use
// Load(), which handles validation and completion.
func FromYAML(contents []byte) (*Config, error) {
	parsed, parseErr := parseBytes(contents)
	if parseErr != nil {
		return nil, parseErr
	}

	// Plain conversion only; completion is left to the caller.
	return configFileToConfig(parsed)
}

// configFileToConfig builds a Config from a parsed configFile without running
// any completion logic. It is the minimal conversion shared by Load and
// FromYAML.
//
// Pointer-valued settings are copied only when set, so unset ones keep the
// zero value. The result always has a non-nil Build. The returned error is
// always nil in the current implementation.
func configFileToConfig(cfg *configFile) (*Config, error) {
	result := &Config{Build: &Build{}}

	if src := cfg.Build; src != nil {
		dst := result.Build

		// Settings that are copied across directly, without a nil check.
		dst.PythonPackages = src.PythonPackages
		dst.SystemPackages = src.SystemPackages
		dst.PreInstall = src.PreInstall

		// Optional settings: dereference only when present.
		if src.GPU != nil {
			dst.GPU = *src.GPU
		}
		if src.PythonVersion != nil {
			dst.PythonVersion = *src.PythonVersion
		}
		if src.PythonRequirements != nil {
			dst.PythonRequirements = *src.PythonRequirements
		}
		if src.CUDA != nil {
			dst.CUDA = *src.CUDA
		}
		if src.CuDNN != nil {
			dst.CuDNN = *src.CuDNN
		}
		if src.SDKVersion != nil {
			dst.SDKVersion = *src.SDKVersion
		}

		// Run steps: one RunItem per entry, with mounts copied field by field.
		dst.Run = make([]RunItem, len(src.Run))
		for i, step := range src.Run {
			item := RunItem{Command: step.Command}
			if n := len(step.Mounts); n > 0 {
				item.Mounts = make([]struct {
					Type   string `json:"type,omitempty" yaml:"type"`
					ID     string `json:"id,omitempty" yaml:"id"`
					Target string `json:"target,omitempty" yaml:"target"`
				}, n)
				for j, mount := range step.Mounts {
					item.Mounts[j].Type = mount.Type
					item.Mounts[j].ID = mount.ID
					item.Mounts[j].Target = mount.Target
				}
			}
			dst.Run[i] = item
		}
	}

	// Top-level optional settings.
	if cfg.Image != nil {
		result.Image = *cfg.Image
	}
	if cfg.Predict != nil {
		result.Predict = *cfg.Predict
	}
	if cfg.Train != nil {
		result.Train = *cfg.Train
	}
	if cc := cfg.Concurrency; cc != nil {
		result.Concurrency = &Concurrency{}
		if cc.Max != nil {
			result.Concurrency.Max = *cc.Max
		}
	}
	result.Environment = cfg.Environment

	// Weights are converted element by element, and only when present.
	if count := len(cfg.Weights); count > 0 {
		result.Weights = make([]WeightSource, count)
		for i, weight := range cfg.Weights {
			result.Weights[i] = WeightSource(weight)
		}
	}

	return result, nil
}


================================================
FILE: pkg/config/tf_compatibility.json
================================================
[
  {
    "TF": "2.20.0",
    "TFCPUPackage": "tensorflow==2.20.0",
    "TFGPUPackage": "tensorflow==2.20.0",
    "CUDA": "12.5",
    "CuDNN": "9.3",
    "Pythons": [
      "3.9",
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "TF": "2.19.0",
    "TFCPUPackage": "tensorflow==2.19.0",
    "TFGPUPackage": "tensorflow==2.19.0",
    "CUDA": "12.5",
    "CuDNN": "9.3",
    "Pythons": [
      "3.9",
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "TF": "2.18.0",
    "TFCPUPackage": "tensorflow==2.18.0",
    "TFGPUPackage": "tensorflow==2.18.0",
    "CUDA": "12.5",
    "CuDNN": "9.3",
    "Pythons": [
      "3.9",
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "TF": "2.17.0",
    "TFCPUPackage": "tensorflow==2.17.0",
    "TFGPUPackage": "tensorflow==2.17.0",
    "CUDA": "12.3",
    "CuDNN": "8.9",
    "Pythons": [
      "3.9",
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "TF": "2.16.1",
    "TFCPUPackage": "tensorflow==2.16.1",
    "TFGPUPackage": "tensorflow==2.16.1",
    "CUDA": "12.3",
    "CuDNN": "8.9",
    "Pythons": [
      "3.9",
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "TF": "2.15.0",
    "TFCPUPackage": "tensorflow==2.15.0",
    "TFGPUPackage": "tensorflow==2.15.0",
    "CUDA": "12.2",
    "CuDNN": "8.9",
    "Pythons": [
      "3.9",
      "3.10",
      "3.11"
    ]
  },
  {
    "TF": "2.14.0",
    "TFCPUPackage": "tensorflow==2.14.0",
    "TFGPUPackage": "tensorflow==2.14.0",
    "CUDA": "11.8",
    "CuDNN": "8.7",
    "Pythons": [
      "3.9",
      "3.10",
      "3.11"
    ]
  },
  {
    "TF": "2.13.0",
    "TFCPUPackage": "tensorflow==2.13.0",
    "TFGPUPackage": "tensorflow==2.13.0",
    "CUDA": "11.8",
    "CuDNN": "8.6",
    "Pythons": [
      "3.8",
      "3.9",
      "3.10",
      "3.11"
    ]
  },
  {
    "TF": "2.12.0",
    "TFCPUPackage": "tensorflow==2.12.0",
    "TFGPUPackage": "tensorflow==2.12.0",
    "CUDA": "11.8",
    "CuDNN": "8.6",
    "Pythons": [
      "3.8",
      "3.9",
      "3.10",
      "3.11"
    ]
  },
  {
    "TF": "2.11.0",
    "TFCPUPackage": "tensorflow==2.11.0",
    "TFGPUPackage": "tensorflow==2.11.0",
    "CUDA": "11.2",
    "CuDNN": "8.1",
    "Pythons": [
      "3.7",
      "3.8",
      "3.9",
      "3.10"
    ]
  },
  {
    "TF": "2.10.0",
    "TFCPUPackage": "tensorflow==2.10.0",
    "TFGPUPackage": "tensorflow==2.10.0",
    "CUDA": "11.2",
    "CuDNN": "8.1",
    "Pythons": [
      "3.7",
      "3.8",
      "3.9",
      "3.10"
    ]
  },
  {
    "TF": "2.9.0",
    "TFCPUPackage": "tensorflow==2.9.0",
    "TFGPUPackage": "tensorflow==2.9.0",
    "CUDA": "11.2",
    "CuDNN": "8.1",
    "Pythons": [
      "3.7",
      "3.8",
      "3.9",
      "3.10"
    ]
  },
  {
    "TF": "2.8.0",
    "TFCPUPackage": "tensorflow==2.8.0",
    "TFGPUPackage": "tensorflow==2.8.0",
    "CUDA": "11.2",
    "CuDNN": "8.1",
    "Pythons": [
      "3.7",
      "3.8",
      "3.9",
      "3.10"
    ]
  },
  {
    "TF": "2.7.0",
    "TFCPUPackage": "tensorflow==2.7.0",
    "TFGPUPackage": "tensorflow==2.7.0",
    "CUDA": "11.2",
    "CuDNN": "8.1",
    "Pythons": [
      "3.7",
      "3.8",
      "3.9"
    ]
  },
  {
    "TF": "2.6.0",
    "TFCPUPackage": "tensorflow==2.6.0",
    "TFGPUPackage": "tensorflow==2.6.0",
    "CUDA": "11.2",
    "CuDNN": "8.1",
    "Pythons": [
      "3.6",
      "3.7",
      "3.8",
      "3.9"
    ]
  },
  {
    "TF": "2.5.0",
    "TFCPUPackage": "tensorflow==2.5.0",
    "TFGPUPackage": "tensorflow==2.5.0",
    "CUDA": "11.2",
    "CuDNN": "8.1",
    "Pythons": [
      "3.6",
      "3.7",
      "3.8",
      "3.9"
    ]
  },
  {
    "TF": "2.4.0",
    "TFCPUPackage": "tensorflow==2.4.0",
    "TFGPUPackage": "tensorflow==2.4.0",
    "CUDA": "11.0",
    "CuDNN": "8.0",
    "Pythons": [
      "3.6",
      "3.7",
      "3.8"
    ]
  }
]

================================================
FILE: pkg/config/torch_compatibility.json
================================================
[
  {
    "Torch": "2.10.0+cu129",
    "Torchvision": "0.25.0",
    "Torchaudio": "2.10.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu129/",
    "CUDA": "12.9",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.10.0+cu130",
    "Torchvision": "0.25.0",
    "Torchaudio": "2.10.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu130/",
    "CUDA": "13.0",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.10.0+cpu",
    "Torchvision": "0.25.0",
    "Torchaudio": "2.10.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu/",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.10.0+cu126",
    "Torchvision": "0.25.0",
    "Torchaudio": "2.10.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu126/",
    "CUDA": "12.6",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.10.0+cu128",
    "Torchvision": "0.25.0",
    "Torchaudio": "2.10.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu128/",
    "CUDA": "12.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.9.1",
    "Torchvision": "0.24.1",
    "Torchaudio": "2.9.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu126",
    "CUDA": "12.6",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.9.1",
    "Torchvision": "0.24.1",
    "Torchaudio": "2.9.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu128",
    "CUDA": "12.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.9.1",
    "Torchvision": "0.24.1",
    "Torchaudio": "2.9.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu130",
    "CUDA": "13.0",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.9.1",
    "Torchvision": "0.24.1",
    "Torchaudio": "2.9.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.9.0",
    "Torchvision": "0.24.0",
    "Torchaudio": "2.9.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu126",
    "CUDA": "12.6",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.9.0",
    "Torchvision": "0.24.0",
    "Torchaudio": "2.9.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu128",
    "CUDA": "12.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.9.0",
    "Torchvision": "0.24.0",
    "Torchaudio": "2.9.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu130",
    "CUDA": "13.0",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.9.0",
    "Torchvision": "0.24.0",
    "Torchaudio": "2.9.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13",
      "3.14"
    ]
  },
  {
    "Torch": "2.8.0",
    "Torchvision": "0.23.0",
    "Torchaudio": "2.8.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu126",
    "CUDA": "12.6",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.8.0",
    "Torchvision": "0.23.0",
    "Torchaudio": "2.8.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu128",
    "CUDA": "12.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.8.0",
    "Torchvision": "0.23.0",
    "Torchaudio": "2.8.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu129",
    "CUDA": "12.9",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.8.0",
    "Torchvision": "0.23.0",
    "Torchaudio": "2.8.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.7.1",
    "Torchvision": "0.22.1",
    "Torchaudio": "2.7.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.7.1",
    "Torchvision": "0.22.1",
    "Torchaudio": "2.7.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu126",
    "CUDA": "12.6",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.7.1",
    "Torchvision": "0.22.1",
    "Torchaudio": "2.7.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu128",
    "CUDA": "12.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.7.1",
    "Torchvision": "0.22.1",
    "Torchaudio": "2.7.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.7.0",
    "Torchvision": "0.22.0",
    "Torchaudio": "2.7.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.7.0",
    "Torchvision": "0.22.0",
    "Torchaudio": "2.7.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu126",
    "CUDA": "12.6",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.7.0",
    "Torchvision": "0.22.0",
    "Torchaudio": "2.7.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu128",
    "CUDA": "12.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.7.0",
    "Torchvision": "0.22.0",
    "Torchaudio": "2.7.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.6.0",
    "Torchvision": "0.21.0",
    "Torchaudio": "2.6.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.6.0",
    "Torchvision": "0.21.0",
    "Torchaudio": "2.6.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu124",
    "CUDA": "12.4",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.6.0",
    "Torchvision": "0.21.0",
    "Torchaudio": "2.6.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu126",
    "CUDA": "12.6",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.6.0",
    "Torchvision": "0.21.0",
    "Torchaudio": "2.6.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12",
      "3.13"
    ]
  },
  {
    "Torch": "2.5.1",
    "Torchvision": "0.20.1",
    "Torchaudio": "2.5.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.5.1",
    "Torchvision": "0.20.1",
    "Torchaudio": "2.5.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.5.1",
    "Torchvision": "0.20.1",
    "Torchaudio": "2.5.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu124",
    "CUDA": "12.4",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.5.1",
    "Torchvision": "0.20.1",
    "Torchaudio": "2.5.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.5.0",
    "Torchvision": "0.20.0",
    "Torchaudio": "2.5.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.5.0",
    "Torchvision": "0.20.0",
    "Torchaudio": "2.5.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.5.0",
    "Torchvision": "0.20.0",
    "Torchaudio": "2.5.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu124",
    "CUDA": "12.4",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.5.0",
    "Torchvision": "0.20.0",
    "Torchaudio": "2.5.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.4.1",
    "Torchvision": "0.19.1",
    "Torchaudio": "2.4.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.4.1",
    "Torchvision": "0.19.1",
    "Torchaudio": "2.4.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.4.1",
    "Torchvision": "0.19.1",
    "Torchaudio": "2.4.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu124",
    "CUDA": "12.4",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.4.1",
    "Torchvision": "0.19.1",
    "Torchaudio": "2.4.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.4.0",
    "Torchvision": "0.19.0",
    "Torchaudio": "2.4.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.4.0",
    "Torchvision": "0.19.0",
    "Torchaudio": "2.4.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.4.0",
    "Torchvision": "0.19.0",
    "Torchaudio": "2.4.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu124",
    "CUDA": "12.4",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.4.0",
    "Torchvision": "0.19.0",
    "Torchaudio": "2.4.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.3.1",
    "Torchvision": "0.18.1",
    "Torchaudio": "2.3.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.3.1",
    "Torchvision": "0.18.1",
    "Torchaudio": "2.3.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.3.1",
    "Torchvision": "0.18.1",
    "Torchaudio": "2.3.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.3.0",
    "Torchvision": "0.18.0",
    "Torchaudio": "2.3.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.3.0",
    "Torchvision": "0.18.0",
    "Torchaudio": "2.3.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.3.0",
    "Torchvision": "0.18.0",
    "Torchaudio": "2.3.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.2.2",
    "Torchvision": "0.17.2",
    "Torchaudio": "2.2.2",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.2.2",
    "Torchvision": "0.17.2",
    "Torchaudio": "2.2.2",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.2.2",
    "Torchvision": "0.17.2",
    "Torchaudio": "2.2.2",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.2.1",
    "Torchvision": "0.17.1",
    "Torchaudio": "2.2.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.2.1",
    "Torchvision": "0.17.1",
    "Torchaudio": "2.2.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.2.1",
    "Torchvision": "0.17.1",
    "Torchaudio": "2.2.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.2.0",
    "Torchvision": "0.17.0",
    "Torchaudio": "2.2.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.2.0",
    "Torchvision": "0.17.0",
    "Torchaudio": "2.2.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.2.0",
    "Torchvision": "0.17.0",
    "Torchaudio": "2.2.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11",
      "3.12"
    ]
  },
  {
    "Torch": "2.1.2",
    "Torchvision": "0.16.2",
    "Torchaudio": "2.1.2",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.1.2",
    "Torchvision": "0.16.2",
    "Torchaudio": "2.1.2",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.1.2",
    "Torchvision": "0.16.2",
    "Torchaudio": "2.1.2",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.1.1",
    "Torchvision": "0.16.1",
    "Torchaudio": "2.1.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.1.1",
    "Torchvision": "0.16.1",
    "Torchaudio": "2.1.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.1.1",
    "Torchvision": "0.16.1",
    "Torchaudio": "2.1.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.1.0",
    "Torchvision": "0.16.0",
    "Torchaudio": "2.1.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.1.0",
    "Torchvision": "0.16.0",
    "Torchaudio": "2.1.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
    "CUDA": "12.1",
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.1.0",
    "Torchvision": "0.16.0",
    "Torchaudio": "2.1.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.0.1",
    "Torchvision": "0.15.2",
    "Torchaudio": "2.0.2",
    "FindLinks": "",
    "ExtraIndexURL": "",
    "CUDA": "11.7",
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.0.1",
    "Torchvision": "0.15.2",
    "Torchaudio": "2.0.2",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.0.1",
    "Torchvision": "0.15.2",
    "Torchaudio": "2.0.2",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.0.0",
    "Torchvision": "0.15.1",
    "Torchaudio": "2.0.1",
    "FindLinks": "",
    "ExtraIndexURL": "",
    "CUDA": "11.7",
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.0.0",
    "Torchvision": "0.15.1",
    "Torchaudio": "2.0.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
    "CUDA": "11.8",
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "2.0.0",
    "Torchvision": "0.15.1",
    "Torchaudio": "2.0.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10",
      "3.11"
    ]
  },
  {
    "Torch": "1.13.1+cu116",
    "Torchvision": "0.14.1+cu116",
    "Torchaudio": "0.13.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu116",
    "CUDA": "11.6",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.13.1+cu117",
    "Torchvision": "0.14.1+cu117",
    "Torchaudio": "0.13.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu117",
    "CUDA": "11.7",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.13.1+cpu",
    "Torchvision": "0.14.1+cpu",
    "Torchaudio": "0.13.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.13.0+cu116",
    "Torchvision": "0.14.0+cu116",
    "Torchaudio": "0.13.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu116",
    "CUDA": "11.6",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.13.0+cu117",
    "Torchvision": "0.14.0+cu117",
    "Torchaudio": "0.13.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu117",
    "CUDA": "11.7",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.13.0+cpu",
    "Torchvision": "0.14.0+cpu",
    "Torchaudio": "0.13.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.12.1+cu116",
    "Torchvision": "0.13.1+cu116",
    "Torchaudio": "0.12.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu116",
    "CUDA": "11.6",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.12.1+cu113",
    "Torchvision": "0.13.1+cu113",
    "Torchaudio": "0.12.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu113",
    "CUDA": "11.3",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.12.1+cu102",
    "Torchvision": "0.13.1+cu102",
    "Torchaudio": "0.12.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu102",
    "CUDA": "10.2",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.12.1+cpu",
    "Torchvision": "0.13.1+cpu",
    "Torchaudio": "0.12.1",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.12.0+cu116",
    "Torchvision": "0.13.0+cu116",
    "Torchaudio": "0.12.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu116",
    "CUDA": "11.6",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.12.0+cu113",
    "Torchvision": "0.13.0+cu113",
    "Torchaudio": "0.12.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu113",
    "CUDA": "11.3",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.12.0+cu102",
    "Torchvision": "0.13.0+cu102",
    "Torchaudio": "0.12.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu102",
    "CUDA": "10.2",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.12.0+cpu",
    "Torchvision": "0.13.0+cpu",
    "Torchaudio": "0.12.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.11.0+cu113",
    "Torchvision": "0.12.0+cu113",
    "Torchaudio": "0.11.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu113",
    "CUDA": "11.3",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.11.0+cu102",
    "Torchvision": "0.12.0+cu102",
    "Torchaudio": "0.11.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cu102",
    "CUDA": "10.2",
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.11.0+cpu",
    "Torchvision": "0.12.0+cpu",
    "Torchaudio": "0.11.0",
    "FindLinks": "",
    "ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
    "CUDA": null,
    "Pythons": [
      "3.10"
    ]
  },
  {
    "Torch": "1.4.0",
    "Torchvision": "0.5.0",
    "Torchaudio": "",
    "FindLinks": "",
    "ExtraIndexURL": "",
    "CUDA": "10.1",
    "Pythons": [
      "2.7"
    ]
  },
  {
    "Torch": "1.4.0+cu92",
    "Torchvision": "0.5.0+cu92",
    "Torchaudio": "",
    "FindLinks": "https://download.pytorch.org/whl/torch_stable.html",
    "ExtraIndexURL": "",
    "CUDA": "9.2",
    "Pythons": [
      "2.7"
    ]
  },
  {
    "Torch": "1.4.0+cpu",
    "Torchvision": "0.5.0+cpu",
    "Torchaudio": "",
    "FindLinks": "https://download.pytorch.org/whl/torch_stable.html",
    "ExtraIndexURL": "",
    "CUDA": null,
    "Pythons": [
      "2.7"
    ]
  },
  {
    "Torch": "1.2.0",
    "Torchvision": "0.4.0",
    "Torchaudio": "",
    "FindLinks": "",
    "ExtraIndexURL": "",
    "CUDA": "10.0",
    "Pythons": [
      "2.7"
    ]
  },
  {
    "Torch": "1.2.0+cu92",
    "Torchvision": "0.4.0+cu92",
    "Torchaudio": "",
    "FindLinks": "https://download.pytorch.org/whl/torch_stable.html",
    "ExtraIndexURL": "",
    "CUDA": "9.2",
    "Pythons": [
      "2.7"
    ]
  },
  {
    "Torch": "1.2.0+cpu",
    "Torchvision": "0.4.0+cpu",
    "Torchaudio": "",
    "FindLinks": "https://download.pytorch.org/whl/torch_stable.html",
    "ExtraIndexURL": "",
    "CUDA": null,
    "Pythons": [
      "2.7"
    ]
  }
]

================================================
FILE: pkg/config/validate.go
================================================
package config

import (
	// blank import for embeds
	_ "embed"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"slices"
	"strconv"
	"strings"

	"github.com/xeipuuv/gojsonschema"

	"github.com/replicate/cog/pkg/requirements"
)

// schemaV1 holds the raw bytes of the embedded config JSON schema file.
// validateSchema converts it to a string and hands it to gojsonschema on every call.
//
//go:embed data/config_schema_v1.0.json
var schemaV1 []byte

// ValidateOption configures validation behavior.
// ValidateConfigFile applies each option, in order, to a fresh validateOptions value.
type ValidateOption func(*validateOptions)

// validateOptions is the set of knobs the ValidateOption functions can turn.
//
//   - projectDir: joined onto non-absolute python_requirements paths before
//     they are looked up on the real filesystem (ignored when empty).
//   - requirementsFS: when non-nil, the requirements file is read from this
//     filesystem instead of from disk.
//   - strictDeprecations: when true, deprecation warnings are converted into errors.
type validateOptions struct {
	projectDir         string
	requirementsFS     fs.FS
	strictDeprecations bool
}

// WithProjectDir returns an option that records dir as the project directory.
// Relative python_requirements paths are joined onto it when they are checked
// against the real filesystem.
func WithProjectDir(dir string) ValidateOption {
	return func(vo *validateOptions) {
		vo.projectDir = dir
	}
}

// WithRequirementsFS returns an option that makes validation read the
// python_requirements file from fsys rather than from the real filesystem.
func WithRequirementsFS(fsys fs.FS) ValidateOption {
	return func(vo *validateOptions) {
		vo.requirementsFS = fsys
	}
}

// WithStrictDeprecations returns an option that turns every deprecation
// warning into a validation error.
func WithStrictDeprecations() ValidateOption {
	return func(vo *validateOptions) {
		vo.strictDeprecations = true
	}
}

// ValidateConfigFile checks a configFile for errors.
// Returns all validation errors and deprecation warnings.
// Does not mutate the input.
//
// Validation runs in a fixed order: JSON-schema validation first, then the
// semantic checks (predict, train, build, environment, concurrency), then the
// deprecated-field scan. When WithStrictDeprecations is supplied, every
// collected warning is re-registered as an error and the warning list is cleared.
func ValidateConfigFile(cfg *configFile, opts ...ValidateOption) *ValidationResult {
	options := &validateOptions{}
	for _, opt := range opts {
		opt(options)
	}

	result := NewValidationResult()

	// Schema validation
	if err := validateSchema(cfg); err != nil {
		result.AddError(err)
	}

	// Semantic validation
	validatePredict(cfg, result)
	validateTrain(cfg, result)
	validateBuild(cfg, options, result)
	validateEnvironment(cfg, result)
	validateConcurrency(cfg, result)

	// Check deprecated fields
	checkDeprecatedFields(cfg, result)

	// If strict deprecations, convert warnings to errors
	if options.strictDeprecations && result.HasWarnings() {
		// Take the address of the slice element rather than of the range
		// variable: with pre-1.22 loop semantics the range variable is shared,
		// so every promoted error would end up pointing at the last warning.
		for i := range result.Warnings {
			result.AddError(&result.Warnings[i])
		}
		result.Warnings = nil
	}

	return result
}

// validateSchema runs cfg through the embedded JSON schema.
// It returns nil when cfg conforms, a SchemaError on "(root)" when the
// validator itself fails, and otherwise the single most specific violation.
func validateSchema(cfg *configFile) error {
	outcome, err := gojsonschema.Validate(
		gojsonschema.NewStringLoader(string(schemaV1)),
		gojsonschema.NewGoLoader(cfg),
	)
	if err != nil {
		return &SchemaError{Field: "(root)", Message: err.Error()}
	}
	if outcome.Valid() {
		return nil
	}

	// Many violations may be reported; surface only the most specific one.
	return getMostSpecificSchemaError(outcome.Errors())
}

// validatePredict checks the shape of the predict reference.
// An unset or empty value is accepted; otherwise the value must contain the
// ".py:" separator exactly once, or an error is added to result.
func validatePredict(cfg *configFile, result *ValidationResult) {
	if cfg.Predict == nil {
		return
	}

	ref := *cfg.Predict
	if ref == "" || strings.Count(ref, ".py:") == 1 {
		return
	}

	result.AddError(&ValidationError{
		Field:   "predict",
		Value:   ref,
		Message: "must be in the form 'predict.py:Predictor'",
	})
}

// validateTrain checks the shape of the train reference.
// An unset or empty value is accepted; otherwise the value must contain the
// ".py:" separator exactly once, or an error is added to result.
func validateTrain(cfg *configFile, result *ValidationResult) {
	if cfg.Train == nil {
		return
	}

	ref := *cfg.Train
	if ref == "" || strings.Count(ref, ".py:") == 1 {
		return
	}

	result.AddError(&ValidationError{
		Field:   "train",
		Value:   ref,
		Message: "must be in the form 'train.py:Trainer'",
	})
}

// validateBuild runs the checks that apply to the build section and records
// every problem on result. A config without a build section is skipped entirely.
//
// Checks, in order: python_version present and valid; python_packages and
// python_requirements not both set; the requirements file can be found; the
// CUDA version is acceptable; and, for GPU builds, the GPU-specific checks.
func validateBuild(cfg *configFile, opts *validateOptions, result *ValidationResult) {
	build := cfg.Build
	if build == nil {
		return
	}

	// Python version: required, and must pass validatePythonVersion.
	switch {
	case build.PythonVersion == nil || *build.PythonVersion == "":
		result.AddError(&ValidationError{
			Field:   "build.python_version",
			Message: "python_version is required. Add it to the build section of your cog.yaml, e.g. `python_version: \"3.13\"`",
		})
	default:
		if err := validatePythonVersion(*build.PythonVersion); err != nil {
			result.AddError(err)
		}
	}

	hasRequirements := build.PythonRequirements != nil && *build.PythonRequirements != ""

	// python_packages and python_requirements are mutually exclusive.
	if hasRequirements && len(build.PythonPackages) > 0 {
		result.AddError(&ValidationError{
			Field:   "build",
			Message: "only one of python_packages or python_requirements can be set, not both",
		})
	}

	// The requirements file itself must be present.
	if hasRequirements {
		if err := validateRequirementsFile(*build.PythonRequirements, opts); err != nil {
			result.AddError(err)
		}
	}

	// CUDA is optional; validate it only when given.
	if cuda := build.CUDA; cuda != nil && *cuda != "" {
		if err := validateCUDAVersion(*cuda); err != nil {
			result.AddError(err)
		}
	}

	// GPU builds get the extra compatibility checks.
	if build.GetGPU() {
		validateGPUConfig(cfg, opts, result)
	}
}

// validatePythonVersion checks a "major.minor[.patch]" Python version string.
// Surrounding whitespace is ignored. It returns a ValidationError when the
// minor component is missing, when major or minor is not an integer, or when
// the version is below the minimum supported one; otherwise it returns nil.
func validatePythonVersion(version string) error {
	version = strings.TrimSpace(version)

	// reject builds the error for this field with the given message.
	reject := func(message string) error {
		return &ValidationError{
			Field:   "build.python_version",
			Value:   version,
			Message: message,
		}
	}

	majorText, remainder, hasMinor := strings.Cut(version, ".")
	if !hasMinor {
		return reject("must include major and minor version (e.g., '3.11')")
	}
	// Anything after a second dot (the patch level) is not inspected.
	minorText, _, _ := strings.Cut(remainder, ".")

	major, err := strconv.Atoi(majorText)
	if err != nil {
		return reject("invalid major version number")
	}

	minor, err := strconv.Atoi(minorText)
	if err != nil {
		return reject("invalid minor version number")
	}

	belowMinimum := major < MinimumMajorPythonVersion ||
		(major == MinimumMajorPythonVersion && minor < MinimumMinorPythonVersion)
	if !belowMinimum {
		return nil
	}
	return reject(fmt.Sprintf("minimum supported Python version is %d.%d", MinimumMajorPythonVersion, MinimumMinorPythonVersion))
}

// validateCUDAVersion validates the CUDA version string.
//
// The version must have at least "major.minor" components, both of which must
// be integers, and the major version must not be below MinimumMajorCudaVersion.
// Components after the minor one (e.g. the patch in "11.6.2") are not inspected.
// It returns a *ValidationError describing the first problem found, or nil.
func validateCUDAVersion(cudaVersion string) error {
	parts := strings.Split(cudaVersion, ".")
	if len(parts) < 2 {
		return &ValidationError{
			Field:   "build.cuda",
			Value:   cudaVersion,
			Message: "must include both major and minor versions (e.g., '11.8')",
		}
	}

	major, err := strconv.Atoi(parts[0])
	if err != nil {
		return &ValidationError{
			Field:   "build.cuda",
			Value:   cudaVersion,
			Message: "invalid major version number",
		}
	}

	if major < MinimumMajorCudaVersion {
		return &ValidationError{
			Field:   "build.cuda",
			Value:   cudaVersion,
			Message: fmt.Sprintf("minimum supported CUDA version is %d", MinimumMajorCudaVersion),
		}
	}

	// A minor component that is present but empty or non-numeric ("11.",
	// "11.x") is not a usable version; reject it the same way
	// validatePythonVersion rejects a bad minor number.
	if _, err := strconv.Atoi(parts[1]); err != nil {
		return &ValidationError{
			Field:   "build.cuda",
			Value:   cudaVersion,
			Message: "invalid minor version number",
		}
	}

	return nil
}

// validateRequirementsFile validates that the requirements file exists and is readable.
//
// When opts.requirementsFS is set, reqPath is read from that filesystem as-is.
// Otherwise the real filesystem is consulted: a non-absolute reqPath is joined
// onto opts.projectDir (when one is configured) and the result is stat'ed.
// It returns a *ValidationError for build.python_requirements on failure, or nil.
func validateRequirementsFile(reqPath string, opts *validateOptions) error {
	if opts.requirementsFS != nil {
		if _, err := fs.ReadFile(opts.requirementsFS, reqPath); err != nil {
			return &ValidationError{
				Field:   "build.python_requirements",
				Value:   reqPath,
				Message: fmt.Sprintf("cannot read file: %v", err),
			}
		}
		return nil
	}

	// Use the real filesystem
	fullPath := reqPath
	if !filepath.IsAbs(reqPath) && opts.projectDir != "" {
		fullPath = filepath.Join(opts.projectDir, reqPath)
	}

	if _, err := os.Stat(fullPath); err != nil {
		if os.IsNotExist(err) {
			return &ValidationError{
				Field:   "build.python_requirements",
				Value:   reqPath,
				Message: "file does not exist",
			}
		}
		// Any other stat failure (e.g. permission denied on a parent
		// directory) also means the file cannot be used; report it instead
		// of silently passing validation.
		return &ValidationError{
			Field:   "build.python_requirements",
			Value:   reqPath,
			Message: fmt.Sprintf("cannot read file: %v", err),
		}
	}

	return nil
}

// validateGPUConfig performs the checks that only apply to GPU builds.
//
// When both CUDA and CuDNN are pinned, the CuDNN version must be one of those
// listed as compatible with the CUDA version (the check is skipped when no
// compatible versions are known). After that, the Python requirements — from
// the requirements file if one is set, else from python_packages — are passed
// to the framework compatibility check.
func validateGPUConfig(cfg *configFile, opts *validateOptions, result *ValidationResult) {
	build := cfg.Build
	if build == nil {
		return
	}

	cudaPinned := build.CUDA != nil && *build.CUDA != ""
	cudnnPinned := build.CuDNN != nil && *build.CuDNN != ""
	if cudaPinned && cudnnPinned {
		cudaVersion, cudnnVersion := *build.CUDA, *build.CuDNN
		supported := compatibleCuDNNsForCUDA(cudaVersion)
		if len(supported) > 0 && !slices.Contains(supported, cudnnVersion) {
			result.AddError(&CompatibilityError{
				Component1: "CUDA",
				Version1:   cudaVersion,
				Component2: "CuDNN",
				Version2:   cudnnVersion,
				Message:    fmt.Sprintf("compatible CuDNN versions are: %s", strings.Join(supported, ", ")),
			})
		}
	}

	// A requirements file takes precedence over inline python_packages.
	switch {
	case build.PythonRequirements != nil && *build.PythonRequirements != "":
		// An unreadable or empty requirements file simply skips the check.
		if loaded := loadRequirementsForValidation(*build.PythonRequirements, opts); len(loaded) > 0 {
			validateFrameworkCompatibility(cfg, loaded, result)
		}
	case len(build.PythonPackages) > 0:
		validateFrameworkCompatibility(cfg, build.PythonPackages, result)
	}
}

// loadRequirementsForValidation loads requirements file contents for validation.
//
// When opts.requirementsFS is set, reqPath is read from it and parsed with
// parseRequirementsContent. Otherwise a non-absolute reqPath is joined onto
// opts.projectDir (when configured) and read via requirements.ReadRequirements.
// Loading is best-effort: any read error yields nil, and the caller skips the
// checks that depend on the requirements.
func loadRequirementsForValidation(reqPath string, opts *validateOptions) []string {
	if opts.requirementsFS != nil {
		data, err := fs.ReadFile(opts.requirementsFS, reqPath)
		if err != nil {
			return nil
		}
		return parseRequirementsContent(string(data))
	}

	// filepath.IsAbs recognises platform-specific absolute forms, which a
	// plain "/" prefix test does not.
	fullPath := reqPath
	if !filepath.IsAbs(reqPath) && opts.projectDir != "" {
		fullPath = filepath.Join(opts.projectDir, reqPath)
	}

	reqs, err := requirements.ReadRequirements(fullPath)
	if err != nil {
		return nil
	}
	return reqs
}

// parseRequirementsContent splits requirements.txt text into requirement lines.
// Each line is whitespace-trimmed; blank lines, comment lines (leading "#")
// and option lines (leading "-") are dropped. The result is never nil.
func parseRequirementsContent(content string) []string {
	rawLines := strings.Split(content, "\n")
	kept := make([]string, 0, len(rawLines))
	for i := range rawLines {
		entry := strings.TrimSpace(rawLines[i])
		switch {
		case entry == "":
			// blank line
		case strings.HasPrefix(entry, "#"):
			// comment
		case strings.HasPrefix(entry, "-"):
			// pip option such as -r or --extra-index-url
		default:
			kept = append(kept, entry)
		}
	}
	return kept
}

// validateFrameworkCompatibility checks framework requirements against the
// configured CUDA version.
//
// This is a simplified check for obvious errors; the full logic lives in
// Complete(), which also handles torch (where it can emit warnings). Only
// TensorFlow is examined here: when both a CUDA version and a tensorflow
// requirement are present, the configured CUDA must start with the major
// version of the CUDA release that TensorFlow version maps to.
func validateFrameworkCompatibility(cfg *configFile, reqs []string, result *ValidationResult) {
	build := cfg.Build
	if build == nil {
		return
	}

	tensorflow := findPackageVersion(reqs, "tensorflow")

	// Nothing to compare unless both sides are known.
	if build.CUDA == nil || *build.CUDA == "" || tensorflow == "" {
		return
	}
	configuredCUDA := *build.CUDA

	// Only the first result is used; an empty value means no CUDA mapping
	// was obtained for this TensorFlow version, so the check is skipped.
	requiredCUDA, _, _ := cudaFromTF(tensorflow)
	if requiredCUDA == "" {
		return
	}

	requiredMajor, _, _ := strings.Cut(requiredCUDA, ".")
	if strings.HasPrefix(configuredCUDA, requiredMajor) {
		return
	}

	result.AddError(&CompatibilityError{
		Component1: "TensorFlow",
		Version1:   tensorflow,
		Component2: "CUDA",
		Version2:   configuredCUDA,
		Message:    fmt.Sprintf("TensorFlow %s requires CUDA %s", tensorflow, requiredCUDA),
	})
}

// findPackageVersion returns the first version listed for the first
// requirement named name that carries any version, or "" when there is none.
func findPackageVersion(reqs []string, name string) string {
	for _, candidate := range reqs {
		if requirements.PackageName(candidate) != name {
			continue
		}
		// A matching requirement without versions does not end the search.
		if found := requirements.Versions(candidate); len(found) > 0 {
			return found[0]
		}
	}
	return ""
}

// validateEnvironment parses the environment entries and records a single
// "environment" error on result when parsing fails. An empty list is skipped.
func validateEnvironment(cfg *configFile, result *ValidationResult) {
	if len(cfg.Environment) == 0 {
		return
	}

	// Only the error matters here; the parsed value is discarded.
	if _, parseErr := parseAndValidateEnvironment(cfg.Environment); parseErr != nil {
		result.AddError(&ValidationError{
			Field:   "environment",
			Message: parseErr.Error(),
		})
	}
}

// validateConcurrency checks concurrency.max when it is set.
//
// The value must be at least 1. When it is greater than 1 and a Python
// version is configured and parseable, a Python version with the minimum
// major version must also meet the minimum minor version for concurrency.
func validateConcurrency(cfg *configFile, result *ValidationResult) {
	if cfg.Concurrency == nil || cfg.Concurrency.Max == nil {
		return
	}

	limit := *cfg.Concurrency.Max
	shown := fmt.Sprintf("%d", limit)

	if limit < 1 {
		result.AddError(&ValidationError{
			Field:   "concurrency.max",
			Value:   shown,
			Message: "must be at least 1",
		})
	}

	// The Python requirement only applies to real concurrency (max > 1).
	if limit <= 1 || cfg.Build == nil || cfg.Build.PythonVersion == nil {
		return
	}

	major, minor, err := splitPythonVersion(*cfg.Build.PythonVersion)
	if err != nil {
		// An unparseable version is not reported here.
		return
	}

	// Only the minimum major version is subject to the minor-version floor;
	// other major versions are not checked here.
	if major != MinimumMajorPythonVersion || minor >= MinimumMinorPythonVersionForConcurrency {
		return
	}

	result.AddError(&ValidationError{
		Field:   "concurrency.max",
		Value:   shown,
		Message: fmt.Sprintf("concurrency requires Python %d.%d or higher", MinimumMajorPythonVersion, MinimumMinorPythonVersionForConcurrency),
	})
}

// checkDeprecatedFields adds a deprecation warning to result for each
// deprecated build field in use: build.python_packages, then build.pre_install.
// A config without a build section produces no warnings.
func checkDeprecatedFields(cfg *configFile, result *ValidationResult) {
	build := cfg.Build
	if build == nil {
		return
	}

	warn := func(field, replacement, message string) {
		result.AddWarning(DeprecationWarning{Field: field, Replacement: replacement, Message: message})
	}

	if len(build.PythonPackages) > 0 {
		warn("build.python_packages", "build.python_requirements", "use a requirements.txt file instead")
	}
	if len(build.PreInstall) > 0 {
		warn("build.pre_install", "build.run", "use build.run commands instead")
	}
}

// getMostSpecificSchemaError picks one error out of a schema validation run.
//
// The winner is the error with the deepest field path; among equally deep
// errors the earliest wins, except that an "invalid_type" error displaces a
// current winner of any other type. A "(root)" field is reported as
// "cog.yaml". An empty input yields a generic "(unknown)" error.
func getMostSpecificSchemaError(errors []gojsonschema.ResultError) *SchemaError {
	if len(errors) == 0 {
		return &SchemaError{Field: "(unknown)", Message: "unknown schema error"}
	}

	winner := 0
	winnerDepth := schemaErrorSpecificity(errors[0])
	for i, candidate := range errors {
		depth := schemaErrorSpecificity(candidate)
		switch {
		case depth > winnerDepth:
			winner, winnerDepth = i, depth
		case depth == winnerDepth &&
			candidate.Type() == "invalid_type" &&
			errors[winner].Type() != "invalid_type":
			// Invalid type errors win in a tie-breaker
			winner = i
		}
	}

	chosen := errors[winner]
	fieldName := chosen.Field()
	if fieldName == "(root)" {
		fieldName = "cog.yaml"
	}

	return &SchemaError{
		Field:   fieldName,
		Message: getSchemaErrorDescription(chosen, errors, winner),
	}
}

// getSchemaErrorDescription produces the message shown for a schema error.
//
// "invalid_type" errors with a string "expected" detail become "must be a
// <type>". "number_one_of"/"number_any_of" errors borrow the description of
// the error that follows them in allErrors (index is err's position there).
// Everything else falls back to the error's own description.
func getSchemaErrorDescription(err gojsonschema.ResultError, allErrors []gojsonschema.ResultError, index int) string {
	kind := err.Type()
	if kind == "invalid_type" {
		if expected, ok := err.Details()["expected"].(string); ok {
			return fmt.Sprintf("must be a %s", humanReadableSchemaType(expected))
		}
	} else if kind == "number_one_of" || kind == "number_any_of" {
		if next := index + 1; next < len(allErrors) {
			return allErrors[next].Description()
		}
	}
	return err.Description()
}

// humanReadableSchemaType converts JSON schema type names to human-readable names.
//
// "object" becomes "mapping" and "array" becomes "list"; any other name is
// returned unchanged. A bracketed, comma-separated list such as
// "[string,object]" is rendered as "string or mapping", with three or more
// entries joined as "a, b or c". A list with a single entry renders as just
// that entry, and input that opens with "[" but is not closed by "]" is
// returned as-is rather than sliced blindly.
func humanReadableSchemaType(definition string) string {
	if len(definition) >= 2 && strings.HasPrefix(definition, "[") && strings.HasSuffix(definition, "]") {
		allTypes := strings.Split(definition[1:len(definition)-1], ",")
		for i, t := range allTypes {
			allTypes[i] = humanReadableSchemaType(strings.TrimSpace(t))
		}
		// With one entry there is nothing to put on the left of "or".
		if len(allTypes) == 1 {
			return allTypes[0]
		}
		last := len(allTypes) - 1
		return fmt.Sprintf("%s or %s", strings.Join(allTypes[:last], ", "), allTypes[last])
	}
	switch definition {
	case "object":
		return "mapping"
	case "array":
		return "list"
	default:
		return definition
	}
}

// schemaErrorSpecificity scores a schema error by the depth of its field
// path: the number of dot-separated segments in err.Field().
func schemaErrorSpecificity(err gojsonschema.ResultError) int {
	// n separators delimit n+1 segments.
	return strings.Count(err.Field(), ".") + 1
}

// Note: The legacy Validate function is in validator.go for backwards compatibility


================================================
FILE: pkg/config/validate_test.go
================================================
package config

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestValidateConfigFile(t *testing.T) {
	cfg := &configFile{
		Build: &buildFile{
			GPU:           ptr(true),
			PythonVersion: ptr("3.10"),
			PythonPackages: []string{
				"tensorflow==2.12.0",
				"foo==1.0.0",
			},
			CUDA: ptr("11.8"),
		},
	}
	result := ValidateConfigFile(cfg)
	require.False(t, result.HasErrors(), "expected no errors, got: %v", result.Errors)
}

func TestValidateConfigFileSuccess(t *testing.T) {
	cfg := &configFile{
		Build: &buildFile{
			GPU: ptr(true),
			SystemPackages: []string{
				"libgl1-mesa-glx",
				"libglib2.0-0",
			},
			PythonVersion: ptr("3.10"),
			PythonPackages: []string{
				"torch==1.8.1",
			},
		},
	}

	result := ValidateConfigFile(cfg)
	require.False(t, result.HasErrors(), "expected no errors, got: %v", result.Errors)
}

func TestValidateConfigFilePythonVersionNumerical(t *testing.T) {
	cfg := &configFile{
		Build: &buildFile{
			GPU: ptr(true),
			SystemPackages: []string{
				"libgl1-mesa-glx",
				"libglib2.0-0",
			},
			PythonVersion: ptr("3.10"),
			PythonPackages: []string{
				"torch==1.8.1",
			},
		},
	}

	result := ValidateConfigFile(cfg)
	require.False(t, result.HasErrors(), "expected no errors, got: %v", result.Errors)
}

func TestValidateConfigFileNullListsAllowed(t *testing.T) {
	cfg := &configFile{
		Build: &buildFile{
			GPU:            ptr(true),
			PythonVersion:  ptr("3.10"),
			SystemPackages: nil,
			PythonPackages: nil,
			Run:            nil,
		},
	}

	result := ValidateConfigFile(cfg)
	require.False(t, result.HasErrors(), "expected no errors, got: %v", result.Errors)
}

// TestValidateConfigFilePredictFormat checks both sides of the predict format
// rule: "predict.py:Predictor" passes, and a value without the ".py:"
// separator fails with a message naming the expected form.
func TestValidateConfigFilePredictFormat(t *testing.T) {
	config := &configFile{
		Build: &buildFile{
			PythonVersion: ptr("3.10"),
		},
		Predict: ptr("predict.py:Predictor"),
	}

	// Well-formed reference.
	accepted := ValidateConfigFile(config)
	require.False(t, accepted.HasErrors(), "expected no errors, got: %v", accepted.Errors)

	// Malformed reference on the same config.
	config.Predict = ptr("invalid_format")
	rejected := ValidateConfigFile(config)
	require.True(t, rejected.HasErrors())
	require.Contains(t, rejected.Err().Error(), "predict.py:Predictor")
}

func TestValidateConfigFileConcurrencyType(t *testing.T) {
	cfg := &configFile{
		Build: &buildFile{
			GPU:           ptr(true),
			CUDA:          ptr("11.8"),
			PythonVersion: ptr("3.11"),
			PythonPackages: []string{
				"torch==2.0.1",
			},
		},
		Predict: ptr("predict.py:Predictor"),
		Concurrency: &concurrencyFile{
			Max: ptr(5),
		},
	}

	result := ValidateConfigFile(cfg)
	require.False(t, result.HasErrors(), "expected no errors, got: %v", result.Errors)
}

func TestValidateConfigFileDeprecatedPythonPackages(t *testing.T) {
	cfg := &configFile{
		Build: &buildFile{
			PythonVersion: ptr("3.10"),
			PythonPackages: []string{
				"torch==1.8.1",
			},
		},
	}

	result := ValidateConfigFile(cfg)
	require.False(t, result.HasErrors())
	require.Len(t, result.Warnings, 1)
	require.Contains(t, result.Warnings[0].Message, "requirements.txt")
}

func TestValidateConfigFileDeprecatedPreInstall(t *testing.T) {
	cfg := &configFile{
		Build: &buildFile{
			PythonVersion: ptr("3.10"),
			PreInstall: []string{
				"echo hello",
			},
		},
	}

	result := ValidateConfigFile(cfg)
	require.False(t, result.HasErrors())
	require.Len(t, result.Warnings, 1)
	require.Contains(t, result.Warnings[0].Message, "build.run")
}

// TestValidateConfigFileMissingPythonVersion expects a build section without
// python_version to fail with the "python_version is required" message.
func TestValidateConfigFileMissingPythonVersion(t *testing.T) {
	config := &configFile{
		Build: &buildFile{
			GPU: ptr(true),
		},
	}

	got := ValidateConfigFile(config)
	require.True(t, got.HasErrors())
	require.Contains(t, got.Err().Error(), "python_version is required")
}

// TestValidateConfigFileMissingPythonVersionEmptyBuild expects an empty (but
// present) build section to fail with the "python_version is required" message.
func TestValidateConfigFileMissingPythonVersionEmptyBuild(t *testing.T) {
	config := &configFile{Build: &buildFile{}}

	got := ValidateConfigFile(config)
	require.True(t, got.HasErrors())
	require.Contains(t, got.Err().Error(), "python_version is required")
}

// TestValidateConfigFileNilBuildSkipsPythonVersionCheck expects a config with
// no build section at all to validate without a python_version error.
func TestValidateConfigFileNilBuildSkipsPythonVersionCheck(t *testing.T) {
	var config configFile

	got := ValidateConfigFile(&config)
	require.False(t, got.HasErrors(), "expected no errors for nil build, got: %v", got.Errors)
}

// ptr stores a copy of v in freshly allocated memory and returns its address.
// It exists so tests can fill pointer-typed config fields from literals.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}


================================================
FILE: pkg/config/version.go
================================================
package config

// ArgumentType represents the type of a run argument.
// It is a string whose recognized values are the constants declared below.
type ArgumentType string

// Recognized ArgumentType values and their string spellings.
const (
	ArgumentTypeString ArgumentType = "str"
	ArgumentTypeInt    ArgumentType = "int"
	ArgumentTypeFloat  ArgumentType = "float"
	ArgumentTypeBool   ArgumentType = "bool"
	ArgumentTypePath   ArgumentType = "Path"
)

// RunArgument describes a single argument for a prediction run.
//
// Type is always present. Every other field is a pointer, so an unset value
// is distinguishable from an empty one; without omitempty, unset fields
// marshal to JSON null. Default, Min and Max are carried as strings
// regardless of Type — presumably the textual form of the typed value
// (TODO confirm against the producer of this struct).
type RunArgument struct {
	Type    ArgumentType `json:"type"`
	Default *string      `json:"default"`
	Min     *string      `json:"min"`
	Max     *string      `json:"max"`
	Options *[]string    `json:"options"`
	Help    *string      `json:"help"`
}


================================================
FILE: pkg/docker/build_secrets.go
================================================
package docker

import (
	"fmt"
	"path/filepath"
	"strings"

	"github.com/moby/buildkit/session/secrets"
	"github.com/moby/buildkit/session/secrets/secretsprovider"
	"github.com/pkg/errors"
	"github.com/tonistiigi/go-csvvalue"
)

// ParseSecretsFromHost parses every secret spec in secrets with
// parseSecretFromHost, passing workingDir through to it, and returns a secret
// store built from the resulting sources. The first spec that fails to parse
// aborts the call with that error.
func ParseSecretsFromHost(workingDir string, secrets []string) (secrets.SecretStore, error) {
	parsed := make([]secretsprovider.Source, len(secrets))

	for i, spec := range secrets {
		source, err := parseSecretFromHost(workingDir, spec)
		if err != nil {
			return nil, err
		}
		parsed[i] = *source
	}

	return secretsprovider.NewStore(parsed)
}

// parseSecretFromHost parses one CSV secret spec of key=value fields
// (for example "id=token,src=./token.txt" or "type=env,id=token,env=TOKEN")
// into a secretsprovider.Source.
//
// Recognized keys (case-insensitive): type ("file" or "env"), id,
// source/src, and env. Any other key, or a field without "=", is an error.
//
// For a file secret, a relative source is resolved against workingDir and made
// absolute. For type=env without an explicit env key, the source value is used
// verbatim as the environment variable name — it is deliberately NOT turned
// into a path, otherwise "type=env,src=TOKEN" would look up a variable named
// like "/work/dir/TOKEN".
func parseSecretFromHost(workingDir, secret string) (*secretsprovider.Source, error) {
	fields, err := csvvalue.Fields(secret, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to parse csv secret: %w", err)
	}

	src := secretsprovider.Source{}

	var (
		typ       string
		source    string // raw source/src value, resolved only after all fields are known
		hasSource bool
	)
	for _, field := range fields {
		key, value, ok := strings.Cut(field, "=")
		if !ok {
			return nil, errors.Errorf("invalid field %q must be a key=value pair", field)
		}
		key = strings.ToLower(key)
		switch key {
		case "type":
			if value != "file" && value != "env" {
				return nil, errors.Errorf("unsupported secret type %q", value)
			}
			typ = value
		case "id":
			src.ID = value
		case "source", "src":
			source, hasSource = value, true
		case "env":
			src.Env = value
		default:
			return nil, errors.Errorf("unexpected key '%s' in '%s'", key, field)
		}
	}

	// type=env with no env key: the source names the environment variable.
	if typ == "env" && src.Env == "" {
		src.Env = source
		return &src, nil
	}

	if hasSource {
		if !filepath.IsAbs(source) {
			joined := filepath.Join(workingDir, source)
			abs, err := filepath.Abs(joined)
			if err != nil {
				// Report the path we tried to resolve, not Abs's empty result.
				return nil, fmt.Errorf("failed to get absolute path for %q: %w", joined, err)
			}
			source = abs
		}
		src.FilePath = source
	}
	return &src, nil
}


================================================
FILE: pkg/docker/buildkit.go
================================================
package docker

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"sync"

	"github.com/docker/docker/api/types/registry"
	buildkitclient "github.com/moby/buildkit/client"
	"github.com/moby/buildkit/session"
	"github.com/moby/buildkit/session/auth"
	"github.com/moby/buildkit/session/secrets/secretsprovider"
	"github.com/moby/buildkit/util/progress/progressui"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	cogconfig "github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/util/console"
)

// prepareDockerfileDir writes dockerfileContents to a file named "Dockerfile"
// inside buildDir (mode 0o644) and returns the path of that file.
// If the write fails, the error is returned with an empty path.
func prepareDockerfileDir(buildDir string, dockerfileContents string) (string, error) {
	target := filepath.Join(buildDir, "Dockerfile")
	if err := os.WriteFile(target, []byte(dockerfileContents), 0o644); err != nil {
		return "", err
	}
	return target, nil
}

// solveOptFromImageOptions writes the Dockerfile from opts into buildDir and
// translates opts into the buildkit SolveOpt for a dockerfile.v0 build that is
// exported through the "moby" exporter.
//
// An error is returned when the Dockerfile cannot be written or when the
// secrets listed in opts cannot be parsed; in both cases the returned SolveOpt
// is the zero value.
func solveOptFromImageOptions(buildDir string, opts command.ImageBuildOptions) (buildkitclient.SolveOpt, error) {
	dockerfile, err := prepareDockerfileDir(buildDir, opts.DockerfileContents)
	if err != nil {
		return buildkitclient.SolveOpt{}, err
	}

	// Attributes for the dockerfile.v0 frontend.
	//   filename: path of the Dockerfile within the "dockerfile" local dir.
	//   platform: Replicate only supports linux/amd64, but the local Docker
	//             Engine could be running on ARM (including Apple Silicon),
	//             so the platform is forced to linux/amd64 for now.
	// TODO[md]: support multi-stage target ("target": opts.Target), i.e. the
	// name of a stage in a multi-stage Dockerfile.
	attrs := map[string]string{
		"filename": filepath.Base(dockerfile),
		"syntax":   "docker/dockerfile:1",
		"platform": "linux/amd64",
	}
	if opts.NoCache {
		attrs["no-cache"] = ""
	}
	for name, value := range opts.Labels {
		attrs["label:"+name] = value
	}
	for name, value := range opts.BuildArgs {
		// nil build args carry no value and are skipped
		if value != nil {
			attrs["build-arg:"+name] = *value
		}
	}
	// A non-negative Epoch is passed on as SOURCE_DATE_EPOCH.
	if epoch := opts.Epoch; epoch != nil && *epoch >= 0 {
		attrs["build-arg:SOURCE_DATE_EPOCH"] = fmt.Sprintf("%d", *epoch)
	}

	// A relative ContextDir is resolved against WorkingDir to stay consistent
	// with the CLI client.
	buildContext := opts.ContextDir
	if opts.WorkingDir != "" && !filepath.IsAbs(buildContext) {
		buildContext = filepath.Join(opts.WorkingDir, buildContext)
	}
	dirs := map[string]string{
		"context":    buildContext,
		"dockerfile": filepath.Dir(dockerfile),
	}

	// User-supplied build contexts must not replace the two reserved names.
	for name, dir := range opts.BuildContexts {
		switch name {
		case "dockerfile", "context":
			console.Warnf("build context name collision: %q", name)
		default:
			dirs[name] = dir
			// make the extra context known to the dockerfile frontend
			attrs["context:"+name] = "local:" + name
		}
	}

	// Exporter attributes; timestamps are rewritten whenever a
	// SOURCE_DATE_EPOCH build arg is present.
	exportAttrs := map[string]string{"name": opts.ImageName}
	if _, hasEpoch := attrs["build-arg:SOURCE_DATE_EPOCH"]; hasEpoch {
		exportAttrs["rewrite-timestamp"] = "true"
	}

	solveOpts := buildkitclient.SolveOpt{
		Frontend:      "dockerfile.v0",
		FrontendAttrs: attrs,
		LocalDirs:     dirs,
		// The auth provider lets the local engine pull and push images.
		Session: []session.Attachable{newBuildkitAuthProvider("r8.im")},
		// Docker Engine's worker only supports three exporters.
		// "moby" exporter works best for cog, since we want to keep images in
		// Docker Engine's image store. The others are exporting images to somewhere else.
		// https://github.com/moby/moby/blob/v20.10.24/builder/builder-next/worker/worker.go#L221
		Exports: []buildkitclient.ExportEntry{
			{Type: "moby", Attrs: exportAttrs},
		},
	}

	// TODO[md]: support secrets direct from input in addition to env+file
	if len(opts.Secrets) > 0 {
		store, err := ParseSecretsFromHost(opts.WorkingDir, opts.Secrets)
		if err != nil {
			return buildkitclient.SolveOpt{}, fmt.Errorf("failed to parse secrets: %w", err)
		}
		solveOpts.Session = append(solveOpts.Session, secretsprovider.NewSecretProvider(store))
	}

	// Cache imports/exports match the DockerCommand logic: an inline cache by
	// default, a local cache when cogconfig.BuildXCachePath is set.
	if cogconfig.BuildXCachePath == "" {
		solveOpts.CacheExports = []buildkitclient.CacheOptionsEntry{
			{Type: "inline"},
		}
		return solveOpts, nil
	}
	solveOpts.CacheImports = []buildkitclient.CacheOptionsEntry{
		{Type: "local", Attrs: map[string]string{"src": cogconfig.BuildXCachePath}},
	}
	solveOpts.CacheExports = []buildkitclient.CacheOptionsEntry{
		{Type: "local", Attrs: map[string]string{"dest": cogconfig.BuildXCachePath}},
	}
	return solveOpts, nil
}

// newDisplay returns a function that renders buildkit solve status updates
// from statusCh to stderr using the given progress display mode. The returned
// function blocks until statusCh is closed.
func newDisplay(statusCh chan *buildkitclient.SolveStatus, displayMode string) func() error {
	render := func() error {
		mode := progressui.DisplayMode(displayMode)
		ui, err := progressui.NewDisplay(os.Stderr, mode)
		if err != nil {
			return err
		}

		// UpdateFrom must not use the incoming context.
		// Canceling this context kills the reader of statusCh which blocks buildkit.Client's Solve() indefinitely.
		// Solve() closes statusCh at the end and UpdateFrom returns by reading the closed channel.
		//
		// See https://github.com/superfly/flyctl/pull/2682 for the context.
		if _, err := ui.UpdateFrom(context.Background(), statusCh); err != nil {
			return err
		}
		return nil
	}
	return render
}

// newBuildkitAuthProvider returns a session attachable that serves registry
// credentials for registryHosts. The credentials are loaded on first use and
// the result (including any error) is reused afterwards.
func newBuildkitAuthProvider(registryHosts ...string) session.Attachable {
	loadOnce := sync.OnceValues(func() (map[string]registry.AuthConfig, error) {
		return loadRegistryAuths(context.Background(), registryHosts...)
	})

	// TODO[md]: here's where we'd set the token from config rather than fetching from the credentials helper
	// token: token,
	return &buildkitAuthProvider{registryHosts: loadOnce}
}

// buildkitAuthProvider answers buildkit auth requests using registry
// credentials obtained from registryHosts.
type buildkitAuthProvider struct {
	// registryHosts returns the auth configs keyed by registry host, or the
	// error encountered while loading them.
	registryHosts func() (map[string]registry.AuthConfig, error)
}

// Register registers ap as the auth service on the given gRPC server.
func (ap *buildkitAuthProvider) Register(server *grpc.Server) {
	auth.RegisterAuthServer(server, ap)
}

// Credentials returns the username and secret known for req.Host. When no
// credentials are known for the host an empty response is returned; an error
// is returned only when the auth configs cannot be loaded.
func (ap *buildkitAuthProvider) Credentials(ctx context.Context, req *auth.CredentialsRequest) (*auth.CredentialsResponse, error) {
	known, err := ap.registryHosts()
	if err != nil {
		return nil, fmt.Errorf("failed to load registry auth configs: %w", err)
	}

	entry, found := known[req.Host]
	if !found {
		return &auth.CredentialsResponse{}, nil
	}
	return &auth.CredentialsResponse{
		Username: entry.Username,
		Secret:   entry.Password,
	}, nil
}

// FetchToken always fails with codes.Unavailable: client side tokens are
// disabled for this provider.
func (ap *buildkitAuthProvider) FetchToken(ctx context.Context, req *auth.FetchTokenRequest) (*auth.FetchTokenResponse, error) {
	disabled := status.Errorf(codes.Unavailable, "client side tokens disabled")
	return nil, disabled
}

// GetTokenAuthority always fails with codes.Unavailable: client side tokens
// are disabled for this provider.
func (ap *buildkitAuthProvider) GetTokenAuthority(ctx context.Context, req *auth.GetTokenAuthorityRequest) (*auth.GetTokenAuthorityResponse, error) {
	disabled := status.Errorf(codes.Unavailable, "client side tokens disabled")
	return nil, disabled
}

// VerifyTokenAuthority always fails with codes.Unavailable: client side tokens
// are disabled for this provider.
func (ap *buildkitAuthProvider) VerifyTokenAuthority(ctx context.Context, req *auth.VerifyTokenAuthorityRequest) (*auth.VerifyTokenAuthorityResponse, error) {
	disabled := status.Errorf(codes.Unavailable, "client side tokens disabled")
	return nil, disabled
}


================================================
FILE: pkg/docker/command/command.go
================================================
package command

import (
	"context"
	"io"

	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/image"
)

// Command is the set of Docker operations (images, containers, registry
// credentials) that the rest of the code base relies on.
type Command interface {
	// Pull pulls an image from a remote registry and returns the inspect response for the local image.
	// If the image already exists, it will return the inspect response for the local image without pulling.
	// When force is true, it will always attempt to pull the image.
	Pull(ctx context.Context, ref string, force bool) (*image.InspectResponse, error)
	// Push pushes the image identified by ref.
	Push(ctx context.Context, ref string) error
	// RemoveImage removes the image identified by ref.
	RemoveImage(ctx context.Context, ref string) error
	// LoadUserInformation returns the user information stored for registryHost.
	LoadUserInformation(ctx context.Context, registryHost string) (*UserInfo, error)
	// Inspect returns the inspect response for the image identified by ref.
	Inspect(ctx context.Context, ref string) (*image.InspectResponse, error)
	// ImageExists reports whether the image identified by ref exists.
	ImageExists(ctx context.Context, ref string) (bool, error)
	// ContainerLogs writes the logs of the container containerID to w.
	ContainerLogs(ctx context.Context, containerID string, w io.Writer) error
	// ContainerInspect returns the inspect response for the container id.
	ContainerInspect(ctx context.Context, id string) (*container.InspectResponse, error)
	// ContainerStop stops the container containerID.
	ContainerStop(ctx context.Context, containerID string) error

	// ImageBuild builds an image and returns the image ID (sha256:...) on success.
	ImageBuild(ctx context.Context, options ImageBuildOptions) (string, error)
	// Run runs a container described by options.
	Run(ctx context.Context, options RunOptions) error
	// ContainerStart starts a container described by options.
	// NOTE(review): the returned string is presumably the container ID — confirm against implementations.
	ContainerStart(ctx context.Context, options RunOptions) (string, error)

	// ImageSave exports a Docker image as a tar stream.
	// The caller must close the returned ReadCloser.
	ImageSave(ctx context.Context, imageRef string) (io.ReadCloser, error)
}

// ImageBuildOptions describes an image build requested through
// Command.ImageBuild: the Dockerfile contents, the build context, the resulting
// image name and optional labels, secrets, build args and extra contexts.
type ImageBuildOptions struct {
	WorkingDir         string
	DockerfileContents string
	// TODO[md]: ImageName should be renamed to Tag
	ImageName string
	// Secrets in the format of "id=foo,src=/path/to/file" or "id=kube,env=KUBECONFIG"
	// docs: https://docs.docker.com/build/building/secrets/#use-secrets-in-dockerfile
	Secrets        []string
	NoCache        bool
	ProgressOutput string
	Epoch          *int64
	ContextDir     string
	BuildContexts  map[string]string
	Labels         map[string]string

	// only supported on buildkit client, not cli client
	BuildArgs map[string]*string
}

// RunOptions describes a container to run; it is the argument of Command.Run
// and Command.ContainerStart.
// NOTE(review): the fields presumably mirror the corresponding `docker run`
// settings (image, args, env, GPUs, ports, volumes, workdir, extra hosts and
// the standard streams) — confirm against the implementations.
type RunOptions struct {
	Detach     bool
	Args       []string
	Env        []string
	GPUs       string
	Image      string
	Ports      []Port
	Volumes    []Volume
	Workdir    string
	ExtraHosts []string
	Stdin      io.Reader
	Stdout     io.Writer
	Stderr     io.Writer
}

// Port pairs a host port with a container port.
// NOTE(review): presumably a host-to-container port publication — confirm
// against the code that consumes RunOptions.Ports.
type Port struct {
	HostPort      int
	ContainerPort int
}

// Volume pairs a source with a destination.
// NOTE(review): presumably a host path mounted at a container path — confirm
// against the code that consumes RunOptions.Volumes.
type Volume struct {
	Source      string
	Destination string
}


================================================
FILE: pkg/docker/command/errors.go
================================================
package command

import (
	"errors"
	"fmt"
)

// NotFoundError reports that an object was not found inside the Docker engine.
type NotFoundError struct {
	// Ref is a unique identifier, such as an image reference, container ID, etc.
	Ref string
	// Object is the ref type, such as "container", "image", "volume", etc.
	Object string
}

// Error formats the error as `<object> not found: "<ref>"`, using the generic
// word "object" when no Object type was set.
func (e *NotFoundError) Error() string {
	kind := "object"
	if e.Object != "" {
		kind = e.Object
	}
	return fmt.Sprintf("%s not found: %q", kind, e.Ref)
}

// Is reports whether target is a *NotFoundError, regardless of its field
// values, so that errors.Is matches any not-found error.
func (e *NotFoundError) Is(target error) bool {
	switch target.(type) {
	case *NotFoundError:
		return true
	default:
		return false
	}
}

// IsNotFoundError reports whether err matches a *NotFoundError according to
// errors.Is.
func IsNotFoundError(err error) bool {
	probe := new(NotFoundError)
	return errors.Is(err, probe)
}

// ErrAuthorizationFailed is the sentinel error for a failed authorization;
// compare against it with errors.Is.
var ErrAuthorizationFailed = errors.New("authorization failed")


================================================
FILE: pkg/docker/command/manifest.go
================================================
package command

import "github.com/replicate/cog/pkg/global"

// Config holds the "Labels" and "Env" entries of a JSON-encoded image config.
type Config struct {
	Labels map[string]string `json:"Labels"`
	Env    []string          `json:"Env"`
}

// Manifest is a JSON document whose "Config" object is decoded into Config.
type Manifest struct {
	Config Config `json:"Config"`
}

// Names of the R8_* environment variables that carry version information
// (cog, torch, CUDA, cuDNN and Python).
const (
	R8CogVersionEnvVarName    = "R8_COG_VERSION"
	R8TorchVersionEnvVarName  = "R8_TORCH_VERSION"
	R8CudaVersionEnvVarName   = "R8_CUDA_VERSION"
	R8CudnnVersionEnvVarName  = "R8_CUDNN_VERSION"
	R8PythonVersionEnvVarName = "R8_PYTHON_VERSION"
)

// Label keys used by cog, each formed by appending a suffix to
// global.LabelNamespace.
var (
	CogConfigLabelKey          = global.LabelNamespace + "config"
	CogVersionLabelKey         = global.LabelNamespace + "version"
	CogOpenAPISchemaLabelKey   = global.LabelNamespace + "openapi_schema"
	CogWeightsManifestLabelKey = global.LabelNamespace + "r8_weights_manifest"
)


================================================
FILE: pkg/docker/command/user_info.go
================================================
package command

// UserInfo holds the username and token of a registry user.
type UserInfo struct {
	Token    string
	Username string
}


================================================
FILE: pkg/docker/credential_helper_input.go
================================================
package docker

// CredentialHelperInput is the credential record (username, secret and server
// URL) exchanged with a docker credential helper.
type CredentialHelperInput struct {
	Username  string
	Secret    string //nolint:gosec // G117: this is a Docker credential, not a hardcoded secret
	ServerURL string
}


================================================
FILE: pkg/docker/credentials.go
================================================
package docker

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"os/exec"
	"strings"

	"github.com/docker/cli/cli/config"
	"github.com/docker/cli/cli/config/configfile"
	"github.com/docker/cli/cli/config/types"
	"github.com/docker/docker/api/types/registry"

	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/global"
	"github.com/replicate/cog/pkg/util/console"
)

// loadUserInformation returns the username and token stored for registryHost
// in the default Docker config. When the config names a credentials store the
// store's helper is queried; otherwise the entry from the config file's auth
// section is used.
func loadUserInformation(ctx context.Context, registryHost string) (*command.UserInfo, error) {
	conf := config.LoadDefaultConfigFile(os.Stderr)

	if store := conf.CredentialsStore; store != "" {
		fromHelper, err := loadAuthFromCredentialsStore(ctx, store, registryHost)
		if err != nil {
			return nil, err
		}
		info := command.UserInfo{
			Username: fromHelper.Username,
			Token:    fromHelper.Secret,
		}
		return &info, nil
	}

	fromFile, err := loadAuthFromConfig(conf, registryHost)
	if err != nil {
		return nil, err
	}
	info := command.UserInfo{
		Username: fromFile.Username,
		Token:    fromFile.Password,
	}
	return &info, nil
}

// loadAuthFromConfig returns the auth entry stored for registryHost in conf.
// A missing entry yields the zero AuthConfig; the error is always nil.
func loadAuthFromConfig(conf *configfile.ConfigFile, registryHost string) (types.AuthConfig, error) {
	entry := conf.AuthConfigs[registryHost]
	return entry, nil
}

// loadRegistryAuths collects registry credentials for each of registryHosts
// from the default Docker config. A host without credentials of its own falls
// back to the credentials of global.DefaultReplicateRegistryHost (with the
// server address rewritten to the requested host). Hosts for which nothing is
// found are left out of the result; the returned error is always nil.
func loadRegistryAuths(ctx context.Context, registryHosts ...string) (map[string]registry.AuthConfig, error) {
	conf := config.LoadDefaultConfigFile(os.Stderr)
	found := make(map[string]registry.AuthConfig)

	for _, host := range registryHosts {
		// Credentials stored for the requested host win.
		if direct, err := tryLoadAuthForHost(ctx, conf, host); err == nil && direct != nil {
			found[host] = *direct
			continue
		}

		// The default registry has nothing to fall back to.
		if host == global.DefaultReplicateRegistryHost {
			continue
		}

		// FALLBACK: reuse the default registry's credentials for the
		// alternate registry.
		shared, err := tryLoadAuthForHost(ctx, conf, global.DefaultReplicateRegistryHost)
		if err != nil || shared == nil {
			continue
		}
		shared.ServerAddress = host // point the reused credentials at the new host
		found[host] = *shared
		console.Infof("Using existing %s credentials for %s", global.DefaultReplicateRegistryHost, host)
	}

	return found, nil
}

func tryLoadAuthForHost(ctx context.Context, conf *configfile.ConfigFile, host string) (*registry.AuthConfig, error) {
	// Try credentials store first (e.g., osxkeychain, pass)
	if conf.CredentialsStore != "" {
		credsHelper, err := loadAuthFromCredentialsStore(ctx, conf.CredentialsStore, host)
		if err == nil {
			return ®istry.AuthConfig{
				Username:      credsHelper.Username,
				Password:      credsHelper.Secret,
				ServerAddress: host,
			}, nil
		}
	}

	// Fallback to config file
	if auth, ok := conf.AuthConfigs[host]; ok {
		return ®istry.AuthConfig{
			Username:      auth.Username,
			Password:      auth.Password,
			Auth:          auth.Auth,
			ServerAddress: host,
			IdentityToken: auth.IdentityToken,
			RegistryToken: auth.RegistryToken,
		}, nil
	}

	return nil, fmt.Errorf("no credentials found for %s", host)
}

// loadAuthFromCredentialsStore asks the docker credential helper named by
// credsStore for the credentials of registryHost.
//
// It runs `docker-credential-<credsStore> get`, writes registryHost to the
// helper's stdin and decodes the helper's JSON output. When the helper exits
// with an error, the helper's output (if any) is used as the error message.
//
// Once the helper has been started it is always waited for, so a failed write
// to its stdin cannot leave an unreaped child process behind.
func loadAuthFromCredentialsStore(ctx context.Context, credsStore string, registryHost string) (*CredentialHelperInput, error) {
	var out strings.Builder
	binary := dockerCredentialBinary(credsStore)
	cmd := exec.CommandContext(ctx, binary, "get") //nolint:gosec // G702: binary is from Docker config, not user input
	cmd.Env = os.Environ()
	cmd.Stdout = &out
	cmd.Stderr = &out
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, err
	}
	// Safety net for the early return below; a second Close on the pipe only
	// yields an error that is deliberately ignored.
	defer stdin.Close()

	console.Debug("$ " + strings.Join(cmd.Args, " "))
	if err := cmd.Start(); err != nil {
		return nil, err
	}

	// Send the host and close stdin so the helper sees EOF, then reap the
	// process regardless of whether the write or close succeeded.
	_, writeErr := io.WriteString(stdin, registryHost)
	closeErr := stdin.Close()
	if err := cmd.Wait(); err != nil {
		output := strings.TrimSpace(out.String())
		if output != "" {
			return nil, fmt.Errorf("failed to get credentials for %q: %s", registryHost, output)
		}
		return nil, fmt.Errorf("failed to get credentials for %q: %w", registryHost, err)
	}
	if writeErr != nil {
		return nil, writeErr
	}
	if closeErr != nil {
		return nil, closeErr
	}

	// Named creds (not config) to avoid shadowing the imported config package.
	var creds CredentialHelperInput
	if err := json.Unmarshal([]byte(out.String()), &creds); err != nil {
		return nil, fmt.Errorf("failed to parse credentials for %q: %w", registryHost, err)
	}

	return &creds, nil
}

// dockerCredentialBinary returns the executable name of the docker credential
// helper for the given credentials store, e.g. "docker-credential-osxkeychain".
func dockerCredentialBinary(credsStore string) string {
	const helperPrefix = "docker-credential-"
	return helperPrefix + credsStore
}


================================================
FILE: pkg/docker/credentials_test.go
================================================
package docker

import (
	"context"
	"path/filepath"
	"testing"

	"github.com/docker/cli/cli/config/configfile"
	"github.com/docker/cli/cli/config/types"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/global"
)

// TestLoadRegistryAuths_Fallback covers credential lookup for a requested
// registry host and the fallback to the default Replicate registry's
// credentials. Only the second subtest calls loadRegistryAuths (against a
// temporary Docker config selected via DOCKER_CONFIG); the others exercise
// tryLoadAuthForHost with in-memory configs.
func TestLoadRegistryAuths_Fallback(t *testing.T) {
	ctx := context.Background()

	t.Run("uses credentials for requested host when available", func(t *testing.T) {
		// Create a mock config with credentials for the requested host
		conf := &configfile.ConfigFile{
			AuthConfigs: map[string]types.AuthConfig{
				"registry.example.com": {
					Username: "user1",
					Password: "pass1",
				},
			},
		}

		auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
		require.NoError(t, err)
		require.NotNil(t, auth)
		assert.Equal(t, "user1", auth.Username)
		assert.Equal(t, "pass1", auth.Password)
		assert.Equal(t, "registry.example.com", auth.ServerAddress)
	})

	t.Run("falls back to default registry credentials when alternate registry has no credentials", func(t *testing.T) {
		// Set up a temporary docker config file
		tmpDir := t.TempDir()
		dockerConfigPath := filepath.Join(tmpDir, "config.json")

		// Create a config file with credentials only for the default registry
		conf := &configfile.ConfigFile{
			Filename: dockerConfigPath,
			AuthConfigs: map[string]types.AuthConfig{
				global.DefaultReplicateRegistryHost: {
					Username: "defaultuser",
					Password: "defaultpass",
				},
			},
		}
		require.NoError(t, conf.Save())

		// Point Docker to our test config
		t.Setenv("DOCKER_CONFIG", tmpDir)

		// Try loading auth for an alternate registry that doesn't have credentials
		auths, err := loadRegistryAuths(ctx, "registry.example.com")
		require.NoError(t, err)
		require.NotNil(t, auths)

		// Should have fallen back to default registry credentials
		auth, ok := auths["registry.example.com"]
		require.True(t, ok, "should have auth for registry.example.com")
		assert.Equal(t, "defaultuser", auth.Username)
		assert.Equal(t, "defaultpass", auth.Password)
		assert.Equal(t, "registry.example.com", auth.ServerAddress, "server address should be updated to the requested host")
	})

	t.Run("does not fallback when requesting default registry", func(t *testing.T) {
		// This test uses tryLoadAuthForHost directly to avoid credential store issues
		conf := &configfile.ConfigFile{
			AuthConfigs: map[string]types.AuthConfig{},
		}

		// Try loading auth for the default registry
		auth, err := tryLoadAuthForHost(ctx, conf, global.DefaultReplicateRegistryHost)
		require.Error(t, err, "should error when no credentials found")
		assert.Nil(t, auth)
		assert.Contains(t, err.Error(), "no credentials found")
	})

	t.Run("prefers direct credentials over fallback", func(t *testing.T) {
		// Create a mock config with credentials for both registries
		conf := &configfile.ConfigFile{
			AuthConfigs: map[string]types.AuthConfig{
				global.DefaultReplicateRegistryHost: {
					Username: "defaultuser",
					Password: "defaultpass",
				},
				"registry.example.com": {
					Username: "directuser",
					Password: "directpass",
				},
			},
		}

		// Try loading auth for the alternate registry
		auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
		require.NoError(t, err)
		require.NotNil(t, auth)

		// Should use direct credentials, not fallback
		assert.Equal(t, "directuser", auth.Username)
		assert.Equal(t, "directpass", auth.Password)
		assert.Equal(t, "registry.example.com", auth.ServerAddress)
	})

	t.Run("returns empty map when no credentials available", func(t *testing.T) {
		// This test uses tryLoadAuthForHost to avoid credential store issues
		// The loadRegistryAuths function doesn't error when no credentials are found,
		// it just returns an empty map
		// NOTE(review): loadRegistryAuths itself is never called here, so the
		// "empty map" behaviour named in the subtest title is not asserted.
		conf := &configfile.ConfigFile{
			AuthConfigs: map[string]types.AuthConfig{},
		}

		// Try loading auth for an alternate registry (will fail)
		auth1, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
		require.Error(t, err)
		assert.Nil(t, auth1)

		// Try loading auth for default registry (will also fail)
		auth2, err := tryLoadAuthForHost(ctx, conf, global.DefaultReplicateRegistryHost)
		require.Error(t, err)
		assert.Nil(t, auth2)

		// Since both fail, loadRegistryAuths would return an empty map
		// (it doesn't error, just silently skips hosts without credentials)
	})
}

// TestTryLoadAuthForHost checks that tryLoadAuthForHost copies an auth entry
// from the config file (no credentials store configured) and that it reports
// an error when the config has no entry for the host.
func TestTryLoadAuthForHost(t *testing.T) {
	ctx := context.Background()

	t.Run("loads auth from config file", func(t *testing.T) {
		conf := &configfile.ConfigFile{
			AuthConfigs: map[string]types.AuthConfig{
				"registry.example.com": {
					Username: "testuser",
					Password: "testpass",
					Auth:     "dGVzdHVzZXI6dGVzdHBhc3M=",
				},
			},
		}

		auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
		require.NoError(t, err)
		require.NotNil(t, auth)
		assert.Equal(t, "testuser", auth.Username)
		assert.Equal(t, "testpass", auth.Password)
		assert.Equal(t, "dGVzdHVzZXI6dGVzdHBhc3M=", auth.Auth)
		assert.Equal(t, "registry.example.com", auth.ServerAddress)
	})

	t.Run("returns error when no auth found", func(t *testing.T) {
		// An empty auth map and no credentials store leave nothing to look up.
		conf := &configfile.ConfigFile{
			AuthConfigs: map[string]types.AuthConfig{},
		}

		auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
		require.Error(t, err)
		assert.Nil(t, auth)
		assert.Contains(t, err.Error(), "no credentials found")
	})
}


================================================
FILE: pkg/docker/docker.go
================================================
package docker

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"strconv"
	"strings"

	"github.com/containerd/errdefs"
	"github.com/docker/docker/api/types"
	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/image"
	"github.com/docker/docker/api/types/network"
	"github.com/docker/docker/api/types/registry"
	"github.com/docker/docker/client"
	"github.com/docker/docker/pkg/jsonmessage"
	"github.com/docker/docker/pkg/stdcopy"
	"github.com/docker/go-connections/nat"
	"github.com/google/go-containerregistry/pkg/name"
	"github.com/mattn/go-isatty"
	buildkitclient "github.com/moby/buildkit/client"
	"github.com/moby/buildkit/exporter/containerimage/exptypes"
	"github.com/moby/term"
	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
	"golang.org/x/sync/errgroup"

	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/global"
	"github.com/replicate/cog/pkg/util/console"
)

// ptrVal returns a pointer to a fresh copy of v.
func ptrVal[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}

// NewClient creates a Docker API client, verifies that the daemon is
// reachable and loads the registry credentials used for later operations.
//
// The Docker host comes from the options when set and is determined
// automatically otherwise. Credentials are loaded for
// global.ReplicateRegistryHost and then extended (or overridden) by the auth
// configs supplied through options, keyed by their ServerAddress.
//
// An error is returned when the host cannot be determined, the client cannot
// be created, the daemon does not answer a ping, or credentials cannot be
// loaded. The underlying client is closed again on the latter two failures so
// its connections are not leaked.
func NewClient(ctx context.Context, opts ...Option) (*apiClient, error) {
	cfg := &clientOptions{
		authConfigs: make(map[string]registry.AuthConfig),
	}
	for _, opt := range opts {
		opt(cfg)
	}

	if cfg.host == "" {
		host, err := determineDockerHost()
		if err != nil {
			return nil, fmt.Errorf("error determining docker host: %w", err)
		}
		cfg.host = host
	}

	// TODO[md]: we create a client at the top of each cli invocation, the sdk client hits an api which
	// adds (a tiny bit of) overhead. swap this with a handle that'll lazily initialize a client and ping for health.
	// ditto for fetching registry credentials.

	dockerClientOpts := []client.Opt{
		client.WithTLSClientConfigFromEnv(),
		client.WithVersionFromEnv(),
		client.WithAPIVersionNegotiation(),
		client.WithHost(cfg.host),
	}

	// Named dockerClient (not client) so the client package stays accessible.
	dockerClient, err := client.NewClientWithOpts(dockerClientOpts...)
	if err != nil {
		return nil, fmt.Errorf("error creating docker client: %w", err)
	}

	if _, err := dockerClient.Ping(ctx); err != nil {
		// The client is unusable; the Close error adds nothing to the ping error.
		_ = dockerClient.Close()
		return nil, fmt.Errorf("error pinging docker daemon: %w", err)
	}

	// Load authentication for configured registry and any other registries that might be needed
	authConfig, err := loadRegistryAuths(ctx, global.ReplicateRegistryHost)
	if err != nil {
		// Same as above: release the client before reporting the failure.
		_ = dockerClient.Close()
		return nil, fmt.Errorf("error loading user information: %w, you may need to authenticate using cog login", err)
	}

	// Add any additional auth configs passed via options
	for _, extra := range cfg.authConfigs {
		authConfig[extra.ServerAddress] = extra
	}

	return &apiClient{client: dockerClient, authConfig: authConfig}, nil
}

// apiClient talks to the Docker Engine through the Docker SDK client and
// keeps the registry auth configs it has loaded, keyed by registry host.
type apiClient struct {
	client     *client.Client
	authConfig map[string]registry.AuthConfig
}

// Pull returns the inspect response for imageRef, pulling the image for
// linux/amd64 when needed. Unless force is set, an image that can already be
// inspected locally is returned without pulling. Pull progress is copied to
// stderr. A *command.NotFoundError is returned when the engine reports the
// image as not found.
func (c *apiClient) Pull(ctx context.Context, imageRef string, force bool) (*image.InspectResponse, error) {
	console.Debugf("=== APIClient.Pull %s force:%t", imageRef, force)

	if !force {
		local, inspectErr := c.Inspect(ctx, imageRef)
		switch {
		case inspectErr == nil:
			return local, nil
		case !command.IsNotFoundError(inspectErr):
			// Log a warning if inspect fails for any reason other than not found.
			// It's likely that pull will fail as well, but it's better to return that error
			// so the caller can handle it appropriately than to fail silently here.
			console.Warnf("failed to inspect image before pulling %q: %s", imageRef, inspectErr)
		}
	}

	// force image to linux/amd64 to match production
	pullOpts := image.PullOptions{Platform: "linux/amd64"}
	stream, err := c.client.ImagePull(ctx, imageRef, pullOpts)
	if err != nil {
		if !errdefs.IsNotFound(err) {
			return nil, fmt.Errorf("failed to pull image %q: %w", imageRef, err)
		}
		return nil, &command.NotFoundError{Ref: imageRef, Object: "image"}
	}
	defer stream.Close()

	if _, err := io.Copy(os.Stderr, stream); err != nil {
		return nil, fmt.Errorf("failed to copy pull output: %w", err)
	}

	// pull succeeded, inspect the image again and return
	pulled, err := c.Inspect(ctx, imageRef)
	if err != nil {
		return nil, fmt.Errorf("failed to inspect image after pulling %q: %w", imageRef, err)
	}
	return pulled, nil
}

// ContainerStop stops the container containerID with a stop timeout of 3.
// A *command.NotFoundError is returned when the engine does not know the
// container.
func (c *apiClient) ContainerStop(ctx context.Context, containerID string) error {
	console.Debugf("=== APIClient.ContainerStop %s", containerID)

	stopTimeout := 3
	stopOpts := container.StopOptions{Timeout: &stopTimeout}

	err := c.client.ContainerStop(ctx, containerID, stopOpts)
	switch {
	case err == nil:
		return nil
	case errdefs.IsNotFound(err):
		return &command.NotFoundError{Ref: containerID, Object: "container"}
	default:
		return fmt.Errorf("failed to stop container %q: %w", containerID, err)
	}
}

// ContainerInspect returns the engine's inspect response for containerID.
// A *command.NotFoundError is returned when the engine does not know the
// container.
func (c *apiClient) ContainerInspect(ctx context.Context, containerID string) (*container.InspectResponse, error) {
	console.Debugf("=== APIClient.ContainerInspect %s", containerID)

	info, err := c.client.ContainerInspect(ctx, containerID)
	if err == nil {
		return &info, nil
	}
	if errdefs.IsNotFound(err) {
		return nil, &command.NotFoundError{Ref: containerID, Object: "container"}
	}
	return nil, fmt.Errorf("failed to inspect container %q: %w", containerID, err)
}

// ContainerLogs follows the stdout and stderr logs of containerID and writes
// them to w. For containers without a TTY the multiplexed log stream is
// demultiplexed first, with both streams written to w.
func (c *apiClient) ContainerLogs(ctx context.Context, containerID string, w io.Writer) error {
	console.Debugf("=== APIClient.ContainerLogs %s", containerID)

	// The TTY setting of the container decides how the log stream is decoded.
	details, err := c.ContainerInspect(ctx, containerID)
	if err != nil {
		return fmt.Errorf("failed to inspect container %q: %w", containerID, err)
	}

	logOpts := container.LogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     true,
	}
	stream, err := c.client.ContainerLogs(ctx, containerID, logOpts)
	if err != nil {
		if !errdefs.IsNotFound(err) {
			return fmt.Errorf("failed to get container logs for %q: %w", containerID, err)
		}
		return &command.NotFoundError{Ref: containerID, Object: "container"}
	}
	defer stream.Close()

	var copyErr error
	if details.Config.Tty {
		// With a TTY the stream is raw and can be copied as is.
		_, copyErr = io.Copy(w, stream)
	} else {
		// Without a TTY stdout and stderr are multiplexed; StdCopy splits them.
		_, copyErr = stdcopy.StdCopy(w, w, stream)
	}
	if copyErr != nil {
		return fmt.Errorf("failed to copy logs: %w", copyErr)
	}
	return nil
}

// Push pushes imageRef to its registry, writing progress to stderr.
//
// Credentials for the image's registry host are taken from the client's auth
// configs; when none are cached they are loaded on demand and cached. Errors
// reported in the push stream are mapped to *command.NotFoundError (tag or
// repository) or command.ErrAuthorizationFailed where recognised.
func (c *apiClient) Push(ctx context.Context, imageRef string) error {
	console.Debugf("=== APIClient.Push %s", imageRef)

	parsedName, err := name.ParseReference(imageRef)
	if err != nil {
		return fmt.Errorf("failed to parse image reference: %w", err)
	}

	console.Debugf("fully qualified image ref: %s", parsedName)

	// Use the cached auth config for the registry, or load it on demand.
	host := parsedName.Context().RegistryStr()
	authConfig, cached := c.authConfig[host]
	if !cached {
		if loaded, loadErr := loadRegistryAuths(ctx, host); loadErr == nil {
			if fresh, ok := loaded[host]; ok {
				authConfig = fresh
				// Cache the auth config for future use
				c.authConfig[host] = fresh
			}
		}
	}

	encodedAuth, err := registry.EncodeAuthConfig(authConfig)
	if err != nil {
		return fmt.Errorf("failed to encode auth config: %w", err)
	}
	pushOpts := image.PushOptions{RegistryAuth: encodedAuth}

	stream, err := c.client.ImagePush(ctx, imageRef, pushOpts)
	if err != nil {
		return fmt.Errorf("failed to push image: %w", err)
	}
	defer stream.Close()

	// The stream is a sequence of JSON messages: render progress to stderr and
	// surface errors embedded in the stream.
	isTTY := console.IsTTY(os.Stderr)
	streamFailure := jsonmessage.DisplayJSONMessagesStream(stream, os.Stderr, os.Stderr.Fd(), isTTY, nil)
	if streamFailure == nil {
		return nil
	}

	var jsonErr *jsonmessage.JSONError
	if errors.As(streamFailure, &jsonErr) {
		switch {
		case isTagNotFoundError(streamFailure):
			return &command.NotFoundError{Ref: imageRef, Object: "tag"}
		case isRepositoryNotFoundError(streamFailure):
			return &command.NotFoundError{Ref: imageRef, Object: "repository"}
		case isAuthorizationFailedError(streamFailure):
			return command.ErrAuthorizationFailed
		}
	}
	return fmt.Errorf("error during image push: %w", streamFailure)
}

// ImageSave asks the engine to save imageRef and returns the resulting
// stream; the result of the underlying client call is returned unchanged.
func (c *apiClient) ImageSave(ctx context.Context, imageRef string) (io.ReadCloser, error) {
	console.Debugf("=== APIClient.ImageSave %s", imageRef)
	return c.client.ImageSave(ctx, []string{imageRef})
}

// LoadUserInformation returns the user information stored for registryHost
// by delegating to loadUserInformation; it does not use the receiver's state.
//
// TODO[md]: this doesn't need to be on the interface, move to auth handler
func (c *apiClient) LoadUserInformation(ctx context.Context, registryHost string) (*command.UserInfo, error) {
	console.Debugf("=== APIClient.LoadUserInformation %s", registryHost)

	return loadUserInformation(ctx, registryHost)
}

// Inspect returns the engine's inspect response for the image ref.
// A *command.NotFoundError is returned when the engine does not know the
// image.
func (c *apiClient) Inspect(ctx context.Context, ref string) (*image.InspectResponse, error) {
	console.Debugf("=== APIClient.Inspect %s", ref)

	// TODO[md]: platform requires engine 1.49+, and it's not widely available as of 2025-05.
	// 	platform := ocispec.Platform{OS: "linux", Architecture: "amd64"}
	//  client.ImageInspectWithPlatform(&platform),
	details, err := c.client.ImageInspect(ctx, ref)
	switch {
	case err == nil:
		return &details, nil
	case errdefs.IsNotFound(err):
		return nil, &command.NotFoundError{Ref: ref, Object: "image"}
	default:
		return nil, fmt.Errorf("error inspecting image: %w", err)
	}
}

// RemoveImage removes the image ref from the engine.
//
// A *command.NotFoundError is returned when the engine reports the image as
// not found or when it reports that nothing was removed, matching how the
// other image and container methods of apiClient report missing objects. Any
// other failure is returned wrapped with the image reference.
func (c *apiClient) RemoveImage(ctx context.Context, ref string) error {
	console.Debugf("=== APIClient.RemoveImage %s", ref)

	resp, err := c.client.ImageRemove(ctx, ref, image.RemoveOptions{})
	if err != nil {
		if errdefs.IsNotFound(err) {
			return &command.NotFoundError{Ref: ref, Object: "image"}
		}
		return fmt.Errorf("failed to remove image %q: %w", ref, err)
	}

	// An empty response means the engine neither untagged nor deleted anything.
	if len(resp) == 0 {
		return &command.NotFoundError{Ref: ref, Object: "image"}
	}
	return nil
}

// ImageExists reports whether the image ref can be inspected. A not-found
// result yields (false, nil); any other inspect failure is returned as the
// error.
func (c *apiClient) ImageExists(ctx context.Context, ref string) (bool, error) {
	console.Debugf("=== APIClient.ImageExists %s", ref)

	switch _, err := c.Inspect(ctx, ref); {
	case err == nil:
		return true, nil
	case command.IsNotFoundError(err):
		return false, nil
	default:
		return false, err
	}
}

// ImageBuild builds an image through the Docker Engine's embedded BuildKit and
// returns the image digest reported by the exporter.
//
// A temporary "cog-build" directory is created for the duration of the build
// and handed to solveOptFromImageOptions; it is removed before returning. The
// progress display mode is taken from options.ProgressOutput, then the
// BUILDKIT_PROGRESS environment variable, and finally defaults to "auto".
//
// An error is returned if the temporary directory or the BuildKit client
// cannot be created, if translating the options or the solve itself fails, if
// the progress display fails, or if BuildKit does not report an image digest.
func (c *apiClient) ImageBuild(ctx context.Context, options command.ImageBuildOptions) (string, error) {
	console.Debugf("=== APIClient.ImageBuild %s", options.ImageName)

	buildDir, err := os.MkdirTemp("", "cog-build")
	if err != nil {
		return "", err
	}
	defer os.RemoveAll(buildDir)

	bc, err := buildkitclient.New(ctx, "",
		// Connect to Docker Engine's embedded Buildkit.
		buildkitclient.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
			return c.client.DialHijack(ctx, "/grpc", "h2c", map[string][]string{})
		}),
	)
	if err != nil {
		return "", err
	}
	// Release the client's connection once the build is over; without this
	// every build leaves a hijacked connection to the daemon open.
	defer func() {
		if closeErr := bc.Close(); closeErr != nil {
			console.Debugf("error closing buildkit client: %s", closeErr)
		}
	}()

	statusCh := make(chan *buildkitclient.SolveStatus)
	var res *buildkitclient.SolveResponse

	// Determine display mode: options.ProgressOutput > env > 'auto'
	displayMode := options.ProgressOutput
	if displayMode == "" {
		displayMode = os.Getenv("BUILDKIT_PROGRESS")
	}
	if displayMode == "" {
		displayMode = "auto"
	}

	// Build the image.
	eg, ctx := errgroup.WithContext(ctx)

	// run the build in a goroutine
	eg.Go(func() error {
		solveOpt, err := solveOptFromImageOptions(buildDir, options)
		if err != nil {
			return err
		}

		// run the display in a goroutine _after_ we've built SolveOpt, so a
		// failure above never leaves a display waiting on statusCh
		eg.Go(newDisplay(statusCh, displayMode))

		res, err = bc.Solve(ctx, nil, solveOpt, statusCh)
		if err != nil {
			return err
		}
		return nil
	})
	if err := eg.Wait(); err != nil {
		return "", err
	}

	imageID := res.ExporterResponse[exptypes.ExporterImageDigestKey]
	if imageID == "" {
		return "", fmt.Errorf("buildkit did not return an image digest")
	}
	console.Debugf("image digest %s", imageID)

	return imageID, nil
}

// containerRun creates a container from options, optionally attaches the
// caller's stdio to it, starts it and, unless options.Detach is set, blocks
// until it stops.
//
// Behavior by mode:
//   - Detached: no streams are attached and the container ID is returned as
//     soon as the start request succeeds.
//   - Attached: container output is copied to options.Stdout/options.Stderr
//     and options.Stdin is forwarded when shouldAttachStdin says so. The
//     function returns after the container has stopped and the output copy
//     has finished, so output is drained before a non-zero exit status is
//     reported as an error.
//
// The container is always created with AutoRemove, a 6GB shared memory size
// and the linux/amd64 platform. If attaching or starting fails, the created
// container is removed on a best-effort basis so it does not linger.
//
// Returns the container ID on success, ErrMissingDeviceDriver when the start
// error is recognized by isMissingDeviceDriverError, and a wrapped error for
// other failures.
func (c *apiClient) containerRun(ctx context.Context, options command.RunOptions) (string, error) {
	console.Debugf("=== APIClient.containerRun %s", options.Image)

	var attachStdin, tty, attachStderr, attachStdout bool
	if !options.Detach {
		// Determine if we should attach stdin (file, pipe, interactive stdin, etc)
		attachStdin, tty = shouldAttachStdin(options.Stdin)
		attachStdout = options.Stdout != nil
		attachStderr = options.Stderr != nil
	}

	containerCfg := &container.Config{
		Image:        options.Image,
		Cmd:          options.Args,
		Env:          options.Env,
		AttachStdin:  attachStdin,
		AttachStdout: attachStdout,
		AttachStderr: attachStderr,
		OpenStdin:    attachStdin,
		StdinOnce:    attachStdin,
		Tty:          tty,
	}

	// Set working directory if specified
	if options.Workdir != "" {
		containerCfg.WorkingDir = options.Workdir
	}

	if len(options.Ports) > 0 {
		containerCfg.ExposedPorts = make(nat.PortSet)
		for _, port := range options.Ports {
			containerPort := nat.Port(fmt.Sprintf("%d/tcp", port.ContainerPort))
			containerCfg.ExposedPorts[containerPort] = struct{}{}
		}
	}

	hostCfg := &container.HostConfig{
		// always remove container after it exits
		AutoRemove: true,
		// https://github.com/pytorch/pytorch/issues/2244
		// https://github.com/replicate/cog/issues/1293
		ShmSize:   6 * 1024 * 1024 * 1024, // 6GB
		Resources: container.Resources{},
	}

	if options.GPUs != "" {
		deviceRequest, err := parseGPURequest(options)
		if err != nil {
			return "", err
		}
		hostCfg.DeviceRequests = []container.DeviceRequest{deviceRequest}
	}

	// Configure port bindings
	if len(options.Ports) > 0 {
		hostCfg.PortBindings = make(nat.PortMap)
		for _, port := range options.Ports {
			containerPort := nat.Port(fmt.Sprintf("%d/tcp", port.ContainerPort))
			hostCfg.PortBindings[containerPort] = []nat.PortBinding{
				{
					HostIP:   "", // use empty string to bind to all interfaces
					HostPort: strconv.Itoa(port.HostPort),
				},
			}
		}
	}

	// Configure volume bindings
	if len(options.Volumes) > 0 {
		hostCfg.Binds = make([]string, len(options.Volumes))
		for i, volume := range options.Volumes {
			hostCfg.Binds[i] = fmt.Sprintf("%s:%s", volume.Source, volume.Destination)
		}
	}

	// Configure extra hosts (e.g. host.docker.internal on Linux)
	if len(options.ExtraHosts) > 0 {
		hostCfg.ExtraHosts = options.ExtraHosts
	}

	networkingCfg := &network.NetworkingConfig{
		EndpointsConfig: map[string]*network.EndpointSettings{},
	}

	platform := &ocispec.Platform{
		// force platform to linux/amd64
		Architecture: "amd64",
		OS:           "linux",
	}

	runContainer, err := c.client.ContainerCreate(ctx,
		containerCfg,
		hostCfg,
		networkingCfg,
		platform,
		"")
	if err != nil {
		return "", fmt.Errorf("failed to create container: %w", err)
	}

	console.Debugf("container id: %s", runContainer.ID)

	// removeUnstarted is a best-effort cleanup for a container that was created
	// but could not be attached to or started. Such a container never exits,
	// so nothing else would remove it. A failed removal is only logged: the
	// attach/start error is the one worth returning.
	// NOTE(review): on a start failure the daemon may already have removed the
	// container itself, in which case this just logs a not-found error.
	removeUnstarted := func() {
		// Detach from the caller's cancellation so cleanup still runs when the
		// failure was caused by ctx being cancelled.
		rmCtx := context.WithoutCancel(ctx)
		rmErr := c.client.ContainerRemove(rmCtx, runContainer.ID, container.RemoveOptions{Force: true})
		if rmErr != nil {
			console.Debugf("error removing container %s: %s", runContainer.ID, rmErr)
		}
	}

	// Create error group for stream copying
	var eg *errgroup.Group
	var stream types.HijackedResponse

	// Attach to container streams if we have any writers and not detached.
	// This happens before the container is started so no output is missed.
	if attachStderr || attachStdout || attachStdin {
		attachOpts := container.AttachOptions{
			Stream: true,
			Stdin:  attachStdin,
			Stdout: attachStdout,
			Stderr: attachStderr,
		}

		var err error
		stream, err = c.client.ContainerAttach(ctx, runContainer.ID, attachOpts)
		if err != nil {
			removeUnstarted()
			return "", fmt.Errorf("failed to attach to container: %w", err)
		}
		defer stream.Close()

		// Start copying streams in the background
		eg, _ = errgroup.WithContext(ctx)
		if attachStdout || attachStderr {
			eg.Go(func() (err error) {
				if containerCfg.Tty {
					// With a TTY the stream is copied as-is to a single writer.
					w := options.Stdout
					if w == nil {
						w = options.Stderr
					}
					_, err = io.Copy(w, stream.Reader)
				} else {
					// Without a TTY stdout and stderr are multiplexed on one stream.
					_, err = stdcopy.StdCopy(options.Stdout, options.Stderr, stream.Reader)
				}
				return err
			})
		}
		if attachStdin {
			// if we're in a TTY we need to set the terminal to raw mode, and restore it when we're done
			if tty {
				// TODO[md]: handle terminal resize events, see: github.com/containerd/console
				state, err := term.SetRawTerminal(os.Stdin.Fd())
				if err != nil {
					console.Warnf("error setting raw terminal on stdin: %s", err)
				} else {
					// Only restore when raw mode was actually entered; there is
					// no saved state to restore otherwise.
					defer func() {
						if err := term.RestoreTerminal(os.Stdin.Fd(), state); err != nil {
							console.Warnf("error restoring terminal on stdin: %s", err)
						}
					}()
				}
			}

			go func() {
				_, err := io.Copy(stream.Conn, options.Stdin)
				// Close the stdin stream to signal EOF to the container
				if err := errors.Join(err, stream.CloseWrite()); err != nil {
					console.Errorf("error copying and closing stdin stream: %s", err)
				}
			}()
		}
	}

	// Start the container
	if err := c.client.ContainerStart(ctx, runContainer.ID, container.StartOptions{}); err != nil {
		removeUnstarted()
		if isMissingDeviceDriverError(err) {
			return "", ErrMissingDeviceDriver
		}
		return "", fmt.Errorf("failed to start container: %w", err)
	}

	// If detached, return as soon as the start request has succeeded
	if options.Detach {
		return runContainer.ID, nil
	}

	// Wait for the container to exit
	var exitCode int64
	statusCh, errCh := c.client.ContainerWait(ctx, runContainer.ID, container.WaitConditionNotRunning)
	select {
	case err := <-errCh:
		return "", fmt.Errorf("error waiting for container: %w", err)
	case status := <-statusCh:
		exitCode = status.StatusCode
	}

	// container is gone, close the attached streams so stdin is released, ignore the error
	_ = stream.CloseWrite()

	// Wait for stream copying to complete before reporting the outcome, so the
	// container's final output is not cut off when it exits with a failure.
	var copyErr error
	if eg != nil {
		copyErr = eg.Wait()
	}

	// The exit status takes precedence over a copy error.
	if exitCode != 0 {
		return "", fmt.Errorf("container exited with status %d", exitCode)
	}
	if copyErr != nil {
		return "", fmt.Errorf("error copying streams: %w", copyErr)
	}

	return runContainer.ID, nil
}

// Run executes a container described by options via containerRun and discards
// the returned container ID. When the caller supplied no writers, container
// output goes to the process's own stdout and stderr.
func (c *apiClient) Run(ctx context.Context, options command.RunOptions) error {
	console.Debugf("=== APIClient.Run %s", options.Image)

	// options is received by value, so these defaults stay local to this call.
	if options.Stderr == nil {
		options.Stderr = os.Stderr
	}
	if options.Stdout == nil {
		options.Stdout = os.Stdout
	}

	if _, runErr := c.containerRun(ctx, options); runErr != nil {
		return runErr
	}
	return nil
}

// ContainerStart launches a container in detached mode and returns its ID.
// Detach is forced on regardless of what the caller set in options.
func (c *apiClient) ContainerStart(ctx context.Context, options command.RunOptions) (string, error) {
	console.Debugf("=== APIClient.ContainerStart %s", options.Image)

	options.Detach = true
	return c.containerRun(ctx, options)
}

// parseGPURequest converts a Docker CLI style --gpus value (opts.GPUs) into a
// single NVIDIA DeviceRequest.
//
// Accepted forms:
//   - ""            no GPUs: the zero DeviceRequest and a nil error
//   - "all"         every available GPU (Count = -1)
//   - "<n>"         an integer GPU count, e.g. "2"
//   - "device=IDS"  a comma-separated list of device IDs, e.g. "device=0,1";
//     whitespace around each ID is ignored
//
// Any other value, including a device list with an empty entry such as
// "device=" or "device=0,,1", yields the zero DeviceRequest and an
// "invalid GPU specification" error.
func parseGPURequest(opts command.RunOptions) (container.DeviceRequest, error) {
	if opts.GPUs == "" {
		return container.DeviceRequest{}, nil
	}

	deviceRequest := container.DeviceRequest{
		Driver:       "nvidia",
		Capabilities: [][]string{{"gpu"}},
	}

	invalid := func() (container.DeviceRequest, error) {
		return container.DeviceRequest{}, fmt.Errorf("invalid GPU specification: %q", opts.GPUs)
	}

	if opts.GPUs == "all" {
		deviceRequest.Count = -1 // Use all available GPUs
		return deviceRequest, nil
	}

	// A plain number is a GPU count.
	if count, err := strconv.Atoi(opts.GPUs); err == nil {
		deviceRequest.Count = count
		return deviceRequest, nil
	}

	// Otherwise only the device=0,1 format is supported.
	list, ok := strings.CutPrefix(opts.GPUs, "device=")
	if !ok {
		return invalid()
	}

	ids := strings.Split(list, ",")
	for i, id := range ids {
		id = strings.TrimSpace(id)
		// Reject empty IDs here rather than sending a malformed request to
		// the daemon.
		if id == "" {
			return invalid()
		}
		ids[i] = id
	}
	deviceRequest.DeviceIDs = ids

	return deviceRequest, nil
}

// shouldAttachStdin decides whether stdin should be attached to the container
// and whether the container needs a TTY.
//
//   - nil reader: nothing to attach, no TTY
//   - any reader that is not an *os.File (buffer, pipe reader, ...): attach
//     without a TTY
//   - an *os.File, including os.Stdin: always attach; a TTY is requested exactly
//     when the file descriptor is a terminal
func shouldAttachStdin(stdin io.Reader) (attach bool, tty bool) {
	if stdin == nil {
		return false, false
	}

	file, isFile := stdin.(*os.File)
	if !isFile {
		// Not a file, so it's probably a buffer/pipe with actual data
		return true, false
	}

	// Explicit files, piped os.Stdin and interactive os.Stdin are all attached;
	// they only differ in whether a terminal is behind the descriptor. If
	// attaching an interactive os.Stdin by default becomes a problem we need to
	// add a flag to the run command similar to `docker run -i` that instructs
	// the container to attach stdin and keep it open.
	return true, isatty.IsTerminal(file.Fd())
}


================================================
FILE: pkg/docker/docker_client_test.go
================================================
package docker

import (
	"bytes"
	"net"
	"strings"
	"testing"

	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/registry"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/testcontainers/testcontainers-go"
	"github.com/testcontainers/testcontainers-go/wait"

	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/docker/dockertest"
	"github.com/replicate/cog/pkg/registry_testhelpers"
)

// TestDockerClient is an integration suite for the Docker API client. It needs
// a reachable Docker daemon, starts a local test registry and is skipped in
// short mode. The subtests cover image inspection, pulling, container
// stop/inspect/logs and pushing (including an authenticated registry), and run
// in parallel.
func TestDockerClient(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping docker client tests in short mode")
	}

	dockerClient, err := NewClient(t.Context())
	require.NoError(t, err, "Failed to create docker client")
	dockerHelper := dockertest.NewHelperClient(t)
	testRegistry := registry_testhelpers.StartTestRegistry(t)

	dockerHelper.CleanupImages(t)

	t.Run("ImageInspect", func(t *testing.T) {
		t.Parallel()

		t.Run("ExistingLocalImage", func(t *testing.T) {
			t.Parallel()

			ref := dockertest.NewRef(t)
			dockerHelper.ImageFixture(t, "alpine", ref.String())

			expectedImage := dockerHelper.InspectImage(t, ref.String())
			resp, err := dockerClient.Inspect(t.Context(), ref.String())
			require.NoError(t, err, "Failed to inspect image %q", ref.String())
			assert.Equal(t, expectedImage.ID, resp.ID)
		})

		t.Run("MissingLocalImage", func(t *testing.T) {
			t.Parallel()

			image := "not-a-valid-image"
			_, err := dockerClient.Inspect(t.Context(), image)
			assert.ErrorIs(t, err, &command.NotFoundError{})
			assert.ErrorContains(t, err, "image not found")
		})
	})

	t.Run("Pull", func(t *testing.T) {
		t.Parallel()

		// TODO[md]: add tests for the following permutations:
		// - remote reference exists/not exists
		// - local reference exists/not exists
		// - force pull true/false

		t.Run("RemoteImageExists", func(t *testing.T) {
			t.Parallel()
			repo := testRegistry.CloneRepoForTest(t, "alpine")
			imageRef := repo + ":latest"

			assertNoImageExists(t, dockerClient, imageRef)

			resp, err := dockerClient.Pull(t.Context(), imageRef, false)
			require.NoError(t, err, "Failed to pull image %q", imageRef)
			dockerHelper.CleanupImage(t, imageRef)

			assertImageExists(t, dockerClient, imageRef)
			expectedResp := dockerHelper.InspectImage(t, imageRef)
			// TODO[md]: we should check that the responses are actually equal beyond the IDs. but atm
			// the CLI and api are slightly different. The CLI leaves the descriptor field nil while the
			// API response is populated. These should be identical on the new client, so we can change to EqualValues
			assert.Equal(t, expectedResp.ID, resp.ID, "inspect response should match expected")
		})

		t.Run("RemoteReferenceNotFound", func(t *testing.T) {
			t.Parallel()
			imageRef := testRegistry.ImageRefForTest(t, "")

			assertNoImageExists(t, dockerClient, imageRef)

			resp, err := dockerClient.Pull(t.Context(), imageRef, false)
			// TODO[md]: this might not be the right check. we probably want to wrap the error from the registry
			// so we handle other failure cases, like failed auth, unknown tag, and unknown repo
			require.Error(t, err, "Failed to pull image %q", imageRef)
			assert.ErrorIs(t, err, &command.NotFoundError{Object: "manifest", Ref: imageRef})
			assert.Nil(t, resp, "inspect response should be nil")
		})

		t.Run("InvalidAuth", func(t *testing.T) {
			t.Skip("skip auth tests until we're using the docker engine since we can't set auth on the host without side effects")
			imageRef := testRegistry.ImageRefForTest(t, "")

			assertNoImageExists(t, dockerClient, imageRef)

			resp, err := dockerClient.Pull(t.Context(), imageRef, false)
			// TODO[md]: this might not be the right check. we probably want to wrap the error from the registry
			// so we handle other failure cases, like failed auth, unknown tag, and unknown repo
			require.Error(t, err, "Failed to pull image %q", imageRef)
			assert.ErrorContains(t, err, "failed to resolve reference")
			assert.Nil(t, resp, "inspect response should be nil")
		})
	})

	t.Run("ContainerStop", func(t *testing.T) {
		t.Parallel()

		t.Run("ContainerExistsAndIsRunning", func(t *testing.T) {
			t.Parallel()

			container, err := testcontainers.Run(
				t.Context(),
				testRegistry.ImageRef("alpine:latest"),
				testcontainers.WithCmd("sleep", "5000"),
			)
			defer dockerHelper.CleanupImages(t)
			defer testcontainers.CleanupContainer(t, container)
			require.NoError(t, err, "Failed to run container")

			err = dockerClient.ContainerStop(t.Context(), container.ID)
			require.NoError(t, err, "Failed to stop container %q", container.ID)

			state, err := container.State(t.Context())
			require.NoError(t, err, "Failed to get container state")
			assert.False(t, state.Running, "container should not be running after stop")
		})

		t.Run("ContainerExistsAndIsNotRunning", func(t *testing.T) {
			t.Parallel()

			container, err := testcontainers.GenericContainer(t.Context(),
				testcontainers.GenericContainerRequest{
					ContainerRequest: testcontainers.ContainerRequest{
						Image: testRegistry.ImageRef("alpine:latest"),
						Cmd:   []string{"sleep", "5000"},
					},
					Started: false,
				},
			)
			defer testcontainers.CleanupContainer(t, container)
			// Check the error before touching the container: it may be nil when
			// creation failed.
			require.NoError(t, err, "Failed to create container")
			containerID := container.GetContainerID()

			err = dockerClient.ContainerStop(t.Context(), containerID)
			require.NoError(t, err, "Failed to stop container %q", containerID)

			state, err := container.State(t.Context())
			require.NoError(t, err, "Failed to get container state")
			assert.False(t, state.Running, "container should not be running after stop")
		})

		t.Run("ContainerDoesNotExist", func(t *testing.T) {
			t.Parallel()

			err := dockerClient.ContainerStop(t.Context(), "containerid-that-does-not-exist")
			require.ErrorIs(t, err, &command.NotFoundError{})
			require.ErrorContains(t, err, "container not found")
		})
	})

	t.Run("ContainerInspect", func(t *testing.T) {
		t.Parallel()

		t.Run("ContainerExists", func(t *testing.T) {
			t.Parallel()

			container, err := testcontainers.Run(
				t.Context(),
				testRegistry.ImageRef("alpine:latest"),
				testcontainers.WithCmd("sleep", "5000"),
			)
			defer testcontainers.CleanupContainer(t, container)
			require.NoError(t, err, "Failed to run container")

			expected, err := container.Inspect(t.Context())
			require.NoError(t, err, "Failed to inspect container for expected response")

			resp, err := dockerClient.ContainerInspect(t.Context(), container.ID)
			require.NoError(t, err, "Failed to inspect container")
			require.Equal(t, expected, resp)
		})

		t.Run("ContainerDoesNotExist", func(t *testing.T) {
			t.Parallel()

			_, err := dockerClient.ContainerInspect(t.Context(), "containerid-that-does-not-exist")
			require.ErrorIs(t, err, &command.NotFoundError{})
		})
	})

	t.Run("ContainerLogs", func(t *testing.T) {
		t.Parallel()

		t.Run("ContainerExistsAndIsRunning", func(t *testing.T) {
			t.Parallel()

			container, err := testcontainers.Run(
				t.Context(),
				testRegistry.ImageRef("alpine:latest"),
				// print the numbers 1 through 5, one per line and one per second, then exit
				testcontainers.WithCmd("sh", "-c", "for i in $(seq 1 5); do echo \"$i\"; sleep 1; done"),
				// testcontainers.WithConfigModifier(func(config *container.Config) {
				// 	config.Tty = true
				// }),
			)
			require.NoError(t, err, "Failed to run container")
			defer testcontainers.CleanupContainer(t, container)

			var buf bytes.Buffer
			err = dockerClient.ContainerLogs(t.Context(), container.ID, &buf)
			require.NoError(t, err, "Failed to get container logs")

			assert.Equal(t, "1\n2\n3\n4\n5\n", buf.String())
		})

		t.Run("ContainerAlreadyStopped", func(t *testing.T) {
			t.Parallel()

			container, err := testcontainers.Run(
				t.Context(),
				testRegistry.ImageRef("alpine:latest"),
				testcontainers.WithCmd("sh", "-c", "for i in $(seq 1 3); do echo \"$i\"; sleep 0.1; done"),
				testcontainers.WithWaitStrategy(wait.ForExit()),
			)
			require.NoError(t, err, "Failed to run container")
			defer testcontainers.CleanupContainer(t, container)

			state, err := container.State(t.Context())
			require.NoError(t, err, "Failed to get container state")
			assert.False(t, state.Running, "container should have exited before logs are read")

			var buf bytes.Buffer
			err = dockerClient.ContainerLogs(t.Context(), container.ID, &buf)
			require.NoError(t, err, "Failed to get container logs")

			assert.Equal(t, "1\n2\n3\n", buf.String())
		})

		t.Run("TTY and non-TTY streams match", func(t *testing.T) {
			t.Parallel()

			runContainer := func(tty bool) string {
				container, err := testcontainers.Run(
					t.Context(),
					testRegistry.ImageRef("alpine:latest"),
					// print the numbers 1 through 5, one per line, then exit
					testcontainers.WithCmd("sh", "-c", "for i in $(seq 1 5); do echo \"$i\"; sleep 0.1; done"),
					testcontainers.WithConfigModifier(func(config *container.Config) {
						config.Tty = tty
					}),
				)
				require.NoError(t, err, "Failed to run container")
				defer testcontainers.CleanupContainer(t, container)

				var buf bytes.Buffer
				err = dockerClient.ContainerLogs(t.Context(), container.ID, &buf)
				require.NoError(t, err, "Failed to get container logs")
				return buf.String()
			}

			ttyOutput := runContainer(true)
			nonTtyOutput := runContainer(false)

			// TTY uses CRLF for line endings, non-TTY uses LF. replace \r\n with \n so they match
			ttyOutput = strings.ReplaceAll(ttyOutput, "\r\n", "\n")

			assert.Equal(t, ttyOutput, nonTtyOutput, "TTY and non-TTY streams should match after normalizing line endings")
		})

		t.Run("ContainerDoesNotExist", func(t *testing.T) {
			t.Parallel()

			err := dockerClient.ContainerLogs(t.Context(), "containerid-that-does-not-exist", &bytes.Buffer{})
			require.ErrorIs(t, err, &command.NotFoundError{})
		})
	})

	t.Run("Push", func(t *testing.T) {
		t.Parallel()

		t.Run("valid image, valid registry", func(t *testing.T) {
			t.Parallel()

			ref := dockertest.NewRef(t).WithRegistry(testRegistry.RegistryHost())

			dockerHelper.ImageFixture(t, "alpine", ref.String())

			err := dockerClient.Push(t.Context(), ref.String())
			require.NoError(t, err)
			assert.NoError(t, testRegistry.ImageExists(t, ref.String()))
		})

		t.Run("non-existent registry", func(t *testing.T) {
			t.Parallel()

			// start a local tcp server that immediately closes connections
			listener, err := net.Listen("tcp", "127.0.0.1:0")
			require.NoError(t, err)
			defer listener.Close()

			go func() {
				// Accept fails once the listener is closed, which ends this goroutine.
				for {
					conn, err := listener.Accept()
					if err != nil {
						return
					}
					conn.Close()
				}
			}()

			// Create a reference to the mock registry
			ref := dockertest.NewRef(t).WithRegistry(listener.Addr().String())
			dockerHelper.ImageFixture(t, "alpine", ref.String())

			// Try to push to the mock registry
			err = dockerClient.Push(t.Context(), ref.String())
			require.Error(t, err, "Push should fail with unreachable registry")

			assert.True(t, isNetworkError(err), "Error should be a network error, got: %q", err.Error())
		})

		t.Run("missing image", func(t *testing.T) {
			t.Parallel()

			ref := dockertest.NewRef(t).WithRegistry(testRegistry.RegistryHost())

			err := dockerClient.Push(t.Context(), ref.String())
			assertNotFoundError(t, err, ref.String(), "tag")
		})

		t.Run("registry with authentication", func(t *testing.T) {
			t.Parallel()

			authReg := registry_testhelpers.StartTestRegistry(t, registry_testhelpers.WithAuth("testuser", "testpass"))

			t.Run("correct credentials", func(t *testing.T) {
				t.Parallel()

				ref := dockertest.NewRef(t).WithRegistry(authReg.RegistryHost())
				dockerHelper.ImageFixture(t, "alpine", ref.String())

				// create a new client with the correct auth config
				authClient, err := NewClient(t.Context(), WithAuthConfig(registry.AuthConfig{
					Username:      "testuser",
					Password:      "testpass",
					ServerAddress: authReg.RegistryHost(),
				}))
				require.NoError(t, err)

				err = authClient.Push(t.Context(), ref.String())
				require.NoError(t, err, "Failed to push image to auth registry")
				assert.NoError(t, authReg.ImageExists(t, ref.String()))
			})

			t.Run("missing auth", func(t *testing.T) {
				t.Parallel()

				ref := dockertest.NewRef(t).WithRegistry(authReg.RegistryHost())
				dockerHelper.ImageFixture(t, "alpine", ref.String())

				// use root client which doesn't have auth setup
				err := dockerClient.Push(t.Context(), ref.String())
				require.ErrorIs(t, err, command.ErrAuthorizationFailed)
			})

			t.Run("incorrect auth", func(t *testing.T) {
				t.Parallel()

				ref := dockertest.NewRef(t).WithRegistry(authReg.RegistryHost())
				dockerHelper.ImageFixture(t, "alpine", ref.String())

				authClient, err := NewClient(t.Context(), WithAuthConfig(registry.AuthConfig{
					Username:      "testuser",
					Password:      "wrongpass",
					ServerAddress: authReg.RegistryHost(),
				}))
				require.NoError(t, err)

				err = authClient.Push(t.Context(), ref.String())
				require.ErrorIs(t, err, command.ErrAuthorizationFailed)
			})

			t.Run("correct credentials, not authorized", func(t *testing.T) {
				t.Skip("skipping until the registry supports repo authorizations")
			})
		})
	})
}

// assertImageExists records a test failure (without stopping the test) unless
// imageRef can be inspected through dockerClient and the response is non-nil.
func assertImageExists(t *testing.T, dockerClient command.Command, imageRef string) {
	t.Helper()

	ctx := t.Context()
	resp, inspectErr := dockerClient.Inspect(ctx, imageRef)
	assert.NoError(t, inspectErr, "Failed to inspect image %q", imageRef)
	assert.NotNil(t, resp, "Image should exist")
}

// assertNoImageExists records a test failure (without stopping the test)
// unless inspecting imageRef through dockerClient yields a
// command.NotFoundError and a nil response.
func assertNoImageExists(t *testing.T, dockerClient command.Command, imageRef string) {
	t.Helper()

	ctx := t.Context()
	resp, inspectErr := dockerClient.Inspect(ctx, imageRef)
	assert.ErrorIs(t, inspectErr, &command.NotFoundError{}, "Image should not exist")
	assert.Nil(t, resp, "Image should not exist")
}

func assertNotFoundError(t *testing.T, err error, ref string, object string) {
	t.Helper()

	var notFoundErr *command.NotFoundError
	require.ErrorAs(t, err, ¬FoundErr, "should be a not found error")
	require.Equal(t, ref, notFoundErr.Ref, "ref should match")
	require.Equal(t, object, notFoundErr.Object, "object should match")
}


================================================
FILE: pkg/docker/dockertest/command_mocks.go
================================================
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify

package dockertest

import (
	"context"
	"io"

	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/image"
	"github.com/replicate/cog/pkg/docker/command"
	mock "github.com/stretchr/testify/mock"
)

// NewMockCommand2 creates a new instance of MockCommand2. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockCommand2(t interface {
	mock.TestingT
	Cleanup(func())
}) *MockCommand2 {
	mock := &MockCommand2{}
	// Route mock failures through t.
	mock.Mock.Test(t)

	// Verify every registered expectation when the test finishes.
	t.Cleanup(func() { mock.AssertExpectations(t) })

	return mock
}

// MockCommand2 is an autogenerated mock type for the Command type.
// It embeds mock.Mock, which records calls and holds the configured
// expectations; use EXPECT to set expectations in a typed way.
type MockCommand2 struct {
	mock.Mock
}

// MockCommand2_Expecter is the typed expectation builder returned by
// MockCommand2.EXPECT. It points at the mock's underlying mock.Mock, on which
// its methods register expectations.
type MockCommand2_Expecter struct {
	mock *mock.Mock
}

// EXPECT returns an expecter bound to this mock, used to register
// expectations through per-method helpers instead of string-based On calls.
func (_m *MockCommand2) EXPECT() *MockCommand2_Expecter {
	return &MockCommand2_Expecter{mock: &_m.Mock}
}

// ContainerInspect provides a mock function for the type MockCommand2.
//
// It records the call and panics if no return values were configured. If the
// first configured return value is a function with the full method signature,
// that function is called and its results are returned. Otherwise each return
// slot may hold either a function computing that single result or the literal
// value itself.
func (_mock *MockCommand2) ContainerInspect(ctx context.Context, id string) (*container.InspectResponse, error) {
	ret := _mock.Called(ctx, id)

	if len(ret) == 0 {
		panic("no return value specified for ContainerInspect")
	}

	var r0 *container.InspectResponse
	var r1 error
	if returnFunc, ok := ret.Get(0).(func(context.Context, string) (*container.InspectResponse, error)); ok {
		return returnFunc(ctx, id)
	}
	if returnFunc, ok := ret.Get(0).(func(context.Context, string) *container.InspectResponse); ok {
		r0 = returnFunc(ctx, id)
	} else {
		// A nil first return value leaves r0 as a nil pointer.
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(*container.InspectResponse)
		}
	}
	if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
		r1 = returnFunc(ctx, id)
	} else {
		r1 = ret.Error(1)
	}
	return r0, r1
}

// MockCommand2_ContainerInspect_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ContainerInspect'.
// All other *mock.Call methods remain available through the embedded pointer.
type MockCommand2_ContainerInspect_Call struct {
	*mock.Call
}

// ContainerInspect is a helper method to define mock.On call
//   - ctx context.Context
//   - id string
//
// The parameters are interface{} and are passed to mock.On unchanged, then the
// resulting call is wrapped in the typed Call type.
func (_e *MockCommand2_Expecter) ContainerInspect(ctx interface{}, id interface{}) *MockCommand2_ContainerInspect_Call {
	return &MockCommand2_ContainerInspect_Call{Call: _e.mock.On("ContainerInspect", ctx, id)}
}

// Run registers a callback that receives the call's arguments converted to
// their concrete types. A nil argument is passed to run as the zero value of
// its type. Returns the call for chaining.
func (_c *MockCommand2_ContainerInspect_Call) Run(run func(ctx context.Context, id string)) *MockCommand2_ContainerInspect_Call {
	_c.Call.Run(func(args mock.Arguments) {
		var arg0 context.Context
		if args[0] != nil {
			arg0 = args[0].(context.Context)
		}
		var arg1 string
		if args[1] != nil {
			arg1 = args[1].(string)
		}
		run(
			arg0,
			arg1,
		)
	})
	return _c
}

// Return sets the fixed values the mocked ContainerInspect returns for this
// expectation. Returns the call for chaining.
func (_c *MockCommand2_ContainerInspect_Call) Return(inspectResponse *container.InspectResponse, err error) *MockCommand2_ContainerInspect_Call {
	_c.Call.Return(inspectResponse, err)
	return _c
}

// RunAndReturn stores run itself as the call's return value. The mocked
// ContainerInspect recognizes a function of this signature in return slot 0
// and invokes it with the call's arguments to produce both results.
func (_c *MockCommand2_ContainerInspect_Call) RunAndReturn(run func(ctx context.Context, id string) (*container.InspectResponse, error)) *MockCommand2_ContainerInspect_Call {
	_c.Call.Return(run)
	return _c
}

// ContainerLogs provides a mock function for the type MockCommand2.
//
// It records the call and panics if no return value was configured. The
// configured return value may be a function with the method's signature, which
// is called to compute the error, or the error value itself.
func (_mock *MockCommand2) ContainerLogs(ctx context.Context, containerID string, w io.Writer) error {
	ret := _mock.Called(ctx, containerID, w)

	if len(ret) == 0 {
		panic("no return value specified for ContainerLogs")
	}

	var r0 error
	if returnFunc, ok := ret.Get(0).(func(context.Context, string, io.Writer) error); ok {
		r0 = returnFunc(ctx, containerID, w)
	} else {
		r0 = ret.Error(0)
	}
	return r0
}

// MockCommand2_ContainerLogs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ContainerLogs'.
// All other *mock.Call methods remain available through the embedded pointer.
type MockCommand2_ContainerLogs_Call struct {
	*mock.Call
}

// ContainerLogs is a helper method to define mock.On call
//   - ctx context.Context
//   - containerID string
//   - w io.Writer
//
// The parameters are interface{} and are passed to mock.On unchanged, then the
// resulting call is wrapped in the typed Call type.
func (_e *MockCommand2_Expecter) ContainerLogs(ctx interface{}, containerID interface{}, w interface{}) *MockCommand2_ContainerLogs_Call {
	return &MockCommand2_ContainerLogs_Call{Call: _e.mock.On("ContainerLogs", ctx, containerID, w)}
}

// Run registers a callback that receives the call's arguments converted to
// their concrete types. A nil argument is passed to run as the zero value of
// its type. Returns the call for chaining.
func (_c *MockCommand2_ContainerLogs_Call) Run(run func(ctx context.Context, containerID string, w io.Writer)) *MockCommand2_ContainerLogs_Call {
	_c.Call.Run(func(args mock.Arguments) {
		var arg0 context.Context
		if args[0] != nil {
			arg0 = args[0].(context.Context)
		}
		var arg1 string
		if args[1] != nil {
			arg1 = args[1].(string)
		}
		var arg2 io.Writer
		if args[2] != nil {
			arg2 = args[2].(io.Writer)
		}
		run(
			arg0,
			arg1,
			arg2,
		)
	})
	return _c
}

// Return sets the fixed error the mocked ContainerLogs returns for this
// expectation. Returns the call for chaining.
func (_c *MockCommand2_ContainerLogs_Call) Return(err error) *MockCommand2_ContainerLogs_Call {
	_c.Call.Return(err)
	return _c
}

// RunAndReturn stores run itself as the call's return value. The mocked
// ContainerLogs recognizes a function of this signature in return slot 0 and
// invokes it with the call's arguments to produce the result.
func (_c *MockCommand2_ContainerLogs_Call) RunAndReturn(run func(ctx context.Context, containerID string, w io.Writer) error) *MockCommand2_ContainerLogs_Call {
	_c.Call.Return(run)
	return _c
}

// ContainerStart provides a mock function for the type MockCommand2.
//
// It records the call and panics if no return values were configured. If the
// first configured return value is a function with the full method signature,
// that function is called and its results are returned. Otherwise each return
// slot may hold either a function computing that single result or the literal
// value itself.
func (_mock *MockCommand2) ContainerStart(ctx context.Context, options command.RunOptions) (string, error) {
	ret := _mock.Called(ctx, options)

	if len(ret) == 0 {
		panic("no return value specified for ContainerStart")
	}

	var r0 string
	var r1 error
	if returnFunc, ok := ret.Get(0).(func(context.Context, command.RunOptions) (string, error)); ok {
		return returnFunc(ctx, options)
	}
	if returnFunc, ok := ret.Get(0).(func(context.Context, command.RunOptions) string); ok {
		r0 = returnFunc(ctx, options)
	} else {
		// Unchecked assertion: a first return value that is not a string panics here.
		r0 = ret.Get(0).(string)
	}
	if returnFunc, ok := ret.Get(1).(func(context.Context, command.RunOptions) error); ok {
		r1 = returnFunc(ctx, options)
	} else {
		r1 = ret.Error(1)
	}
	return r0, r1
}

// MockCommand2_ContainerStart_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ContainerStart'.
// All other *mock.Call methods remain available through the embedded pointer.
type MockCommand2_ContainerStart_Call struct {
	*mock.Call
}

// ContainerStart is a helper method to define mock.On call
//   - ctx context.Context
//   - options command.RunOptions
//
// The parameters are interface{} and are passed to mock.On unchanged, then the
// resulting call is wrapped in the typed Call type.
func (_e *MockCommand2_Expecter) ContainerStart(ctx interface{}, options interface{}) *MockCommand2_ContainerStart_Call {
	return &MockCommand2_ContainerStart_Call{Call: _e.mock.On("ContainerStart", ctx, options)}
}

// Run registers a callback that receives the call's arguments converted to
// their concrete types. A nil argument is passed to run as the zero value of
// its type. Returns the call for chaining.
func (_c *MockCommand2_ContainerStart_Call) Run(run func(ctx context.Context, options command.RunOptions)) *MockCommand2_ContainerStart_Call {
	_c.Call.Run(func(args mock.Arguments) {
		var arg0 context.Context
		if args[0] != nil {
			arg0 = args[0].(context.Context)
		}
		var arg1 command.RunOptions
		if args[1] != nil {
			arg1 = args[1].(command.RunOptions)
		}
		run(
			arg0,
			arg1,
		)
	})
	return _c
}

// Return sets the fixed values the mocked ContainerStart returns for this
// expectation. Returns the call for chaining.
func (_c *MockCommand2_ContainerStart_Call) Return(s string, err error) *MockCommand2_ContainerStart_Call {
	_c.Call.Return(s, err)
	return _c
}

// RunAndReturn stores run itself as the call's return value. The mocked
// ContainerStart recognizes a function of this signature in return slot 0 and
// invokes it with the call's arguments to produce both results.
func (_c *MockCommand2_ContainerStart_Call) RunAndReturn(run func(ctx context.Context, options command.RunOptions) (string, error)) *MockCommand2_ContainerStart_Call {
	_c.Call.Return(run)
	return _c
}

// ContainerStop provides a mock function for the type MockCommand2.
//
// It records the call and panics if no return value was configured. The
// configured return value may be a function with the method's signature, which
// is called to compute the error, or the error value itself.
func (_mock *MockCommand2) ContainerStop(ctx context.Context, containerID string) error {
	ret := _mock.Called(ctx, containerID)

	if len(ret) == 0 {
		panic("no return value specified for ContainerStop")
	}

	var r0 error
	if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok {
		r0 = returnFunc(ctx, containerID)
	} else {
		r0 = ret.Error(0)
	}
	return r0
}

// MockCommand2_ContainerStop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ContainerStop'.
// All other *mock.Call methods remain available through the embedded pointer.
type MockCommand2_ContainerStop_Call struct {
	*mock.Call
}

// ContainerStop registers an expectation for 'ContainerStop' via mock.On and
// returns it wrapped in a typed call handle.
//   - ctx context.Context
//   - containerID string
func (_e *MockCommand2_Expecter) ContainerStop(ctx interface{}, containerID interface{}) *MockCommand2_ContainerStop_Call {
	call := _e.mock.On("ContainerStop", ctx, containerID)
	return &MockCommand2_ContainerStop_Call{Call: call}
}

// Run installs run on the embedded mock.Call's Run hook, converting the
// untyped mock.Arguments into the typed parameters of ContainerStop.
func (_c *MockCommand2_ContainerStop_Call) Run(run func(ctx context.Context, containerID string)) *MockCommand2_ContainerStop_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// Nil arguments keep their zero value; anything else must have the expected type.
		var (
			ctxArg context.Context
			cidArg string
		)
		if raw := args[0]; raw != nil {
			ctxArg = raw.(context.Context)
		}
		if raw := args[1]; raw != nil {
			cidArg = raw.(string)
		}
		run(ctxArg, cidArg)
	})
	return _c
}

// Return forwards err to the embedded mock.Call's Return and hands back the
// typed call so further configuration can be chained.
func (_c *MockCommand2_ContainerStop_Call) Return(err error) *MockCommand2_ContainerStop_Call {
	_c.Call.Return(err)
	return _c
}

// RunAndReturn stores the function run itself as the call's return value by
// passing it to the embedded mock.Call's Return, then returns the typed call.
func (_c *MockCommand2_ContainerStop_Call) RunAndReturn(run func(ctx context.Context, containerID string) error) *MockCommand2_ContainerStop_Call {
	_c.Call.Return(run)
	return _c
}

// ImageBuild provides a mock function for the type MockCommand2.
// It panics when the matched expectation carries no return value.
func (_mock *MockCommand2) ImageBuild(ctx context.Context, options command.ImageBuildOptions) (string, error) {
	results := _mock.Called(ctx, options)
	if len(results) == 0 {
		panic("no return value specified for ImageBuild")
	}

	var imageID string
	switch first := results.Get(0).(type) {
	case func(context.Context, command.ImageBuildOptions) (string, error):
		// A single function may supply both results at once.
		return first(ctx, options)
	case func(context.Context, command.ImageBuildOptions) string:
		imageID = first(ctx, options)
	default:
		imageID = first.(string)
	}

	if fn, ok := results.Get(1).(func(context.Context, command.ImageBuildOptions) error); ok {
		return imageID, fn(ctx, options)
	}
	return imageID, results.Error(1)
}

// MockCommand2_ImageBuild_Call wraps *mock.Call so that the Run, Return and
// RunAndReturn helpers defined on it are typed for the 'ImageBuild' method.
type MockCommand2_ImageBuild_Call struct {
	*mock.Call
}

// ImageBuild registers an expectation for 'ImageBuild' via mock.On and returns
// it wrapped in a typed call handle.
//   - ctx context.Context
//   - options command.ImageBuildOptions
func (_e *MockCommand2_Expecter) ImageBuild(ctx interface{}, options interface{}) *MockCommand2_ImageBuild_Call {
	call := _e.mock.On("ImageBuild", ctx, options)
	return &MockCommand2_ImageBuild_Call{Call: call}
}

// Run installs run on the embedded mock.Call's Run hook, converting the
// untyped mock.Arguments into the typed parameters of ImageBuild.
func (_c *MockCommand2_ImageBuild_Call) Run(run func(ctx context.Context, options command.ImageBuildOptions)) *MockCommand2_ImageBuild_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// Nil arguments keep their zero value; anything else must have the expected type.
		var (
			ctxArg context.Context
			optArg command.ImageBuildOptions
		)
		if raw := args[0]; raw != nil {
			ctxArg = raw.(context.Context)
		}
		if raw := args[1]; raw != nil {
			optArg = raw.(command.ImageBuildOptions)
		}
		run(ctxArg, optArg)
	})
	return _c
}

// Return forwards imageID and err to the embedded mock.Call's Return and hands
// back the typed call so further configuration can be chained.
func (_c *MockCommand2_ImageBuild_Call) Return(imageID string, err error) *MockCommand2_ImageBuild_Call {
	_c.Call.Return(imageID, err)
	return _c
}

// RunAndReturn stores the function run itself as the call's return value by
// passing it to the embedded mock.Call's Return, then returns the typed call.
func (_c *MockCommand2_ImageBuild_Call) RunAndReturn(run func(ctx context.Context, options command.ImageBuildOptions) (string, error)) *MockCommand2_ImageBuild_Call {
	_c.Call.Return(run)
	return _c
}

// ImageSave provides a mock function for the type MockCommand2.
// It panics when the matched expectation carries no return value.
func (_mock *MockCommand2) ImageSave(ctx context.Context, imageRef string) (io.ReadCloser, error) {
	results := _mock.Called(ctx, imageRef)
	if len(results) == 0 {
		panic("no return value specified for ImageSave")
	}

	var stream io.ReadCloser
	switch first := results.Get(0).(type) {
	case func(context.Context, string) (io.ReadCloser, error):
		// A single function may supply both results at once.
		return first(ctx, imageRef)
	case func(context.Context, string) io.ReadCloser:
		stream = first(ctx, imageRef)
	case nil:
		// A nil first return value leaves the stream nil.
	default:
		stream = first.(io.ReadCloser)
	}

	if fn, ok := results.Get(1).(func(context.Context, string) error); ok {
		return stream, fn(ctx, imageRef)
	}
	return stream, results.Error(1)
}

// MockCommand2_ImageSave_Call wraps *mock.Call so that the Run, Return and
// RunAndReturn helpers defined on it are typed for the 'ImageSave' method.
type MockCommand2_ImageSave_Call struct {
	*mock.Call
}

// ImageSave registers an expectation for 'ImageSave' via mock.On and returns
// it wrapped in a typed call handle.
//   - ctx context.Context
//   - imageRef string
func (_e *MockCommand2_Expecter) ImageSave(ctx interface{}, imageRef interface{}) *MockCommand2_ImageSave_Call {
	call := _e.mock.On("ImageSave", ctx, imageRef)
	return &MockCommand2_ImageSave_Call{Call: call}
}

// Run installs run on the embedded mock.Call's Run hook, converting the
// untyped mock.Arguments into the typed parameters of ImageSave.
func (_c *MockCommand2_ImageSave_Call) Run(run func(ctx context.Context, imageRef string)) *MockCommand2_ImageSave_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// Nil arguments keep their zero value; anything else must have the expected type.
		var (
			ctxArg context.Context
			refArg string
		)
		if raw := args[0]; raw != nil {
			ctxArg = raw.(context.Context)
		}
		if raw := args[1]; raw != nil {
			refArg = raw.(string)
		}
		run(ctxArg, refArg)
	})
	return _c
}

// Return forwards rc and err to the embedded mock.Call's Return and hands back
// the typed call so further configuration can be chained.
func (_c *MockCommand2_ImageSave_Call) Return(rc io.ReadCloser, err error) *MockCommand2_ImageSave_Call {
	_c.Call.Return(rc, err)
	return _c
}

// RunAndReturn stores the function run itself as the call's return value by
// passing it to the embedded mock.Call's Return, then returns the typed call.
func (_c *MockCommand2_ImageSave_Call) RunAndReturn(run func(ctx context.Context, imageRef string) (io.ReadCloser, error)) *MockCommand2_ImageSave_Call {
	_c.Call.Return(run)
	return _c
}

// ImageExists provides a mock function for the type MockCommand2.
// It panics when the matched expectation carries no return value.
func (_mock *MockCommand2) ImageExists(ctx context.Context, ref string) (bool, error) {
	results := _mock.Called(ctx, ref)
	if len(results) == 0 {
		panic("no return value specified for ImageExists")
	}

	var exists bool
	switch first := results.Get(0).(type) {
	case func(context.Context, string) (bool, error):
		// A single function may supply both results at once.
		return first(ctx, ref)
	case func(context.Context, string) bool:
		exists = first(ctx, ref)
	default:
		exists = first.(bool)
	}

	if fn, ok := results.Get(1).(func(context.Context, string) error); ok {
		return exists, fn(ctx, ref)
	}
	return exists, results.Error(1)
}

// MockCommand2_ImageExists_Call wraps *mock.Call so that the Run, Return and
// RunAndReturn helpers defined on it are typed for the 'ImageExists' method.
type MockCommand2_ImageExists_Call struct {
	*mock.Call
}

// ImageExists registers an expectation for 'ImageExists' via mock.On and
// returns it wrapped in a typed call handle.
//   - ctx context.Context
//   - ref string
func (_e *MockCommand2_Expecter) ImageExists(ctx interface{}, ref interface{}) *MockCommand2_ImageExists_Call {
	call := _e.mock.On("ImageExists", ctx, ref)
	return &MockCommand2_ImageExists_Call{Call: call}
}

// Run installs run on the embedded mock.Call's Run hook, converting the
// untyped mock.Arguments into the typed parameters of ImageExists.
func (_c *MockCommand2_ImageExists_Call) Run(run func(ctx context.Context, ref string)) *MockCommand2_ImageExists_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// Nil arguments keep their zero value; anything else must have the expected type.
		var (
			ctxArg context.Context
			refArg string
		)
		if raw := args[0]; raw != nil {
			ctxArg = raw.(context.Context)
		}
		if raw := args[1]; raw != nil {
			refArg = raw.(string)
		}
		run(ctxArg, refArg)
	})
	return _c
}

// Return forwards b and err to the embedded mock.Call's Return and hands back
// the typed call so further configuration can be chained.
func (_c *MockCommand2_ImageExists_Call) Return(b bool, err error) *MockCommand2_ImageExists_Call {
	_c.Call.Return(b, err)
	return _c
}

// RunAndReturn stores the function run itself as the call's return value by
// passing it to the embedded mock.Call's Return, then returns the typed call.
func (_c *MockCommand2_ImageExists_Call) RunAndReturn(run func(ctx context.Context, ref string) (bool, error)) *MockCommand2_ImageExists_Call {
	_c.Call.Return(run)
	return _c
}

// Inspect provides a mock function for the type MockCommand2.
// It panics when the matched expectation carries no return value.
func (_mock *MockCommand2) Inspect(ctx context.Context, ref string) (*image.InspectResponse, error) {
	results := _mock.Called(ctx, ref)
	if len(results) == 0 {
		panic("no return value specified for Inspect")
	}

	var resp *image.InspectResponse
	switch first := results.Get(0).(type) {
	case func(context.Context, string) (*image.InspectResponse, error):
		// A single function may supply both results at once.
		return first(ctx, ref)
	case func(context.Context, string) *image.InspectResponse:
		resp = first(ctx, ref)
	case nil:
		// A nil first return value leaves the response nil.
	default:
		resp = first.(*image.InspectResponse)
	}

	if fn, ok := results.Get(1).(func(context.Context, string) error); ok {
		return resp, fn(ctx, ref)
	}
	return resp, results.Error(1)
}

// MockCommand2_Inspect_Call wraps *mock.Call so that the Run, Return and
// RunAndReturn helpers defined on it are typed for the 'Inspect' method.
type MockCommand2_Inspect_Call struct {
	*mock.Call
}

// Inspect registers an expectation for 'Inspect' via mock.On and returns it
// wrapped in a typed call handle.
//   - ctx context.Context
//   - ref string
func (_e *MockCommand2_Expecter) Inspect(ctx interface{}, ref interface{}) *MockCommand2_Inspect_Call {
	call := _e.mock.On("Inspect", ctx, ref)
	return &MockCommand2_Inspect_Call{Call: call}
}

// Run installs run on the embedded mock.Call's Run hook, converting the
// untyped mock.Arguments into the typed parameters of Inspect.
func (_c *MockCommand2_Inspect_Call) Run(run func(ctx context.Context, ref string)) *MockCommand2_Inspect_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// Nil arguments keep their zero value; anything else must have the expected type.
		var (
			ctxArg context.Context
			refArg string
		)
		if raw := args[0]; raw != nil {
			ctxArg = raw.(context.Context)
		}
		if raw := args[1]; raw != nil {
			refArg = raw.(string)
		}
		run(ctxArg, refArg)
	})
	return _c
}

// Return forwards inspectResponse and err to the embedded mock.Call's Return
// and hands back the typed call so further configuration can be chained.
func (_c *MockCommand2_Inspect_Call) Return(inspectResponse *image.InspectResponse, err error) *MockCommand2_Inspect_Call {
	_c.Call.Return(inspectResponse, err)
	return _c
}

// RunAndReturn stores the function run itself as the call's return value by
// passing it to the embedded mock.Call's Return, then returns the typed call.
func (_c *MockCommand2_Inspect_Call) RunAndReturn(run func(ctx context.Context, ref string) (*image.InspectResponse, error)) *MockCommand2_Inspect_Call {
	_c.Call.Return(run)
	return _c
}

// LoadUserInformation provides a mock function for the type MockCommand2.
// It panics when the matched expectation carries no return value.
func (_mock *MockCommand2) LoadUserInformation(ctx context.Context, registryHost string) (*command.UserInfo, error) {
	results := _mock.Called(ctx, registryHost)
	if len(results) == 0 {
		panic("no return value specified for LoadUserInformation")
	}

	var info *command.UserInfo
	switch first := results.Get(0).(type) {
	case func(context.Context, string) (*command.UserInfo, error):
		// A single function may supply both results at once.
		return first(ctx, registryHost)
	case func(context.Context, string) *command.UserInfo:
		info = first(ctx, registryHost)
	case nil:
		// A nil first return value leaves the user info nil.
	default:
		info = first.(*command.UserInfo)
	}

	if fn, ok := results.Get(1).(func(context.Context, string) error); ok {
		return info, fn(ctx, registryHost)
	}
	return info, results.Error(1)
}

// MockCommand2_LoadUserInformation_Call wraps *mock.Call so that the Run,
// Return and RunAndReturn helpers defined on it are typed for the
// 'LoadUserInformation' method.
type MockCommand2_LoadUserInformation_Call struct {
	*mock.Call
}

// LoadUserInformation registers an expectation for 'LoadUserInformation' via
// mock.On and returns it wrapped in a typed call handle.
//   - ctx context.Context
//   - registryHost string
func (_e *MockCommand2_Expecter) LoadUserInformation(ctx interface{}, registryHost interface{}) *MockCommand2_LoadUserInformation_Call {
	call := _e.mock.On("LoadUserInformation", ctx, registryHost)
	return &MockCommand2_LoadUserInformation_Call{Call: call}
}

// Run installs run on the embedded mock.Call's Run hook, converting the
// untyped mock.Arguments into the typed parameters of LoadUserInformation.
func (_c *MockCommand2_LoadUserInformation_Call) Run(run func(ctx context.Context, registryHost string)) *MockCommand2_LoadUserInformation_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// Nil arguments keep their zero value; anything else must have the expected type.
		var (
			ctxArg context.Context
			regArg string
		)
		if raw := args[0]; raw != nil {
			ctxArg = raw.(context.Context)
		}
		if raw := args[1]; raw != nil {
			regArg = raw.(string)
		}
		run(ctxArg, regArg)
	})
	return _c
}

// Return forwards userInfo and err to the embedded mock.Call's Return and
// hands back the typed call so further configuration can be chained.
func (_c *MockCommand2_LoadUserInformation_Call) Return(userInfo *command.UserInfo, err error) *MockCommand2_LoadUserInformation_Call {
	_c.Call.Return(userInfo, err)
	return _c
}

// RunAndReturn stores the function run itself as the call's return value by
// passing it to the embedded mock.Call's Return, then returns the typed call.
func (_c *MockCommand2_LoadUserInformation_Call) RunAndReturn(run func(ctx context.Context, registryHost string) (*command.UserInfo, error)) *MockCommand2_LoadUserInformation_Call {
	_c.Call.Return(run)
	return _c
}

// Pull provides a mock function for the type MockCommand2.
// It panics when the matched expectation carries no return value.
func (_mock *MockCommand2) Pull(ctx context.Context, ref string, force bool) (*image.InspectResponse, error) {
	results := _mock.Called(ctx, ref, force)
	if len(results) == 0 {
		panic("no return value specified for Pull")
	}

	var resp *image.InspectResponse
	switch first := results.Get(0).(type) {
	case func(context.Context, string, bool) (*image.InspectResponse, error):
		// A single function may supply both results at once.
		return first(ctx, ref, force)
	case func(context.Context, string, bool) *image.InspectResponse:
		resp = first(ctx, ref, force)
	case nil:
		// A nil first return value leaves the response nil.
	default:
		resp = first.(*image.InspectResponse)
	}

	if fn, ok := results.Get(1).(func(context.Context, string, bool) error); ok {
		return resp, fn(ctx, ref, force)
	}
	return resp, results.Error(1)
}

// MockCommand2_Pull_Call wraps *mock.Call so that the Run, Return and
// RunAndReturn helpers defined on it are typed for the 'Pull' method.
type MockCommand2_Pull_Call struct {
	*mock.Call
}

// Pull registers an expectation for 'Pull' via mock.On and returns it wrapped
// in a typed call handle.
//   - ctx context.Context
//   - ref string
//   - force bool
func (_e *MockCommand2_Expecter) Pull(ctx interface{}, ref interface{}, force interface{}) *MockCommand2_Pull_Call {
	call := _e.mock.On("Pull", ctx, ref, force)
	return &MockCommand2_Pull_Call{Call: call}
}

// Run installs run on the embedded mock.Call's Run hook, converting the
// untyped mock.Arguments into the typed parameters of Pull.
func (_c *MockCommand2_Pull_Call) Run(run func(ctx context.Context, ref string, force bool)) *MockCommand2_Pull_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// Nil arguments keep their zero value; anything else must have the expected type.
		var (
			ctxArg   context.Context
			refArg   string
			forceArg bool
		)
		if raw := args[0]; raw != nil {
			ctxArg = raw.(context.Context)
		}
		if raw := args[1]; raw != nil {
			refArg = raw.(string)
		}
		if raw := args[2]; raw != nil {
			forceArg = raw.(bool)
		}
		run(ctxArg, refArg, forceArg)
	})
	return _c
}

// Return forwards inspectResponse and err to the embedded mock.Call's Return
// and hands back the typed call so further configuration can be chained.
func (_c *MockCommand2_Pull_Call) Return(inspectResponse *image.InspectResponse, err error) *MockCommand2_Pull_Call {
	_c.Call.Return(inspectResponse, err)
	return _c
}

// RunAndReturn stores the function run itself as the call's return value by
// passing it to the embedded mock.Call's Return, then returns the typed call.
func (_c *MockCommand2_Pull_Call) RunAndReturn(run func(ctx context.Context, ref string, force bool) (*image.InspectResponse, error)) *MockCommand2_Pull_Call {
	_c.Call.Return(run)
	return _c
}

// Push provides a mock function for the type MockCommand2.
// It panics when the matched expectation carries no return value.
func (_mock *MockCommand2) Push(ctx context.Context, ref string) error {
	results := _mock.Called(ctx, ref)
	if len(results) == 0 {
		panic("no return value specified for Push")
	}

	// A function-typed return value is called with the live arguments;
	// anything else is read back as a plain error.
	if fn, ok := results.Get(0).(func(context.Context, string) error); ok {
		return fn(ctx, ref)
	}
	return results.Error(0)
}

// MockCommand2_Push_Call wraps *mock.Call so that the Run, Return and
// RunAndReturn helpers defined on it are typed for the 'Push' method.
type MockCommand2_Push_Call struct {
	*mock.Call
}

// Push registers an expectation for 'Push' via mock.On and returns it wrapped
// in a typed call handle.
//   - ctx context.Context
//   - ref string
func (_e *MockCommand2_Expecter) Push(ctx interface{}, ref interface{}) *MockCommand2_Push_Call {
	call := _e.mock.On("Push", ctx, ref)
	return &MockCommand2_Push_Call{Call: call}
}

// Run installs run on the embedded mock.Call's Run hook, converting the
// untyped mock.Arguments into the typed parameters of Push.
func (_c *MockCommand2_Push_Call) Run(run func(ctx context.Context, ref string)) *MockCommand2_Push_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// Nil arguments keep their zero value; anything else must have the expected type.
		var (
			ctxArg context.Context
			refArg string
		)
		if raw := args[0]; raw != nil {
			ctxArg = raw.(context.Context)
		}
		if raw := args[1]; raw != nil {
			refArg = raw.(string)
		}
		run(ctxArg, refArg)
	})
	return _c
}

// Return forwards err to the embedded mock.Call's Return and hands back the
// typed call so further configuration can be chained.
func (_c *MockCommand2_Push_Call) Return(err error) *MockCommand2_Push_Call {
	_c.Call.Return(err)
	return _c
}

// RunAndReturn stores the function run itself as the call's return value by
// passing it to the embedded mock.Call's Return, then returns the typed call.
func (_c *MockCommand2_Push_Call) RunAndReturn(run func(ctx context.Context, ref string) error) *MockCommand2_Push_Call {
	_c.Call.Return(run)
	return _c
}

// RemoveImage provides a mock function for the type MockCommand2.
// It panics when the matched expectation carries no return value.
func (_mock *MockCommand2) RemoveImage(ctx context.Context, ref string) error {
	results := _mock.Called(ctx, ref)
	if len(results) == 0 {
		panic("no return value specified for RemoveImage")
	}

	// A function-typed return value is called with the live arguments;
	// anything else is read back as a plain error.
	if fn, ok := results.Get(0).(func(context.Context, string) error); ok {
		return fn(ctx, ref)
	}
	return results.Error(0)
}

// MockCommand2_RemoveImage_Call wraps *mock.Call so that the Run, Return and
// RunAndReturn helpers defined on it are typed for the 'RemoveImage' method.
type MockCommand2_RemoveImage_Call struct {
	*mock.Call
}

// RemoveImage registers an expectation for 'RemoveImage' via mock.On and
// returns it wrapped in a typed call handle.
//   - ctx context.Context
//   - ref string
func (_e *MockCommand2_Expecter) RemoveImage(ctx interface{}, ref interface{}) *MockCommand2_RemoveImage_Call {
	call := _e.mock.On("RemoveImage", ctx, ref)
	return &MockCommand2_RemoveImage_Call{Call: call}
}

// Run installs run on the embedded mock.Call's Run hook, converting the
// untyped mock.Arguments into the typed parameters of RemoveImage.
func (_c *MockCommand2_RemoveImage_Call) Run(run func(ctx context.Context, ref string)) *MockCommand2_RemoveImage_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// Nil arguments keep their zero value; anything else must have the expected type.
		var (
			ctxArg context.Context
			refArg string
		)
		if raw := args[0]; raw != nil {
			ctxArg = raw.(context.Context)
		}
		if raw := args[1]; raw != nil {
			refArg = raw.(string)
		}
		run(ctxArg, refArg)
	})
	return _c
}

// Return forwards err to the embedded mock.Call's Return and hands back the
// typed call so further configuration can be chained.
func (_c *MockCommand2_RemoveImage_Call) Return(err error) *MockCommand2_RemoveImage_Call {
	_c.Call.Return(err)
	return _c
}

// RunAndReturn stores the function run itself as the call's return value by
// passing it to the embedded mock.Call's Return, then returns the typed call.
func (_c *MockCommand2_RemoveImage_Call) RunAndReturn(run func(ctx context.Context, ref string) error) *MockCommand2_RemoveImage_Call {
	_c.Call.Return(run)
	return _c
}

// Run provides a mock function for the type MockCommand2.
// It panics when the matched expectation carries no return value.
func (_mock *MockCommand2) Run(ctx context.Context, options command.RunOptions) error {
	results := _mock.Called(ctx, options)
	if len(results) == 0 {
		panic("no return value specified for Run")
	}

	// A function-typed return value is called with the live arguments;
	// anything else is read back as a plain error.
	if fn, ok := results.Get(0).(func(context.Context, command.RunOptions) error); ok {
		return fn(ctx, options)
	}
	return results.Error(0)
}

// MockCommand2_Run_Call wraps *mock.Call so that the Run, Return and
// RunAndReturn helpers defined on it are typed for the 'Run' method.
type MockCommand2_Run_Call struct {
	*mock.Call
}

// Run registers an expectation for 'Run' via mock.On and returns it wrapped in
// a typed call handle.
//   - ctx context.Context
//   - options command.RunOptions
func (_e *MockCommand2_Expecter) Run(ctx interface{}, options interface{}) *MockCommand2_Run_Call {
	call := _e.mock.On("Run", ctx, options)
	return &MockCommand2_Run_Call{Call: call}
}

// Run installs run on the embedded mock.Call's Run hook, converting the
// untyped mock.Arguments into the typed parameters of the mocked Run method.
func (_c *MockCommand2_Run_Call) Run(run func(ctx context.Context, options command.RunOptions)) *MockCommand2_Run_Call {
	_c.Call.Run(func(args mock.Arguments) {
		// Nil arguments keep their zero value; anything else must have the expected type.
		var (
			ctxArg context.Context
			optArg command.RunOptions
		)
		if raw := args[0]; raw != nil {
			ctxArg = raw.(context.Context)
		}
		if raw := args[1]; raw != nil {
			optArg = raw.(command.RunOptions)
		}
		run(ctxArg, optArg)
	})
	return _c
}

// Return forwards err to the embedded mock.Call's Return and hands back the
// typed call so further configuration can be chained.
func (_c *MockCommand2_Run_Call) Return(err error) *MockCommand2_Run_Call {
	_c.Call.Return(err)
	return _c
}

// RunAndReturn stores the function run itself as the call's return value by
// passing it to the embedded mock.Call's Return, then returns the typed call.
func (_c *MockCommand2_Run_Call) RunAndReturn(run func(ctx context.Context, options command.RunOptions) error) *MockCommand2_Run_Call {
	_c.Call.Return(run)
	return _c
}


================================================
FILE: pkg/docker/dockertest/helper_client.go
================================================
package dockertest

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"runtime"
	"slices"
	"sync"
	"testing"

	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/image"
	"github.com/docker/docker/api/types/registry"
	"github.com/docker/docker/client"
	"github.com/stretchr/testify/require"
)

// NewHelperClient returns a Docker client for testing.
//
// The test is skipped when SKIP_INTEGRATION_TESTS is "1" or when the Docker
// daemon does not answer a ping, and fails when the client cannot be created.
// A cleanup is registered that force-removes every image fixture recorded on
// the helper and then closes the client.
func NewHelperClient(t testing.TB) *HelperClient {
	t.Helper()

	// Check if we should skip integration tests
	if os.Getenv("SKIP_INTEGRATION_TESTS") == "1" {
		t.Skip("Skipping integration tests")
	}

	// Create Docker client
	cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	if err != nil {
		t.Fatalf("Failed to create Docker client: %v", err)
	}

	// Verify Docker daemon is running
	if _, err := cli.Ping(t.Context()); err != nil {
		// The closing cleanup below has not been registered yet, so release
		// the client here instead of leaking it when the test is skipped.
		_ = cli.Close()
		t.Skip("Docker daemon is not running")
	}

	helper := &HelperClient{
		Client:   cli,
		fixtures: make(map[string]*imageFixture),
		mu:       &sync.Mutex{},
	}

	t.Cleanup(func() {
		// Cleanups run after the test context is done, so use a background context.
		for _, img := range helper.fixtures {
			_, err := helper.Client.ImageRemove(context.Background(), img.imageID, image.RemoveOptions{Force: true, PruneChildren: true})
			if err != nil {
				t.Logf("Warning: Failed to remove image %q: %v", img.imageID, err)
			}
		}

		if err := cli.Close(); err != nil {
			t.Fatalf("Failed to close Docker client: %v", err)
		}
	})

	return helper
}

// HelperClient wraps a Docker API client with convenience methods for tests.
type HelperClient struct {
	// Client is the underlying Docker API client used by every helper method.
	Client *client.Client

	// mu guards fixtures while image fixtures are being loaded.
	mu *sync.Mutex
	// fixtures caches loaded image fixtures, keyed by their image reference.
	fixtures map[string]*imageFixture
}

// Close closes the underlying Docker client and returns its error, if any.
func (c *HelperClient) Close() error {
	return c.Client.Close()
}

// PullImage pulls ref through the Docker client, streaming the pull output to
// stderr. Once the pull has been started, a cleanup is registered that
// force-removes the image when the test finishes.
func (c *HelperClient) PullImage(t testing.TB, ref string) error {
	t.Helper()

	stream, pullErr := c.Client.ImagePull(t.Context(), ref, image.PullOptions{})
	if pullErr != nil {
		return pullErr
	}
	defer stream.Close()

	t.Cleanup(func() {
		t.Logf("Removing image %q", ref)

		// use a background context because t.Context() is already closed when cleanup functions are called
		_, rmErr := c.Client.ImageRemove(context.Background(), ref, image.RemoveOptions{Force: true})
		if rmErr != nil {
			t.Logf("Warning: Failed to remove image %q: %v", ref, rmErr)
		}
	})

	_, copyErr := io.Copy(os.Stderr, stream)
	return copyErr
}

// MustPullImage pulls ref and fails the test immediately if the pull returns an error.
func (c *HelperClient) MustPullImage(t testing.TB, ref string) {
	t.Helper()

	pullErr := c.PullImage(t, ref)
	require.NoError(t, pullErr, "Failed to pull image %q", ref)
}

// PushImage pushes ref through the Docker client using an empty auth config,
// echoing the daemon's progress stream to stdout.
//
// The daemon reports push failures as JSON messages inside the progress
// stream, not as an error from ImagePush, so the stream is decoded and
// the first message carrying an "error" field is returned as an error.
func (c *HelperClient) PushImage(t testing.TB, ref string) error {
	t.Helper()

	// Create auth config for the registry
	authConfig := registry.AuthConfig{
		// "username": "testuser",
		// "password": "testpassword",
	}
	authBytes, err := json.Marshal(authConfig)
	if err != nil {
		return fmt.Errorf("failed to marshal auth config: %w", err)
	}
	authStr := base64.URLEncoding.EncodeToString(authBytes)

	out, err := c.Client.ImagePush(t.Context(), ref, image.PushOptions{
		RegistryAuth: authStr,
	})
	if err != nil {
		return err
	}
	defer out.Close()

	// Tee the raw stream to stdout so the output stays visible while it is
	// being inspected for error messages.
	dec := json.NewDecoder(io.TeeReader(out, os.Stdout))
	for {
		var msg struct {
			Error string `json:"error"`
		}
		if err := dec.Decode(&msg); err != nil {
			if err == io.EOF {
				return nil
			}
			return fmt.Errorf("failed to read push output for %q: %w", ref, err)
		}
		if msg.Error != "" {
			return fmt.Errorf("failed to push image %q: %s", ref, msg.Error)
		}
	}
}

// MustPushImage pushes ref and fails the test immediately if the push returns an error.
func (c *HelperClient) MustPushImage(t testing.TB, ref string) {
	t.Helper()

	pushErr := c.PushImage(t, ref)
	require.NoError(t, pushErr, "Failed to push image %q", ref)
}

// RunContainer creates and starts a container from imageName that runs
// "sleep 60", and returns the container ID. The test fails immediately if the
// container cannot be created or started.
//
// Two cleanups are registered: one force-removes the container, and, once
// the container has started, one stops it with a zero timeout. Cleanups run
// last-in-first-out, so the stop runs before the removal.
func (c *HelperClient) RunContainer(t testing.TB, imageName string) string {
	t.Helper()

	containerConfig := &container.Config{
		Image: imageName,
		Cmd:   []string{"sleep", "60"}, // Run a long sleep to keep container alive
	}
	hostConfig := &container.HostConfig{
		AutoRemove: true,
	}

	resp, err := c.Client.ContainerCreate(t.Context(), containerConfig, hostConfig, nil, nil, "")
	require.NoError(t, err, "Failed to create container")
	containerID := resp.ID
	t.Cleanup(func() {
		t.Logf("Removing container %q", containerID)
		// Best effort: the container may already be gone (AutoRemove), so the error is ignored.
		_ = c.Client.ContainerRemove(context.Background(), containerID, container.RemoveOptions{
			RemoveVolumes: true,
			RemoveLinks:   false,
			Force:         true,
		})
	})

	t.Logf("Created container %q", containerID)
	if len(resp.Warnings) > 0 {
		t.Logf("Warnings: %v", resp.Warnings)
	}

	err = c.Client.ContainerStart(t.Context(), containerID, container.StartOptions{})
	require.NoError(t, err, "Failed to start container")

	// Register the stop only after a successful start. Previously this sat
	// inside the error branch behind a failing require call and never ran.
	t.Cleanup(func() {
		t.Logf("Stopping container %q", containerID)
		// Best effort: the test may have stopped the container already.
		_ = c.Client.ContainerStop(context.Background(), containerID, container.StopOptions{
			// a zero timeout forces an immediate stop
			Timeout: new(int),
		})
	})

	return containerID
}

// StopContainer stops containerID with a zero timeout and fails the test
// immediately if the stop request returns an error.
func (c *HelperClient) StopContainer(t testing.TB, containerID string) {
	t.Helper()

	// set timeout to 0 to force immediate stop
	noWait := 0
	stopOpts := container.StopOptions{Timeout: &noWait}

	stopErr := c.Client.ContainerStop(t.Context(), containerID, stopOpts)
	require.NoErrorf(t, stopErr, "Failed to stop container %q", containerID)
}

// InspectImage returns the inspect response for imageRef, failing the test
// immediately if the image cannot be inspected.
func (c *HelperClient) InspectImage(t testing.TB, imageRef string) *image.InspectResponse {
	t.Helper()

	inspected, inspectErr := c.Client.ImageInspect(t.Context(), imageRef)
	require.NoError(t, inspectErr, "Failed to inspect image %q", imageRef)

	return &inspected
}

// ImageExists reports whether inspecting imageRef succeeds. Any inspect
// error, whatever its cause, is reported as false.
func (c *HelperClient) ImageExists(t testing.TB, imageRef string) bool {
	t.Helper()

	if _, inspectErr := c.Client.ImageInspect(t.Context(), imageRef); inspectErr != nil {
		return false
	}
	return true
}

// DeleteImage force-removes imageRef, pruning untagged parents, and returns
// any error from the removal.
func (c *HelperClient) DeleteImage(t testing.TB, imageRef string) error {
	t.Helper()

	removeOpts := image.RemoveOptions{Force: true, PruneChildren: true}
	if _, rmErr := c.Client.ImageRemove(t.Context(), imageRef, removeOpts); rmErr != nil {
		return rmErr
	}
	return nil
}

// MustDeleteImage force-removes imageRef, pruning untagged parents, and fails
// the test immediately if the removal returns an error.
func (c *HelperClient) MustDeleteImage(t testing.TB, imageRef string) {
	t.Helper()

	removeOpts := image.RemoveOptions{Force: true, PruneChildren: true}
	_, rmErr := c.Client.ImageRemove(t.Context(), imageRef, removeOpts)
	require.NoError(t, rmErr, "Failed to delete image %q", imageRef)
}

// CleanupImage registers a cleanup that force-removes imageRef when the test
// finishes. A failed removal is only logged.
func (c *HelperClient) CleanupImage(t testing.TB, imageRef string) {
	t.Helper()

	removeOpts := image.RemoveOptions{Force: true, PruneChildren: true}
	t.Cleanup(func() {
		// The cleanup uses a background context, not the test's.
		if _, rmErr := c.Client.ImageRemove(context.Background(), imageRef, removeOpts); rmErr != nil {
			t.Logf("Warning: Failed to remove image %q: %v", imageRef, rmErr)
		}
	})
}

// CleanupImages snapshots the IDs of the images that exist right now and
// registers a cleanup that force-removes every image whose ID is not in that
// snapshot when the test finishes.
//
// The test fails immediately if the initial image list cannot be obtained.
// Failures during the cleanup itself are only logged.
func (c *HelperClient) CleanupImages(t testing.TB) {
	t.Helper()

	existingImages, err := c.Client.ImageList(t.Context(), image.ListOptions{})
	require.NoError(t, err, "Failed to list images")

	imageIDs := make([]string, 0, len(existingImages))
	for _, img := range existingImages {
		imageIDs = append(imageIDs, img.ID)
	}

	// Log through the test instead of printing to stdout so the output is
	// attributed to the test and only shown when relevant.
	t.Logf("existing imageIDs %v", imageIDs)

	t.Cleanup(func() {
		// The cleanup uses a background context, not the test's.
		currentImages, err := c.Client.ImageList(context.Background(), image.ListOptions{})
		if err != nil {
			t.Logf("Warning: Failed to list images: %v", err)
			return
		}

		removeOpts := image.RemoveOptions{Force: true, PruneChildren: true}
		for _, img := range currentImages {
			if slices.Contains(imageIDs, img.ID) {
				continue
			}

			// Remove the image directly; registering another cleanup from
			// inside a running cleanup is unnecessary indirection.
			t.Logf("Removing new image %q", img.ID)
			if _, err := c.Client.ImageRemove(context.Background(), img.ID, removeOpts); err != nil {
				t.Logf("Warning: Failed to remove image %q: %v", img.ID, err)
			}
		}
	})
}

// InspectContainer returns the inspect response for containerID, failing the
// test immediately if the container cannot be inspected.
func (c *HelperClient) InspectContainer(t testing.TB, containerID string) *container.InspectResponse {
	t.Helper()

	details, inspectErr := c.Client.ContainerInspect(t.Context(), containerID)
	require.NoError(t, inspectErr, "Failed to inspect container %q", containerID)

	return &details
}

// ImageFixture loads the named image fixture, tags it as tag, and registers a
// cleanup that force-removes that tag when the test finishes. The test fails
// immediately if tagging fails.
func (c *HelperClient) ImageFixture(t testing.TB, name string, tag string) {
	t.Helper()
	fixture := c.loadImageFixture(t, name)

	t.Logf("Tagging image fixture %q with %q", fixture.ref, tag)
	tagErr := c.Client.ImageTag(t.Context(), fixture.imageID, tag)
	require.NoError(t, tagErr, "Failed to tag image %q with %q: %v", fixture.ref, tag, tagErr)

	// remove the image when the test is done
	t.Cleanup(func() {
		_, _ = c.Client.ImageRemove(context.Background(), tag, image.RemoveOptions{Force: true})
	})
}

// loadImageFixture loads testdata/<name>.tar (resolved relative to this source
// file) into the Docker daemon and returns the resulting fixture. Fixtures are
// cached by their "cog-test-fixture:<name>" reference, so each one is loaded at
// most once per helper. The mutex is held for the whole load. Any failure
// aborts the test.
func (c *HelperClient) loadImageFixture(t testing.TB, name string) *imageFixture {
	t.Helper()

	c.mu.Lock()
	defer c.mu.Unlock()

	ref := fmt.Sprintf("cog-test-fixture:%s", name)
	if cached, found := c.fixtures[ref]; found {
		return cached
	}

	// Locate the testdata directory next to this source file.
	_, thisFile, _, ok := runtime.Caller(0)
	if !ok {
		t.Fatal("Could not get current file path")
	}
	fixturePath := filepath.Join(filepath.Dir(thisFile), "testdata", name+".tar")

	t.Logf("Loading image fixture %q from %s", ref, fixturePath)

	tarball, err := os.Open(fixturePath)
	require.NoError(t, err, "Failed to open fixture %q", name)
	defer tarball.Close()

	loaded, err := c.Client.ImageLoad(t.Context(), tarball)
	require.NoError(t, err, "Failed to load fixture %q", name)
	defer loaded.Body.Close()

	// Drain the load response to stderr before inspecting the image.
	_, err = io.Copy(os.Stderr, loaded.Body)
	require.NoError(t, err, "Failed to copy fixture %q", name)

	inspected, err := c.Client.ImageInspect(t.Context(), ref)
	require.NoError(t, err, "Failed to inspect image %q", ref)

	loadedFixture := &imageFixture{
		ref:     ref,
		imageID: inspected.ID,
	}
	c.fixtures[ref] = loadedFixture

	return loadedFixture
}

// imageFixture records an image that was loaded from a fixture tarball.
type imageFixture struct {
	// imageID is the ID reported by inspecting the loaded image.
	imageID string
	// ref is the "cog-test-fixture:<name>" reference the fixture was loaded under.
	ref string
}


================================================
FILE: pkg/docker/dockertest/image.go
================================================
package dockertest

import (
	"fmt"
	"path"
	"strings"
	"testing"
	"time"
)

// ImageRef returns a reference based on the unique test name and label.
// If the label is empty, it will default to "test-" followed by the current unix epoch time.
//
// The test name is lower-cased, and every character that is not valid in an
// image repository name (anything other than a-z, 0-9, '.', '_', '-' and '/')
// is replaced with an underscore, so subtest names such as "Case=1" or
// "dup#01" still yield a usable reference.
func ImageRef(t *testing.T, label string) string {
	if label == "" {
		label = fmt.Sprintf("test-%d", time.Now().Unix())
	}

	repo := strings.Map(func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z', r >= '0' && r <= '9', r == '.', r == '_', r == '-', r == '/':
			return r
		default:
			return '_'
		}
	}, strings.ToLower(t.Name()))

	return fmt.Sprintf("cog-test/%s:%s", repo, label)
}

// ImageRefWithRegistry returns the reference produced by ImageRef for the
// test and label, joined onto registryAddr with path.Join.
func ImageRefWithRegistry(t *testing.T, registryAddr string, label string) string {
	localRef := ImageRef(t, label)
	return path.Join(registryAddr, localRef)
}


================================================
FILE: pkg/docker/dockertest/mock_command.go
================================================
package dockertest

import (
	"context"
	"io"
	"os"
	"path/filepath"

	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/image"
	dockerspec "github.com/moby/docker-image-spec/specs-go/v1"
	ocispec "github.com/opencontainers/image-spec/specs-go/v1"

	"github.com/replicate/cog/pkg/docker/command"
)

// PushError is the error returned by MockCommand.Push; it is nil by default.
var PushError error = nil

// MockCogConfig is the JSON string MockCommand.Inspect reports under the cog config label.
var MockCogConfig string = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"beautifulsoup4==4.12.3\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"predict.py:Predictor\"}"

// MockOpenAPISchema is the JSON string MockCommand.Inspect reports under the OpenAPI schema label.
var MockOpenAPISchema string = "{}"

// MockCommand is a hand-written, stateless stand-in for a Docker command
// implementation. Its methods return canned values or panic with "not implemented".
type MockCommand struct{}

// NewMockCommand returns a new, empty MockCommand.
func NewMockCommand() *MockCommand {
	return new(MockCommand)
}

// Pull is a no-op that ignores its arguments and returns a nil response and a nil error.
func (c *MockCommand) Pull(ctx context.Context, image string, force bool) (*image.InspectResponse, error) {
	return nil, nil
}

// Push ignores its arguments and returns the current value of the
// package-level PushError variable.
func (c *MockCommand) Push(ctx context.Context, image string) error {
	return PushError
}

// LoadUserInformation ignores its arguments and returns fixed test
// credentials (token "test-token", username "test-user") with a nil error.
func (c *MockCommand) LoadUserInformation(ctx context.Context, registryHost string) (*command.UserInfo, error) {
	return &command.UserInfo{
		Token:    "test-token",
		Username: "test-user",
	}, nil
}

// CreateTarFile writes a small placeholder file to tmpDir/tarFile and returns
// its path. The image and folder arguments are ignored, and the file is plain
// text, not a real tar archive.
func (c *MockCommand) CreateTarFile(ctx context.Context, image string, tmpDir string, tarFile string, folder string) (string, error) {
	target := filepath.Join(tmpDir, tarFile)
	if err := os.WriteFile(target, []byte("hello\ngo\n"), 0o644); err != nil {
		return "", err
	}
	return target, nil
}

// CreateAptTarFile writes a small placeholder file to tmpDir/aptTarFile and
// returns its path. The packages argument is ignored, and the file is plain
// text, not a real tar archive.
func (c *MockCommand) CreateAptTarFile(ctx context.Context, tmpDir string, aptTarFile string, packages ...string) (string, error) {
	target := filepath.Join(tmpDir, aptTarFile)
	if err := os.WriteFile(target, []byte("hello\ngo\n"), 0o644); err != nil {
		return "", err
	}
	return target, nil
}

// Inspect ignores ref and returns a canned inspect response. Its labels carry
// MockCogConfig, MockOpenAPISchema and cog version "0.11.3", and its
// environment carries fixed torch, CUDA, cuDNN and Python versions.
func (c *MockCommand) Inspect(ctx context.Context, ref string) (*image.InspectResponse, error) {
	labels := map[string]string{
		command.CogConfigLabelKey:        MockCogConfig,
		command.CogOpenAPISchemaLabelKey: MockOpenAPISchema,
		command.CogVersionLabelKey:       "0.11.3",
	}
	env := []string{
		command.R8TorchVersionEnvVarName + "=2.5.0",
		command.R8CudaVersionEnvVarName + "=2.4",
		command.R8CudnnVersionEnvVarName + "=1.0",
		command.R8PythonVersionEnvVarName + "=3.12",
	}

	imageConfig := &dockerspec.DockerOCIImageConfig{
		ImageConfig: ocispec.ImageConfig{
			Labels: labels,
			Env:    env,
		},
	}

	return &image.InspectResponse{Config: imageConfig}, nil
}

// ImageExists is not implemented by MockCommand and panics when called.
func (c *MockCommand) ImageExists(ctx context.Context, ref string) (bool, error) {
	panic("not implemented")
}

// ContainerLogs is not implemented by MockCommand and panics when called.
func (c *MockCommand) ContainerLogs(ctx context.Context, containerID string, w io.Writer) error {
	panic("not implemented")
}

// ContainerInspect is not implemented by MockCommand and panics when called.
func (c *MockCommand) ContainerInspect(ctx context.Context, id string) (*container.InspectResponse, error) {
	panic("not implemented")
}

// ContainerStop is not implemented by MockCommand and panics when called.
func (c *MockCommand) ContainerStop(ctx context.Context, containerID string) error {
	panic("not implemented")
}

// RemoveImage is not implemented by MockCommand and panics when called.
func (c *MockCommand) RemoveImage(ctx context.Context, ref string) error {
	panic("not implemented")
}

// ImageBuild is not implemented by MockCommand and panics when called.
func (c *MockCommand) ImageBuild(ctx context.Context, options command.ImageBuildOptions) (string, error) {
	panic("not implemented")
}

// Run is not implemented by MockCommand and panics when called.
func (c *MockCommand) Run(ctx context.Context, options command.RunOptions) error {
	panic("not implemented")
}

// ContainerStart is not implemented by MockCommand and panics when called.
func (c *MockCommand) ContainerStart(ctx context.Context, options command.RunOptions) (string, error) {
	panic("not implemented")
}

// ImageSave is not implemented by MockCommand and panics when called.
func (c *MockCommand) ImageSave(ctx context.Context, imageRef string) (io.ReadCloser, error) {
	panic("not implemented")
}


================================================
FILE: pkg/docker/dockertest/ref.go
================================================
package dockertest

import (
	"strings"
	"testing"

	"github.com/google/go-containerregistry/pkg/name"
	"github.com/stretchr/testify/require"
)

// Ref is an image reference bound to a test. The With* methods return
// modified copies, and failures while building a reference are reported
// through the stored *testing.T.
type Ref struct {
	// t is the test used to report failures.
	t *testing.T
	// ref is the underlying parsed image reference.
	ref name.Reference
}

// NewRef builds a Ref whose repository name is derived from the name of the
// running test. The test fails immediately if the derived name cannot be
// parsed as an image reference.
func NewRef(t *testing.T) Ref {
	t.Helper()

	// Keep only the runes allowed in a docker image repo name
	// (a-z, 0-9, '.', '_', '-', '/'); everything else becomes '_'.
	sanitize := func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z', r >= '0' && r <= '9':
			return r
		case r == '.', r == '_', r == '-', r == '/':
			return r
		default:
			return '_'
		}
	}
	repoName := strings.Map(sanitize, strings.ToLower(t.Name()))

	parsed, err := name.ParseReference(repoName, name.WithDefaultRegistry(""))
	require.NoError(t, err, "Failed to create reference for test")

	return Ref{t: t, ref: parsed}
}

// WithTag returns a copy of r that points at the same repository with the
// given tag.
func (r Ref) WithTag(tagName string) Ref {
	return Ref{
		t:   r.t,
		ref: r.ref.Context().Tag(tagName),
	}
}

// WithDigest returns a copy of r that points at the same repository with the
// given digest.
func (r Ref) WithDigest(digest string) Ref {
	return Ref{
		t:   r.t,
		ref: r.ref.Context().Digest(digest),
	}
}

// WithRegistry returns a copy of r whose repository lives on the given
// registry, keeping the current tag or digest. The test fails if the registry
// cannot be parsed or the reference is neither a tag nor a digest.
func (r Ref) WithRegistry(registry string) Ref {
	parsedRegistry, err := name.NewRegistry(registry)
	require.NoError(r.t, err, "Failed to create registry for test")

	repository := r.ref.Context()
	repository.Registry = parsedRegistry

	rebuilt := Ref{t: r.t}
	switch r.ref.(type) {
	case name.Tag:
		rebuilt.ref = repository.Tag(r.ref.Identifier())
	case name.Digest:
		rebuilt.ref = repository.Digest(r.ref.Identifier())
	default:
		require.Fail(r.t, "Unsupported reference type")
	}

	return rebuilt
}

// WithoutRegistry returns a copy of r whose repository has a zero-value
// registry, keeping the current tag or digest. The test fails if the
// reference is neither a tag nor a digest.
func (r Ref) WithoutRegistry() Ref {
	repository := r.ref.Context()
	repository.Registry = name.Registry{}

	rebuilt := Ref{t: r.t}
	switch r.ref.(type) {
	case name.Tag:
		rebuilt.ref = repository.Tag(r.ref.Identifier())
	case name.Digest:
		rebuilt.ref = repository.Digest(r.ref.Identifier())
	default:
		require.Fail(r.t, "Unsupported reference type")
	}

	return rebuilt
}

// String returns the value of Name() on the wrapped reference.
func (r Ref) String() string {
	return r.ref.Name()
}


================================================
FILE: pkg/docker/dockertest/ref_test.go
================================================
package dockertest

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestRef walks a Ref through each With* helper in turn and checks the string
// form after every step.
func TestRef(t *testing.T) {
	// The repository name comes from the lowercased test name.
	ref := NewRef(t)
	assert.Equal(t, "testref:latest", ref.String())

	ref = ref.WithTag("v2")
	assert.Equal(t, "testref:v2", ref.String())

	// Adding a registry keeps the tag set above.
	ref = ref.WithRegistry("r8.im")
	assert.Equal(t, "r8.im/testref:v2", ref.String())

	ref = ref.WithoutRegistry()
	assert.Equal(t, "testref:v2", ref.String())

	// Switching to a digest replaces the tag.
	ref = ref.WithDigest("sha256:71859b0c62df47efaeae4f93698b56a8dddafbf041778fd668bbd1ab45a864f8")
	assert.Equal(t, "testref@sha256:71859b0c62df47efaeae4f93698b56a8dddafbf041778fd668bbd1ab45a864f8", ref.String())
}


================================================
FILE: pkg/docker/dockertest/testdata/create-image-fixtures.sh
================================================
#!/usr/bin/env bash
# Creates the alpine.tar image fixture in the current working directory.
# Requires a working `docker` CLI and network access to pull the source image.
set -euo pipefail

SRC="amd64/alpine:3.14"
TAG="cog-test-fixture:alpine"

echo "Creating test image fixtures"

# Pull the source image, re-tag it with the fixture name, export that tag to
# alpine.tar, then remove the temporary tag again.
docker pull "$SRC"
docker tag "$SRC" "$TAG"
docker save -o alpine.tar "$TAG"
docker rmi "$TAG"

echo "Test fixtures created"


================================================
FILE: pkg/docker/env.go
================================================
package docker

import "os"

// DockerCommandEnvVarName is the environment variable read by
// DockerCommandFromEnvironment.
const DockerCommandEnvVarName = "R8_DOCKER_COMMAND"

// DockerCommandFromEnvironment returns the value of the R8_DOCKER_COMMAND
// environment variable, or "docker" when the variable is unset or empty.
func DockerCommandFromEnvironment() string {
	if override := os.Getenv(DockerCommandEnvVarName); override != "" {
		return override
	}
	return "docker"
}


================================================
FILE: pkg/docker/errors.go
================================================
package docker

import (
	"errors"
	"strings"
)

// Error messages vary between different backends (dockerd, containerd, podman, orbstack, etc) or even versions of docker.
// These helpers normalize the check so callers can handle situations without worrying about the underlying implementation.
// Yes, it's gross, but whattaya gonna do

// isTagNotFoundError reports whether err's message contains one of the known
// "tag not found" phrasings. err must be non-nil.
func isTagNotFoundError(err error) bool {
	text := err.Error()
	for _, needle := range []string{
		"tag does not exist",
		"An image does not exist locally with the tag",
	} {
		if strings.Contains(text, needle) {
			return true
		}
	}
	return false
}

// isAuthorizationFailedError reports whether err's message looks like a
// registry authentication failure. err must be non-nil.
func isAuthorizationFailedError(err error) bool {
	text := err.Error()

	switch {
	// registry requires auth and none were provided
	case strings.Contains(text, "no basic auth credentials"):
		return true
	// registry rejected the provided auth
	case strings.Contains(text, "authorization failed"),
		strings.Contains(text, "401 Unauthorized"),
		strings.Contains(text, "unauthorized: authentication required"):
		return true
	default:
		return false
	}
}

// isRepositoryNotFoundError reports whether err's message contains
// "NAME_UNKNOWN", the OCI registry error code meaning "repository name not
// known to registry". This typically means the model hasn't been created on
// Replicate yet. err must be non-nil.
func isRepositoryNotFoundError(err error) bool {
	const ociNameUnknown = "NAME_UNKNOWN"
	return strings.Contains(err.Error(), ociNameUnknown)
}

// isMissingDeviceDriverError reports whether err's message contains one of
// the known "device driver unavailable" phrasings. err must be non-nil.
func isMissingDeviceDriverError(err error) bool {
	text := err.Error()
	if strings.Contains(text, "could not select device driver") {
		return true
	}
	return strings.Contains(text, "nvidia-container-cli: initialization error")
}

// isNetworkError checks if the error is a network error. This is janky and
// intended for use in tests only. err must be non-nil.
//
// For both CLI and API clients, network errors are wrapped and lose the
// net.Error interface (CLI client: wrapped by exec.Command as exec.ExitError;
// API client: wrapped by JSON message stream processing), so this relies on
// string matching against common network error messages.
func isNetworkError(err error) bool {
	needles := [...]string{
		"connection refused",
		"connection reset by peer",
		"dial tcp",
		"EOF",
		"no route to host",
		"network is unreachable",
		"server closed",
	}

	// Check err itself, then each error reachable through errors.Unwrap.
	current := err
	for {
		text := current.Error()
		for i := range needles {
			if strings.Contains(text, needles[i]) {
				return true
			}
		}

		current = errors.Unwrap(current)
		if current == nil {
			return false
		}
	}
}


================================================
FILE: pkg/docker/host.go
================================================
package docker

import (
	"fmt"
	"os"

	dconfig "github.com/docker/cli/cli/config"
	dctxdocker "github.com/docker/cli/cli/context/docker"
	dctxstore "github.com/docker/cli/cli/context/store"

	"github.com/replicate/cog/pkg/util/console"
)

// determineDockerHost returns the host to use for the docker client.
//
// Resolution order:
//  1. the DOCKER_HOST environment variable, if non-empty;
//  2. the docker context named by DOCKER_CONTEXT, or the current context when
//     DOCKER_CONTEXT is empty;
//  3. the platform's defaultDockerHost.
//
// A context lookup failure is only returned as an error when DOCKER_CONTEXT
// was set explicitly; otherwise it is logged at debug level and the default
// host is used.
func determineDockerHost() (string, error) {
	// 1) if DOCKER_HOST is set, use it
	if envHost := os.Getenv("DOCKER_HOST"); envHost != "" {
		console.Debug("using docker host from DOCKER_HOST")

		return envHost, nil
	}

	// 2) try to get a host from the docker context
	requestedContext := os.Getenv("DOCKER_CONTEXT")
	ctxHost, contextName, err := dockerHostFromContext(requestedContext)
	switch {
	case err != nil:
		console.Debugf("could not find docker host from context %q: %v", contextName, err)

		// the user asked for this context by name, so surface the failure
		if requestedContext != "" {
			return "", err
		}
	case ctxHost != "":
		console.Debugf("using docker host from context %q", contextName)

		return ctxHost, nil
	}

	console.Debug("using system default docker host")

	// 3) nothing from env or context: fall back to the system default
	return defaultDockerHost, nil
}

// dockerHostFromContext looks up the docker endpoint host recorded for a
// docker context. When contextName is empty, the CurrentContext from the
// docker config file is used instead.
//
// It returns the host, the context name that was looked up, and an error if
// the config or context metadata cannot be loaded, or if the context has no
// usable docker endpoint host.
func dockerHostFromContext(contextName string) (string, string, error) {
	if contextName == "" {
		configFile, err := dconfig.Load(dconfig.Dir())
		if err != nil {
			return "", "", fmt.Errorf("error loading docker config: %w", err)
		}
		contextName = configFile.CurrentContext
	}

	newEndpointMeta := func() any { return &dctxdocker.EndpointMeta{} }
	storeConfig := dctxstore.NewConfig(
		newEndpointMeta,
		dctxstore.EndpointTypeGetter(dctxdocker.DockerEndpoint, newEndpointMeta),
	)
	contextStore := dctxstore.New(dconfig.ContextStoreDir(), storeConfig)

	metadata, err := contextStore.GetMetadata(contextName)
	if err != nil {
		return "", contextName, fmt.Errorf("error getting metadata for context %q: %w", contextName, err)
	}

	rawEndpoint, found := metadata.Endpoints[dctxdocker.DockerEndpoint]
	if !found {
		return "", contextName, fmt.Errorf("no docker endpoints found for context %q", contextName)
	}

	endpointMeta, isDockerMeta := rawEndpoint.(dctxdocker.EndpointMeta)
	switch {
	case !isDockerMeta:
		return "", contextName, fmt.Errorf("invalid context config: %v", rawEndpoint)
	case endpointMeta.Host == "":
		return "", contextName, fmt.Errorf("no host found for context %q", contextName)
	}

	return endpointMeta.Host, contextName, nil
}


================================================
FILE: pkg/docker/host_unix.go
================================================
//go:build !windows

package docker

// defaultDockerHost is the docker host used on non-Windows platforms when
// neither DOCKER_HOST nor a docker context supplies one.
const defaultDockerHost = "unix:///var/run/docker.sock"


================================================
FILE: pkg/docker/host_windows.go
================================================
package docker

const (
	// defaultDockerHost is the docker host used on Windows when neither
	// DOCKER_HOST nor a docker context supplies one. It is the Docker Engine
	// named pipe \\.\pipe\docker_engine written as an npipe URL; the "./"
	// after the four slashes is required for the path to resolve to that pipe.
	defaultDockerHost = "npipe:////./pipe/docker_engine"
)


================================================
FILE: pkg/docker/login.go
================================================
package docker

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"os/exec"
	"strings"

	"github.com/docker/cli/cli/config"
	"github.com/docker/cli/cli/config/configfile"
	"github.com/docker/cli/cli/config/types"

	"github.com/replicate/cog/pkg/util/console"
)

// SaveLoginToken stores the username/token pair for registryHost in the
// user's docker configuration. When the default docker config names a
// credentials store, the pair is handed to that store; otherwise it is
// written into the config file itself.
func SaveLoginToken(ctx context.Context, registryHost string, username string, token string) error {
	conf := config.LoadDefaultConfigFile(os.Stderr)
	if store := conf.CredentialsStore; store != "" {
		return saveAuthToCredentialsStore(ctx, store, registryHost, username, token)
	}
	return saveAuthToConfig(conf, registryHost, username, token)
}

// saveAuthToConfig records the username/token pair for registryHost in conf
// and saves conf. It returns an error if saving fails.
func saveAuthToConfig(conf *configfile.ConfigFile, registryHost string, username string, token string) error {
	// A config file whose "auths" key is JSON null decodes to a nil map;
	// assigning into a nil map would panic.
	if conf.AuthConfigs == nil {
		conf.AuthConfigs = make(map[string]types.AuthConfig)
	}

	// conf.Save() will base64 encode username and password
	conf.AuthConfigs[registryHost] = types.AuthConfig{
		Username: username,
		Password: token,
	}
	if err := conf.Save(); err != nil {
		return fmt.Errorf("Failed to save Docker config.json: %w", err)
	}
	return nil
}

// saveAuthToCredentialsStore hands the username/token pair for registryHost
// to the credential helper binary for credsStore. The helper is run with the
// "store" argument and receives the credentials as a JSON document on stdin;
// its stderr is forwarded to this process's stderr.
//
// It returns an error if the helper cannot be started, if the credentials
// cannot be written to it, or if it exits unsuccessfully.
func saveAuthToCredentialsStore(ctx context.Context, credsStore string, registryHost string, username string, token string) error {
	binary := dockerCredentialBinary(credsStore)
	input := CredentialHelperInput{
		Username:  username,
		Secret:    token,
		ServerURL: registryHost,
	}
	cmd := exec.CommandContext(ctx, binary, "store") //nolint:gosec // G702: binary is from Docker config, not user input
	cmd.Env = os.Environ()
	cmd.Stderr = os.Stderr
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return fmt.Errorf("Failed to connect stdin to %s: %w", binary, err)
	}
	console.Debug("$ " + strings.Join(cmd.Args, " "))
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("Failed to start %s: %w", binary, err)
	}
	if err := json.NewEncoder(stdin).Encode(input); err != nil {
		// The helper is already running: close its stdin so it sees EOF, then
		// reap it so the child process is not leaked. The write error is the
		// one worth reporting, so the cleanup errors are dropped.
		_ = stdin.Close()
		_ = cmd.Wait()
		return fmt.Errorf("Failed to write to %s: %w", binary, err)
	}
	if err := stdin.Close(); err != nil {
		// Still reap the started process before reporting the close failure.
		_ = cmd.Wait()
		return fmt.Errorf("Failed to close stdin to %s: %w", binary, err)
	}
	if err := cmd.Wait(); err != nil {
		return fmt.Errorf("Failed to run %s: %w", binary, err)
	}
	return nil
}


================================================
FILE: pkg/docker/options.go
================================================
package docker

import "github.com/docker/docker/api/types/registry"

// clientOptions collects the settings that Option functions can modify.
type clientOptions struct {
	// authConfigs holds registry credentials keyed by the auth config's
	// ServerAddress (see WithAuthConfig).
	authConfigs map[string]registry.AuthConfig
	// host is the docker host set by WithHost.
	host        string
}

// Option mutates a clientOptions value; see WithAuthConfig and WithHost.
type Option func(*clientOptions)

// WithAuthConfig returns an Option that stores authConfig in the options'
// authConfigs map under authConfig.ServerAddress, replacing any entry already
// stored for that address.
func WithAuthConfig(authConfig registry.AuthConfig) Option {
	address := authConfig.ServerAddress
	return func(opts *clientOptions) {
		opts.authConfigs[address] = authConfig
	}
}

// WithHost returns an Option that sets the options' host field to host.
func WithHost(host string) Option {
	return func(o *clientOptions) {
		o.host = host
	}
}


================================================
FILE: pkg/docker/progress.go
================================================
package docker

import (
	"encoding/json"
	"io"
	"os"
	"sync"

	"github.com/docker/docker/pkg/jsonmessage"

	"github.com/replicate/cog/pkg/util/console"
)

// ProgressWriter adapts push progress callbacks to Docker's jsonmessage rendering.
//
// This uses the same ANSI cursor movement and progress display as `docker push`,
// which handles terminal resizing correctly: each line is erased and rewritten
// individually (ESC[2K + cursor up/down per line), rather than relying on a
// bulk cursor-up count that can desync when lines wrap after a terminal resize.
type ProgressWriter struct {
	// mu guards pw.
	mu   sync.Mutex
	// pw is the write end of the pipe feeding the renderer; Close sets it to nil.
	pw   *io.PipeWriter
	// done receives the renderer's result when the rendering goroutine finishes.
	done chan error
	// once makes Close run its shutdown sequence a single time.
	once sync.Once
}

// NewProgressWriter creates a ProgressWriter that renders push progress to stderr
// using Docker's jsonmessage format, matching the output of `docker push`.
//
// A goroutine is started that reads messages from an in-memory pipe and
// renders them until the pipe is closed; callers should call Close when done
// so that goroutine is stopped and waited for.
func NewProgressWriter() *ProgressWriter {
	pr, pw := io.Pipe()
	isTTY := console.IsTTY(os.Stderr)
	done := make(chan error, 1)

	go func() {
		err := jsonmessage.DisplayJSONMessagesStream(pr, os.Stderr, os.Stderr.Fd(), isTTY, nil)
		// If rendering stopped before the write end was closed (for example
		// because writing to stderr failed), nothing reads the pipe any more.
		// io.Pipe is synchronous, so a later Write would block forever while
		// holding the ProgressWriter mutex and Close would deadlock. Closing
		// the read end makes pending and future writes return immediately.
		_ = pr.CloseWithError(err)
		done <- err
	}()

	return &ProgressWriter{
		pw:   pw,
		done: done,
	}
}

// Write sends a progress update for a specific layer/artifact.
//
//   - id is a unique identifier for the item (layer digest, artifact name).
//   - status is the current operation (e.g. "Pushing").
//   - current and total are the byte counts for the progress bar.
func (p *ProgressWriter) Write(id, status string, current, total int64) {
	progress := &jsonmessage.JSONProgress{
		Current: current,
		Total:   total,
	}
	p.writeMessage(jsonmessage.JSONMessage{
		ID:       id,
		Status:   status,
		Progress: progress,
	})
}

// WriteStatus sends a status-only message (no progress bar) for a specific
// layer/artifact, e.g. "Pushed", "FAILED", or retry messages.
func (p *ProgressWriter) WriteStatus(id, status string) {
	p.writeMessage(jsonmessage.JSONMessage{
		ID:     id,
		Status: status,
	})
}

// writeMessage encodes msg as one JSON line and writes it to the renderer's
// pipe while holding the mutex. It is a no-op once the writer has been closed,
// and encoding or write failures are dropped silently.
func (p *ProgressWriter) writeMessage(msg jsonmessage.JSONMessage) {
	p.mu.Lock()
	defer p.mu.Unlock()

	// Close sets pw to nil; nothing to write to after that.
	if p.pw == nil {
		return
	}

	encoded, err := json.Marshal(msg)
	if err != nil {
		return
	}

	line := append(encoded, '\n')
	_, _ = p.pw.Write(line)
}

// Close shuts down the progress display. Safe to call multiple times.
//
// It detaches the pipe writer under the mutex, closes it, and then waits for
// the rendering goroutine to report completion on the done channel.
func (p *ProgressWriter) Close() {
	p.once.Do(func() {
		// Take the writer out under the lock so writeMessage sees nil from
		// here on and becomes a no-op.
		p.mu.Lock()
		pw := p.pw
		p.pw = nil
		p.mu.Unlock()

		if pw != nil {
			// Closing the write end ends the renderer's input; then wait for
			// the renderer to finish. Its result is discarded.
			_ = pw.Close()
			<-p.done
		}
	})
}


================================================
FILE: pkg/docker/push.go
================================================
package docker

import (
	"context"
	"net/http"
	"time"

	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/util/console"
	"github.com/replicate/cog/pkg/web"
)

// BuildInfo carries the build details that Push reports before pushing.
type BuildInfo struct {
	// BuildTime is the build duration passed to PostPushStart.
	BuildTime time.Duration
	// BuildID is the build identifier passed to PostPushStart.
	BuildID   string
}

// Push reports the build ID and build time to the web client and then pushes
// image via StandardPush. A failure to report is only logged as a warning;
// the returned error is the result of the push itself. projectDir is unused.
func Push(ctx context.Context, image string, projectDir string, command command.Command, buildInfo BuildInfo, client *http.Client) error {
	// Reporting timings is best-effort and must not block the push.
	reportErr := web.NewClient(command, client).PostPushStart(ctx, buildInfo.BuildID, buildInfo.BuildTime)
	if reportErr != nil {
		console.Warnf("Failed to send build timings to server: %v", reportErr)
	}

	return StandardPush(ctx, image, command)
}


================================================
FILE: pkg/docker/run.go
================================================
package docker

import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"strconv"

	"github.com/docker/go-connections/nat"

	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/util/console"
)

// ErrMissingDeviceDriver is a sentinel error for a Docker installation that
// lacks a required device driver. Compare against it with errors.Is.
var ErrMissingDeviceDriver = errors.New("Docker is missing required device driver")

// Run calls RunWithIO with the process's own stdin, stdout and stderr.
func Run(ctx context.Context, dockerClient command.Command, options command.RunOptions) error {
	return RunWithIO(ctx, dockerClient, options, os.Stdin, os.Stdout, os.Stderr)
}

// RunWithIO sets the given streams on options (overwriting any Stdin, Stdout
// and Stderr already set there) and passes the result to dockerClient.Run.
// options is received by value, so the caller's copy is not modified.
func RunWithIO(ctx context.Context, dockerClient command.Command, options command.RunOptions, stdin io.Reader, stdout, stderr io.Writer) error {
	options.Stdin = stdin
	options.Stdout = stdout
	options.Stderr = stderr
	// TODO[md]: we're gonna stop passing the entire host env to the container by default, if users indeed rely on that behavior we can uncomment this line:
	// options.Env = append(os.Environ(), options.Env...)
	return dockerClient.Run(ctx, options)
}

// RunDaemon sets stderr on options and passes the result to
// dockerClient.ContainerStart, returning that call's string result and error.
// NOTE(review): the string is presumably the container ID; confirm against
// the command.Command implementation.
func RunDaemon(ctx context.Context, dockerClient command.Command, options command.RunOptions, stderr io.Writer) (string, error) {
	options.Stderr = stderr
	return dockerClient.ContainerStart(ctx, options)
}

// GetHostPortForContainer returns the host port that containerPort (TCP) of
// the given container is published on.
//
// Only bindings whose host IP is exactly "0.0.0.0" are considered, and the
// first such binding wins. An error is returned when the container cannot be
// inspected, is not running, has no network/port information, has no binding
// on 0.0.0.0 for the requested port, or when a port value cannot be parsed.
func GetHostPortForContainer(ctx context.Context, dockerCommand command.Command, containerID string, containerPort int) (int, error) {
	console.Debugf("=== DockerCommand.GetPort %s/%d", containerID, containerPort)

	info, err := dockerCommand.ContainerInspect(ctx, containerID)
	if err != nil {
		return 0, fmt.Errorf("failed to inspect container %q: %w", containerID, err)
	}

	isRunning := info.ContainerJSONBase != nil && info.State != nil && info.State.Running
	if !isRunning {
		return 0, fmt.Errorf("container %s is not running", containerID)
	}

	wantedPort, err := nat.NewPort("tcp", strconv.Itoa(containerPort))
	if err != nil {
		return 0, fmt.Errorf("failed to create target port: %w", err)
	}

	hasPorts := info.NetworkSettings != nil && info.NetworkSettings.Ports != nil
	if !hasPorts {
		return 0, fmt.Errorf("container %s does not have expected network configuration", containerID)
	}

	bindings := info.NetworkSettings.Ports[wantedPort]
	for i := range bindings {
		// TODO[md]: this should not be hardcoded since docker may be bound to a different address
		if bindings[i].HostIP == "0.0.0.0" {
			hostPort, err := nat.ParsePort(bindings[i].HostPort)
			if err != nil {
				return 0, fmt.Errorf("failed to parse host port: %w", err)
			}
			return hostPort, nil
		}
	}

	return 0, fmt.Errorf("container %s does not have a port bound to 0.0.0.0", containerID)
}


================================================
FILE: pkg/docker/run_test.go
================================================
//nolint:staticcheck // container.NetworkSettingsBase deprecated but Ports field moving to NetworkSettings in docker v29
package docker

import (
	"testing"

	"github.com/docker/docker/api/types/container"
	"github.com/docker/go-connections/nat"
	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/docker/dockertest"
)

func TestGetHostPortForContainer(t *testing.T) {
	t.Run("WithExposedPort", func(t *testing.T) {
		testClient := dockertest.NewMockCommand2(t)
		testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{
			ContainerJSONBase: &container.ContainerJSONBase{
				State: &container.State{
					Status:  "running",
					Running: true,
				},
			},
			NetworkSettings: &container.NetworkSettings{
				NetworkSettingsBase: container.NetworkSettingsBase{
					Ports: nat.PortMap{
						nat.Port("5678/tcp"): []nat.PortBinding{
							{
								HostIP:   "0.0.0.0",
								HostPort: "12345",
							},
						},
					},
				},
			},
		}, nil)

		hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678)
		require.NoError(t, err)
		require.Equal(t, 12345, hostPort)
	})

	t.Run("WithMultipleExposedPorts", func(t *testing.T) {
		testClient := dockertest.NewMockCommand2(t)
		testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{
			ContainerJSONBase: &container.ContainerJSONBase{
				State: &container.State{
					Status:  "running",
					Running: true,
				},
			},
			NetworkSettings: &container.NetworkSettings{
				NetworkSettingsBase: container.NetworkSettingsBase{
					Ports: nat.PortMap{
						nat.Port("5678/tcp"): []nat.PortBinding{
							{
								HostIP:   "0.0.0.0",
								HostPort: "12345",
							},
							{
								HostIP:   "0.0.0.0",
								HostPort: "54321",
							},
						},
					},
				},
			},
		}, nil)

		hostPort, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678)
		require.NoError(t, err)
		require.Equal(t, 12345, hostPort)
	})

	t.Run("WithExposedPortOnDifferentAddress", func(t *testing.T) {
		testClient := dockertest.NewMockCommand2(t)
		testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{
			ContainerJSONBase: &container.ContainerJSONBase{
				State: &container.State{
					Status:  "running",
					Running: true,
				},
			},
			NetworkSettings: &container.NetworkSettings{
				NetworkSettingsBase: container.NetworkSettingsBase{
					Ports: nat.PortMap{
						nat.Port("5678/tcp"): []nat.PortBinding{
							{
								HostIP:   "127.0.0.1",
								HostPort: "12345",
							},
						},
					},
				},
			},
		}, nil)

		_, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678)
		require.ErrorContains(t, err, "does not have a port bound to 0.0.0.0")
	})

	t.Run("WithDifferentPortExposed", func(t *testing.T) {
		testClient := dockertest.NewMockCommand2(t)
		testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{
			ContainerJSONBase: &container.ContainerJSONBase{
				State: &container.State{
					Status:  "running",
					Running: true,
				},
			},
			NetworkSettings: &container.NetworkSettings{
				NetworkSettingsBase: container.NetworkSettingsBase{
					Ports: nat.PortMap{
						nat.Port("1234/tcp"): []nat.PortBinding{
							{
								HostIP:   "0.0.0.0",
								HostPort: "12345",
							},
						},
					},
				},
			},
		}, nil)

		_, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678)
		require.ErrorContains(t, err, "does not have a port bound to 0.0.0.0")
	})

	t.Run("WithNoExposedPort", func(t *testing.T) {
		testClient := dockertest.NewMockCommand2(t)
		testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{
			ContainerJSONBase: &container.ContainerJSONBase{
				State: &container.State{
					Status:  "running",
					Running: true,
				},
			},
		}, nil)

		_, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678)
		require.ErrorContains(t, err, "does not have expected network configuration")
	})

	t.Run("ContainerNotRunning", func(t *testing.T) {
		testClient := dockertest.NewMockCommand2(t)
		testClient.EXPECT().ContainerInspect(t.Context(), "container123").Return(&container.InspectResponse{
			ContainerJSONBase: &container.ContainerJSONBase{
				State: &container.State{
					Status: "dead",
					Dead:   true,
				},
			},
		}, nil)

		_, err := GetHostPortForContainer(t.Context(), testClient, "container123", 5678)
		require.ErrorContains(t, err, "is not running")
	})
}


================================================
FILE: pkg/docker/standard_push.go
================================================
package docker

import (
	"context"

	"github.com/replicate/cog/pkg/docker/command"
)

// StandardPush pushes image by calling command.Push and returns its error.
func StandardPush(ctx context.Context, image string, command command.Command) error {
	return command.Push(ctx, image)
}


================================================
FILE: pkg/docker/standard_push_test.go
================================================
package docker

import (
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/docker/dockertest"
)

// TestStandardPush checks that StandardPush returns no error when run against
// the mock command.
func TestStandardPush(t *testing.T) {
	command := dockertest.NewMockCommand()
	// Reset the package-level PushError before pushing.
	// NOTE(review): presumably the mock's Push returns this value; confirm in dockertest.
	dockertest.PushError = nil
	err := StandardPush(t.Context(), "test", command)
	require.NoError(t, err)
}


================================================
FILE: pkg/dockercontext/build_tempdir.go
================================================
package dockercontext

import (
	"os"
	"path"
	"time"

	"github.com/replicate/cog/pkg/global"
)

// CogBuildArtifactsDirPath returns the build artifacts folder
// (global.CogBuildArtifactsFolder) under dir, creating it and any missing
// parents first.
func CogBuildArtifactsDirPath(dir string) (string, error) {
	artifactsDir := path.Join(dir, global.CogBuildArtifactsFolder)
	if err := os.MkdirAll(artifactsDir, 0o777); err != nil {
		return "", err
	}
	return artifactsDir, nil
}

// CogTempDir returns the path <artifacts dir>/tmp/<contextDir> for dir. The
// artifacts directory is created as a side effect; the returned path itself
// is not.
func CogTempDir(dir string, contextDir string) (string, error) {
	artifactsDir, err := CogBuildArtifactsDirPath(dir)
	if err != nil {
		return "", err
	}
	tempPath := path.Join(artifactsDir, "tmp", contextDir)
	return tempPath, nil
}

// BuildCogTempDir resolves the temp path for subDir via CogTempDir and
// creates that directory (with any missing parents) before returning it.
func BuildCogTempDir(dir string, subDir string) (string, error) {
	tempPath, err := CogTempDir(dir, subDir)
	if err != nil {
		return "", err
	}

	err = os.MkdirAll(tempPath, 0o777)
	if err != nil {
		return "", err
	}
	return tempPath, nil
}

// BuildTempDir creates and returns a timestamped temp directory for a build
// under dir.
func BuildTempDir(dir string) (string, error) {
	// The result ends up being something like dir/.cog/tmp/build20240620123456.000000
	const stampLayout = "20060102150405.000000"
	subDir := "build" + time.Now().Format(stampLayout)
	return BuildCogTempDir(dir, subDir)
}


================================================
FILE: pkg/dockercontext/build_tempdir_test.go
================================================
package dockercontext

import (
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestBuildCogTempDir checks that BuildCogTempDir returns
// <dir>/.cog/tmp/<subDir> for a fresh temporary directory.
func TestBuildCogTempDir(t *testing.T) {
	tmpDir := t.TempDir()
	cogTmpDir, err := BuildCogTempDir(tmpDir, "weights")
	require.NoError(t, err)
	require.Equal(t, filepath.Join(tmpDir, ".cog/tmp/weights"), cogTmpDir)
}


================================================
FILE: pkg/dockercontext/directories.go
================================================
package dockercontext

import "path/filepath"

// StandardBuildDirectory is the current directory, ".".
// NOTE(review): presumably the default docker build directory; confirm against callers.
const StandardBuildDirectory = "."

// ContextBuildDir is the parent directory name of the per-context build
// directories declared below.
const ContextBuildDir = "context"
// AptBuildContextName, RequirementsBuildContextName and SrcBuildContextName
// are the names of the individual build contexts.
const AptBuildContextName = "apt"
const RequirementsBuildContextName = "requirements"
const SrcBuildContextName = "src"

// SrcBuildDir, AptBuildDir and RequirementsBuildDir are the per-context
// directories beneath ContextBuildDir, joined with the OS path separator.
var SrcBuildDir = filepath.Join(ContextBuildDir, "src")
var AptBuildDir = filepath.Join(ContextBuildDir, "apt")
var RequirementsBuildDir = filepath.Join(ContextBuildDir, "requirements")


================================================
FILE: pkg/dockerfile/base.go
================================================
package dockerfile

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"

	"github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/global"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/util/version"
)

// MinimumCUDAVersion is the lowest CUDA version BaseImageConfigurations accepts.
const MinimumCUDAVersion = "11.6"
// MinimumPythonVersion is the lowest Python version BaseImageConfigurations accepts.
const MinimumPythonVersion = "3.10"
// MinimumTorchVersion is the lowest Torch version BaseImageConfigurations accepts.
const MinimumTorchVersion = "1.13.1"
// CogBaseImageName is the repository name BaseImageName uses for base images.
const CogBaseImageName = "cog-base"

var (
	// baseImageSystemPackages is the list of system packages that makeConfig
	// sets as SystemPackages on the build config used for base images.
	baseImageSystemPackages = []string{
		"build-essential",
		"cmake",
		"curl",
		"ffmpeg",
		"findutils",
		"g++",
		"gcc",
		"git",
		"libavcodec-dev",
		"libcairo2-dev",
		"libfontconfig1",
		"libgirepository1.0-dev",
		"libgl1",
		"libglx-mesa0",
		"libglib2.0-0",
		"libopencv-dev",
		"libsm6",
		"libsndfile1",
		"libssl-dev",
		"libunistring-dev",
		"libxext6",
		"libxrender1",
		"sox",
		"unzip",
		"wget",
		"zip",
		"zstd",
	}
)

// CUDAVersion is a single CUDA version entry in the JSON form of the
// available base image configurations.
type CUDAVersion struct {
	// NOTE(review): the JSON key is "versions" (plural) while the sibling
	// types use "version"; confirm this matches the producer of the JSON.
	Version string `json:"versions"`
}

// PyTorchVersion is a single PyTorch version entry in the JSON form of the
// available base image configurations.
type PyTorchVersion struct {
	Version string `json:"version"`
}

// PythonVersion is a Python version together with the PyTorch and CUDA
// version entries listed for it.
type PythonVersion struct {
	Version string           `json:"version"`
	PyTorch []PyTorchVersion `json:"pytorch"`
	CUDA    []CUDAVersion    `json:"cuda"`
}

// AvailableBaseImageConfigurations is the top-level JSON document listing
// Python versions and their PyTorch/CUDA entries.
type AvailableBaseImageConfigurations struct {
	PythonVersions []PythonVersion `json:"python_versions"`
}

// BaseImageConfiguration identifies one base image by its CUDA, Python and
// Torch versions. BaseImageConfigurations leaves CUDAVersion and/or
// TorchVersion empty for configurations without that component.
type BaseImageConfiguration struct {
	CUDAVersion   string `json:"cuda_version" yaml:"cuda_version"`
	PythonVersion string `json:"python_version" yaml:"python_version"`
	TorchVersion  string `json:"torch_version" yaml:"torch_version"`
}

// BaseImageGenerator generates the Dockerfile for one base image
// configuration. Create it with NewBaseImageGenerator.
type BaseImageGenerator struct {
	// cudaVersion is empty for images without CUDA; makeConfig enables GPU
	// only when it is non-empty.
	cudaVersion   string
	pythonVersion string
	// torchVersion is empty for images without Torch; pythonPackages then
	// returns no packages.
	torchVersion  string
	// command and client are passed through to NewGenerator.
	command       command.Command
	client        registry.Client
}

// MarshalJSON encodes the configuration's own fields plus two derived ones,
// "image_name" and "image_tag", taken from BaseImageName with the Replicate
// registry host prefix removed. It returns an error when the remaining name
// does not contain exactly one ':' separating name from tag.
func (b BaseImageConfiguration) MarshalJSON() ([]byte, error) {
	// Alias has the same fields as BaseImageConfiguration but none of its
	// methods, so encoding the wrapper below does not call MarshalJSON again.
	type Alias BaseImageConfiguration
	type BaseImageConfigWithImageName struct {
		Alias
		ImageName string `json:"image_name,omitempty" yaml:"image_name,omitempty"`
		Tag       string `json:"image_tag,omitempty" yaml:"image_tag,omitempty"`
	}

	fullName := BaseImageName(b.CUDAVersion, b.PythonVersion, b.TorchVersion)
	rawName := strings.TrimPrefix(fullName, global.ReplicateRegistryHost+"/")

	// Exactly one ':' is allowed: one separator, and none left in the tag.
	imageName, tag, found := strings.Cut(rawName, ":")
	if !found || strings.Contains(tag, ":") {
		return nil, fmt.Errorf("invalid base image name and tag: %s", rawName)
	}

	return json.Marshal(&BaseImageConfigWithImageName{
		Alias:     Alias(b),
		ImageName: imageName,
		Tag:       tag,
	})
}

// BaseImageConfigurations returns a list of CUDA/Python/Torch versions.
//
// The list is built from config.TorchCompatibilityMatrix in three parts:
//  1. one entry per (Torch, Python[, CUDA]) combination that meets the
//     minimum Python, Torch and (when present) CUDA versions;
//  2. CUDA+Python entries without Torch, for every pairing of the Python and
//     CUDA versions seen on accepted CUDA entries;
//  3. Python-only entries for those same Python versions.
//
// The result is deterministic: parts 2 and 3 follow the order in which the
// versions first appear in the matrix.
func BaseImageConfigurations() []BaseImageConfiguration {
	configs := []BaseImageConfiguration{}

	// Assuming that the Torch versions cover all Python and CUDA versions to avoid
	// having to hard-code a list of Python versions here.
	//
	// The versions are kept in first-seen slices (with maps only for
	// de-duplication) because ranging over a map yields an unspecified order,
	// which would make the returned list differ between calls.
	var pythonVersions, cudaVersions []string
	seenPython := make(map[string]bool)
	seenCUDA := make(map[string]bool)

	// Torch configs
	for _, compat := range config.TorchCompatibilityMatrix {
		for _, python := range compat.Pythons {
			if !version.GreaterOrEqual(python, MinimumPythonVersion) || !version.GreaterOrEqual(compat.Torch, MinimumTorchVersion) {
				continue
			}

			if compat.CUDA == nil {
				configs = append(configs, BaseImageConfiguration{
					PythonVersion: python,
					TorchVersion:  compat.Torch,
				})
				continue
			}

			cuda := *compat.CUDA
			if !version.GreaterOrEqual(cuda, MinimumCUDAVersion) {
				continue
			}
			configs = append(configs, BaseImageConfiguration{
				CUDAVersion:   cuda,
				PythonVersion: python,
				TorchVersion:  compat.Torch,
			})
			if !seenPython[python] {
				seenPython[python] = true
				pythonVersions = append(pythonVersions, python)
			}
			if !seenCUDA[cuda] {
				seenCUDA[cuda] = true
				cudaVersions = append(cudaVersions, cuda)
			}
		}
	}

	// Python and CUDA-only configs
	for _, python := range pythonVersions {
		for _, cuda := range cudaVersions {
			configs = append(configs, BaseImageConfiguration{
				CUDAVersion:   cuda,
				PythonVersion: python,
			})
		}
	}

	// Python-only configs
	for _, python := range pythonVersions {
		configs = append(configs, BaseImageConfiguration{
			PythonVersion: python,
		})
	}

	return configs
}

// NewBaseImageGenerator validates the requested CUDA/Python/Torch combination
// through BaseImageConfigurationExists and, when it is valid, returns a
// generator for the versions that call reports back. An unsupported
// combination yields an error naming each component, with "(none)" standing
// in for empty ones.
func NewBaseImageGenerator(ctx context.Context, client registry.Client, cudaVersion string, pythonVersion string, torchVersion string, command command.Command, generate bool) (*BaseImageGenerator, error) {
	valid, cudaVersion, pythonVersion, torchVersion, err := BaseImageConfigurationExists(ctx, client, cudaVersion, pythonVersion, torchVersion, generate)
	if err != nil {
		return nil, err
	}

	if !valid {
		orNone := func(v string) string {
			if v != "" {
				return v
			}
			return "(none)"
		}
		return nil, fmt.Errorf("unsupported base image configuration: CUDA: %s / Python: %s / Torch: %s", orNone(cudaVersion), orNone(pythonVersion), orNone(torchVersion))
	}

	return &BaseImageGenerator{
		cudaVersion:   cudaVersion,
		pythonVersion: pythonVersion,
		torchVersion:  torchVersion,
		command:       command,
		client:        client,
	}, nil
}

// GenerateDockerfile builds the config for this base image, creates a
// Generator for it with the cog base image explicitly disabled, and returns
// the result of the generator's GenerateInitialSteps.
func (g *BaseImageGenerator) GenerateDockerfile(ctx context.Context) (string, error) {
	baseConfig, err := g.makeConfig()
	if err != nil {
		return "", err
	}

	gen, err := NewGenerator(baseConfig, "", "", g.command, g.client, false)
	if err != nil {
		return "", err
	}

	// The setter takes a pointer, so the flag needs an addressable variable.
	disabled := false
	gen.SetUseCogBaseImagePtr(&disabled)

	contents, err := gen.GenerateInitialSteps(ctx)
	if err != nil {
		return "", err
	}
	return contents, nil
}

// makeConfig assembles the config for this base image: GPU is enabled only
// when a CUDA version is set, and the Python packages, run statements and
// system packages come from the generator's helpers and
// baseImageSystemPackages. The config is passed through Complete("") before
// being returned.
func (g *BaseImageGenerator) makeConfig() (*config.Config, error) {
	build := &config.Build{
		GPU:            g.cudaVersion != "",
		PythonVersion:  g.pythonVersion,
		PythonPackages: g.pythonPackages(),
		Run:            g.runStatements(),
		SystemPackages: baseImageSystemPackages,
		CUDA:           g.cudaVersion,
	}

	conf := &config.Config{Build: build}
	err := conf.Complete("")
	if err != nil {
		return nil, err
	}
	return conf, nil
}

// pythonPackages returns the pinned Python packages for this base image.
// Without a Torch version the list is empty. Otherwise it holds torch and
// opencv-python, followed by torchvision and torchaudio pinned to the first
// entry in config.TorchCompatibilityMatrix that names that package and
// matches the Torch version (each is omitted if no such entry exists).
func (g *BaseImageGenerator) pythonPackages() []string {
	if g.torchVersion == "" {
		return []string{}
	}

	pkgs := []string{
		"torch==" + g.torchVersion,
		"opencv-python==4.12.0.88",
	}

	// Find torchvision compatibility.
	for _, compat := range config.TorchCompatibilityMatrix {
		if len(compat.Torchvision) != 0 && version.Matches(g.torchVersion, compat.TorchVersion()) {
			pkgs = append(pkgs, "torchvision=="+compat.Torchvision)
			break
		}
	}

	// Find torchaudio compatibility.
	for _, compat := range config.TorchCompatibilityMatrix {
		if len(compat.Torchaudio) != 0 && version.Matches(g.torchVersion, compat.TorchVersion()) {
			pkgs = append(pkgs, "torchaudio=="+compat.Torchaudio)
			break
		}
	}

	return pkgs
}

// runStatements returns the extra run items for the base image config.
// There are none; the result is an empty, non-nil slice.
func (g *BaseImageGenerator) runStatements() []config.RunItem {
	noItems := []config.RunItem{}
	return noItems
}

// baseImageComponentNormalisation resolves the requested torch version to a
// concrete one offered by BaseImageConfigurations.
//
// The CUDA and Python versions are returned unchanged. The torch version is
// replaced by the greatest torch version (modifier stripped) among the
// configurations compatible with all three requested components, or "" when
// no configuration is compatible.
func baseImageComponentNormalisation(cudaVersion string, pythonVersion string, torchVersion string) (string, string, string) {
	bestTorch := ""
	for _, candidate := range BaseImageConfigurations() {
		compatible := isVersionCompatible(candidate.CUDAVersion, cudaVersion) &&
			isVersionCompatible(candidate.PythonVersion, pythonVersion) &&
			isVersionCompatible(candidate.TorchVersion, torchVersion)
		if !compatible {
			continue
		}

		// Keep the newest compatible torch version seen so far.
		if bestTorch == "" || version.Greater(candidate.TorchVersion, bestTorch) {
			bestTorch = version.StripModifier(candidate.TorchVersion)
		}
	}

	return cudaVersion, pythonVersion, bestTorch
}

// BaseImageName returns the fully qualified cog base image reference for the
// given component versions.
//
// The versions are normalised first. The tag is the "-"-joined list of the
// non-empty components ("cuda<ver>", "python<ver>", "torch<ver>"), or "latest"
// when every component is empty.
func BaseImageName(cudaVersion string, pythonVersion string, torchVersion string) string {
	cudaVersion, pythonVersion, torchVersion = baseImageComponentNormalisation(cudaVersion, pythonVersion, torchVersion)

	var parts []string
	if cudaVersion != "" {
		parts = append(parts, "cuda"+version.StripPatch(cudaVersion))
	}
	if pythonVersion != "" {
		parts = append(parts, "python"+version.StripPatch(pythonVersion))
	}
	if torchVersion != "" {
		parts = append(parts, "torch"+version.StripModifier(torchVersion))
	}

	tag := "latest"
	if len(parts) > 0 {
		tag = strings.Join(parts, "-")
	}

	return global.ReplicateRegistryHost + "/" + CogBaseImageName + ":" + tag
}

// BaseImageConfigurationExists reports whether a base image is available for
// the requested component versions.
//
// The versions are normalised first and the normalised values are returned
// alongside the result. A configuration must be listed in
// BaseImageConfigurations; unless generate is true, the corresponding image
// must additionally exist according to client.Exists, whose error (if any) is
// returned.
func BaseImageConfigurationExists(ctx context.Context, client registry.Client, cudaVersion, pythonVersion, torchVersion string, generate bool) (bool, string, string, string, error) {
	cudaVersion, pythonVersion, torchVersion = baseImageComponentNormalisation(cudaVersion, pythonVersion, torchVersion)

	listed := false
	for _, candidate := range BaseImageConfigurations() {
		if isVersionCompatible(candidate.CUDAVersion, cudaVersion) &&
			isVersionCompatible(candidate.PythonVersion, pythonVersion) &&
			isVersionCompatible(candidate.TorchVersion, torchVersion) {
			listed = true
			break
		}
	}

	// When generating, or when nothing is listed, the registry is not consulted.
	if !listed || generate {
		return listed, cudaVersion, pythonVersion, torchVersion, nil
	}

	exists, err := client.Exists(ctx, BaseImageName(cudaVersion, pythonVersion, torchVersion))
	return exists, cudaVersion, pythonVersion, torchVersion, err
}

// isVersionCompatible reports whether a requested version is satisfied by a
// configuration's version. When both are non-empty the decision is delegated
// to version.Matches; otherwise they are compatible only if both are empty.
func isVersionCompatible(confVersion, requestedVersion string) bool {
	if confVersion != "" && requestedVersion != "" {
		return version.Matches(requestedVersion, confVersion)
	}
	// At least one side is empty: a component may only be absent on both sides.
	return confVersion == requestedVersion
}


================================================
FILE: pkg/dockerfile/base_test.go
================================================
package dockerfile

import (
	"reflect"
	"strings"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/docker/dockertest"
	"github.com/replicate/cog/pkg/registry/registrytest"
)

// TestBaseImageName checks the image reference produced for several
// combinations of CUDA, Python and Torch versions. Each combination runs as
// its own subtest so a failure identifies the offending inputs.
func TestBaseImageName(t *testing.T) {
	for _, tt := range []struct {
		cuda     string
		python   string
		torch    string
		expected string
	}{
		{"", "3.10", "",
			"r8.im/cog-base:python3.10"},
		{"", "3.10", "2.1",
			"r8.im/cog-base:python3.10-torch2.1.2"},
		{"12.1", "3.10", "",
			"r8.im/cog-base:cuda12.1-python3.10"},
		{"12.1", "3.10", "2.1",
			"r8.im/cog-base:cuda12.1-python3.10-torch2.1.2"},
	} {
		name := "cuda=" + tt.cuda + ",python=" + tt.python + ",torch=" + tt.torch
		t.Run(name, func(t *testing.T) {
			actual := BaseImageName(tt.cuda, tt.python, tt.torch)
			require.Equal(t, tt.expected, actual)
		})
	}
}

// TestGenerateDockerfile checks that a CUDA 12.1 / Python 3.10 / Torch 2.1.0
// base image Dockerfile starts from the expected nvidia/cuda image.
func TestGenerateDockerfile(t *testing.T) {
	const (
		cuda   = "12.1"
		python = "3.10"
		torch  = "2.1.0"
	)

	registryClient := registrytest.NewMockRegistryClient()
	registryClient.AddMockImage(BaseImageName(cuda, python, torch))
	dockerCommand := dockertest.NewMockCommand()

	generator, err := NewBaseImageGenerator(t.Context(), registryClient, cuda, python, torch, dockerCommand, false)
	require.NoError(t, err)

	contents, err := generator.GenerateDockerfile(t.Context())
	require.NoError(t, err)

	hasBase := strings.Contains(contents, "FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04")
	require.True(t, hasBase)
}

// TestBaseImageNameWithVersionModifier checks that a "+cu118" style modifier
// on the torch version does not appear in the image tag.
func TestBaseImageNameWithVersionModifier(t *testing.T) {
	const want = "r8.im/cog-base:cuda11.8-python3.10-torch2.0.1"

	got := BaseImageName("11.8", "3.10", "2.0.1+cu118")

	require.Equal(t, want, got)
}

// TestBaseImageConfigurationExists checks that a partial torch version is
// resolved to a concrete one and that the matching image is found.
func TestBaseImageConfigurationExists(t *testing.T) {
	const (
		cuda           = "12.1"
		python         = "3.10"
		requestedTorch = "2.3"
	)
	registryClient := registrytest.NewMockRegistryClient()
	registryClient.AddMockImage(BaseImageName(cuda, python, requestedTorch))

	found, _, _, resolvedTorch, err := BaseImageConfigurationExists(t.Context(), registryClient, cuda, python, requestedTorch, false)

	require.NoError(t, err)
	require.True(t, found)
	require.Equal(t, "2.3.1", resolvedTorch)
}

// TestBaseImageConfigurationExistsNoTorch covers a Python-only configuration
// (no CUDA, no Torch).
func TestBaseImageConfigurationExistsNoTorch(t *testing.T) {
	const (
		cuda   = ""
		python = "3.12"
		torch  = ""
	)
	registryClient := registrytest.NewMockRegistryClient()
	registryClient.AddMockImage(BaseImageName(cuda, python, torch))

	found, _, _, _, err := BaseImageConfigurationExists(t.Context(), registryClient, cuda, python, torch, false)

	require.NoError(t, err)
	require.True(t, found)
}

// TestBaseImageConfigurationExistsNoCUDA covers a CPU-only configuration with
// a partial torch version that must be resolved.
func TestBaseImageConfigurationExistsNoCUDA(t *testing.T) {
	const (
		cuda           = ""
		python         = "3.10"
		requestedTorch = "2.1"
	)
	registryClient := registrytest.NewMockRegistryClient()
	registryClient.AddMockImage(BaseImageName(cuda, python, requestedTorch))

	found, _, _, resolvedTorch, err := BaseImageConfigurationExists(t.Context(), registryClient, cuda, python, requestedTorch, false)

	require.NoError(t, err)
	require.True(t, found)
	require.Equal(t, "2.1.2", resolvedTorch)
}

func TestIsVersionCompatible(t *testing.T) {
	compatible := isVersionCompatible("2.3.1+cu121", "2.3")
	require.True(t, compatible)
}

// TestPythonPackages checks the pinned package list generated for a
// CUDA 12.1 / Python 3.10 / Torch 2.1.0 base image.
func TestPythonPackages(t *testing.T) {
	cudaVersion := "12.1"
	pythonVersion := "3.10"
	torchVersion := "2.1.0"
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	client.AddMockImage(BaseImageName(cudaVersion, pythonVersion, torchVersion))
	generator, err := NewBaseImageGenerator(t.Context(), client, cudaVersion, pythonVersion, torchVersion, command, false)
	require.NoError(t, err)

	want := []string{
		"torch==" + torchVersion,
		"opencv-python==4.12.0.88",
		"torchvision==0.16.0",
		"torchaudio==2.1.0",
	}
	pkgs := generator.pythonPackages()

	// Report both sides on failure; the previous message printed the actual
	// value under the label "expected".
	require.Truef(t, reflect.DeepEqual(pkgs, want), "pythonPackages() = %v, want %v", pkgs, want)
}

// TestInvalidBaseImage checks that constructing a generator for CUDA 12.78
// fails.
func TestInvalidBaseImage(t *testing.T) {
	dockerCommand := dockertest.NewMockCommand()
	registryClient := registrytest.NewMockRegistryClient()

	_, err := NewBaseImageGenerator(t.Context(), registryClient, "12.78", "3.10", "2.1.0", dockerCommand, false)

	require.Error(t, err)
}

// TestBaseImageConfigurationNoTorchPythonVersionDoesNotExist checks that
// Python 3.99 without CUDA or Torch is reported as missing, without error.
func TestBaseImageConfigurationNoTorchPythonVersionDoesNotExist(t *testing.T) {
	registryClient := registrytest.NewMockRegistryClient()

	found, _, _, _, err := BaseImageConfigurationExists(t.Context(), registryClient, "", "3.99", "", false)

	require.NoError(t, err)
	require.False(t, found)
}


================================================
FILE: pkg/dockerfile/cacert.go
================================================
package dockerfile

import (
	"encoding/base64"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/replicate/cog/pkg/util/console"
)

// Names and paths used when injecting a custom CA certificate into an image.
const (
	// CACertEnvVar is the name of the environment variable that specifies the
	// CA certificate to inject.
	CACertEnvVar = "COG_CA_CERT"

	// CACertFilename is the filename used for the CA cert in the build context
	// and in the container.
	CACertFilename = "cog-ca-cert.crt"

	// CACertContainerPath is where the cert is installed in the container:
	// CACertFilename inside /usr/local/share/ca-certificates/.
	CACertContainerPath = "/usr/local/share/ca-certificates/" + CACertFilename

	// SystemCertBundle is the path to the system certificate bundle after
	// update-ca-certificates has run.
	SystemCertBundle = "/etc/ssl/certs/ca-certificates.crt"
)

// ReadCACert reads the CA certificate from the COG_CA_CERT environment variable.
// Surrounding whitespace in the value is ignored. It supports multiple input
// formats, tried in this order:
//   - File path: /path/to/cert.crt
//   - Directory: /path/to/certs/ (concatenates all *.crt and *.pem files)
//   - Inline PEM: -----BEGIN CERTIFICATE-----...
//   - Base64-encoded PEM: LS0tLS1CRUdJTi...
//
// Returns:
//   - (nil, nil) if COG_CA_CERT is not set, empty, or only whitespace (no-op case)
//   - (certBytes, nil) if a valid certificate was found
//   - (nil, error) if the input is invalid
func ReadCACert() ([]byte, error) {
	// Trim before the emptiness check so that a whitespace-only value is
	// treated like an unset variable instead of being rejected as invalid.
	value := strings.TrimSpace(os.Getenv(CACertEnvVar))
	if value == "" {
		return nil, nil
	}

	// Check if it's a file path. Any Stat failure (including a missing file)
	// falls through to the inline formats below.
	if info, err := os.Stat(value); err == nil { //nolint:gosec // G703: path from trusted COG_CA_CERT env var
		if info.IsDir() {
			return readCACertDirectory(value)
		}
		return readCACertFile(value)
	}

	// Check if it's inline PEM
	if strings.HasPrefix(value, "-----BEGIN") {
		return validatePEM([]byte(value))
	}

	// Try base64 decoding; only accept it when the payload looks like PEM.
	decoded, err := base64.StdEncoding.DecodeString(value)
	if err == nil && strings.HasPrefix(string(decoded), "-----BEGIN") {
		return validatePEM(decoded)
	}

	return nil, fmt.Errorf("%s: invalid value - must be a file path, directory, PEM certificate, or base64-encoded PEM", CACertEnvVar)
}

// readCACertFile loads one certificate file and returns its contents after
// PEM validation. Read failures are wrapped with the env var name and path.
func readCACertFile(path string) ([]byte, error) {
	contents, readErr := os.ReadFile(path) //nolint:gosec // G703: path from trusted COG_CA_CERT env var
	if readErr != nil {
		return nil, fmt.Errorf("%s: failed to read file %s: %w", CACertEnvVar, path, readErr)
	}

	return validatePEM(contents)
}

// readCACertDirectory reads every .crt and .pem file (extension matched
// case-insensitively) directly inside dir and concatenates them into a single
// bundle. Subdirectories and files with other extensions are skipped.
//
// Each file is validated with validatePEM and its normalised form (surrounding
// whitespace trimmed, exactly one trailing newline) is what gets appended, so
// the bundle is formatted the same way as the single-file path and consecutive
// certificates are always newline-separated.
//
// It returns an error if the directory cannot be read, if any matching file
// cannot be read or fails validation, or if no matching file is found.
func readCACertDirectory(dir string) ([]byte, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, fmt.Errorf("%s: failed to read directory %s: %w", CACertEnvVar, dir, err)
	}

	var certs []byte
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		switch strings.ToLower(filepath.Ext(entry.Name())) {
		case ".crt", ".pem":
		default:
			continue
		}

		path := filepath.Join(dir, entry.Name())
		data, err := os.ReadFile(path) //nolint:gosec // G703: path from trusted COG_CA_CERT env var directory
		if err != nil {
			return nil, fmt.Errorf("%s: failed to read file %s: %w", CACertEnvVar, path, err)
		}

		// Validate each cert and keep the normalised bytes, which always end
		// in a newline, so no separator bookkeeping is needed.
		normalized, err := validatePEM(data)
		if err != nil {
			return nil, fmt.Errorf("%s: invalid certificate in %s: %w", CACertEnvVar, path, err)
		}
		certs = append(certs, normalized...)
	}

	if len(certs) == 0 {
		return nil, fmt.Errorf("%s: no .crt or .pem files found in directory %s", CACertEnvVar, dir)
	}

	return certs, nil
}

// validatePEM performs a shallow check that data looks like a PEM
// certificate: after trimming surrounding whitespace it must start with the
// BEGIN CERTIFICATE marker and contain the END CERTIFICATE marker. The
// payload itself is not decoded.
//
// On success it returns the trimmed text followed by a single newline.
func validatePEM(data []byte) ([]byte, error) {
	const (
		beginMarker = "-----BEGIN CERTIFICATE-----"
		endMarker   = "-----END CERTIFICATE-----"
	)

	pemText := strings.TrimSpace(string(data))
	switch {
	case !strings.HasPrefix(pemText, beginMarker):
		return nil, fmt.Errorf("invalid PEM: must start with '-----BEGIN CERTIFICATE-----'")
	case !strings.Contains(pemText, endMarker):
		return nil, fmt.Errorf("invalid PEM: must contain '-----END CERTIFICATE-----'")
	}

	return []byte(pemText + "\n"), nil
}

// GenerateCACertInstall generates the Dockerfile lines to install a CA certificate.
// It writes the cert to the build context via writeTemp and returns the
// Dockerfile lines joined with newlines.
//
// The returned lines:
//  1. The COPY line(s) produced by writeTemp, which place the cert in the image
//  2. A RUN step that copies the cert to CACertContainerPath, runs
//     update-ca-certificates, and appends the cert to SystemCertBundle
//  3. ENV lines pointing SSL_CERT_FILE and REQUESTS_CA_BUNDLE at SystemCertBundle
//
// Parameters:
//   - certData: The PEM-encoded certificate data
//   - writeTemp: Function to write a file to the build context (returns COPY lines and container path)
//
// Returns "" and a nil error when certData is empty (writeTemp is not called).
// Returns an error if writeTemp fails.
func GenerateCACertInstall(certData []byte, writeTemp func(filename string, contents []byte) ([]string, string, error)) (string, error) {
	if len(certData) == 0 {
		return "", nil
	}

	console.Infof("Injecting CA certificate from %s", CACertEnvVar)

	// Write cert to build context
	copyLines, containerPath, err := writeTemp(CACertFilename, certData)
	if err != nil {
		return "", fmt.Errorf("failed to write CA certificate to build context: %w", err)
	}
	// Use the location writeTemp reports rather than assuming one, so the RUN
	// step stays correct if the copy destination ever changes. Fall back to
	// the previously hard-coded location when no path is reported.
	if containerPath == "" {
		containerPath = "/tmp/" + CACertFilename
	}

	lines := make([]string, 0, len(copyLines)+3)
	lines = append(lines, copyLines...)

	// Copy to system CA directory, update the certificate store, and set env vars.
	// Also append the cert directly to the bundle file as a fallback for images
	// where update-ca-certificates may not work as expected.
	lines = append(lines,
		fmt.Sprintf("RUN cp %s %s && update-ca-certificates && cat %s >> %s", containerPath, CACertContainerPath, containerPath, SystemCertBundle),
		fmt.Sprintf("ENV SSL_CERT_FILE=%s", SystemCertBundle),
		fmt.Sprintf("ENV REQUESTS_CA_BUNDLE=%s", SystemCertBundle),
	)

	return strings.Join(lines, "\n"), nil
}


================================================
FILE: pkg/dockerfile/cacert_test.go
================================================
package dockerfile

import (
	"encoding/base64"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
)

// testCertPEM is a PEM-shaped fixture with BEGIN/END CERTIFICATE markers.
// NOTE(review): the body contains placeholder text ("Example...") and is
// presumably not a parseable X.509 certificate — the tests only rely on the
// markers.
const testCertPEM = `-----BEGIN CERTIFICATE-----
MIIBkTCB+wIJAKHBfpegPjMCMA0GCSqGSIb3DQEBCwUAMBExDzANBgNVBAMMBlRl
c3RDQTAeFw0yNDAxMDEwMDAwMDBaFw0yNTAxMDEwMDAwMDBaMBExDzANBgNVBAMM
BlRlc3RDQTBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQC7o96WzE5gvnMXvPjNdXjH
HwjE7F5Q4X5g5W5P5s5Q5Y5V5y5v5p5o5k5f5d5c5b5a5Z5X5W5U5T5S5R5P5N5L
AgMBAAGjUzBRMB0GA1UdDgQWBBQExample1234567890ABCDEFGHIJKLMN
MB8GA1UdIwQYMBaAFBQExample1234567890ABCDEFGHIJKLMNMA8GA1Ud
EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADQQBExample1234567890
-----END CERTIFICATE-----`

// testCertPEM2 is a second PEM-shaped fixture, distinct from testCertPEM, used
// where a test needs two different certificate files.
const testCertPEM2 = `-----BEGIN CERTIFICATE-----
MIIBkTCB+wIJAKHBfpegPjMDMA0GCSqGSIb3DQEBCwUAMBExDzANBgNVBAMMBlRl
c3RDQTAeFw0yNDAxMDEwMDAwMDBaFw0yNTAxMDEwMDAwMDBaMBExDzANBgNVBAMM
BlRlc3RDQTBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQC8p97XzF6hvoPYvQkOeYkI
HxkF8G6Q5Y6h6X6Q6Z6W6z6w6q6p6l6g6e6d6c6b6a6Y6X6W6V6U6T6S6R6Q6O6M
AgMBAAGjUzBRMB0GA1UdDgQWBBQExample2222222222ABCDEFGHIJKLMN
MB8GA1UdIwQYMBaAFBQExample2222222222ABCDEFGHIJKLMNMA8GA1Ud
EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADQQBExample2222222222
-----END CERTIFICATE-----`

// TestReadCACert_NotSet checks the no-op case: with COG_CA_CERT absent from
// the environment, ReadCACert returns (nil, nil).
func TestReadCACert_NotSet(t *testing.T) {
	// t.Setenv records the current value and restores it when the test ends;
	// the variable is then removed so the truly-unset case is exercised
	// without leaking the change into other tests.
	t.Setenv(CACertEnvVar, "")
	require.NoError(t, os.Unsetenv(CACertEnvVar))

	cert, err := ReadCACert()
	require.NoError(t, err)
	require.Nil(t, cert)
}

func TestReadCACert_FilePath(t *testing.T) {
	tmpDir := t.TempDir()
	certPath := filepath.Join(tmpDir, "test.crt")
	require.NoError(t, os.WriteFile(certPath, []byte(testCertPEM), 0o644))

	t.Setenv(CACertEnvVar, certPath)

	cert, err := ReadCACert()
	require.NoError(t, err)
	require.NotNil(t, cert)
	require.Contains(t, string(cert), "-----BEGIN CERTIFICATE-----")
	require.Contains(t, string(cert), "-----END CERTIFICATE-----")
}

func TestReadCACert_Directory(t *testing.T) {
	tmpDir := t.TempDir()

	// Write two cert files
	require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "cert1.crt"), []byte(testCertPEM), 0o644))
	require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "cert2.pem"), []byte(testCertPEM2), 0o644))
	// Also write a non-cert file that should be ignored
	require.NoError(t, os.WriteFile(filepath.Join(tmpDir, "readme.txt"), []byte("ignore me"), 0o644))

	t.Setenv(CACertEnvVar, tmpDir)

	cert, err := ReadCACert()
	require.NoError(t, err)
	require.NotNil(t, cert)
	// Should contain both certificates
	require.Equal(t, 2, strings.Count(string(cert), "-----BEGIN CERTIFICATE-----"))
	require.Equal(t, 2, strings.Count(string(cert), "-----END CERTIFICATE-----"))
}

// TestReadCACert_InlinePEM checks that a PEM certificate placed directly in
// the environment variable is accepted.
func TestReadCACert_InlinePEM(t *testing.T) {
	t.Setenv(CACertEnvVar, testCertPEM)

	got, err := ReadCACert()

	require.NoError(t, err)
	require.NotNil(t, got)
	require.Contains(t, string(got), "-----BEGIN CERTIFICATE-----")
}

func TestReadCACert_Base64EncodedPEM(t *testing.T) {
	encoded := base64.StdEncoding.EncodeToString([]byte(testCertPEM))
	t.Setenv(CACertEnvVar, encoded)

	cert, err := ReadCACert()
	require.NoError(t, err)
	require.NotNil(t, cert)
	require.Contains(t, string(cert), "-----BEGIN CERTIFICATE-----")
}

// TestReadCACert_InvalidPEM checks that a value matching none of the
// supported formats is rejected with an "invalid value" error.
func TestReadCACert_InvalidPEM(t *testing.T) {
	t.Setenv(CACertEnvVar, "not a valid certificate")

	got, err := ReadCACert()

	require.Error(t, err)
	require.Nil(t, got)
	require.Contains(t, err.Error(), "invalid value")
}

// TestReadCACert_MissingFile checks that a path to a non-existent file is
// reported as an "invalid value" error.
func TestReadCACert_MissingFile(t *testing.T) {
	const missingPath = "/nonexistent/path/to/cert.crt"
	t.Setenv(CACertEnvVar, missingPath)

	got, err := ReadCACert()

	require.Error(t, err)
	require.Nil(t, got)
	require.Contains(t, err.Error(), "invalid value")
}

// TestReadCACert_EmptyDirectory checks that a directory without any
// certificate files is rejected.
func TestReadCACert_EmptyDirectory(t *testing.T) {
	t.Setenv(CACertEnvVar, t.TempDir())

	got, err := ReadCACert()

	require.Error(t, err)
	require.Nil(t, got)
	require.Contains(t, err.Error(), "no .crt or .pem files found")
}

// TestReadCACert_InvalidFileInDirectory checks that a directory containing a
// .crt file that is not PEM is rejected.
func TestReadCACert_InvalidFileInDirectory(t *testing.T) {
	certDir := t.TempDir()
	badFile := filepath.Join(certDir, "bad.crt")
	// The file has a certificate extension but no PEM content.
	require.NoError(t, os.WriteFile(badFile, []byte("not a cert"), 0o644))

	t.Setenv(CACertEnvVar, certDir)

	got, err := ReadCACert()

	require.Error(t, err)
	require.Nil(t, got)
	require.Contains(t, err.Error(), "invalid certificate")
}

func TestReadCACert_TrimsWhitespace(t *testing.T) {
	// Test that whitespace is trimmed from the env var value
	t.Setenv(CACertEnvVar, "  "+testCertPEM+"  \n")

	cert, err := ReadCACert()
	require.NoError(t, err)
	require.NotNil(t, cert)
	require.True(t, strings.HasPrefix(string(cert), "-----BEGIN"))
}

// TestGenerateCACertInstall checks that the certificate is handed to
// writeTemp under CACertFilename and that the generated Dockerfile lines
// contain the COPY, update-ca-certificates and ENV steps.
func TestGenerateCACertInstall(t *testing.T) {
	var (
		gotName string
		gotData []byte
	)
	// Recording stand-in for the build-context writer.
	recordingWriteTemp := func(filename string, contents []byte) ([]string, string, error) {
		gotName, gotData = filename, contents
		containerPath := "/tmp/" + filename
		return []string{"COPY .cog/tmp/" + filename + " " + containerPath}, containerPath, nil
	}

	dockerfileLines, err := GenerateCACertInstall([]byte(testCertPEM), recordingWriteTemp)
	require.NoError(t, err)

	// The cert must have been written under the expected name.
	require.Equal(t, CACertFilename, gotName)
	require.Contains(t, string(gotData), "-----BEGIN CERTIFICATE-----")

	// The generated Dockerfile lines must contain every expected step.
	for _, fragment := range []string{
		"COPY .cog/tmp/" + CACertFilename,
		"update-ca-certificates",
		"ENV SSL_CERT_FILE=" + SystemCertBundle,
		"ENV REQUESTS_CA_BUNDLE=" + SystemCertBundle,
	} {
		require.Contains(t, dockerfileLines, fragment)
	}
}

// TestGenerateCACertInstall_EmptyData checks that nil and empty certificate
// data produce no Dockerfile lines and never invoke writeTemp.
func TestGenerateCACertInstall_EmptyData(t *testing.T) {
	failingWriteTemp := func(filename string, contents []byte) ([]string, string, error) {
		t.Fatal("writeTemp should not be called for empty data")
		return nil, "", nil
	}

	for _, certData := range [][]byte{nil, {}} {
		dockerfileLines, err := GenerateCACertInstall(certData, failingWriteTemp)
		require.NoError(t, err)
		require.Empty(t, dockerfileLines)
	}
}


================================================
FILE: pkg/dockerfile/env.go
================================================
package dockerfile

import (
	"maps"
	"slices"

	"github.com/replicate/cog/pkg/config"
)

// envLineFromConfig renders the config's parsed environment as a single
// Dockerfile ENV instruction terminated by a newline, with variables in
// sorted name order. It returns "" when there are no variables. The returned
// error is always nil.
func envLineFromConfig(c *config.Config) (string, error) {
	env := c.ParsedEnvironment()
	if len(env) == 0 {
		return "", nil
	}

	// Sorted names keep the generated line deterministic.
	line := []byte("ENV")
	for _, name := range slices.Sorted(maps.Keys(env)) {
		line = append(line, ' ')
		line = append(line, name...)
		line = append(line, '=')
		line = append(line, env[name]...)
	}
	line = append(line, '\n')

	return string(line), nil
}


================================================
FILE: pkg/dockerfile/generator.go
================================================
package dockerfile

import (
	"context"

	"github.com/replicate/cog/pkg/weights"
)

// Generator produces Dockerfile content and related build metadata for a
// model. NewGenerator returns the StandardGenerator implementation; the notes
// below describe that implementation where the method name alone is not
// self-explanatory.
type Generator interface {
	// GenerateInitialSteps returns the opening part of the Dockerfile. In
	// StandardGenerator this is everything before the WORKDIR/EXPOSE/CMD
	// lines that GenerateModelBase appends.
	GenerateInitialSteps(ctx context.Context) (string, error)
	// SetUseCogBaseImage explicitly enables or disables cog base images.
	SetUseCogBaseImage(bool)
	// SetUseCogBaseImagePtr sets the same option from a pointer. In
	// StandardGenerator a nil pointer leaves the option unset, which
	// IsUsingCogBaseImage reports as true.
	SetUseCogBaseImagePtr(*bool)
	// GenerateModelBaseWithSeparateWeights returns, in order, the weights
	// Dockerfile, the model Dockerfile, and the .dockerignore contents.
	GenerateModelBaseWithSeparateWeights(ctx context.Context, imageName string) (string, string, string, error)
	// Cleanup releases resources held by the generator. StandardGenerator
	// removes its temporary directory.
	Cleanup() error
	// SetStrip sets the strip option.
	SetStrip(bool)
	// SetPrecompile sets the precompile option. When set, StandardGenerator
	// adds PrecompilePythonCommand to the generated steps.
	SetPrecompile(bool)
	// SetUseCudaBaseImage sets the CUDA base image option from its string
	// form. StandardGenerator treats every value except "false" as enabled.
	SetUseCudaBaseImage(string)
	// IsUsingCogBaseImage reports whether cog base images are in use.
	IsUsingCogBaseImage() bool
	// BaseImage returns the image reference used in the FROM line.
	BaseImage(ctx context.Context) (string, error)
	// GenerateWeightsManifest returns a manifest of the model's weights.
	GenerateWeightsManifest(ctx context.Context) (*weights.Manifest, error)
	// GenerateDockerfileWithoutSeparateWeights returns a Dockerfile that
	// does not write model weights to a separate layer.
	GenerateDockerfileWithoutSeparateWeights(ctx context.Context) (string, error)
	// GenerateModelBase returns the initial steps followed by the
	// WORKDIR, EXPOSE, runtime ENV and CMD lines (in StandardGenerator).
	GenerateModelBase(ctx context.Context) (string, error)
	// Name returns an identifier for the generator implementation.
	Name() string
	// BuildDir returns the build directory to use.
	BuildDir() (string, error)
	// BuildContexts returns additional named build contexts.
	// StandardGenerator returns an empty map.
	BuildContexts() (map[string]string, error)
}


================================================
FILE: pkg/dockerfile/generator_factory.go
================================================
package dockerfile

import (
	"github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/registry"
)

// NewGenerator creates the Dockerfile generator for the given config. It
// currently always builds a StandardGenerator with the same arguments.
//
// On failure it returns a nil Generator together with the error. The error
// branch returns a literal nil rather than forwarding the *StandardGenerator
// result directly, because a nil *StandardGenerator stored in the Generator
// interface would compare as non-nil.
func NewGenerator(config *config.Config, dir string, configFilename string, command command.Command, client registry.Client, requiresCog bool) (Generator, error) {
	generator, err := NewStandardGenerator(config, dir, configFilename, command, client, requiresCog)
	if err != nil {
		return nil, err
	}
	return generator, nil
}


================================================
FILE: pkg/dockerfile/generator_factory_test.go
================================================
package dockerfile

import (
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/docker/dockertest"
	"github.com/replicate/cog/pkg/registry/registrytest"
)

// TestGeneratorFactoryStandardGenerator checks that NewGenerator returns the
// standard generator for a config with a pinned torch package.
func TestGeneratorFactoryStandardGenerator(t *testing.T) {
	dir := t.TempDir()
	build := config.Build{
		PythonPackages: []string{"torch==2.5.1"},
	}
	cfg := config.Config{
		Build: &build,
	}
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	generator, err := NewGenerator(&cfg, dir, "", command, client, true)
	require.NoError(t, err)
	// require.Equal takes (expected, actual); the order matters for the
	// failure message.
	require.Equal(t, STANDARD_GENERATOR_NAME, generator.Name())
}


================================================
FILE: pkg/dockerfile/standard_generator.go
================================================
package dockerfile

import (
	"context"
	"fmt"
	"os"
	"path"
	"path/filepath"
	"slices"
	"strings"

	"github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/dockercontext"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/requirements"
	"github.com/replicate/cog/pkg/util/console"
	"github.com/replicate/cog/pkg/util/version"
	"github.com/replicate/cog/pkg/weights"
	"github.com/replicate/cog/pkg/wheels"
)

// DockerignoreHeader is the fixed leading section of generated .dockerignore
// files: a marker comment followed by common Python, tooling and VCS patterns.
const DockerignoreHeader = `# generated by replicate/cog
__pycache__
*.pyc
*.pyo
*.pyd
.Python
env
pip-log.txt
pip-delete-this-directory.txt
.tox
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.log
.git
.mypy_cache
.pytest_cache
.hypothesis
`

// LDConfigCacheBuildCommand is a RUN instruction that writes the unique
// directories containing "*python*.so" files to /etc/ld.so.conf.d/cog.conf
// and then runs ldconfig.
const LDConfigCacheBuildCommand = "RUN find / -type f -name \"*python*.so\" -printf \"%h\\n\" | sort -u > /etc/ld.so.conf.d/cog.conf && ldconfig"

// StripDebugSymbolsCommand is a shell command (without a RUN prefix) that runs
// "strip -S" on every "*python*.so" file whose name does not match
// "*cpython*.so".
const StripDebugSymbolsCommand = "find / -type f -name \"*python*.so\" -not -name \"*cpython*.so\" -exec strip -S {} \\;"

// CFlags is an ENV instruction setting CFLAGS for compiled builds.
// NOTE(review): "-S" tells GCC-compatible compilers to stop after generating
// assembly — confirm this flag is intended here.
const CFlags = "ENV CFLAGS=\"-O3 -funroll-loops -fno-strict-aliasing -flto -S\""

// UVVersion is the pinned uv version string.
const UVVersion = "0.9.26"

// uvCacheMount is a build mount option that provides a cache at /root/.cache/uv.
const uvCacheMount = "--mount=type=cache,target=/root/.cache/uv"

// uvPip is the command prefix for uv's pip interface.
const uvPip = "uv pip"

// PrecompilePythonCommand is a RUN instruction that deletes existing
// .pyc/.pyo files, sets the timestamp of every .py file to 1970-01-01 00:00,
// and byte-compiles all directories containing .py files with compileall
// (timestamp invalidation, optimisation level 2, -j 0).
const PrecompilePythonCommand = "RUN find / -type f -name \"*.py[co]\" -delete && find / -type f -name \"*.py\" -exec touch -t 197001010000 {} \\; && find / -type f -name \"*.py\" -printf \"%h\\n\" | sort -u | /usr/bin/python3 -m compileall --invalidation-mode timestamp -o 2 -j 0"

// STANDARD_GENERATOR_NAME is the value returned by StandardGenerator.Name.
const STANDARD_GENERATOR_NAME = "STANDARD_GENERATOR"

// StandardGenerator generates Dockerfiles for a model from its config. Create
// one with NewStandardGenerator.
type StandardGenerator struct {
	// Config is the model configuration the Dockerfile is generated from.
	Config *config.Config
	// Dir is the directory passed to NewStandardGenerator; the temporary
	// directory is created relative to it.
	Dir            string
	ConfigFilename string // Base filename like "cog.yaml" or "my-config.yaml"

	// these are here to make this type testable
	GOOS   string
	GOARCH string

	// useCudaBaseImage selects the CUDA base image for GPU builds when no cog
	// base image is used. Set by SetUseCudaBaseImage; defaults to true.
	useCudaBaseImage bool
	// useCogBaseImage is nil until explicitly set. While nil,
	// IsUsingCogBaseImage reports true and BaseImage may fall back to
	// disabling it if no suitable base image is found.
	useCogBaseImage *bool
	// strip is set by SetStrip.
	strip bool
	// precompile adds PrecompilePythonCommand to the generated steps.
	precompile bool

	// absolute path to tmpDir, a directory that will be cleaned up
	tmpDir string
	// tmpDir relative to Dir
	relativeTmpDir string

	// fileWalker is the walker passed to weights.FindWeights; defaults to
	// filepath.Walk.
	fileWalker weights.FileWalker

	// modelDirs and modelFiles hold the weights paths found by the most
	// recent GenerateModelBaseWithSeparateWeights call.
	modelDirs  []string
	modelFiles []string

	// NOTE(review): presumably the generated requirements file contents —
	// not established by the code visible here; confirm against its users.
	pythonRequirementsContents string
	// command, client and requiresCog are stored as passed to
	// NewStandardGenerator.
	command     command.Command
	client      registry.Client
	requiresCog bool

	// Optional overrides for wheel configs (used by tests for deterministic output).
	// When nil, auto-detection is used (env var → dist/ → PyPI).
	cogWheelConfig    *wheels.WheelConfig
	cogletWheelConfig *wheels.WheelConfig

	// Resolved wheel configs — set once by resolveCogWheelConfigs() and shared
	// between filterManagedPackages() (for warnings) and installCog() (for install).
	resolvedCogConfig    *wheels.WheelConfig
	resolvedCogletConfig *wheels.WheelConfig
}

// NewStandardGenerator creates a StandardGenerator for the model in dir.
//
// It creates the build temp directory for dir and records both its absolute
// path and its path relative to dir (the latter is what appears in the
// Dockerfile). An empty configFilename defaults to "cog.yaml". The CUDA base
// image is enabled by default; the cog base image option is left unset and
// strip/precompile are off.
func NewStandardGenerator(config *config.Config, dir string, configFilename string, command command.Command, client registry.Client, requiresCog bool) (*StandardGenerator, error) {
	absTmpDir, err := dockercontext.BuildTempDir(dir)
	if err != nil {
		return nil, err
	}
	// Path of the temp directory as seen from dir.
	relTmpDir, err := filepath.Rel(dir, absTmpDir)
	if err != nil {
		return nil, err
	}

	generator := &StandardGenerator{
		Config:         config,
		Dir:            dir,
		ConfigFilename: "cog.yaml",
		// Docker build target is always linux/amd64 (see pkg/docker/buildkit.go).
		// These must match the container platform, not the host.
		GOOS:             "linux",
		GOARCH:           "amd64",
		tmpDir:           absTmpDir,
		relativeTmpDir:   relTmpDir,
		fileWalker:       filepath.Walk,
		useCudaBaseImage: true,
		command:          command,
		client:           client,
		requiresCog:      requiresCog,
		// useCogBaseImage, strip and precompile keep their zero values
		// (nil, false, false).
	}
	if configFilename != "" {
		generator.ConfigFilename = configFilename
	}
	return generator, nil
}

// SetUseCudaBaseImage sets the CUDA base image option from its string form.
// Only the exact value "false" disables it; any other value ("true", "auto",
// or anything unrecognised) enables it.
func (g *StandardGenerator) SetUseCudaBaseImage(argumentValue string) {
	switch argumentValue {
	case "false":
		g.useCudaBaseImage = false
	default:
		g.useCudaBaseImage = true
	}
}

// SetUseCogBaseImage explicitly enables or disables cog base images. The
// value is stored in a freshly allocated bool owned by the generator.
func (g *StandardGenerator) SetUseCogBaseImage(useCogBaseImage bool) {
	stored := useCogBaseImage
	g.useCogBaseImage = &stored
}

// SetUseCogBaseImagePtr stores the given pointer as the cog base image
// option without copying the value it points to. Passing nil leaves the
// option unset.
func (g *StandardGenerator) SetUseCogBaseImagePtr(useCogBaseImage *bool) {
	g.useCogBaseImage = useCogBaseImage
}

// IsUsingCogBaseImage reports whether cog base images are in use. An unset
// option counts as enabled.
func (g *StandardGenerator) IsUsingCogBaseImage() bool {
	if g.useCogBaseImage == nil {
		return true
	}
	return *g.useCogBaseImage
}

// SetStrip records the strip option on the generator.
func (g *StandardGenerator) SetStrip(strip bool) {
	g.strip = strip
}

// SetPrecompile records the precompile option. When enabled, the generated
// steps include PrecompilePythonCommand.
func (g *StandardGenerator) SetPrecompile(precompile bool) {
	g.precompile = precompile
}

// GenerateInitialSteps returns the opening part of the Dockerfile: the syntax
// directive, the FROM line, and the dependency installation steps.
//
// All step fragments are computed up front, in a fixed order, and the first
// error aborts generation. The fragments are then assembled in one of two
// layouts depending on whether a cog base image is in use (evaluated after
// BaseImage has run, which may switch the option off).
func (g *StandardGenerator) GenerateInitialSteps(ctx context.Context) (string, error) {
	baseImage, err := g.BaseImage(ctx)
	if err != nil {
		return "", err
	}
	installPython, err := g.installPython()
	if err != nil {
		return "", err
	}
	aptInstalls, err := g.aptInstalls()
	if err != nil {
		return "", err
	}
	envs, err := g.envVars()
	if err != nil {
		return "", err
	}
	runCommands, err := g.runCommands()
	if err != nil {
		return "", err
	}
	pipInstalls, err := g.pipInstalls()
	if err != nil {
		return "", err
	}
	installCog, err := g.installCog()
	if err != nil {
		return "", err
	}
	installCACert, err := g.installCACert()
	if err != nil {
		return "", err
	}

	// Cog base image layout: no preamble, tini, Python install or ldconfig
	// step; cog is installed before the pip requirements.
	if g.IsUsingCogBaseImage() {
		steps := []string{
			"#syntax=docker/dockerfile:1.4",
			"FROM " + baseImage,
			installCACert, // First! Before any network requests (apt, pip, etc.)
			envs,
			aptInstalls,
			g.installUV(),
		}
		if installCog != "" {
			steps = append(steps, installCog)
		}
		steps = append(steps, pipInstalls)
		if g.precompile {
			steps = append(steps, PrecompilePythonCommand)
		}
		steps = append(steps, runCommands)

		return joinStringsWithoutLineSpace(steps), nil
	}

	// For the CUDA path, uv is installed inside installPython (after the apt step).
	// For all other paths (python:X-slim), install uv after apt.
	uvInstall := ""
	if installPython == "" {
		uvInstall = g.installUV()
	}
	// Layout without a cog base image: pip requirements come before cog, and
	// the ldconfig cache is rebuilt before the user's run commands.
	steps := []string{
		"#syntax=docker/dockerfile:1.4",
		"FROM " + baseImage,
		g.preamble(),
		installCACert, // Early! Before tini (uses curl), apt, pip, etc.
		g.installTini(),
		envs,
		aptInstalls,
		uvInstall,
		installPython,
		pipInstalls,
		installCog,
	}
	if g.precompile {
		steps = append(steps, PrecompilePythonCommand)
	}
	steps = append(steps, LDConfigCacheBuildCommand, runCommands)

	return joinStringsWithoutLineSpace(steps), nil
}

// GenerateModelBase returns the initial steps followed by the lines that turn
// the image into a runnable model server: WORKDIR /src, EXPOSE 5000, the
// runtime ENV lines from cogEnvVars, and the CMD. Lines are joined with "\n".
func (g *StandardGenerator) GenerateModelBase(ctx context.Context) (string, error) {
	initial, err := g.GenerateInitialSteps(ctx)
	if err != nil {
		return "", err
	}

	lines := append([]string{initial, `WORKDIR /src`, `EXPOSE 5000`}, g.cogEnvVars()...)
	lines = append(lines, `CMD ["python", "-m", "cog.server.http"]`)

	return strings.Join(lines, "\n"), nil
}

// GenerateDockerfileWithoutSeparateWeights generates a complete Dockerfile in
// which the model source, weights included, is copied with a single
// "COPY . /src" rather than through a separate weights layer. When a
// non-default config filename is used, a step copying it to /src/cog.yaml is
// appended.
func (g *StandardGenerator) GenerateDockerfileWithoutSeparateWeights(ctx context.Context) (string, error) {
	modelBase, err := g.GenerateModelBase(ctx)
	if err != nil {
		return "", err
	}

	sections := make([]string, 0, 3)
	sections = append(sections, modelBase, `COPY . /src`)
	if copyConfig := g.cpCogYaml(); copyConfig != "" {
		sections = append(sections, copyConfig)
	}

	return joinStringsWithoutLineSpace(sections), nil
}

// GenerateModelBaseWithSeparateWeights generates the Dockerfiles and
// .dockerignore contents for a build that keeps model weights in their own
// image stage. The weights directories and files it finds are stored on the
// generator.
//
// It returns four values:
//   - weightsBase: the Dockerfile for the weights-only image.
//   - dockerfile: the model Dockerfile, which declares "<imageName>-weights"
//     as a stage named "weights" ahead of the first FROM line and copies each
//     weights path from that stage.
//   - dockerignoreContents: .dockerignore contents that exclude the weights
//     paths.
//   - err: non-nil if weights discovery or initial step generation failed.
func (g *StandardGenerator) GenerateModelBaseWithSeparateWeights(ctx context.Context, imageName string) (weightsBase string, dockerfile string, dockerignoreContents string, err error) {
	weightsBase, g.modelDirs, g.modelFiles, err = g.generateForWeights()
	if err != nil {
		return "", "", "", fmt.Errorf("Failed to generate Dockerfile for model weights files: %w", err)
	}
	initialSteps, err := g.GenerateInitialSteps(ctx)
	if err != nil {
		return "", "", "", err
	}

	// Locate the first FROM line so the weights stage can be declared right
	// before it; without one the initial steps are kept unchanged.
	stepLines := strings.Split(initialSteps, "\n")
	fromIdx := -1
	for i, stepLine := range stepLines {
		if strings.HasPrefix(stepLine, "FROM ") {
			fromIdx = i
			break
		}
	}

	lines := []string{}
	if fromIdx < 0 {
		lines = append(lines, stepLines...)
	} else {
		lines = append(lines, stepLines[:fromIdx]...)
		lines = append(lines, fmt.Sprintf("FROM %s AS %s", imageName+"-weights", "weights"))
		lines = append(lines, stepLines[fromIdx:]...)
	}

	// Directories first, then individual files.
	for _, group := range [][]string{g.modelDirs, g.modelFiles} {
		for _, p := range group {
			target := path.Join("/src", p)
			lines = append(lines, "COPY --from=weights --link "+target+" "+target)
		}
	}

	lines = append(lines, `WORKDIR /src`, `EXPOSE 5000`)
	lines = append(lines, g.cogEnvVars()...)
	lines = append(lines, `CMD ["python", "-m", "cog.server.http"]`, `COPY . /src`)
	if copyConfig := g.cpCogYaml(); copyConfig != "" {
		lines = append(lines, copyConfig)
	}

	dockerignoreContents = makeDockerignoreForWeights(g.modelDirs, g.modelFiles)
	return weightsBase, joinStringsWithoutLineSpace(lines), dockerignoreContents, nil
}

// cogEnvVars returns the ENV lines that hand cog.yaml settings to the runtime
// so the container doesn't need to parse cog.yaml at startup. A line is
// emitted only for settings that are present: the predict and train
// references, and a positive maximum concurrency. The result is nil when none
// are set.
func (g *StandardGenerator) cogEnvVars() []string {
	cfg := g.Config

	var lines []string
	if predict := cfg.Predict; predict != "" {
		lines = append(lines, fmt.Sprintf(`ENV COG_PREDICT_TYPE_STUB="%s"`, predict))
	}
	if train := cfg.Train; train != "" {
		lines = append(lines, fmt.Sprintf(`ENV COG_TRAIN_TYPE_STUB="%s"`, train))
	}
	if concurrency := cfg.Concurrency; concurrency != nil && concurrency.Max > 0 {
		lines = append(lines, fmt.Sprintf(`ENV COG_MAX_CONCURRENCY=%d`, concurrency.Max))
	}
	return lines
}

// cpCogYaml returns a RUN instruction that copies a custom-named config file
// to /src/cog.yaml inside the image, or "" when the config file already uses
// the default name (or no name is set).
//
// The source is a path inside the container, so it is built with the
// slash-only path package, like the other /src paths in this file, rather
// than with the host's path separators.
func (g *StandardGenerator) cpCogYaml() string {
	if g.ConfigFilename == "" || g.ConfigFilename == "cog.yaml" {
		return ""
	}
	// Absolute filename doesn't work anyway, so it's always relative.
	// ToSlash normalises host separators in the filename itself.
	containerPath := path.Join("/src", filepath.ToSlash(g.ConfigFilename))
	return fmt.Sprintf("RUN cp %s /src/cog.yaml", containerPath)
}

// generateForWeights looks up the model weights with weights.FindWeights and
// renders a scratch-based Dockerfile that copies every weights directory and
// file to the same relative path under /src.
//
// It returns the Dockerfile contents, the weights directories and the weights
// files, in that order.
func (g *StandardGenerator) generateForWeights() (string, []string, []string, error) {
	modelDirs, modelFiles, err := weights.FindWeights(g.fileWalker)
	if err != nil {
		return "", nil, nil, err
	}

	// Dockerfile that stores only the model weights files.
	var dockerfile strings.Builder
	dockerfile.WriteString("#syntax=docker/dockerfile:1.4\nFROM scratch\n")
	// Directories are listed first, then individual files.
	for _, group := range [][]string{modelDirs, modelFiles} {
		for _, p := range group {
			fmt.Fprintf(&dockerfile, "\nCOPY %s %s", p, path.Join("/src", p))
		}
	}

	return dockerfile.String(), modelDirs, modelFiles, nil
}

// makeDockerignoreForWeights renders .dockerignore contents that exclude the
// given weights directories and files. The output starts with
// DockerignoreHeader; each directory contributes the directory itself plus a
// "<dir>/**/*" pattern, and each file contributes its own path.
func makeDockerignoreForWeights(dirs, files []string) string {
	var ignored strings.Builder
	ignored.WriteString(DockerignoreHeader)
	for _, dir := range dirs {
		fmt.Fprintf(&ignored, "%s\n%s/**/*\n", dir, dir)
	}
	for _, file := range files {
		fmt.Fprintf(&ignored, "%s\n", file)
	}
	return ignored.String()
}

// Cleanup removes the generator's temporary directory and everything in it.
// A removal failure is returned wrapped with the directory path.
func (g *StandardGenerator) Cleanup() error {
	err := os.RemoveAll(g.tmpDir)
	if err == nil {
		return nil
	}
	return fmt.Errorf("Failed to clean up %s: %w", g.tmpDir, err)
}

// BaseImage returns the image the generated Dockerfile builds FROM.
//
// When cog base images are in use, the name comes from
// determineBaseImageName. If that lookup fails and useCogBaseImage was never
// set, a warning is logged, cog base image support is switched off on the
// generator, and the result falls back to the CUDA base image (GPU builds with
// useCudaBaseImage) or to "python:<version>-slim". If useCogBaseImage was set,
// the lookup error is returned to the caller instead.
func (g *StandardGenerator) BaseImage(ctx context.Context) (string, error) {
	if g.IsUsingCogBaseImage() {
		name, err := g.determineBaseImageName(ctx)
		explicitlyConfigured := g.useCogBaseImage != nil
		if err == nil || explicitlyConfigured {
			return name, err
		}
		console.Warnf("Could not find a suitable base image, continuing without base image support (%v).", err)
		// useCogBaseImage is necessarily nil here, so record the fallback
		// decision for later calls.
		disabled := false
		g.useCogBaseImage = &disabled
	}

	if g.Config.Build.GPU && g.useCudaBaseImage {
		return g.Config.CUDABaseImageTag()
	}
	pythonImage := "python:" + g.Config.Build.PythonVersion + "-slim"
	return pythonImage, nil
}

// Name reports the identifier of this generator, STANDARD_GENERATOR_NAME.
func (g *StandardGenerator) Name() string {
	name := STANDARD_GENERATOR_NAME
	return name
}

// BuildDir returns dockercontext.StandardBuildDirectory. It never returns an
// error.
func (g *StandardGenerator) BuildDir() (string, error) {
	dir := dockercontext.StandardBuildDirectory
	return dir, nil
}

// BuildContexts returns an empty, non-nil map: the standard generator declares
// no additional named build contexts. It never returns an error.
func (g *StandardGenerator) BuildContexts() (map[string]string, error) {
	contexts := make(map[string]string)
	return contexts, nil
}

// preamble returns the fixed ENV instructions placed near the top of the
// generated Dockerfile: non-interactive apt, unbuffered Python output, and the
// NVIDIA library path and driver capability settings.
func (g *StandardGenerator) preamble() string {
	return "ENV DEBIAN_FRONTEND=noninteractive\n" +
		"ENV PYTHONUNBUFFERED=1\n" +
		"ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu:/usr/local/nvidia/lib64:/usr/local/nvidia/bin\n" +
		"ENV NVIDIA_DRIVER_CAPABILITIES=all"
}

// installTini returns the Dockerfile instructions that download tini and make
// it the image entrypoint, providing signal handling and process reaping
// appropriate for PID 1.
//
// N.B. If you remove/change this, consider removing/changing the `has_init`
// image label applied in image/build.go.
func (g *StandardGenerator) installTini() string {
	// Shell steps of the single RUN instruction; they are chained with a
	// trailing backslash-newline so the Dockerfile keeps one step per line.
	steps := []string{
		`RUN --mount=type=cache,target=/var/cache/apt,sharing=locked set -eux;`,
		`apt-get update -qq &&`,
		`apt-get install -qqy --no-install-recommends curl;`,
		`rm -rf /var/lib/apt/lists/*;`,
		`TINI_VERSION=v0.19.0;`,
		`TINI_ARCH="$(dpkg --print-architecture)";`,
		`curl -sSL -o /sbin/tini "https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini-${TINI_ARCH}";`,
		`chmod +x /sbin/tini`,
	}
	run := strings.Join(steps, " \\\n")
	return run + "\n" + `ENTRYPOINT ["/sbin/tini", "--"]`
}

// aptInstalls returns a RUN instruction that installs the system packages
// listed in the build config with apt-get, or "" when there is nothing to
// install.
//
// When a cog base image is in use, packages listed in baseImageSystemPackages
// are dropped first. If that leaves no packages, no instruction is emitted
// rather than running apt-get with an empty package list.
func (g *StandardGenerator) aptInstalls() (string, error) {
	packages := g.Config.Build.SystemPackages
	if len(packages) == 0 {
		return "", nil
	}

	if g.IsUsingCogBaseImage() {
		// Filter on a clone so the configured package list is left untouched.
		packages = slices.DeleteFunc(slices.Clone(packages), func(pkg string) bool {
			return slices.Contains(baseImageSystemPackages, pkg)
		})
		// Everything requested may already be in the base image.
		if len(packages) == 0 {
			return "", nil
		}
	}

	return "RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy " +
		strings.Join(packages, " ") +
		" && rm -rf /var/lib/apt/lists/*", nil
}

// installPython returns the instructions that install Python into the image.
// They are only produced for GPU builds that use the CUDA base image and do
// not use a cog base image; in every other case it returns "".
func (g *StandardGenerator) installPython() (string, error) {
	onPlainCUDAImage := g.Config.Build.GPU && g.useCudaBaseImage && !g.IsUsingCogBaseImage()
	if !onPlainCUDAImage {
		return "", nil
	}
	return g.installPythonCUDA()
}

// installUV returns the instructions that copy the uv and uvx binaries from
// the ghcr.io/astral-sh/uv image tagged UVVersion into /usr/local/bin and set
// UV_SYSTEM_PYTHON=true.
func (g *StandardGenerator) installUV() string {
	copyBinaries := "COPY --from=ghcr.io/astral-sh/uv:" + UVVersion + " /uv /uvx /usr/local/bin/"
	return copyBinaries + "\n" + "ENV UV_SYSTEM_PYTHON=true"
}

// installPythonCUDA returns the instructions that install the configured
// Python version with uv: apt prerequisites first, then uv itself, then
// `uv python install`, a /usr/bin/python3 symlink, and the UV_PYTHON and PATH
// environment variables.
func (g *StandardGenerator) installPythonCUDA() (string, error) {
	// TODO: check that python version is valid

	const aptPrerequisites = `RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy --no-install-recommends \
	wget \
	curl \
	xz-utils \
	git \
	ca-certificates \
	&& rm -rf /var/lib/apt/lists/*`

	// The same version string is substituted in all three places.
	pythonSetup := fmt.Sprintf(`RUN uv python install %[1]s && \
	ln -sf $(uv python find %[1]s) /usr/bin/python3
ENV UV_PYTHON=%[1]s
ENV PATH="/usr/local/bin:$PATH"`, g.Config.Build.PythonVersion)

	return aptPrerequisites + "\n" + g.installUV() + "\n" + pythonSetup, nil
}

// resolveCogWheelConfigs resolves the cog SDK and coglet wheel configs and
// caches them on the generator. Once resolvedCogConfig is set, later calls
// return immediately. It must run before filterManagedPackages() and
// installCog().
//
// Precedence for the cog SDK:
//  1. Test override (cogWheelConfig field)
//  2. COG_SDK_WHEEL env var
//  3. build.sdk_version in cog.yaml
//  4. Auto-detect dist/ (dev builds only)
//  5. Latest PyPI
//
// Precedence for coglet:
//  1. Test override (cogletWheelConfig field)
//  2. COGLET_WHEEL env var
//  3. Auto-detect dist/ (dev builds only)
//  4. Latest PyPI
//
// Everything other than the test overrides and build.sdk_version is delegated
// to wheels.GetCogWheelConfig and wheels.GetCogletWheelConfig. The resolved
// SDK config is rejected if wheels.ValidateSDKVersion reports an error.
func (g *StandardGenerator) resolveCogWheelConfigs() error {
	if g.resolvedCogConfig != nil {
		return nil // already resolved
	}

	var err error

	// Cog SDK.
	switch {
	case g.cogWheelConfig != nil:
		g.resolvedCogConfig = g.cogWheelConfig
	case os.Getenv(wheels.CogSDKWheelEnvVar) == "" && g.Config.Build != nil && g.Config.Build.SDKVersion != "":
		// The env var takes priority, so the cog.yaml pin only applies when it
		// is unset.
		g.resolvedCogConfig = &wheels.WheelConfig{
			Source:  wheels.WheelSourcePyPI,
			Version: g.Config.Build.SDKVersion,
		}
	default:
		g.resolvedCogConfig, err = wheels.GetCogWheelConfig()
		if err != nil {
			return err
		}
	}

	// Refuse versions older than the minimum supported SDK.
	if err := wheels.ValidateSDKVersion(g.resolvedCogConfig, "cog"); err != nil {
		return err
	}

	// Coglet.
	if g.cogletWheelConfig != nil {
		g.resolvedCogletConfig = g.cogletWheelConfig
		return nil
	}
	g.resolvedCogletConfig, err = wheels.GetCogletWheelConfig(g.GOARCH)
	return err
}

// cogletMinSDKVersion is the minimum SDK version that supports coglet.
// Older SDKs use the built-in Python HTTP server and are incompatible with
// coglet, so isLegacySDKVersion compares pinned SDK versions against this value.
const cogletMinSDKVersion = "0.17.0"

// isLegacySDKVersion reports whether the resolved cog SDK is pinned on PyPI to
// a version below cogletMinSDKVersion. It reports false when nothing has been
// resolved, when the source is not PyPI, when no version is pinned, or when
// the version cannot be parsed (all of which are assumed to be modern).
func (g *StandardGenerator) isLegacySDKVersion() bool {
	cfg := g.resolvedCogConfig
	if cfg == nil {
		return false
	}
	if cfg.Source != wheels.WheelSourcePyPI || cfg.Version == "" {
		return false
	}

	// Prefer the part matched by BaseVersionRe; fall back to the raw version
	// when the pattern does not match.
	candidate := wheels.BaseVersionRe.FindString(cfg.Version)
	if candidate == "" {
		candidate = cfg.Version
	}

	parsed, err := version.NewVersion(candidate)
	if err != nil {
		return false
	}
	minimum := version.MustVersion(cogletMinSDKVersion)
	return !parsed.GreaterOrEqual(minimum)
}

// installCog returns the Dockerfile instructions that install coglet (when it
// has an explicit source) followed by the cog SDK. It returns "" when the
// generator does not require cog, which is the case for base images.
//
// Errors come from resolving the wheel configs, from generating either
// install step, or from an unknown SDK wheel source.
func (g *StandardGenerator) installCog() (string, error) {
	// Do not install Cog in base images.
	if !g.requiresCog {
		return "", nil
	}
	if err := g.resolveCogWheelConfigs(); err != nil {
		return "", err
	}
	sdk := g.resolvedCogConfig
	coglet := g.resolvedCogletConfig

	// A pre-release SDK needs --pre, and implies a pre-release coglet too.
	sdkIsPreRelease := sdk.Source == wheels.WheelSourcePyPI && wheels.IsPreRelease(sdk.Version)

	// Only install coglet explicitly when there's a specific source:
	//   - COGLET_WHEEL env var (explicit override)
	//   - Local wheel from dist/ (dev/CI auto-detect)
	//   - PyPI with pinned version (e.g. COGLET_WHEEL=pypi:0.17.0)
	// Otherwise, let the SDK's own dependency handle it: cog >= 0.17.0 declares
	// coglet as a hard dependency, older versions don't install it.
	installCoglet := false
	if coglet != nil {
		switch coglet.Source {
		case wheels.WheelSourceFile, wheels.WheelSourceURL:
			installCoglet = true
		case wheels.WheelSourcePyPI:
			installCoglet = coglet.Version != ""
		}
	}
	// Never install coglet when the SDK is explicitly pinned to < 0.17.0: those
	// versions use the built-in Python HTTP server and are incompatible with it.
	if installCoglet && g.isLegacySDKVersion() {
		console.Info("Skipping coglet install for legacy SDK")
		installCoglet = false
	}

	var cogletLines string
	if installCoglet {
		usePre := sdkIsPreRelease
		switch coglet.Source {
		case wheels.WheelSourcePyPI:
			console.Infof("Using coglet from PyPI: %s", coglet.PyPIPackageURL("coglet"))
			usePre = usePre || wheels.IsPreRelease(coglet.Version)
		case wheels.WheelSourceURL:
			console.Infof("Using coglet wheel from URL: %s", coglet.URL)
		case wheels.WheelSourceFile:
			console.Debugf("Using local coglet wheel: %s", coglet.Path)
		}

		lines, err := g.installCogletWheel(coglet, usePre)
		if err != nil {
			return "", fmt.Errorf("failed to install coglet wheel: %w", err)
		}
		cogletLines = lines
	}

	// Cog SDK.
	var sdkLines string
	var err error
	switch sdk.Source {
	case wheels.WheelSourcePyPI:
		sdkLines, err = g.installCogFromPyPI(sdk, sdkIsPreRelease)
	case wheels.WheelSourceURL:
		console.Infof("Using cog wheel from URL: %s", sdk.URL)
		sdkLines, err = g.installWheelFromURL(sdk.URL)
	case wheels.WheelSourceFile:
		console.Debugf("Using local cog wheel: %s", sdk.Path)
		sdkLines, err = g.installWheelFromFile(sdk.Path)
	default:
		return "", fmt.Errorf("unknown wheel source: %v", sdk.Source)
	}
	if err != nil {
		return "", err
	}

	// Coglet goes first; a newline separates the two parts only when both exist.
	switch {
	case cogletLines == "":
		return sdkLines, nil
	case sdkLines == "":
		return cogletLines, nil
	}
	return cogletLines + "\n" + sdkLines, nil
}

// installCogFromPyPI returns the instructions that install the cog SDK from
// PyPI with uv, bracketed by the CFlags line and a CFLAGS reset.
// preRelease adds --pre so pre-release packages can be resolved. When the
// generator strips binaries, StripDebugSymbolsCommand is chained onto the
// install command.
func (g *StandardGenerator) installCogFromPyPI(config *wheels.WheelConfig, preRelease bool) (string, error) {
	packageSpec := config.PyPIPackageURL("cog")

	args := []string{"RUN", uvCacheMount, uvPip, "install"}
	if preRelease {
		args = append(args, "--pre")
	}
	args = append(args, "--no-cache", packageSpec)

	install := strings.Join(args, " ")
	if g.strip {
		install += " && " + StripDebugSymbolsCommand
	}
	return strings.Join([]string{CFlags, install, "ENV CFLAGS="}, "\n"), nil
}

// installWheelFromURL returns the instructions that install a wheel from a URL
// (when COG_SDK_WHEEL=https://...), bracketed by the CFlags line and a CFLAGS
// reset.
//
// A URL containing "coglet" is treated as a coglet wheel: the Python version
// must pass CheckMajorMinorOnly, the R8_COG_VERSION and R8_PYTHON_VERSION
// variables are set, and any installed cog is uninstalled first. Some base
// images (e.g. r8.im/cog-base) have cog pre-installed, which conflicts with
// coglet's cog compatibility shim that provides the same module paths.
func (g *StandardGenerator) installWheelFromURL(url string) (string, error) {
	var envLines []string
	var pipPrefix string
	if strings.Contains(url, "coglet") {
		if !CheckMajorMinorOnly(g.Config.Build.PythonVersion) {
			return "", fmt.Errorf("Python version must be in major.minor form (e.g. 3.12) for coglet")
		}
		envLines = []string{
			"ENV R8_COG_VERSION=coglet",
			"ENV R8_PYTHON_VERSION=" + g.Config.Build.PythonVersion,
		}
		// Best-effort uninstall: failures are ignored so the install still runs.
		pipPrefix = uvPip + " uninstall cog 2>/dev/null || true && "
	}

	pipInstallLine := "RUN " + uvCacheMount + " " + pipPrefix + uvPip + " install --no-cache " + url
	if g.strip {
		pipInstallLine += " && " + StripDebugSymbolsCommand
	}

	envLines = append(envLines, CFlags, pipInstallLine, "ENV CFLAGS=")
	return strings.Join(envLines, "\n"), nil
}

// installWheelFromFile returns the instructions that install a wheel from a
// local file (when COG_SDK_WHEEL=/path/to/file.whl). The wheel is read from
// path, staged through writeTemp so it can be copied into the image, and
// installed with uv between the CFlags line and a CFLAGS reset.
//
// A filename containing "coglet" is treated as a coglet wheel: the Python
// version must pass CheckMajorMinorOnly, the R8_COG_VERSION and
// R8_PYTHON_VERSION variables are set, and any installed cog is uninstalled
// first.
func (g *StandardGenerator) installWheelFromFile(path string) (string, error) {
	// Read the local wheel file
	data, err := os.ReadFile(path)
	if err != nil {
		return "", fmt.Errorf("failed to read wheel file %s: %w", path, err)
	}

	filename := filepath.Base(path)
	lines, containerPath, err := g.writeTemp(filename, data)
	if err != nil {
		return "", err
	}

	// Set coglet env vars if this looks like a coglet wheel
	var pipPrefix string
	if strings.Contains(filename, "coglet") {
		if !CheckMajorMinorOnly(g.Config.Build.PythonVersion) {
			return "", fmt.Errorf("Python version must be in major.minor form (e.g. 3.12) for coglet")
		}
		lines = append(lines,
			"ENV R8_COG_VERSION=coglet",
			"ENV R8_PYTHON_VERSION="+g.Config.Build.PythonVersion,
		)
		// Uninstall cog first to avoid conflicts with coglet's cog shim package.
		// Some base images (e.g. r8.im/cog-base) have cog pre-installed, which conflicts
		// with coglet's cog compatibility shim that provides the same module paths.
		pipPrefix = uvPip + " uninstall cog 2>/dev/null || true && "
	}

	pipInstallLine := "RUN " + uvCacheMount + " " + pipPrefix + uvPip + " install --no-cache " + containerPath
	if g.strip {
		pipInstallLine += " && " + StripDebugSymbolsCommand
	}
	lines = append(lines, CFlags, pipInstallLine, "ENV CFLAGS=")
	return strings.Join(lines, "\n"), nil
}

// installCogletWheel dispatches to the coglet installer that matches the
// config's source (PyPI, URL or local file) and returns its Dockerfile lines.
// preRelease only affects the PyPI path, where it adds --pre. An unrecognised
// source is an error.
func (g *StandardGenerator) installCogletWheel(config *wheels.WheelConfig, preRelease bool) (string, error) {
	source := config.Source
	if source == wheels.WheelSourcePyPI {
		return g.installCogletFromPyPI(config, preRelease)
	}
	if source == wheels.WheelSourceURL {
		return g.installCogletFromURL(config.URL)
	}
	if source == wheels.WheelSourceFile {
		return g.installCogletFromFile(config.Path)
	}
	return "", fmt.Errorf("unknown coglet wheel source: %v", config.Source)
}

// installCogletFromPyPI returns the instructions that install coglet from PyPI
// with uv, bracketed by the CFlags line and a CFLAGS reset.
// preRelease adds --pre so pre-release packages can be resolved.
func (g *StandardGenerator) installCogletFromPyPI(config *wheels.WheelConfig, preRelease bool) (string, error) {
	packageSpec := config.PyPIPackageURL("coglet")

	installFlags := "--no-cache"
	if preRelease {
		installFlags = "--pre --no-cache"
	}

	command := strings.Join([]string{"RUN", uvCacheMount, uvPip, "install", installFlags, packageSpec}, " ")
	if g.strip {
		command = command + " && " + StripDebugSymbolsCommand
	}
	return strings.Join([]string{CFlags, command, "ENV CFLAGS="}, "\n"), nil
}

// installCogletFromURL returns the instructions that install the coglet wheel
// at url with uv, bracketed by the CFlags line and a CFLAGS reset. When the
// generator strips binaries, StripDebugSymbolsCommand is chained onto the
// install command.
func (g *StandardGenerator) installCogletFromURL(url string) (string, error) {
	command := strings.Join([]string{"RUN", uvCacheMount, uvPip, "install --no-cache", url}, " ")
	if g.strip {
		command = command + " && " + StripDebugSymbolsCommand
	}
	return strings.Join([]string{CFlags, command, "ENV CFLAGS="}, "\n"), nil
}

// installCogletFromFile returns the instructions that install coglet from a
// local wheel file. The wheel is read from path, staged through writeTemp so
// it can be copied into the image, and installed with uv between the CFlags
// line and a CFLAGS reset.
func (g *StandardGenerator) installCogletFromFile(path string) (string, error) {
	wheel, err := os.ReadFile(path)
	if err != nil {
		return "", fmt.Errorf("failed to read coglet wheel %s: %w", path, err)
	}

	instructions, containerPath, err := g.writeTemp(filepath.Base(path), wheel)
	if err != nil {
		return "", err
	}

	command := strings.Join([]string{"RUN", uvCacheMount, uvPip, "install --no-cache", containerPath}, " ")
	if g.strip {
		command = command + " && " + StripDebugSymbolsCommand
	}

	instructions = append(instructions, CFlags, command, "ENV CFLAGS=")
	return strings.Join(instructions, "\n"), nil
}

// filterManagedPackages removes cog and coglet entries from user requirements
// content and logs a warning for each one it drops. requirements.txt is not
// the intended mechanism for controlling cog/coglet versions; use
// build.sdk_version in cog.yaml or the COG_SDK_WHEEL / COGLET_WHEEL
// environment variables instead.
//
// Blank lines, comment lines and option lines (starting with "-") are kept
// untouched. Package names are compared case-insensitively with any "[extras]"
// suffix ignored.
func (g *StandardGenerator) filterManagedPackages(reqContents string) string {
	// describe says what the build will install in place of a managed package,
	// so that the warning is actionable.
	describe := func(pkg string) string {
		cfg := g.resolvedCogletConfig
		if pkg == "cog" {
			cfg = g.resolvedCogConfig
		}
		if cfg == nil {
			return "latest from PyPI"
		}
		switch cfg.Source {
		case wheels.WheelSourceURL:
			return cfg.URL
		case wheels.WheelSourceFile:
			return cfg.Path
		case wheels.WheelSourcePyPI:
			if cfg.Version == "" {
				return "latest " + pkg + " from PyPI"
			}
			return fmt.Sprintf("%s==%s from PyPI", pkg, cfg.Version)
		}
		return "unknown source"
	}

	// Managed package name -> environment variable that controls its version.
	controlEnvVar := map[string]string{"cog": "COG_SDK_WHEEL", "coglet": "COGLET_WHEEL"}

	var kept []string
	for line := range strings.SplitSeq(reqContents, "\n") {
		entry := strings.TrimSpace(line)
		passthrough := entry == "" || strings.HasPrefix(entry, "#") || strings.HasPrefix(entry, "-")
		if !passthrough {
			name, _, _ := strings.Cut(requirements.PackageName(entry), "[")
			name = strings.ToLower(name)
			if envVar, managed := controlEnvVar[name]; managed {
				console.Warnf(
					"'%s' found in requirements — overriding with %s. "+
						"Remove it from requirements and use build.sdk_version in cog.yaml or %s to control the version.",
					entry,
					describe(name),
					envVar,
				)
				continue
			}
		}
		kept = append(kept, line)
	}
	return strings.Join(kept, "\n")
}

// pipInstalls generates the requirements file for the target OS/arch, strips
// the build-managed cog/coglet entries from it, stages it through writeTemp,
// and returns the instructions that copy it into the image and install it.
// The generated contents are also stored in g.pythonRequirementsContents.
//
// It returns "" when the resulting requirements are empty or contain only
// whitespace, so no COPY or pip install step is emitted for them.
func (g *StandardGenerator) pipInstalls() (string, error) {
	// Resolve wheel configs early so filterManagedPackages can emit precise warnings.
	if g.requiresCog {
		if err := g.resolveCogWheelConfigs(); err != nil {
			return "", err
		}
	}

	var err error
	// Framework versions reported by the config are passed along as extra
	// pinned packages.
	includePackages := []string{}
	if torchVersion, ok := g.Config.TorchVersion(); ok {
		includePackages = []string{"torch==" + torchVersion}
	}
	if torchvisionVersion, ok := g.Config.TorchvisionVersion(); ok {
		includePackages = append(includePackages, "torchvision=="+torchvisionVersion)
	}
	if torchaudioVersion, ok := g.Config.TorchaudioVersion(); ok {
		includePackages = append(includePackages, "torchaudio=="+torchaudioVersion)
	}
	if tensorflowVersion, ok := g.Config.TensorFlowVersion(); ok {
		includePackages = append(includePackages, "tensorflow=="+tensorflowVersion)
	}
	g.pythonRequirementsContents, err = g.Config.PythonRequirementsForArch(g.GOOS, g.GOARCH, includePackages)
	if err != nil {
		return "", err
	}

	// Strip cog/coglet from user requirements — we always install them ourselves
	// via installCog(). Leaving them in would cause pip to overwrite our version.
	g.pythonRequirementsContents = g.filterManagedPackages(g.pythonRequirementsContents)

	// Filtering can leave nothing but blank lines behind; treat that as empty.
	if strings.TrimSpace(g.pythonRequirementsContents) == "" {
		return "", nil
	}

	console.Debugf("Generated requirements.txt:\n%s", g.pythonRequirementsContents)
	copyLine, containerPath, err := g.writeTemp("requirements.txt", []byte(g.pythonRequirementsContents))
	if err != nil {
		return "", err
	}

	pipInstallLine := "RUN --mount=type=cache,target=/root/.cache/pip uv run pip install --cache-dir /root/.cache/pip -r " + containerPath
	if g.strip {
		pipInstallLine += " && " + StripDebugSymbolsCommand
	}
	return strings.Join([]string{
		copyLine[0],
		CFlags,
		pipInstallLine,
		"ENV CFLAGS=",
	}, "\n"), nil
}

// runCommands returns one RUN instruction per entry in build.run, followed by
// one per entry in the deprecated build.pre_install list.
//
// Commands are trimmed of surrounding whitespace and must not contain a
// newline; a multi-line command is rejected with an error. Mounts of type
// "secret" become --mount flags on the instruction; other mount types are
// ignored.
func (g *StandardGenerator) runCommands() (string, error) {
	build := g.Config.Build

	// Build a fresh slice rather than appending to build.Run, which could write
	// into the config's backing array.
	runCommands := make([]config.RunItem, 0, len(build.Run)+len(build.PreInstall))
	runCommands = append(runCommands, build.Run...)
	// For backwards compatibility
	for _, command := range build.PreInstall {
		runCommands = append(runCommands, config.RunItem{Command: command})
	}

	lines := make([]string, 0, len(runCommands))
	for _, run := range runCommands {
		command := strings.TrimSpace(run.Command)
		if strings.Contains(command, "\n") {
			return "", fmt.Errorf(`One of the commands in 'run' contains a new line, which won't work. You need to create a new list item in YAML prefixed with '-' for each command.

This is the offending line: %s`, command)
		}

		var mountFlags []string
		for _, mount := range run.Mounts {
			if mount.Type == "secret" {
				mountFlags = append(mountFlags, fmt.Sprintf("--mount=type=secret,id=%s,target=%s", mount.ID, mount.Target))
			}
		}
		// With no usable mounts, emit a plain RUN instead of one with an empty
		// flag segment.
		if len(mountFlags) == 0 {
			lines = append(lines, "RUN "+command)
			continue
		}
		lines = append(lines, "RUN "+strings.Join(mountFlags, " ")+" "+command)
	}
	return strings.Join(lines, "\n"), nil
}

// envVars returns whatever envLineFromConfig produces for the generator's
// config, passing its error through unchanged.
func (g *StandardGenerator) envVars() (string, error) {
	line, err := envLineFromConfig(g.Config)
	return line, err
}

// installCACert returns the Dockerfile lines that install a custom CA
// certificate. The certificate comes from ReadCACert; when it yields no data
// (COG_CA_CERT is not set) the result is "" and nothing is installed.
func (g *StandardGenerator) installCACert() (string, error) {
	certData, err := ReadCACert()
	switch {
	case err != nil:
		return "", err
	case certData == nil:
		return "", nil
	}
	return GenerateCACertInstall(certData, g.writeTemp)
}

// writeTemp stores contents as filename inside the generator's temporary
// directory, creating parent directories as needed, so the file can be used
// during the build.
//
// It returns the COPY instruction that places the file at /tmp/<filename> in
// the image (as a one-element slice) and that in-container path.
func (g *StandardGenerator) writeTemp(filename string, contents []byte) ([]string, string, error) {
	hostPath := filepath.Join(g.tmpDir, filename)

	err := os.MkdirAll(filepath.Dir(hostPath), 0o755)
	if err == nil {
		err = os.WriteFile(hostPath, contents, 0o644)
	}
	if err != nil {
		return []string{}, "", fmt.Errorf("Failed to write %s: %w", filename, err)
	}

	containerPath := "/tmp/" + filename
	copyLine := fmt.Sprintf("COPY %s %s", filepath.Join(g.relativeTmpDir, filename), containerPath)
	return []string{copyLine}, containerPath, nil
}

// joinStringsWithoutLineSpace flattens chunks into individual lines, drops
// the empty ones, and joins what remains with single newlines.
func joinStringsWithoutLineSpace(chunks []string) string {
	// Splitting the newline-joined text yields the same sequence of lines as
	// splitting each chunk separately; any extra empty line is filtered out.
	allLines := strings.Split(strings.Join(chunks, "\n"), "\n")
	return strings.Join(filterEmpty(allLines), "\n")
}

// filterEmpty returns a new slice holding the non-empty strings of list in
// their original order. The result is never nil, and strings that contain
// only whitespace are kept.
func filterEmpty(list []string) []string {
	kept := make([]string, 0)
	for i := range list {
		if list[i] == "" {
			continue
		}
		kept = append(kept, list[i])
	}
	return kept
}

// GenerateWeightsManifest builds a weights manifest from the generator's model
// weights: every non-directory entry found by walking each directory in
// g.modelDirs, followed by every path in g.modelFiles. The first walk or
// AddFile error aborts the build and is returned.
func (g *StandardGenerator) GenerateWeightsManifest(ctx context.Context) (*weights.Manifest, error) {
	manifest := weights.NewManifest()

	// Walk callback: propagate walk errors, skip directories, record files.
	addFile := func(p string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() {
			return nil
		}
		return manifest.AddFile(p)
	}

	for _, dir := range g.modelDirs {
		if err := g.fileWalker(dir, addFile); err != nil {
			return nil, err
		}
	}

	for _, file := range g.modelFiles {
		if err := manifest.AddFile(file); err != nil {
			return nil, err
		}
	}

	return manifest, nil
}

// determineBaseImageName returns the cog base image name for the configured
// CUDA, Python and Torch versions. The Python version is reduced to
// major.minor first, with a warning when that changes it. An error is
// returned when the Python version cannot be parsed or when
// NewBaseImageGenerator rejects the combination.
func (g *StandardGenerator) determineBaseImageName(ctx context.Context) (string, error) {
	requestedPython := g.Config.Build.PythonVersion
	pythonVersion, stripped, err := stripPatchVersion(requestedPython)
	if err != nil {
		return "", err
	}
	if stripped {
		console.Warnf("Stripping patch version from Python version %s to %s", requestedPython, pythonVersion)
	}

	cudaVersion := g.Config.Build.CUDA
	torchVersion, _ := g.Config.TorchVersion()

	// Constructing the generator validates that the base image configuration
	// exists; the name is then built from the versions stored on it.
	generator, err := NewBaseImageGenerator(ctx, g.client, cudaVersion, pythonVersion, torchVersion, g.command, false)
	if err != nil {
		return "", err
	}
	return BaseImageName(generator.cudaVersion, generator.pythonVersion, generator.torchVersion), nil
}

// stripPatchVersion reduces versionString to its "major.minor" form. It
// returns the reduced version, whether that differs from the input, and an
// error when the input cannot be parsed. An empty input is returned unchanged
// with no error.
func stripPatchVersion(versionString string) (string, bool, error) {
	if versionString == "" {
		return "", false, nil
	}

	parsed, err := version.NewVersion(versionString)
	if err != nil {
		return "", false, fmt.Errorf("Invalid version: %s", versionString)
	}

	majorMinor := fmt.Sprintf("%d.%d", parsed.Major, parsed.Minor)
	return majorMinor, majorMinor != versionString, nil
}


================================================
FILE: pkg/dockerfile/standard_generator_test.go
================================================
package dockerfile

import (
	"fmt"
	"os"
	"path"
	"path/filepath"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/docker/dockertest"
	"github.com/replicate/cog/pkg/registry/registrytest"
	"github.com/replicate/cog/pkg/wheels"
)

// testTini returns the tini installation block that the expected Dockerfiles
// in these tests contain, including its trailing newline.
func testTini() string {
	const installRun = `RUN --mount=type=cache,target=/var/cache/apt,sharing=locked set -eux; \
apt-get update -qq && \
apt-get install -qqy --no-install-recommends curl; \
rm -rf /var/lib/apt/lists/*; \
TINI_VERSION=v0.19.0; \
TINI_ARCH="$(dpkg --print-architecture)"; \
curl -sSL -o /sbin/tini "https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini-${TINI_ARCH}"; \
chmod +x /sbin/tini`
	const entrypoint = `ENTRYPOINT ["/sbin/tini", "--"]`
	return installRun + "\n" + entrypoint + "\n"
}

// testInstallUVLine is the uv installation block (COPY of the uv binaries for
// UVVersion plus the UV_SYSTEM_PYTHON setting) that the expected Dockerfiles
// in these tests contain. It has no trailing newline.
var testInstallUVLine = "COPY --from=ghcr.io/astral-sh/uv:" + UVVersion + " /uv /uvx /usr/local/bin/\nENV UV_SYSTEM_PYTHON=true"

// testInstallCog returns the cog installation block that the expected
// Dockerfiles contain. With stripped set, the debug-symbol strip command is
// chained onto the install line.
func testInstallCog(stripped bool) string {
	var stripSuffix string
	if stripped {
		stripSuffix = ` && find / -type f -name "*python*.so" -not -name "*cpython*.so" -exec strip -S {} \;`
	}
	// When coglet has no explicit version pin (empty version via pypiWheels()),
	// the SDK's own dependency handles coglet installation, so there is no
	// explicit coglet line.
	const template = "ENV CFLAGS=\"-O3 -funroll-loops -fno-strict-aliasing -flto -S\"\n" +
		"RUN --mount=type=cache,target=/root/.cache/uv uv pip install --no-cache cog%s\n" +
		"ENV CFLAGS="
	return fmt.Sprintf(template, stripSuffix)
}

// pypiWheels points the generator at unpinned PyPI for both cog and coglet by
// setting its test override fields, giving deterministic Dockerfile output
// regardless of local dist/ contents.
func pypiWheels(gen *StandardGenerator) {
	// Each field gets its own config value.
	unpinned := func() *wheels.WheelConfig {
		return &wheels.WheelConfig{Source: wheels.WheelSourcePyPI}
	}
	gen.cogWheelConfig = unpinned()
	gen.cogletWheelConfig = unpinned()
}

// testInstallPython returns the Python installation block (apt prerequisites,
// uv, and the uv-managed Python for version) that the expected GPU
// Dockerfiles contain, including its trailing newline.
func testInstallPython(version string) string {
	// The version is substituted in all three placeholders.
	format := `RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy --no-install-recommends \
	wget \
	curl \
	xz-utils \
	git \
	ca-certificates \
	&& rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:` + UVVersion + ` /uv /uvx /usr/local/bin/
ENV UV_SYSTEM_PYTHON=true
RUN uv python install %[1]s && \
	ln -sf $(uv python find %[1]s) /usr/bin/python3
ENV UV_PYTHON=%[1]s
ENV PATH="/usr/local/bin:$PATH"
`
	return fmt.Sprintf(format, version)
}

// TestGenerateEmptyCPU checks the Dockerfile generated for a CPU-only config
// that declares nothing beyond a Python version and a predictor: it is built
// on python:3.12-slim and contains only the preamble, tini, uv and cog steps.
func TestGenerateEmptyCPU(t *testing.T) {
	tmpDir := t.TempDir()

	conf, err := config.FromYAML([]byte(`
build:
  gpu: false
  python_version: "3.12"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	// Disable cog base images and pin wheel sources so the output is deterministic.
	gen.SetUseCogBaseImage(false)
	pypiWheels(gen)
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	expected := `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM python:3.12-slim
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu:/usr/local/nvidia/lib64:/usr/local/nvidia/bin
ENV NVIDIA_DRIVER_CAPABILITIES=all
` + testTini() + testInstallUVLine + "\n" + testInstallCog(gen.strip) + `
RUN find / -type f -name "*python*.so" -printf "%h\n" | sort -u > /etc/ld.so.conf.d/cog.conf && ldconfig
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`

	require.Equal(t, expected, actual)
}

// TestGenerateEmptyGPU checks the Dockerfile generated for a GPU config that
// declares nothing beyond a Python version and a predictor: it is built on
// the nvidia/cuda image and installs Python through uv before cog.
func TestGenerateEmptyGPU(t *testing.T) {
	tmpDir := t.TempDir()

	conf, err := config.FromYAML([]byte(`
build:
  gpu: true
  python_version: "3.12"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	// Disable cog base images and pin wheel sources so the output is deterministic.
	gen.SetUseCogBaseImage(false)
	pypiWheels(gen)
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	expected := `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu:/usr/local/nvidia/lib64:/usr/local/nvidia/bin
ENV NVIDIA_DRIVER_CAPABILITIES=all
` + testTini() + testInstallPython("3.12") + testInstallCog(gen.strip) + `
RUN find / -type f -name "*python*.so" -printf "%h\n" | sort -u > /etc/ld.so.conf.d/cog.conf && ldconfig
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`

	require.Equal(t, expected, actual)
}

// TestGenerateFullCPU checks the Dockerfile generated for a CPU config with
// system packages, Python packages and a run command, and verifies the
// requirements.txt written to the generator's temporary directory.
func TestGenerateFullCPU(t *testing.T) {
	tmpDir := t.TempDir()

	conf, err := config.FromYAML([]byte(`
build:
  gpu: false
  python_version: "3.12"
  system_packages:
    - ffmpeg
    - cowsay
  python_packages:
    - torch==2.3.0
    - pandas==1.2.0.12
  run:
    - "cowsay moo"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	// Disable cog base images and pin wheel sources so the output is deterministic.
	gen.SetUseCogBaseImage(false)
	pypiWheels(gen)
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	expected := `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM python:3.12-slim
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu:/usr/local/nvidia/lib64:/usr/local/nvidia/bin
ENV NVIDIA_DRIVER_CAPABILITIES=all
` + testTini() + `RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy ffmpeg cowsay && rm -rf /var/lib/apt/lists/*
` + testInstallUVLine + `
COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt
ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S"
RUN --mount=type=cache,target=/root/.cache/pip uv run pip install --cache-dir /root/.cache/pip -r /tmp/requirements.txt
ENV CFLAGS=
` + testInstallCog(gen.strip) + `
RUN find / -type f -name "*python*.so" -printf "%h\n" | sort -u > /etc/ld.so.conf.d/cog.conf && ldconfig
RUN cowsay moo
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`
	require.Equal(t, expected, actual)

	// The generated requirements file should carry the CPU torch index URL.
	requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt"))
	require.NoError(t, err)

	require.Equal(t, `--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.3.0
pandas==1.2.0.12`, string(requirements))
}

// TestGenerateFullGPU checks the Dockerfile generated for a GPU config with
// system packages, Python packages and a run command, and verifies the
// requirements.txt written to the generator's temporary directory.
func TestGenerateFullGPU(t *testing.T) {
	tmpDir := t.TempDir()

	conf, err := config.FromYAML([]byte(`
build:
  gpu: true
  python_version: "3.12"
  system_packages:
    - ffmpeg
    - cowsay
  python_packages:
    - torch==2.0.1
    - pandas==2.0.3
  run:
    - "cowsay moo"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	// Disable cog base images and pin wheel sources so the output is deterministic.
	gen.SetUseCogBaseImage(false)
	pypiWheels(gen)
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	expected := `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu:/usr/local/nvidia/lib64:/usr/local/nvidia/bin
ENV NVIDIA_DRIVER_CAPABILITIES=all
` + testTini() + `RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy ffmpeg cowsay && rm -rf /var/lib/apt/lists/*
` + testInstallPython("3.12") + `COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt
ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S"
RUN --mount=type=cache,target=/root/.cache/pip uv run pip install --cache-dir /root/.cache/pip -r /tmp/requirements.txt
ENV CFLAGS=
` + testInstallCog(gen.strip) + `
RUN find / -type f -name "*python*.so" -printf "%h\n" | sort -u > /etc/ld.so.conf.d/cog.conf && ldconfig
RUN cowsay moo
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`

	require.Equal(t, expected, actual)

	// The generated requirements file should carry the cu118 torch index URL.
	requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt"))
	require.NoError(t, err)
	require.Equal(t, `--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.0.1
pandas==2.0.3`, string(requirements))
}

// TestPreInstall checks that commands listed under build.pre_install are
// emitted as RUN steps in the generated runner Dockerfile.
// pre_install is deprecated but supported for backwards compatibility
func TestPreInstall(t *testing.T) {
	tmpDir := t.TempDir()

	conf, err := config.FromYAML([]byte(`
build:
  python_version: "3.12"
  system_packages:
    - cowsay
  pre_install:
    - "cowsay moo"
`))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	gen.SetUseCogBaseImage(false)
	pypiWheels(gen)
	// Only the runner Dockerfile (second return value) is inspected here.
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// Golden output: the pre_install entry is expected as `RUN cowsay moo`
	// right before WORKDIR.
	expected := `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM python:3.12-slim
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu:/usr/local/nvidia/lib64:/usr/local/nvidia/bin
ENV NVIDIA_DRIVER_CAPABILITIES=all
` + testTini() + `RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/*
` + testInstallUVLine + `
` + testInstallCog(gen.strip) + `
RUN find / -type f -name "*python*.so" -printf "%h\n" | sort -u > /etc/ld.so.conf.d/cog.conf && ldconfig
RUN cowsay moo
WORKDIR /src
EXPOSE 5000
CMD ["python", "-m", "cog.server.http"]
COPY . /src`
	require.Equal(t, expected, actual)

}

func TestPythonRequirements(t *testing.T) {
	tmpDir := t.TempDir()
	err := os.WriteFile(path.Join(tmpDir, "my-requirements.txt"), []byte("torch==1.0.0"), 0o644)
	require.NoError(t, err)
	conf, err := config.FromYAML([]byte(`
build:
  python_version: "3.12"
  python_requirements: "my-requirements.txt"
`))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(tmpDir))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	gen.SetUseCogBaseImage(false)
	pypiWheels(gen)
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)
	fmt.Println(actual)
	require.Contains(t, actual, `uv run pip install --cache-dir /root/.cache/pip -r /tmp/requirements.txt`)
}

// mockFileInfo is a minimal os.FileInfo stand-in for tests. Only the reported
// size is configurable; every other accessor returns its zero value.
type mockFileInfo struct {
	size int64
}

// Name always returns the empty string.
func (m mockFileInfo) Name() string { return "" }

// Size returns the configured size.
func (m mockFileInfo) Size() int64 { return m.size }

// Mode always returns a zero file mode.
func (m mockFileInfo) Mode() os.FileMode { return 0 }

// ModTime always returns the zero time.
func (m mockFileInfo) ModTime() time.Time { return time.Time{} }

// IsDir always reports false.
func (m mockFileInfo) IsDir() bool { return false }

// Sys always returns nil.
func (m mockFileInfo) Sys() any { return nil }

// sizeThreshold is 10 MiB. TestGenerateWithLargeModels uses it as the size
// reported for each mocked file.
// NOTE(review): presumably mirrors the generator's cutoff for treating a file
// as a large weight file — confirm against the generator code.
const sizeThreshold = 10 * 1024 * 1024

// TestGenerateWithLargeModels checks the three artifacts returned by
// GenerateModelBaseWithSeparateWeights when the (mocked) file walk reports
// files of sizeThreshold bytes: the weights Dockerfile, the runner Dockerfile
// that copies from the weights stage, and the generated .dockerignore.
func TestGenerateWithLargeModels(t *testing.T) {
	tmpDir := t.TempDir()

	conf, err := config.FromYAML([]byte(`
build:
  gpu: true
  python_version: "3.12"
  system_packages:
    - ffmpeg
    - cowsay
  python_packages:
    - torch==2.0.1
    - pandas==2.0.3
  run:
    - "cowsay moo"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	gen.SetUseCogBaseImage(false)

	// Replace the filesystem walk with a fixed list of files, each reported
	// with a size of exactly sizeThreshold. Errors from the callback are
	// propagated instead of being dropped, so a failing callback cannot be
	// silently masked by the mock.
	gen.fileWalker = func(root string, walkFn filepath.WalkFunc) error {
		for _, name := range []string{"checkpoints/large-a", "models/large-b", "root-large"} {
			if err := walkFn(name, mockFileInfo{size: sizeThreshold}, nil); err != nil {
				return err
			}
		}
		return nil
	}
	pypiWheels(gen)

	modelDockerfile, runnerDockerfile, dockerignore, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// Weights image: files inside a directory are copied by their top-level
	// directory, the root-level file is copied directly.
	expected := `#syntax=docker/dockerfile:1.4
FROM scratch

COPY checkpoints /src/checkpoints
COPY models /src/models
COPY root-large /src/root-large`

	require.Equal(t, expected, modelDockerfile)

	// model copy should be run before dependency install and code copy
	expected = `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu:/usr/local/nvidia/lib64:/usr/local/nvidia/bin
ENV NVIDIA_DRIVER_CAPABILITIES=all
` + testTini() + `RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy ffmpeg cowsay && rm -rf /var/lib/apt/lists/*` + `
` + testInstallPython("3.12") + `COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt
ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S"
RUN --mount=type=cache,target=/root/.cache/pip uv run pip install --cache-dir /root/.cache/pip -r /tmp/requirements.txt
ENV CFLAGS=
` + testInstallCog(gen.strip) + `
RUN find / -type f -name "*python*.so" -printf "%h\n" | sort -u > /etc/ld.so.conf.d/cog.conf && ldconfig
RUN cowsay moo
COPY --from=weights --link /src/checkpoints /src/checkpoints
COPY --from=weights --link /src/models /src/models
COPY --from=weights --link /src/root-large /src/root-large
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`

	require.Equal(t, expected, runnerDockerfile)

	// filepath.Join rather than path.Join: this is a path on the local filesystem.
	requirements, err := os.ReadFile(filepath.Join(gen.tmpDir, "requirements.txt"))
	require.NoError(t, err)
	require.Equal(t, `--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.0.1
pandas==2.0.3`, string(requirements))

	// The weight paths are expected at the end of the generated .dockerignore.
	expected = `# generated by replicate/cog
__pycache__
*.pyc
*.pyo
*.pyd
.Python
env
pip-log.txt
pip-delete-this-directory.txt
.tox
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.log
.git
.mypy_cache
.pytest_cache
.hypothesis
checkpoints
checkpoints/**/*
models
models/**/*
root-large
`
	require.Equal(t, expected, dockerignore)
}

// TestGenerateDockerfileWithoutSeparateWeights checks the single Dockerfile
// returned by GenerateDockerfileWithoutSeparateWeights for a CPU-only config.
// The expected output has no `FROM ... AS weights` stage.
func TestGenerateDockerfileWithoutSeparateWeights(t *testing.T) {
	tmpDir := t.TempDir()

	conf, err := config.FromYAML([]byte(`
build:
  gpu: false
  python_version: "3.12"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	gen.SetUseCogBaseImage(false)
	pypiWheels(gen)
	actual, err := gen.GenerateDockerfileWithoutSeparateWeights(t.Context())
	require.NoError(t, err)

	// Golden output, assembled from the shared test helper snippets.
	expected := `#syntax=docker/dockerfile:1.4
FROM python:3.12-slim
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu:/usr/local/nvidia/lib64:/usr/local/nvidia/bin
ENV NVIDIA_DRIVER_CAPABILITIES=all
` + testTini() + testInstallUVLine + "\n" + testInstallCog(gen.strip) + `
RUN find / -type f -name "*python*.so" -printf "%h\n" | sort -u > /etc/ld.so.conf.d/cog.conf && ldconfig
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`

	require.Equal(t, expected, actual)
}

// TestGenerateEmptyCPUWithCogBaseImage checks the runner Dockerfile for a
// minimal CPU config when the cog base image is enabled: the runner stage is
// expected to start from r8.im/cog-base:python3.12.
func TestGenerateEmptyCPUWithCogBaseImage(t *testing.T) {
	tmpDir := t.TempDir()
	conf, err := config.FromYAML([]byte(`
build:
  gpu: false
  python_version: "3.12"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	// Register the base image for Python 3.12 (no CUDA, no torch) with the mock registry.
	client.AddMockImage(BaseImageName("", "3.12", ""))
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	gen.SetUseCogBaseImage(true)
	pypiWheels(gen)
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// Golden output: base image, uv and cog install snippets, then the
	// WORKDIR/EXPOSE/ENV/CMD/COPY tail.
	expected := `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM r8.im/cog-base:python3.12
` + testInstallUVLine + `
` + testInstallCog(gen.strip) + `
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`

	require.Equal(t, expected, actual)
}

// TestGeneratePythonCPUWithCogBaseImage checks the runner Dockerfile and the
// written requirements.txt for a CPU config with system packages, Python
// packages and a run command, built on top of the cog base image.
func TestGeneratePythonCPUWithCogBaseImage(t *testing.T) {
	tmpDir := t.TempDir()

	conf, err := config.FromYAML([]byte(`
build:
  gpu: false
  python_version: "3.12"
  system_packages:
    - ffmpeg
    - cowsay
  python_packages:
    - pandas==1.2.0.12
  run:
    - "cowsay moo"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	client.AddMockImage(BaseImageName("", "3.12", ""))
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	gen.SetUseCogBaseImage(true)
	pypiWheels(gen)
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// Only cowsay appears in the expected apt-get line although the config
	// also requests ffmpeg.
	// NOTE(review): presumably ffmpeg is already provided by the base image — confirm.
	expected := `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM r8.im/cog-base:python3.12
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/*
` + testInstallUVLine + `
` + testInstallCog(gen.strip) + `
COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt
ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S"
RUN --mount=type=cache,target=/root/.cache/pip uv run pip install --cache-dir /root/.cache/pip -r /tmp/requirements.txt
ENV CFLAGS=
RUN cowsay moo
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`
	require.Equal(t, expected, actual)

	// The requirements file is written into the generator's temp dir.
	requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt"))
	require.NoError(t, err)
	require.Equal(t, `pandas==1.2.0.12`, string(requirements))
}

// TestGenerateFullGPUWithCogBaseImage checks the runner Dockerfile and the
// written requirements.txt for a GPU config on top of the cog base image,
// once for each of several spellings of the torch version.
//
// Each torch version runs as its own subtest so a failure names the version
// involved and does not prevent the remaining versions from being checked.
// The mock registry client and command are shared across subtests, which run
// sequentially in slice order.
func TestGenerateFullGPUWithCogBaseImage(t *testing.T) {
	tmpDir := t.TempDir()
	client := registrytest.NewMockRegistryClient()
	command := dockertest.NewMockCommand()
	torchVersions := []string{"2.3", "2.3.0", "2.3.1"}
	for _, torchVersion := range torchVersions {
		t.Run(torchVersion, func(t *testing.T) {
			yaml := fmt.Sprintf(`
build:
  gpu: true
  cuda: "11.8"
  python_version: "3.11"
  system_packages:
    - ffmpeg
    - cowsay
  python_packages:
    - torch==%s
    - pandas==2.0.3
  run:
    - "cowsay moo"
predict: predict.py:Predictor
`, torchVersion)
			conf, err := config.FromYAML([]byte(yaml))
			require.NoError(t, err)
			require.NoError(t, conf.Complete(""))
			client.AddMockImage(BaseImageName("11.8", "3.11", torchVersion))
			gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
			require.NoError(t, err)
			gen.SetUseCogBaseImage(true)
			pypiWheels(gen)
			_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
			require.NoError(t, err)

			// We add the patch version to the expected torch version
			expectedTorchVersion := torchVersion
			if torchVersion == "2.3" {
				expectedTorchVersion = "2.3.1"
			}
			expected := fmt.Sprintf(`#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM r8.im/cog-base:cuda11.8-python3.11-torch%s
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/*
`+testInstallUVLine+`
`+testInstallCog(gen.strip)+`
COPY `+gen.relativeTmpDir+`/requirements.txt /tmp/requirements.txt
ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S"
RUN --mount=type=cache,target=/root/.cache/pip uv run pip install --cache-dir /root/.cache/pip -r /tmp/requirements.txt
ENV CFLAGS=
RUN cowsay moo
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`, expectedTorchVersion)

			require.Equal(t, expected, actual)

			requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt"))
			require.NoError(t, err)
			expected = fmt.Sprintf(`--extra-index-url https://download.pytorch.org/whl/cu118
torch==%s
pandas==2.0.3`, expectedTorchVersion)
			require.Equal(t, expected, string(requirements))
		})
	}
}

// TestGenerateTorchWithStrippedModifiedVersion checks that a torch requirement
// carrying a local version suffix (torch==2.3.1+cu118) is expected to resolve
// to the torch2.3.1 base image tag and to be written as torch==2.3.1 in
// requirements.txt.
func TestGenerateTorchWithStrippedModifiedVersion(t *testing.T) {
	tmpDir := t.TempDir()

	yaml := `
build:
  gpu: true
  cuda: "11.8"
  python_version: "3.12"
  system_packages:
    - ffmpeg
    - cowsay
  python_packages:
    - torch==2.3.1+cu118
    - pandas==2.0.3
  run:
    - "cowsay moo"
predict: predict.py:Predictor
`
	conf, err := config.FromYAML([]byte(yaml))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	// The mock base image is registered under the suffix-free torch version.
	client.AddMockImage(BaseImageName("11.8", "3.12", "2.3.1"))
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	gen.SetUseCogBaseImage(true)
	pypiWheels(gen)
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	expected := `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM r8.im/cog-base:cuda11.8-python3.12-torch2.3.1
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/*
` + testInstallUVLine + `
` + testInstallCog(gen.strip) + `
COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt
ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S"
RUN --mount=type=cache,target=/root/.cache/pip uv run pip install --cache-dir /root/.cache/pip -r /tmp/requirements.txt
ENV CFLAGS=
RUN cowsay moo
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`

	require.Equal(t, expected, actual)

	// The +cu118 suffix is expected to be gone from the written requirement.
	requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt"))
	require.NoError(t, err)
	require.Equal(t, `--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.3.1
pandas==2.0.3`, string(requirements))
}

// TestGenerateWithStrip checks the runner Dockerfile when SetStrip(true) is
// applied: the expected pip install step is followed by a `find ... -exec
// strip -S` command on the same RUN line.
func TestGenerateWithStrip(t *testing.T) {
	tmpDir := t.TempDir()

	yaml := `
build:
  gpu: true
  cuda: "11.8"
  python_version: "3.12"
  system_packages:
    - ffmpeg
    - cowsay
  python_packages:
    - torch==2.3.1
    - pandas==2.0.3
  run:
    - "cowsay moo"
predict: predict.py:Predictor
`
	conf, err := config.FromYAML([]byte(yaml))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	client.AddMockImage(BaseImageName("11.8", "3.12", "2.3.1"))
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	gen.SetUseCogBaseImage(true)
	gen.SetStrip(true)
	pypiWheels(gen)
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// gen.strip is passed to testInstallCog so the expected cog install
	// snippet matches the strip setting.
	expected := `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM r8.im/cog-base:cuda11.8-python3.12-torch2.3.1
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/*
` + testInstallUVLine + `
` + testInstallCog(gen.strip) + `
COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt
ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S"
RUN --mount=type=cache,target=/root/.cache/pip uv run pip install --cache-dir /root/.cache/pip -r /tmp/requirements.txt && find / -type f -name "*python*.so" -not -name "*cpython*.so" -exec strip -S {} \;
ENV CFLAGS=
RUN cowsay moo
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`

	require.Equal(t, expected, actual)

	requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt"))
	require.NoError(t, err)
	require.Equal(t, `--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.3.1
pandas==2.0.3`, string(requirements))
}

// TestGenerateDoesNotContainDangerousCFlags asserts that neither
// -march=native nor -mtune=native shows up anywhere in the generated runner
// Dockerfile.
func TestGenerateDoesNotContainDangerousCFlags(t *testing.T) {
	cfg, err := config.FromYAML([]byte(`
build:
  gpu: true
  cuda: "11.8"
  python_version: "3.12"
  system_packages:
    - ffmpeg
    - cowsay
  python_packages:
    - torch==2.3.1
    - pandas==2.0.3
  run:
    - "cowsay moo"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(""))

	dockerCmd := dockertest.NewMockCommand()
	registry := registrytest.NewMockRegistryClient()
	registry.AddMockImage(BaseImageName("11.8", "3.12", "2.3.1"))

	generator, err := NewStandardGenerator(cfg, t.TempDir(), "", dockerCmd, registry, true)
	require.NoError(t, err)
	generator.SetUseCogBaseImage(true)
	pypiWheels(generator)

	_, dockerfile, _, err := generator.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	for _, flag := range []string{"-march=native", "-mtune=native"} {
		require.NotContains(t, dockerfile, flag)
	}
}

// TestGenerateWithPrecompile checks the runner Dockerfile when both
// SetStrip(true) and SetPrecompile(true) are applied: in addition to the
// strip command, the expected output contains a RUN step that deletes
// existing *.py[co] files and runs compileall.
func TestGenerateWithPrecompile(t *testing.T) {
	tmpDir := t.TempDir()

	yaml := `
build:
  gpu: true
  cuda: "11.8"
  python_version: "3.12"
  system_packages:
    - ffmpeg
    - cowsay
  python_packages:
    - torch==2.3.1
    - pandas==2.0.3
  run:
    - "cowsay moo"
predict: predict.py:Predictor
`
	conf, err := config.FromYAML([]byte(yaml))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	client.AddMockImage(BaseImageName("11.8", "3.12", "2.3.1"))
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	gen.SetUseCogBaseImage(true)
	gen.SetStrip(true)
	gen.SetPrecompile(true)
	pypiWheels(gen)
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// The precompile step is expected after `ENV CFLAGS=` and before the
	// user's run commands.
	expected := `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM r8.im/cog-base:cuda11.8-python3.12-torch2.3.1
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/*
` + testInstallUVLine + `
` + testInstallCog(gen.strip) + `
COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt
ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S"
RUN --mount=type=cache,target=/root/.cache/pip uv run pip install --cache-dir /root/.cache/pip -r /tmp/requirements.txt && find / -type f -name "*python*.so" -not -name "*cpython*.so" -exec strip -S {} \;
ENV CFLAGS=
RUN find / -type f -name "*.py[co]" -delete && find / -type f -name "*.py" -exec touch -t 197001010000 {} \; && find / -type f -name "*.py" -printf "%h\n" | sort -u | /usr/bin/python3 -m compileall --invalidation-mode timestamp -o 2 -j 0
RUN cowsay moo
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`

	require.Equal(t, expected, actual)

	requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt"))
	require.NoError(t, err)
	require.Equal(t, `--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.3.1
pandas==2.0.3`, string(requirements))
}

// TestGenerateWithCoglet checks that a coglet wheel URL listed under
// python_packages does not appear in the written requirements.txt, while the
// rest of the runner Dockerfile matches the strip + precompile golden output.
// Unlike the neighbouring tests, this one does not call pypiWheels on the
// generator.
func TestGenerateWithCoglet(t *testing.T) {
	tmpDir := t.TempDir()

	yaml := `
build:
  gpu: true
  cuda: "11.8"
  python_version: "3.12"
  system_packages:
    - ffmpeg
    - cowsay
  python_packages:
    - torch==2.3.1
    - pandas==2.0.3
    - coglet @ https://github.com/replicate/cog-runtime/releases/download/v0.1.0-alpha31/coglet-0.1.0a31-py3-none-any.whl
  run:
    - "cowsay moo"
predict: predict.py:Predictor
`
	conf, err := config.FromYAML([]byte(yaml))
	require.NoError(t, err)
	require.NoError(t, conf.Complete(""))
	command := dockertest.NewMockCommand()
	client := registrytest.NewMockRegistryClient()
	client.AddMockImage(BaseImageName("11.8", "3.12", "2.3.1"))
	gen, err := NewStandardGenerator(conf, tmpDir, "", command, client, true)
	require.NoError(t, err)
	gen.SetUseCogBaseImage(true)
	gen.SetStrip(true)
	gen.SetPrecompile(true)
	_, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// coglet in python_packages is stripped — the build system always installs coglet
	// via installCog(), which runs before pip requirements.
	expected := `#syntax=docker/dockerfile:1.4
FROM r8.im/replicate/cog-test-weights AS weights
FROM r8.im/cog-base:cuda11.8-python3.12-torch2.3.1
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/*
` + testInstallUVLine + `
` + testInstallCog(true) + `
COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt
ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S"
RUN --mount=type=cache,target=/root/.cache/pip uv run pip install --cache-dir /root/.cache/pip -r /tmp/requirements.txt && find / -type f -name "*python*.so" -not -name "*cpython*.so" -exec strip -S {} \;
ENV CFLAGS=
RUN find / -type f -name "*.py[co]" -delete && find / -type f -name "*.py" -exec touch -t 197001010000 {} \; && find / -type f -name "*.py" -printf "%h\n" | sort -u | /usr/bin/python3 -m compileall --invalidation-mode timestamp -o 2 -j 0
RUN cowsay moo
WORKDIR /src
EXPOSE 5000
ENV COG_PREDICT_TYPE_STUB="predict.py:Predictor"
CMD ["python", "-m", "cog.server.http"]
COPY . /src`

	require.Equal(t, expected, actual)

	// coglet URL is stripped from requirements — build system installs coglet itself
	requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt"))
	require.NoError(t, err)
	require.Equal(t, `--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.3.1
pandas==2.0.3`, string(requirements))
}

// TestCOGWheelDefault covers the default case: without any wheel override the
// generated Dockerfile should install cog from PyPI.
func TestCOGWheelDefault(t *testing.T) {
	cfg, err := config.FromYAML([]byte(`
build:
  gpu: false
  python_version: "3.11"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(""))

	dockerCmd := dockertest.NewMockCommand()
	registry := registrytest.NewMockRegistryClient()
	generator, err := NewStandardGenerator(cfg, t.TempDir(), "", dockerCmd, registry, true)
	require.NoError(t, err)
	pypiWheels(generator)

	_, dockerfile, _, err := generator.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// An unpinned `uv pip install` of cog is expected. Coglet is not
	// installed explicitly when unpinned — the SDK dependency handles it.
	require.Contains(t, dockerfile, "uv pip install --no-cache cog")
}

// TestCOGWheelEnvPyPI covers COG_SDK_WHEEL=pypi: cog should be installed from
// PyPI without a version pin.
func TestCOGWheelEnvPyPI(t *testing.T) {
	// The override is set before the generator is constructed.
	t.Setenv("COG_SDK_WHEEL", "pypi")

	cfg, err := config.FromYAML([]byte(`
build:
  gpu: false
  python_version: "3.11"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(""))

	dockerCmd := dockertest.NewMockCommand()
	registry := registrytest.NewMockRegistryClient()
	generator, err := NewStandardGenerator(cfg, t.TempDir(), "", dockerCmd, registry, true)
	require.NoError(t, err)

	_, dockerfile, _, err := generator.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// The Dockerfile must carry the uv install of cog from PyPI.
	require.Contains(t, dockerfile, "uv pip install --no-cache cog")
}

// TestCOGWheelEnvPyPIWithVersion covers COG_SDK_WHEEL=pypi:0.17.0: that
// specific cog version should be installed from PyPI.
func TestCOGWheelEnvPyPIWithVersion(t *testing.T) {
	// The override is set before the generator is constructed.
	t.Setenv("COG_SDK_WHEEL", "pypi:0.17.0")

	cfg, err := config.FromYAML([]byte(`
build:
  gpu: false
  python_version: "3.11"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(""))

	dockerCmd := dockertest.NewMockCommand()
	registry := registrytest.NewMockRegistryClient()
	generator, err := NewStandardGenerator(cfg, t.TempDir(), "", dockerCmd, registry, true)
	require.NoError(t, err)

	_, dockerfile, _, err := generator.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// The install line must pin cog to exactly 0.17.0.
	require.Contains(t, dockerfile, "uv pip install --no-cache cog==0.17.0")
}

// TestCOGWheelEnvURL covers COG_SDK_WHEEL=https://...: the wheel should be
// installed straight from the given URL.
func TestCOGWheelEnvURL(t *testing.T) {
	const customURL = "https://example.com/custom-wheel-0.1.0.whl"
	// The override is set before the generator is constructed.
	t.Setenv("COG_SDK_WHEEL", customURL)

	cfg, err := config.FromYAML([]byte(`
build:
  gpu: false
  python_version: "3.11"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(""))

	dockerCmd := dockertest.NewMockCommand()
	registry := registrytest.NewMockRegistryClient()
	generator, err := NewStandardGenerator(cfg, t.TempDir(), "", dockerCmd, registry, true)
	require.NoError(t, err)

	_, dockerfile, _, err := generator.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// The install line must reference the custom URL verbatim.
	wantInstall := "uv pip install --no-cache " + customURL
	require.Contains(t, dockerfile, wantInstall)
}

// TestCOGWheelEnvFile covers COG_SDK_WHEEL=/path/to/file.whl: a local wheel
// file should be installed from its copied location under /tmp.
func TestCOGWheelEnvFile(t *testing.T) {
	projectDir := t.TempDir()

	// Drop a placeholder wheel on disk and point the override at it.
	localWheel := filepath.Join(projectDir, "test-cog-0.1.0-py3-none-any.whl")
	require.NoError(t, os.WriteFile(localWheel, []byte("mock wheel content"), 0o644))
	t.Setenv("COG_SDK_WHEEL", localWheel)

	cfg, err := config.FromYAML([]byte(`
build:
  gpu: false
  python_version: "3.11"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(""))

	dockerCmd := dockertest.NewMockCommand()
	registry := registrytest.NewMockRegistryClient()
	generator, err := NewStandardGenerator(cfg, projectDir, "", dockerCmd, registry, true)
	require.NoError(t, err)

	_, dockerfile, _, err := generator.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test")
	require.NoError(t, err)

	// The install must use the temp path the wheel is copied to in the container.
	require.Contains(t, dockerfile, "uv pip install --no-cache /tmp/test-cog-0.1.0-py3-none-any.whl")
}

// TestCogletStrippedFromRequirements checks that a user-supplied coglet pin
// does not surface in the output of GenerateInitialSteps.
//
// NOTE(review): the assertion only inspects the returned Dockerfile text;
// other tests in this file read requirements.txt from gen.tmpDir — consider
// asserting on that file as well.
func TestCogletStrippedFromRequirements(t *testing.T) {
	cfg, err := config.FromYAML([]byte(`
build:
  python_version: "3.12"
  python_packages:
    - "coglet==0.1.0"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(""))

	dockerCmd := dockertest.NewMockCommand()
	registry := registrytest.NewMockRegistryClient()
	generator, err := NewStandardGenerator(cfg, t.TempDir(), "", dockerCmd, registry, true)
	require.NoError(t, err)
	generator.SetUseCogBaseImage(false)
	pypiWheels(generator)

	steps, err := generator.GenerateInitialSteps(t.Context())
	require.NoError(t, err)

	// coglet is NOT explicitly installed — the SDK dependency handles it.
	// The user's coglet==0.1.0 pin still has to be dropped so it cannot
	// conflict with whatever version the SDK pulls in.
	require.NotContains(t, steps, "coglet==0.1.0")
}

// TestInstallCogWithSDKVersion checks that build.sdk_version pins the cog SDK
// version installed from PyPI.
func TestInstallCogWithSDKVersion(t *testing.T) {
	cfg, err := config.FromYAML([]byte(`
build:
  python_version: "3.12"
  sdk_version: "0.18.0"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(""))

	dockerCmd := dockertest.NewMockCommand()
	registry := registrytest.NewMockRegistryClient()
	generator, err := NewStandardGenerator(cfg, t.TempDir(), "", dockerCmd, registry, true)
	require.NoError(t, err)
	// Pin only coglet to PyPI so that cog itself is driven by sdk_version.
	generator.cogletWheelConfig = &wheels.WheelConfig{Source: wheels.WheelSourcePyPI}
	generator.SetUseCogBaseImage(false)

	steps, err := generator.GenerateInitialSteps(t.Context())
	require.NoError(t, err)

	require.Contains(t, steps, "uv pip install --no-cache cog==0.18.0")
	// A stable release must not be installed with --pre.
	require.NotContains(t, steps, "--pre")
}

// TestInstallCogWithPreReleaseSDKVersion checks that a pre-release
// build.sdk_version adds --pre to the cog install command.
func TestInstallCogWithPreReleaseSDKVersion(t *testing.T) {
	cfg, err := config.FromYAML([]byte(`
build:
  python_version: "3.12"
  sdk_version: "0.18.0a1"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(""))

	dockerCmd := dockertest.NewMockCommand()
	registry := registrytest.NewMockRegistryClient()
	generator, err := NewStandardGenerator(cfg, t.TempDir(), "", dockerCmd, registry, true)
	require.NoError(t, err)
	generator.cogletWheelConfig = &wheels.WheelConfig{Source: wheels.WheelSourcePyPI}
	generator.SetUseCogBaseImage(false)

	steps, err := generator.GenerateInitialSteps(t.Context())
	require.NoError(t, err)

	// The install line carries both --pre and the pinned version. Coglet is
	// NOT installed explicitly — the SDK dependency pulls it in — so no
	// separate coglet install line is expected.
	require.Contains(t, steps, "uv pip install --pre --no-cache cog==0.18.0a1")
}

// TestInstallCogSDKVersionBelowMinimum checks that a build.sdk_version below
// MinimumSDKVersion makes GenerateInitialSteps return an error.
func TestInstallCogSDKVersionBelowMinimum(t *testing.T) {
	cfg, err := config.FromYAML([]byte(`
build:
  python_version: "3.12"
  sdk_version: "0.15.0"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(""))

	dockerCmd := dockertest.NewMockCommand()
	registry := registrytest.NewMockRegistryClient()
	generator, err := NewStandardGenerator(cfg, t.TempDir(), "", dockerCmd, registry, true)
	require.NoError(t, err)
	generator.cogletWheelConfig = &wheels.WheelConfig{Source: wheels.WheelSourcePyPI}
	generator.SetUseCogBaseImage(false)

	_, err = generator.GenerateInitialSteps(t.Context())
	require.Error(t, err)

	// The message should name the offending version and mention the minimum.
	msg := err.Error()
	require.Contains(t, msg, "0.15.0")
	require.Contains(t, msg, "minimum required version")
}

// TestCOGSDKWheelEnvVarOverridesSDKVersion checks that the COG_SDK_WHEEL
// environment variable takes precedence over build.sdk_version.
func TestCOGSDKWheelEnvVarOverridesSDKVersion(t *testing.T) {
	// The override is set before the config is parsed and the generator is built.
	t.Setenv("COG_SDK_WHEEL", "pypi:0.17.0")

	cfg, err := config.FromYAML([]byte(`
build:
  python_version: "3.12"
  sdk_version: "0.18.0"
predict: predict.py:Predictor
`))
	require.NoError(t, err)
	require.NoError(t, cfg.Complete(""))

	dockerCmd := dockertest.NewMockCommand()
	registry := registrytest.NewMockRegistryClient()
	generator, err := NewStandardGenerator(cfg, t.TempDir(), "", dockerCmd, registry, true)
	require.NoError(t, err)
	generator.cogletWheelConfig = &wheels.WheelConfig{Source: wheels.WheelSourcePyPI}
	generator.SetUseCogBaseImage(false)

	steps, err := generator.GenerateInitialSteps(t.Context())
	require.NoError(t, err)

	// The env var wins: 0.17.0 is installed and 0.18.0 is absent.
	require.Contains(t, steps, "uv pip install --no-cache cog==0.17.0")
	require.NotContains(t, steps, "cog==0.18.0")
}


================================================
FILE: pkg/dockerfile/version_check.go
================================================
package dockerfile

import (
	"regexp"
)

// versionRegex matches a version string in the form x, x.y or x.y.z (e.g.,
// Python version, CUDA version). The components are captured in the named
// groups "major", "minor" and "patch", which parse looks up by name.
// We do not support suffixes like -alpha1 or +cu124.
var versionRegex = regexp.MustCompile(`^(?P<major>\d+)(\.(?P<minor>\d+)(\.(?P<patch>\d+))?)?$`)

// parse splits s into its major, minor and patch components. Components that
// are absent from s are returned as empty strings. If s does not match
// versionRegex at all, all three results are empty.
func parse(s string) (string, string, string) {
	m := versionRegex.FindStringSubmatch(s)
	if m == nil {
		return "", "", ""
	}
	major := m[versionRegex.SubexpIndex("major")]
	minor := m[versionRegex.SubexpIndex("minor")]
	patch := m[versionRegex.SubexpIndex("patch")]
	return major, minor, patch
}

// CheckMajorOnly reports whether s parses as a version that has a major
// component and neither a minor nor a patch component.
func CheckMajorOnly(s string) bool {
	major, minor, patch := parse(s)
	hasMajor := major != ""
	hasMore := minor != "" || patch != ""
	return hasMajor && !hasMore
}

// CheckMajorMinorOnly reports whether s parses as a version that has major
// and minor components but no patch component.
func CheckMajorMinorOnly(s string) bool {
	major, minor, patch := parse(s)
	if major == "" || minor == "" {
		return false
	}
	return patch == ""
}

// CheckMajorMinorPatch reports whether s parses as a version with all three
// of the major, minor and patch components present.
func CheckMajorMinorPatch(s string) bool {
	major, minor, patch := parse(s)
	for _, part := range []string{major, minor, patch} {
		if part == "" {
			return false
		}
	}
	return true
}


================================================
FILE: pkg/dockerignore/dockerignore.go
================================================
package dockerignore

import (
	"bufio"
	"os"
	"path/filepath"

	ignore "github.com/sabhiram/go-gitignore"

	"github.com/replicate/cog/pkg/util/files"
)

// DockerIgnoreFilename is the name of the ignore file that CreateMatcher looks
// for inside a directory and that Walk never reports to its callback.
const DockerIgnoreFilename = ".dockerignore"

// CreateMatcher compiles the .dockerignore file located directly inside dir
// into a matcher. When dir contains no .dockerignore it returns (nil, nil);
// Walk treats a nil matcher as "ignore nothing". Errors from checking for or
// reading the file are returned unchanged.
func CreateMatcher(dir string) (*ignore.GitIgnore, error) {
	ignoreFile := filepath.Join(dir, DockerIgnoreFilename)

	switch found, err := files.Exists(ignoreFile); {
	case err != nil:
		return nil, err
	case !found:
		return nil, nil
	}

	lines, err := readDockerIgnore(ignoreFile)
	if err != nil {
		return nil, err
	}

	matcher := ignore.CompileIgnoreLines(lines...)
	return matcher, nil
}

// Walk traverses the file tree rooted at root with filepath.Walk and calls fn
// for every entry that survives filtering:
//
//   - entries matched by ignoreMatcher are skipped (whole subtree for
//     directories); a nil ignoreMatcher matches nothing
//   - directories named ".cog" are skipped entirely
//   - files named DockerIgnoreFilename are skipped
//
// Any error reported by filepath.Walk for an entry is returned immediately
// without consulting fn.
func Walk(root string, ignoreMatcher *ignore.GitIgnore, fn filepath.WalkFunc) error {
	visit := func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}

		isDir := info.IsDir()
		switch {
		case ignoreMatcher != nil && ignoreMatcher.MatchesPath(path):
			// We ignore files ignored by .dockerignore
			if isDir {
				return filepath.SkipDir
			}
			return nil
		case isDir && info.Name() == ".cog":
			return filepath.SkipDir
		case info.Name() == DockerIgnoreFilename:
			return nil
		}

		return fn(path, info, walkErr)
	}

	return filepath.Walk(root, visit)
}

// readDockerIgnore returns the lines of the file at dockerIgnorePath, one
// slice element per line, in file order. Lines are returned verbatim (blank
// lines and comments included). If the file cannot be opened the result is a
// nil slice and the open error; a scan error is returned together with the
// lines read so far.
func readDockerIgnore(dockerIgnorePath string) ([]string, error) {
	f, err := os.Open(dockerIgnorePath)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var lines []string
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		lines = append(lines, sc.Text())
	}
	return lines, sc.Err()
}


================================================
FILE: pkg/dockerignore/dockerignore_test.go
================================================
package dockerignore

import (
	"os"
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestWalk verifies that Walk reports only files that are not excluded by the
// directory's .dockerignore, and never reports the .dockerignore file itself.
func TestWalk(t *testing.T) {
	dir := t.TempDir()

	const (
		predictPyFilename      = "predict.py"
		predictOtherPyFilename = "predict_other.py"
	)

	// os.WriteFile creates, writes, and closes the file in one step, so no
	// file handle is left open and write failures are not silently dropped.
	writeFile := func(name, contents string) {
		t.Helper()
		require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(contents), 0o644))
	}

	writeFile(predictOtherPyFilename, "import cog")
	// The ignore file excludes predict_other.py by name.
	writeFile(".dockerignore", predictOtherPyFilename)
	writeFile(predictPyFilename, "import cog")

	matcher, err := CreateMatcher(dir)
	require.NoError(t, err)

	foundFiles := []string{}
	err = Walk(dir, matcher, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}

		if info.IsDir() {
			return nil
		}

		relPath, err := filepath.Rel(dir, path)
		if err != nil {
			return err
		}

		foundFiles = append(foundFiles, relPath)

		return nil
	})
	require.NoError(t, err)

	require.Equal(t, []string{predictPyFilename}, foundFiles)
}


================================================
FILE: pkg/env/env.go
================================================
package env

import "os"

// Names of the environment variables consulted by the *FromEnvironment
// helpers in this package.
const (
	SchemeEnvVarName      = "R8_SCHEME"
	WebHostEnvVarName     = "R8_WEB_HOST"
	APIHostEnvVarName     = "R8_API_HOST"
	PytorchHostEnvVarName = "R8_PYTORCH_HOST"
)

// SchemeFromEnvironment returns the URL scheme configured through the
// SchemeEnvVarName environment variable, or "https" when it is unset or empty.
func SchemeFromEnvironment() string {
	if scheme := os.Getenv(SchemeEnvVarName); scheme != "" {
		return scheme
	}
	return "https"
}

// WebHostFromEnvironment returns the web host configured through the
// WebHostEnvVarName environment variable, or "cog.replicate.com" when it is
// unset or empty.
func WebHostFromEnvironment() string {
	if host := os.Getenv(WebHostEnvVarName); host != "" {
		return host
	}
	return "cog.replicate.com"
}

// APIHostFromEnvironment returns the API host configured through the
// APIHostEnvVarName environment variable, or "api.replicate.com" when it is
// unset or empty.
func APIHostFromEnvironment() string {
	if host := os.Getenv(APIHostEnvVarName); host != "" {
		return host
	}
	return "api.replicate.com"
}

// PytorchHostFromEnvironment returns the PyTorch download host configured
// through the PytorchHostEnvVarName environment variable, or
// "download.pytorch.org" when it is unset or empty.
func PytorchHostFromEnvironment() string {
	if host := os.Getenv(PytorchHostEnvVarName); host != "" {
		return host
	}
	return "download.pytorch.org"
}


================================================
FILE: pkg/env/env_test.go
================================================
package env

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// TestSchemeFromEnvironment checks that the scheme is taken from the
// environment variable when it is set.
func TestSchemeFromEnvironment(t *testing.T) {
	const testScheme = "myscheme"
	t.Setenv(SchemeEnvVarName, testScheme)
	// require.Equal takes (expected, actual); keeping that order makes the
	// failure message label the two values correctly.
	require.Equal(t, testScheme, SchemeFromEnvironment())
}

// TestWebHostFromEnvironment checks that the web host is taken from the
// environment variable when it is set.
func TestWebHostFromEnvironment(t *testing.T) {
	const testHost = "web"
	t.Setenv(WebHostEnvVarName, testHost)
	// require.Equal takes (expected, actual); keeping that order makes the
	// failure message label the two values correctly.
	require.Equal(t, testHost, WebHostFromEnvironment())
}


================================================
FILE: pkg/errors/common.go
================================================
package errors

import (
	"errors"

	"github.com/replicate/cog/pkg/global"
)

var (
	// ErrorBadRegistryURL is a sentinel error describing the expected
	// three-component image URL format. The registry host in the message is
	// read from global.ReplicateRegistryHost once, when this package-level
	// variable is initialized; later changes to that variable are not
	// reflected in the message.
	ErrorBadRegistryURL = errors.New("The image URL must have 3 components in the format of " + global.ReplicateRegistryHost + "/your-username/your-model")
)


================================================
FILE: pkg/errors/errors.go
================================================
package errors

const (
	// CodeConfigNotFound is the code carried by errors created with
	// ConfigNotFound and tested for by IsConfigNotFound.
	CodeConfigNotFound = "CONFIG_NOT_FOUND"
)

// Types ////////////////////////////////////////

// CodedError is implemented by errors that carry a machine-readable code in
// addition to their message. The Code helper in this package uses it to
// extract that code from an error value.
type CodedError interface {
	// Code returns the error's code, e.g. CodeConfigNotFound.
	Code() string
}

// codedError is the package's concrete CodedError: a human-readable message
// paired with a machine-readable code.
type codedError struct {
	code, msg string
}

// Error returns the human-readable message, satisfying the error interface.
func (ce *codedError) Error() string { return ce.msg }

// Code returns the machine-readable code, satisfying CodedError.
func (ce *codedError) Code() string { return ce.code }

// Error Creators ///////////////////////////////

// ConfigNotFound builds the error reported when the Cog config was not found.
// The returned error carries CodeConfigNotFound and uses msg as its message.
func ConfigNotFound(msg string) error {
	notFound := codedError{code: CodeConfigNotFound}
	notFound.msg = msg + `` // TODO: populate this
	return &notFound
}

// Helpers //////////////////////////////////////

// IsConfigNotFound reports whether err carries the CodeConfigNotFound code,
// as determined by Code.
func IsConfigNotFound(err error) bool {
	code := Code(err)
	return code == CodeConfigNotFound
}

// Code returns the error code carried by err, or the empty string when no
// error in its chain implements CodedError (including when err is nil).
//
// The chain is followed through the conventional `Unwrap() error` method, so a
// coded error that has been wrapped (for example with fmt.Errorf and %w) is
// still recognised. The outermost CodedError in the chain wins. Errors that
// only expose `Unwrap() []error` (joined errors) are not descended into.
func Code(err error) string {
	for err != nil {
		if cerr, ok := err.(CodedError); ok {
			return cerr.Code()
		}

		wrapper, ok := err.(interface{ Unwrap() error })
		if !ok {
			return ""
		}
		err = wrapper.Unwrap()
	}

	return ""
}


================================================
FILE: pkg/global/global.go
================================================
package global

import "os"

const (
	// DefaultReplicateRegistryHost is the registry host used when the
	// COG_REGISTRY_HOST environment variable is not set (see
	// getDefaultRegistryHost).
	DefaultReplicateRegistryHost = "r8.im"
	// ReplicateWebsiteHost is the host name of the Replicate website.
	ReplicateWebsiteHost         = "replicate.com"
)

// Process-wide settings and build metadata.
//
// NOTE(review): Version, Commit and BuildTime hold placeholder defaults here;
// presumably they are overridden at link time for release builds — confirm
// against the build configuration.
//
// ReplicateRegistryHost is initialized from the COG_REGISTRY_HOST environment
// variable (falling back to DefaultReplicateRegistryHost) when the package is
// loaded.
var (
	Version               = "dev"
	Commit                = ""
	BuildTime             = "none"
	Debug                 = false
	NoColor               = false
	ProfilingEnabled      = false
	ReplicateRegistryHost = getDefaultRegistryHost()

	// LabelNamespace is the prefix for Cog-specific image label keys;
	// CogBuildArtifactsFolder is the name of the build-artifacts directory.
	LabelNamespace          = "run.cog."
	CogBuildArtifactsFolder = ".cog"
)

// getDefaultRegistryHost picks the initial registry host: the value of the
// COG_REGISTRY_HOST environment variable when it is non-empty, otherwise
// DefaultReplicateRegistryHost.
func getDefaultRegistryHost() string {
	// Priority: flag will override at runtime, but env var provides default
	host := os.Getenv("COG_REGISTRY_HOST")
	if host == "" {
		host = DefaultReplicateRegistryHost
	}
	return host
}


================================================
FILE: pkg/http/client.go
================================================
package http

import (
	"context"
	"net/http"

	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/env"
	"github.com/replicate/cog/pkg/global"
)

// UserAgentHeader is the name of the HTTP header that carries the client's
// user agent string.
const UserAgentHeader = "User-Agent"

// BearerHeaderPrefix is prepended to a token to form an Authorization header
// value. A value equal to the bare prefix means no token was supplied.
const BearerHeaderPrefix = "Bearer "

// ProvideHTTPClient returns an *http.Client whose transport decorates
// outgoing requests with a User-Agent and JSON Content-Type header, and adds a
// bearer Authorization header for requests to the web host returned by
// env.WebHostFromEnvironment. The token is taken from the user information the
// docker command loads for global.ReplicateRegistryHost; an error from that
// lookup is returned unchanged.
func ProvideHTTPClient(ctx context.Context, dockerCommand command.Command) (*http.Client, error) {
	userInfo, err := dockerCommand.LoadUserInformation(ctx, global.ReplicateRegistryHost)
	if err != nil {
		return nil, err
	}

	defaultHeaders := map[string]string{
		UserAgentHeader: UserAgent(),
		"Content-Type":  "application/json",
	}
	// Keyed by request host: only requests to the web host are authenticated.
	bearerByHost := map[string]string{
		env.WebHostFromEnvironment(): BearerHeaderPrefix + userInfo.Token,
	}

	return &http.Client{
		Transport: &Transport{
			headers:        defaultHeaders,
			authentication: bearerByHost,
		},
	}, nil
}


================================================
FILE: pkg/http/client_test.go
================================================
package http

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/docker/dockertest"
)

func TestClientDecoratesUserAgent(t *testing.T) {
	// Setup mock http server
	seenUserAgent := false
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, r.Header.Get(UserAgentHeader), UserAgent())
		seenUserAgent = true
	}))
	defer server.Close()

	command := dockertest.NewMockCommand()
	client, err := ProvideHTTPClient(t.Context(), command)
	require.NoError(t, err)

	_, err = client.Get(server.URL)
	require.NoError(t, err)

	require.True(t, seenUserAgent)
}


================================================
FILE: pkg/http/transport.go
================================================
package http

import (
	"errors"
	"net/http"
)

// AuthorizationHeader is the name of the HTTP header that Transport fills in
// with per-host credentials.
const AuthorizationHeader = "Authorization"

// Transport is an http.RoundTripper that decorates outgoing requests before
// delegating to an underlying round tripper.
type Transport struct {
	// headers are default header values, applied only when the request does
	// not already have a value for that header.
	headers        map[string]string
	// authentication maps a request host (req.URL.Host) to the Authorization
	// header value to send to that host.
	authentication map[string]string
	// base performs the actual round trip; http.DefaultTransport is used
	// when it is nil.
	base           http.RoundTripper
}

// RoundTrip implements http.RoundTripper. It sends a copy of req that has been
// decorated with:
//
//   - every default header from t.headers that the request does not already set
//   - the Authorization value configured for req.URL.Host, unless the request
//     already carries an Authorization header
//
// The caller's request is never modified, as the http.RoundTripper contract
// requires; the decorated copy is what the base round tripper sees (and what
// ends up in the response's Request field).
//
// If the configured Authorization value for the host is just the bare
// BearerHeaderPrefix (i.e. no token), the request is not sent: its body is
// closed and an error is returned.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	// Work on a clone so the caller's headers stay untouched.
	out := req.Clone(req.Context())
	if out.Header == nil {
		out.Header = make(http.Header)
	}

	// Write standard headers
	for k, v := range t.headers {
		if out.Header.Get(k) == "" {
			out.Header.Set(k, v)
		}
	}

	// Write authentication
	if out.Header.Get(AuthorizationHeader) == "" {
		if authorisation, ok := t.authentication[out.URL.Host]; ok {
			if authorisation == BearerHeaderPrefix {
				// A RoundTripper must close the request body even when it
				// fails before sending anything.
				if req.Body != nil {
					_ = req.Body.Close()
				}
				return nil, errors.New("No token supplied for HTTP authorization. Have you run 'cog login'?")
			}
			out.Header.Set(AuthorizationHeader, authorisation)
		}
	}

	base := t.base
	if base == nil {
		base = http.DefaultTransport
	}
	return base.RoundTrip(out)
}


================================================
FILE: pkg/http/transport_test.go
================================================
package http

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"

	"github.com/stretchr/testify/require"
)

// TestTransportAddsHeaders checks that a default header configured on the
// Transport is present on the request that was sent.
func TestTransportAddsHeaders(t *testing.T) {
	// Setup mock http server
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	const testHeader = "X-Test-Header"
	const testValue = "TestValue"
	transport := Transport{
		headers: map[string]string{
			testHeader: testValue,
		},
	}
	req, err := http.NewRequest("GET", server.URL, nil)
	require.NoError(t, err)
	resp, err := transport.RoundTrip(req)
	require.NoError(t, err)
	// Close the body so the connection is released before the server shuts down.
	defer resp.Body.Close()
	// require.Equal takes (expected, actual).
	require.Equal(t, testValue, resp.Request.Header.Get(testHeader))
}

// TestTransportOnlyAddsHeaderIfMissing checks that a header already set on the
// request is not overwritten by the Transport's default for that header.
func TestTransportOnlyAddsHeaderIfMissing(t *testing.T) {
	// Setup mock http server
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	const testHeader = "X-Test-Header"
	const testValue = "TestValue"
	transport := Transport{
		headers: map[string]string{
			testHeader: testValue,
		},
	}
	const expectedValue = "ExpectedValue"
	req, err := http.NewRequest("GET", server.URL, nil)
	// Check the error before touching req: it is nil when NewRequest fails.
	require.NoError(t, err)
	req.Header.Set(testHeader, expectedValue)
	resp, err := transport.RoundTrip(req)
	require.NoError(t, err)
	// Close the body so the connection is released before the server shuts down.
	defer resp.Body.Close()
	// require.Equal takes (expected, actual).
	require.Equal(t, expectedValue, resp.Request.Header.Get(testHeader))
}

// TestTransportSendsErrorWithMissingToken checks that RoundTrip fails, without
// producing a response, when the Authorization value configured for the target
// host is a bearer prefix with an empty token.
func TestTransportSendsErrorWithMissingToken(t *testing.T) {
	// Setup mock http server
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	server := httptest.NewServer(okHandler)
	defer server.Close()

	serverURL, err := url.Parse(server.URL)
	require.NoError(t, err)

	// Bearer prefix followed by an empty token.
	emptyBearer := BearerHeaderPrefix + ""
	transport := Transport{
		authentication: map[string]string{serverURL.Host: emptyBearer},
	}

	req, err := http.NewRequest("GET", server.URL, nil)
	require.NoError(t, err)

	resp, err := transport.RoundTrip(req)
	require.Error(t, err)
	require.Nil(t, resp)
}


================================================
FILE: pkg/http/user_agent.go
================================================
package http

import (
	"fmt"
	"runtime"

	"github.com/replicate/cog/pkg/global"
)

// UserAgent returns the User-Agent string sent by Cog's HTTP client, in the
// form "Cog/<version> (<platform>)". The platform is a display name for
// linux, windows and darwin, and the raw runtime.GOOS value for anything else.
func UserAgent() string {
	platform := runtime.GOOS
	switch platform {
	case "darwin":
		platform = "macOS"
	case "linux":
		platform = "Linux"
	case "windows":
		platform = "Windows"
	}

	return fmt.Sprintf("Cog/%s (%s)", global.Version, platform)
}


================================================
FILE: pkg/http/user_agent_test.go
================================================
package http

import (
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestUserAgent(t *testing.T) {
	require.True(t, strings.HasPrefix(UserAgent(), "Cog/"))
}


================================================
FILE: pkg/image/build.go
================================================
package image

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"maps"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"time"

	"github.com/getkin/kin-openapi/openapi3"
	"github.com/google/go-containerregistry/pkg/name"
	"github.com/google/go-containerregistry/pkg/v1/remote"

	"github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/dockercontext"
	"github.com/replicate/cog/pkg/dockerfile"
	"github.com/replicate/cog/pkg/dockerignore"
	"github.com/replicate/cog/pkg/global"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/schema"
	"github.com/replicate/cog/pkg/schema/python"
	"github.com/replicate/cog/pkg/util/console"
	cogversion "github.com/replicate/cog/pkg/util/version"
	"github.com/replicate/cog/pkg/weights"
	"github.com/replicate/cog/pkg/wheels"
)

// dockerignoreBackupPath is the path of the backup copy of the user's
// .dockerignore file.
// NOTE(review): presumably written by backupDockerignore and consumed by
// restoreDockerignore — confirm against those helpers.
const dockerignoreBackupPath = ".dockerignore.cog.bak"

// weightsManifestPath is where the weights manifest is cached between builds
// so an unchanged set of weights does not trigger a weights image rebuild.
const weightsManifestPath = ".cog/cache/weights_manifest.json"

// bundledSchemaFile is the path the OpenAPI schema is written to before it is
// copied into the image.
const bundledSchemaFile = ".cog/openapi_schema.json"

// errGit is the sentinel wrapped by gitHead and gitTag when the directory is
// not inside a Git work tree.
var errGit = errors.New("git error")

// Build builds a Cog model from a config and returns the image ID (sha256:...) on success.
//
// This is separated out from docker.Build(), so that can be as close as possible to the behavior of 'docker build'.
//
// The build runs in these phases:
//
//  1. Pre-build schema: when static schema generation is enabled (see
//     canUseStaticSchemaGen) the schema is generated from source; when
//     schemaFile is given it is read from that file. Either way it is written
//     to bundledSchemaFile and validated before Docker is invoked.
//  2. Docker build: from dockerfileFile when set, otherwise from a generated
//     Dockerfile (optionally with weights split into a separate image).
//  3. Post-build legacy schema: if no schema has been produced yet and
//     validation is not skipped, the schema is obtained by running the built
//     image.
//  4. When skipLabels is true, the schema (if any) is copied into the image
//     and the function returns the temporary image name early.
//  5. Otherwise labels (cog version, config, schema, pip freeze, base image
//     and Git metadata, plus annotations) are added in a final build whose
//     image ID is returned.
//
// For image names starting with "r8.im" the main build is tagged with a
// temporary "cog-tmp:<sha256 of name>" tag, which is removed after the final
// labelled image has been built.
func Build(
	ctx context.Context,
	cfg *config.Config,
	dir,
	imageName string,
	configFilename string,
	secrets []string,
	noCache,
	separateWeights bool,
	useCudaBaseImage string,
	progressOutput string,
	schemaFile string,
	dockerfileFile string,
	useCogBaseImage *bool,
	strip bool,
	precompile bool,
	excludeSource bool,
	skipSchemaValidation bool,
	skipLabels bool,
	annotations map[string]string,
	dockerCommand command.Command,
	client registry.Client) (string, error) {
	// remove bundled schema files that may be left from previous builds
	_ = os.Remove(bundledSchemaFile)

	if err := checkCompatibleDockerIgnore(dir); err != nil {
		return "", err
	}

	// Determine whether to use the static schema generator (Go tree-sitter) or
	// the legacy runtime path (boot container + python introspection).
	//
	// Static generation is opt-in via COG_STATIC_SCHEMA=1 for all commands.
	// The legacy runtime path (boot container + python -m cog.command.openapi_schema)
	// remains the default for `cog build`. For `cog train`, `cog predict`, and
	// `cog serve` (skipLabels=true), no schema is generated unless
	// COG_STATIC_SCHEMA=1 is set, since these paths return before the post-build
	// legacy schema generation step.
	//
	// The SDK version must be >= 0.17.0 (or unpinned/latest/dev) since older
	// SDKs use pydantic-based schemas that cannot be statically analyzed.
	needsSchema := !skipSchemaValidation && schemaFile == ""
	useStatic := needsSchema && canUseStaticSchemaGen(cfg)

	// --- Pre-build static schema generation ---
	// When using the static path, generate schema BEFORE the Docker build so we
	// fail fast on schema errors and the schema file is in the build context.
	var schemaJSON []byte
	switch {
	case useStatic:
		console.Debug("Generating model schema (static)...")
		data, err := generateStaticSchema(cfg, dir)
		if err == nil {
			schemaJSON = data
			break
		}

		// For `cog build` only: fall back to the post-build legacy runtime
		// schema generation which can handle types that require Python import
		// (e.g. package __init__.py modules, pydantic v2 BaseModel subclasses).
		var se *schema.SchemaError
		if !skipLabels && errors.As(err, &se) && se.Kind == schema.ErrUnresolvableType {
			console.Warnf("Static schema generation failed: %s", err)
			console.Warn("Falling back to legacy runtime schema generation...")
			// leave schemaJSON nil — the post-build legacy path will handle it
			break
		}

		return "", fmt.Errorf("image build failed: %w", err)
	case !skipSchemaValidation && schemaFile != "":
		console.Infof("Validating model schema from %s...", schemaFile)
		data, err := os.ReadFile(schemaFile)
		if err != nil {
			return "", fmt.Errorf("Failed to read schema file: %w", err)
		}
		schemaJSON = data
	case skipSchemaValidation:
		console.Debug("Skipping model schema validation")
	}

	// Write and validate pre-build schema (static or from file).
	if len(schemaJSON) > 0 {
		if err := writeAndValidateSchema(schemaJSON); err != nil {
			return "", err
		}
	}

	// --- Docker build ---
	var cogBaseImageName string

	// For r8.im images the main build goes to a temporary tag derived from the
	// image name; the final name is only applied by the label-adding build.
	tmpImageId := imageName
	isR8imImage := strings.HasPrefix(imageName, "r8.im")
	if isR8imImage {
		hash := sha256.New()
		_, err := hash.Write([]byte(imageName))
		if err != nil {
			return "", err
		}
		tmpImageId = fmt.Sprintf("cog-tmp:%s", hex.EncodeToString(hash.Sum(nil)))
	}

	if dockerfileFile != "" {
		dockerfileContents, err := os.ReadFile(dockerfileFile)
		if err != nil {
			return "", fmt.Errorf("Failed to read Dockerfile at %s: %w", dockerfileFile, err)
		}

		buildOpts := command.ImageBuildOptions{
			WorkingDir:         dir,
			DockerfileContents: string(dockerfileContents),
			ImageName:          tmpImageId,
			Secrets:            secrets,
			NoCache:            noCache,
			ProgressOutput:     progressOutput,
			Epoch:              &config.BuildSourceEpochTimestamp,
			ContextDir:         dockercontext.StandardBuildDirectory,
		}
		if _, err := dockerCommand.ImageBuild(ctx, buildOpts); err != nil {
			return "", fmt.Errorf("Failed to build Docker image: %w", err)
		}
	} else {
		generator, err := dockerfile.NewGenerator(cfg, dir, configFilename, dockerCommand, client, true)
		if err != nil {
			return "", fmt.Errorf("Error creating Dockerfile generator: %w", err)
		}
		contextDir, err := generator.BuildDir()
		if err != nil {
			return "", err
		}
		buildContexts, err := generator.BuildContexts()
		if err != nil {
			return "", err
		}
		defer func() {
			if err := generator.Cleanup(); err != nil {
				console.Warnf("Error cleaning up Dockerfile generator: %s", err)
			}
		}()
		generator.SetStrip(strip)
		generator.SetPrecompile(precompile)
		generator.SetUseCudaBaseImage(useCudaBaseImage)
		if useCogBaseImage != nil {
			generator.SetUseCogBaseImage(*useCogBaseImage)
		}

		if generator.IsUsingCogBaseImage() {
			cogBaseImageName, err = generator.BaseImage(ctx)
			if err != nil {
				// Wrap with %w (like every other error in this function) so
				// callers can still inspect the underlying cause.
				return "", fmt.Errorf("Failed to get cog base image name: %w", err)
			}
		}

		if separateWeights {
			weightsDockerfile, runnerDockerfile, dockerignore, err := generator.GenerateModelBaseWithSeparateWeights(ctx, imageName)
			if err != nil {
				return "", fmt.Errorf("Failed to generate Dockerfile: %w", err)
			}

			if err := backupDockerignore(); err != nil {
				return "", fmt.Errorf("Failed to backup .dockerignore file: %w", err)
			}

			weightsManifest, err := generator.GenerateWeightsManifest(ctx)
			if err != nil {
				return "", fmt.Errorf("Failed to generate weights manifest: %w", err)
			}
			// A missing or unreadable cached manifest is treated as "changed".
			cachedManifest, _ := weights.LoadManifest(weightsManifestPath)
			changed := cachedManifest == nil || !weightsManifest.Equal(cachedManifest)
			if changed {
				if err := buildWeightsImage(ctx, dockerCommand, dir, weightsDockerfile, imageName+"-weights", secrets, noCache, progressOutput, contextDir, buildContexts); err != nil {
					return "", fmt.Errorf("Failed to build model weights Docker image: %w", err)
				}
				err := weightsManifest.Save(weightsManifestPath)
				if err != nil {
					return "", fmt.Errorf("Failed to save weights hash: %w", err)
				}
			} else {
				console.Info("Weights unchanged, skip rebuilding and use cached image...")
			}

			if err := buildRunnerImage(ctx, dockerCommand, dir, runnerDockerfile, dockerignore, imageName, secrets, noCache, progressOutput, contextDir, buildContexts); err != nil {
				return "", fmt.Errorf("Failed to build runner Docker image: %w", err)
			}
		} else {
			var dockerfileContents string
			if excludeSource {
				// Dev mode (cog serve): same layers as cog build but without
				// COPY . /src — source is volume-mounted at runtime instead.
				// This shares Docker layer cache with full builds.
				dockerfileContents, err = generator.GenerateModelBase(ctx)
			} else {
				dockerfileContents, err = generator.GenerateDockerfileWithoutSeparateWeights(ctx)
			}
			if err != nil {
				return "", fmt.Errorf("Failed to generate Dockerfile: %w", err)
			}

			buildOpts := command.ImageBuildOptions{
				WorkingDir:         dir,
				DockerfileContents: dockerfileContents,
				ImageName:          tmpImageId,
				Secrets:            secrets,
				NoCache:            noCache,
				ProgressOutput:     progressOutput,
				Epoch:              &config.BuildSourceEpochTimestamp,
				ContextDir:         contextDir,
				BuildContexts:      buildContexts,
			}

			if _, err := dockerCommand.ImageBuild(ctx, buildOpts); err != nil {
				return "", fmt.Errorf("Failed to build Docker image: %w", err)
			}
		}
	}

	// --- Post-build legacy schema generation ---
	// For SDK < 0.17.0 (or when static gen was not used), generate the schema
	// by running the built image with python -m cog.command.openapi_schema.
	// This must run before the skipLabels early return so that cog train/predict/serve
	// have a schema available for input validation and -i flag parsing.
	if len(schemaJSON) == 0 && !skipSchemaValidation {
		console.Info("Validating model schema...")
		enableGPU := cfg.Build != nil && cfg.Build.GPU
		// When excludeSource is true (cog serve/predict/train), /src was not
		// COPYed into the image, so volume-mount the project directory.
		sourceDir := ""
		if excludeSource {
			sourceDir = dir
		}
		legacySchema, err := GenerateOpenAPISchema(ctx, dockerCommand, tmpImageId, enableGPU, sourceDir)
		if err != nil {
			return "", fmt.Errorf("Failed to get type signature: %w", err)
		}
		data, err := json.Marshal(legacySchema)
		if err != nil {
			return "", fmt.Errorf("Failed to convert type signature to JSON: %w", err)
		}
		schemaJSON = data

		if err := writeAndValidateSchema(schemaJSON); err != nil {
			return "", err
		}
	}

	// When skipLabels is true (cog run/predict/serve/train), skip the expensive
	// label-adding phase. This image is for local use only and won't be distributed,
	// so we don't need metadata labels, pip freeze, or git info.
	// We still need the schema bundled, so do a minimal second build to add it.
	if skipLabels {
		if len(schemaJSON) > 0 {
			// Use trailing "/" on the destination so Docker creates the .cog/
			// directory even in ExcludeSource images where COPY . /src was
			// skipped and .cog/ does not yet exist.
			schemaDockerfile := fmt.Sprintf("FROM %s\nCOPY %s .cog/\n", tmpImageId, bundledSchemaFile)
			buildOpts := command.ImageBuildOptions{
				DockerfileContents: schemaDockerfile,
				ImageName:          tmpImageId,
				ProgressOutput:     progressOutput,
			}
			if _, err := dockerCommand.ImageBuild(ctx, buildOpts); err != nil {
				return "", fmt.Errorf("Failed to bundle schema into image: %w", err)
			}
		}
		return tmpImageId, nil
	}

	console.Info("Adding labels to image...")
	console.Info("")

	// We used to set the cog_version and config labels in Dockerfile, because we didn't require running the
	// built image to get those. But, the escaping of JSON inside a label inside a Dockerfile was gnarly, and
	// doesn't seem to be a problem here, so do it here instead.
	configJSON, err := json.Marshal(cfg)
	if err != nil {
		return "", fmt.Errorf("Failed to convert config to JSON: %w", err)
	}

	pipFreeze, err := GeneratePipFreeze(ctx, dockerCommand, tmpImageId)
	if err != nil {
		return "", fmt.Errorf("Failed to generate pip freeze from image: %w", err)
	}

	labels := map[string]string{
		command.CogVersionLabelKey:           global.Version,
		command.CogConfigLabelKey:            string(bytes.TrimSpace(configJSON)),
		command.CogOpenAPISchemaLabelKey:     string(schemaJSON),
		global.LabelNamespace + "pip_freeze": pipFreeze,
		// Mark the image as having an appropriate init entrypoint. We can use this
		// to decide how/if to shim the image.
		global.LabelNamespace + "has_init": "true",
	}

	if cogBaseImageName != "" {
		labels[global.LabelNamespace+"cog-base-image-name"] = cogBaseImageName

		// name.Insecure allows HTTP fallback for local/test registries,
		// consistent with ParseReference calls in pkg/registry/.
		ref, err := name.ParseReference(cogBaseImageName, name.Insecure)
		if err != nil {
			return "", fmt.Errorf("Failed to parse cog base image reference: %w", err)
		}

		img, err := remote.Image(ref)
		if err != nil {
			return "", fmt.Errorf("Failed to fetch cog base image: %w", err)
		}

		layers, err := img.Layers()
		if err != nil {
			return "", fmt.Errorf("Failed to get layers for cog base image: %w", err)
		}

		if len(layers) == 0 {
			return "", fmt.Errorf("Cog base image has no layers: %s", cogBaseImageName)
		}

		// Record the diff ID and index of the base image's last layer.
		lastLayerIndex := len(layers) - 1
		layerLayerDigest, err := layers[lastLayerIndex].DiffID()
		if err != nil {
			return "", fmt.Errorf("Failed to get last layer digest for cog base image: %w", err)
		}

		lastLayer := layerLayerDigest.String()
		console.Debugf("Last layer of the cog base image: %s", lastLayer)

		labels[global.LabelNamespace+"cog-base-image-last-layer-sha"] = lastLayer
		labels[global.LabelNamespace+"cog-base-image-last-layer-idx"] = fmt.Sprintf("%d", lastLayerIndex)
	}

	// Git metadata is best-effort: failures are only logged at debug level.
	if commit, err := gitHead(ctx, dir); commit != "" && err == nil {
		labels["org.opencontainers.image.revision"] = commit
	} else {
		console.Debug("Unable to determine Git commit")
	}

	if tag, err := gitTag(ctx, dir); tag != "" && err == nil {
		labels["org.opencontainers.image.version"] = tag
	} else {
		console.Debug("Unable to determine Git tag")
	}

	// Caller-supplied annotations are applied last, so they override any
	// label computed above that uses the same key.
	maps.Copy(labels, annotations)

	// The final image ID comes from the label-adding step.
	// When schema validation is skipped (cog run), there is no schema file to bundle.
	schemaFileToBundle := bundledSchemaFile
	if skipSchemaValidation {
		schemaFileToBundle = ""
	}
	imageID, err := BuildAddLabelsAndSchemaToImage(ctx, dockerCommand, tmpImageId, imageName, labels, schemaFileToBundle, progressOutput)
	if err != nil {
		return "", fmt.Errorf("Failed to add labels to image: %w", err)
	}

	// We created a temp image, so delete it. Don't "-f" so it doesn't blow anything up
	if isR8imImage {
		if err = dockerCommand.RemoveImage(ctx, tmpImageId); err != nil {
			return "", err
		}
	}

	return imageID, nil
}

// BuildAddLabelsAndSchemaToImage builds a cog model with labels and schema.
// Returns the image ID (sha256:...) of the final image.
//
// The new image is based on the provided image with the labels and schema file appended to it.
// tmpName is the source image to build from, image is the final image name/tag.
// When bundledSchemaFile is empty no schema is copied and only the labels are
// added.
func BuildAddLabelsAndSchemaToImage(ctx context.Context, dockerClient command.Command, tmpName, image string, labels map[string]string, bundledSchemaFile string, progressOutput string) (string, error) {
	var dockerfile string
	if bundledSchemaFile != "" {
		// The trailing "/" makes Docker treat .cog as a directory and create
		// it when the base image does not already contain one; without it the
		// schema would be written as a regular file named ".cog" in that case.
		dockerfile = fmt.Sprintf("FROM %s\nCOPY %s .cog/\n", tmpName, bundledSchemaFile)
	} else {
		dockerfile = fmt.Sprintf("FROM %s\n", tmpName)
	}

	buildOpts := command.ImageBuildOptions{
		DockerfileContents: dockerfile,
		ImageName:          image,
		Labels:             labels,
		ProgressOutput:     progressOutput,
	}

	imageID, err := dockerClient.ImageBuild(ctx, buildOpts)
	if err != nil {
		return "", fmt.Errorf("Failed to add labels and schema to image: %w", err)
	}
	return imageID, nil
}

// staticSchemaGenMinSDKVersion is the minimum SDK version that supports
// static schema generation. Older SDK versions use pydantic-based runtime
// introspection and must fall back to the legacy Docker-based path.
//
// canUseStaticSchemaGen compares the resolved SDK version against this value;
// it must stay parseable by cogversion.MustVersion, which is used to parse it.
const staticSchemaGenMinSDKVersion = "0.17.0"

// canUseStaticSchemaGen reports whether static schema generation should be
// used for this build.
//
// It is opt-in: the COG_STATIC_SCHEMA environment variable must be "1" or
// "true" (case-insensitive). Even then the answer is false when the resolved
// SDK version parses as a version below staticSchemaGenMinSDKVersion, since
// older SDKs use pydantic-based schemas that the static parser cannot analyze.
// An unpinned or unparseable SDK version does not block static generation.
func canUseStaticSchemaGen(cfg *config.Config) bool {
	switch strings.ToLower(os.Getenv("COG_STATIC_SCHEMA")) {
	case "1", "true":
		// opted in
	default:
		return false
	}

	pinned := resolveSDKVersion(cfg)
	if pinned == "" {
		return true
	}

	// Reduce the version to its base form when the pattern matches; otherwise
	// try to parse the string as-is.
	base := pinned
	if matched := wheels.BaseVersionRe.FindString(pinned); matched != "" {
		base = matched
	}

	parsed, err := cogversion.NewVersion(base)
	if err != nil {
		return true
	}
	if parsed.GreaterOrEqual(cogversion.MustVersion(staticSchemaGenMinSDKVersion)) {
		return true
	}

	console.Infof("SDK version %s < %s, using legacy runtime schema generation", pinned, staticSchemaGenMinSDKVersion)
	return false
}

// resolveSDKVersion determines the SDK version that will be installed in the
// container, using the same precedence as the Dockerfile generator:
//  1. COG_SDK_WHEEL env var (parse version from "pypi:X.Y.Z")
//  2. build.sdk_version in cog.yaml
//  3. Auto-detect from dist/ wheel filename
//  4. Empty string (latest/unpinned)
//
// When the env var is set but does not name a PyPI wheel with a version, the
// result is the empty string and the later sources are not consulted.
func resolveSDKVersion(cfg *config.Config) string {
	envVal := os.Getenv(wheels.CogSDKWheelEnvVar)

	switch {
	case envVal != "":
		wc := wheels.ParseWheelValue(envVal)
		if wc != nil && wc.Source == wheels.WheelSourcePyPI && wc.Version != "" {
			return wc.Version
		}
		return ""
	case cfg.Build != nil && cfg.Build.SDKVersion != "":
		return cfg.Build.SDKVersion
	}

	// Empty when no local wheel version is detected, i.e. latest/unpinned.
	return wheels.DetectLocalSDKVersion()
}

// generateStaticSchema produces the OpenAPI schema with the Go tree-sitter
// parser, covering whichever of the predict and train references are set in
// the config (both are handed to schema.GenerateCombined). It fails when
// neither reference is configured.
func generateStaticSchema(cfg *config.Config, dir string) ([]byte, error) {
	hasEntrypoint := cfg.Predict != "" || cfg.Train != ""
	if !hasEntrypoint {
		return nil, fmt.Errorf("no predict or train reference found in cog.yaml")
	}

	return schema.GenerateCombined(dir, cfg.Predict, cfg.Train, python.ParsePredictor)
}

// writeAndValidateSchema stores schemaJSON at bundledSchemaFile (creating the
// parent directory if needed) and then checks that it loads and validates as
// an OpenAPI 3.0 document. The file is written before validation, so it
// remains on disk even when validation fails.
func writeAndValidateSchema(schemaJSON []byte) error {
	schemaDir := filepath.Dir(bundledSchemaFile)
	if err := os.MkdirAll(schemaDir, 0o755); err != nil {
		return fmt.Errorf("failed to create directory for %s: %w", bundledSchemaFile, err)
	}

	if err := os.WriteFile(bundledSchemaFile, schemaJSON, 0o644); err != nil {
		return fmt.Errorf("failed to store bundled schema file %s: %w", bundledSchemaFile, err)
	}

	loader := openapi3.NewLoader()
	loader.IsExternalRefsAllowed = true

	doc, err := loader.LoadFromData(schemaJSON)
	if err != nil {
		return fmt.Errorf("Failed to load model schema JSON: %w", err)
	}

	err = doc.Validate(loader.Context)
	if err == nil {
		return nil
	}
	// Include the offending schema in the message to aid debugging.
	return fmt.Errorf("Model schema is invalid: %w\n\n%s", err, string(schemaJSON))
}

// isGitWorkTree reports whether dir is inside a Git work tree, by asking
// `git rev-parse --is-inside-work-tree`. The probe is limited to three
// seconds; any failure to run git (including a timeout) yields false.
func isGitWorkTree(ctx context.Context, dir string) bool {
	probeCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	cmd := exec.CommandContext(probeCtx, "git", "-C", dir, "rev-parse", "--is-inside-work-tree")
	stdout, err := cmd.Output()

	return err == nil && strings.TrimSpace(string(stdout)) == "true"
}

// gitHead returns the commit to record for the build. A non-empty GITHUB_SHA
// environment variable takes precedence; otherwise, if dir is inside a Git
// work tree, the output of `git rev-parse HEAD` is used (limited to three
// seconds). When dir is not a work tree the returned error wraps errGit.
func gitHead(ctx context.Context, dir string) (string, error) {
	if sha := os.Getenv("GITHUB_SHA"); sha != "" {
		return sha, nil
	}

	if !isGitWorkTree(ctx, dir) {
		return "", fmt.Errorf("Failed to find HEAD commit: %w", errGit)
	}

	cmdCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	raw, err := exec.CommandContext(cmdCtx, "git", "-C", dir, "rev-parse", "HEAD").Output()
	if err != nil {
		return "", err
	}
	return string(bytes.TrimSpace(raw)), nil
}

// gitTag returns the ref name to record for the build. A non-empty
// GITHUB_REF_NAME environment variable takes precedence; otherwise, if dir is
// inside a Git work tree, the output of `git describe --tags --dirty` is used
// (limited to three seconds). When dir is not a work tree the returned error
// wraps errGit.
func gitTag(ctx context.Context, dir string) (string, error) {
	if ref := os.Getenv("GITHUB_REF_NAME"); ref != "" {
		return ref, nil
	}

	if !isGitWorkTree(ctx, dir) {
		return "", fmt.Errorf("Failed to find ref name: %w", errGit)
	}

	cmdCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	raw, err := exec.CommandContext(cmdCtx, "git", "-C", dir, "describe", "--tags", "--dirty").Output()
	if err != nil {
		return "", err
	}
	return string(bytes.TrimSpace(raw)), nil
}

// buildWeightsImage builds the model-weights image named imageName from
// dockerfileContents.
//
// Before building it calls makeDockerignoreForWeightsImage, which backs up any
// existing .dockerignore and writes a generated one. This function does not
// restore the backup itself. All remaining arguments are forwarded unchanged
// into command.ImageBuildOptions.
func buildWeightsImage(ctx context.Context, dockerClient command.Command, dir, dockerfileContents, imageName string, secrets []string, noCache bool, progressOutput string, contextDir string, buildContexts map[string]string) error {
	err := makeDockerignoreForWeightsImage()
	if err != nil {
		return fmt.Errorf("Failed to create .dockerignore file: %w", err)
	}

	_, err = dockerClient.ImageBuild(ctx, command.ImageBuildOptions{
		WorkingDir:         dir,
		DockerfileContents: dockerfileContents,
		ImageName:          imageName,
		Secrets:            secrets,
		NoCache:            noCache,
		ProgressOutput:     progressOutput,
		Epoch:              &config.BuildSourceEpochTimestamp,
		ContextDir:         contextDir,
		BuildContexts:      buildContexts,
	})
	if err != nil {
		return fmt.Errorf("Failed to build Docker image for model weights: %w", err)
	}

	return nil
}

// buildRunnerImage builds the runner image named imageName from
// dockerfileContents, using a temporary .dockerignore derived from
// dockerignoreContents.
//
// The temporary .dockerignore is written before the build and is restored
// afterwards whether or not the build succeeded. Leaving the generated file
// next to an existing backup after a failed build would let a later
// backupDockerignore call rename the generated file over that backup.
//
// It returns an error if the .dockerignore cannot be written, if the build
// fails, or if the restore fails. When both the build and the restore fail the
// two errors are combined with errors.Join, so each is still reachable through
// errors.Is / errors.As.
func buildRunnerImage(ctx context.Context, dockerClient command.Command, dir, dockerfileContents, dockerignoreContents, imageName string, secrets []string, noCache bool, progressOutput string, contextDir string, buildContexts map[string]string) error {
	if err := writeDockerignore(dockerignoreContents); err != nil {
		return fmt.Errorf("Failed to write .dockerignore file with weights included: %w", err)
	}
	buildOpts := command.ImageBuildOptions{
		WorkingDir:         dir,
		DockerfileContents: dockerfileContents,
		ImageName:          imageName,
		Secrets:            secrets,
		NoCache:            noCache,
		ProgressOutput:     progressOutput,
		Epoch:              &config.BuildSourceEpochTimestamp,
		ContextDir:         contextDir,
		BuildContexts:      buildContexts,
	}

	var buildErr error
	if _, err := dockerClient.ImageBuild(ctx, buildOpts); err != nil {
		buildErr = fmt.Errorf("Failed to build Docker image: %w", err)
	}

	// Undo the temporary .dockerignore on both the success and failure paths.
	var restoreErr error
	if err := restoreDockerignore(); err != nil {
		restoreErr = fmt.Errorf("Failed to restore backup .dockerignore file: %w", err)
	}

	switch {
	case buildErr != nil && restoreErr != nil:
		return errors.Join(buildErr, restoreErr)
	case buildErr != nil:
		return buildErr
	default:
		// restoreErr is nil on full success.
		return restoreErr
	}
}

// makeDockerignoreForWeightsImage prepares the .dockerignore used while
// building the weights image: it first backs up any existing .dockerignore,
// then writes a new one starting from dockerfile.DockerignoreHeader.
func makeDockerignoreForWeightsImage() error {
	err := backupDockerignore()
	if err != nil {
		return fmt.Errorf("Failed to backup .dockerignore file: %w", err)
	}

	err = writeDockerignore(dockerfile.DockerignoreHeader)
	if err != nil {
		return fmt.Errorf("Failed to write .dockerignore file: %w", err)
	}

	return nil
}

// writeDockerignore writes contents to .dockerignore in the current working
// directory with mode 0o644.
//
// When the backup file at dockerignoreBackupPath can be stat'ed, its contents
// are placed first, followed by a newline and then contents. A failure to read
// a backup that was successfully stat'ed is returned; a failed stat simply
// means nothing is prepended.
func writeDockerignore(contents string) error {
	payload := contents

	if _, statErr := os.Stat(dockerignoreBackupPath); statErr == nil {
		backedUp, readErr := os.ReadFile(dockerignoreBackupPath)
		if readErr != nil {
			return readErr
		}
		payload = string(backedUp) + "\n" + contents
	}

	return os.WriteFile(".dockerignore", []byte(payload), 0o644)
}

// backupDockerignore moves .dockerignore (in the current working directory)
// to dockerignoreBackupPath. It is a no-op when .dockerignore does not exist;
// any other stat error, or a rename failure, is returned.
func backupDockerignore() error {
	_, err := os.Stat(".dockerignore")
	switch {
	case err == nil:
		// The file exists: move it aside under the backup name.
		return os.Rename(".dockerignore", dockerignoreBackupPath)
	case os.IsNotExist(err):
		// No .dockerignore, so there is nothing to back up.
		return nil
	default:
		return err
	}
}

// restoreDockerignore removes the current .dockerignore and, when a backup
// exists at dockerignoreBackupPath, renames it back to .dockerignore.
//
// A failure to remove .dockerignore (including the file being absent) is
// returned before the backup is looked at. A missing backup is not an error.
func restoreDockerignore() error {
	if removeErr := os.Remove(".dockerignore"); removeErr != nil {
		return removeErr
	}

	_, statErr := os.Stat(dockerignoreBackupPath)
	switch {
	case statErr == nil:
		return os.Rename(dockerignoreBackupPath, ".dockerignore")
	case os.IsNotExist(statErr):
		// No backup was taken, so there is nothing to put back.
		return nil
	default:
		return statErr
	}
}

// checkCompatibleDockerIgnore verifies that the .dockerignore rules found for
// dir do not exclude the ".cog" path.
//
// It returns the matcher-creation error if there is one, nil when no matcher
// is produced (no .dockerignore to scan) or when ".cog" is not matched, and a
// descriptive error when ".cog" is ignored.
func checkCompatibleDockerIgnore(dir string) error {
	rules, err := dockerignore.CreateMatcher(dir)
	switch {
	case err != nil:
		return err
	case rules == nil:
		// A nil matcher without an error means there is no .dockerignore to scan.
		return nil
	case rules.MatchesPath(".cog"):
		return errors.New("The .cog tmp path cannot be ignored by docker in .dockerignore")
	}
	return nil
}


================================================
FILE: pkg/image/build_test.go
================================================
package image

import (
	"context"
	"os"
	"os/exec"
	"path/filepath"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/config"
)

// hasGit records, at package initialisation, whether a "git" executable can
// be found via exec.LookPath.
var hasGit = func() bool {
	if _, err := exec.LookPath("git"); err != nil {
		return false
	}
	return true
}()

// gitRun runs `git argv...` under a two-second deadline derived from ctx,
// logs the combined stdout/stderr to t, and fails the test immediately if the
// command returns an error. The deadline's cancel function is released via
// t.Cleanup.
func gitRun(ctx context.Context, argv []string, t *testing.T) {
	deadlineCtx, stop := context.WithTimeout(ctx, 2*time.Second)
	t.Cleanup(stop)

	cmd := exec.CommandContext(deadlineCtx, "git", argv...)
	combined, runErr := cmd.CombinedOutput()
	t.Logf("git output:\n%s", string(combined))

	require.NoError(t, runErr)
}

// setupGitWorkTree creates a throwaway git repository under t.TempDir() and
// returns its path. The repository gets a local user identity, one empty
// commit, and an annotated tag "v0.0.1+walrus".
//
// The test is skipped when no git executable is available.
func setupGitWorkTree(t *testing.T) string {
	if !hasGit {
		t.Skip("no git executable available")
		return ""
	}

	ctx := t.Context()

	workDir := filepath.Join(t.TempDir(), "wd")
	require.NoError(t, os.MkdirAll(workDir, 0o755))

	// Each entry is the argv of one git invocation, run in order.
	steps := [][]string{
		{"init", workDir},
		{"-C", workDir, "config", "user.email", "cog@localhost"},
		{"-C", workDir, "config", "user.name", "Cog Tests"},
		{"-C", workDir, "commit", "--allow-empty", "-m", "walrus"},
		{"-C", workDir, "tag", "-a", "v0.0.1+walrus", "-m", "walrus time"},
	}
	for _, argv := range steps {
		gitRun(ctx, argv, t)
	}

	return workDir
}

func TestIsGitWorkTree(t *testing.T) {
	ctx := t.Context()
	r := require.New(t)

	r.False(isGitWorkTree(ctx, "/dev/null"))
	r.True(isGitWorkTree(ctx, setupGitWorkTree(t)))
}

func TestGitHead(t *testing.T) {
	t.Run("via github env", func(t *testing.T) {
		t.Setenv("GITHUB_SHA", "fafafaf")

		head, err := gitHead(t.Context(), "/dev/null")

		require.NoError(t, err)
		require.Equal(t, "fafafaf", head)
	})

	t.Run("via git", func(t *testing.T) {
		tmp := setupGitWorkTree(t)
		if tmp == "" {
			return
		}

		t.Setenv("GITHUB_SHA", "")

		head, err := gitHead(t.Context(), tmp)
		require.NoError(t, err)
		require.NotEqual(t, "", head)
	})

	t.Run("unavailable", func(t *testing.T) {
		t.Setenv("GITHUB_SHA", "")

		head, err := gitHead(t.Context(), "/dev/null")
		require.Error(t, err)
		require.Equal(t, "", head)
	})
}

func TestGitTag(t *testing.T) {
	t.Run("via github env", func(t *testing.T) {
		t.Setenv("GITHUB_REF_NAME", "v0.0.1+manatee")

		tag, err := gitTag(t.Context(), "/dev/null")
		require.NoError(t, err)
		require.Equal(t, "v0.0.1+manatee", tag)
	})

	t.Run("via git", func(t *testing.T) {
		tmp := setupGitWorkTree(t)
		if tmp == "" {
			return
		}

		t.Setenv("GITHUB_REF_NAME", "")

		tag, err := gitTag(t.Context(), tmp)
		require.NoError(t, err)
		require.Equal(t, "v0.0.1+walrus", tag)
	})

	t.Run("unavailable", func(t *testing.T) {
		t.Setenv("GITHUB_REF_NAME", "")

		tag, err := gitTag(t.Context(), "/dev/null")
		require.Error(t, err)
		require.Equal(t, "", tag)
	})
}

// TestCanUseStaticSchemaGen exercises canUseStaticSchemaGen across
// combinations of the COG_STATIC_SCHEMA opt-in, the SDK version pinned in the
// config, and the COG_SDK_WHEEL override.
func TestCanUseStaticSchemaGen(t *testing.T) {
	// pinned returns a config whose build section pins the given SDK version.
	pinned := func(sdk string) *config.Config {
		return &config.Config{Build: &config.Build{SDKVersion: sdk}}
	}
	// unpinned has no build section, so no SDK version is pinned.
	unpinned := &config.Config{}

	// Every case sets COG_STATIC_SCHEMA to staticSchema and COG_SDK_WHEEL to
	// wheel via t.Setenv; fields left out are set as the empty string.
	cases := []struct {
		label        string
		conf         *config.Config
		staticSchema string
		wheel        string
		expect       bool
	}{
		{label: "disabled by default (env not set)", conf: pinned("0.18.0"), expect: false},
		{label: "disabled when env is empty string", conf: pinned("0.18.0"), staticSchema: "", expect: false},
		{label: "disabled when env is 0", conf: pinned("0.18.0"), staticSchema: "0", expect: false},
		{label: "enabled when env is 1", conf: pinned("0.18.0"), staticSchema: "1", expect: true},
		{label: "enabled when env is true", conf: pinned("0.18.0"), staticSchema: "true", expect: true},
		{label: "enabled when env is True (mixed case)", conf: pinned("0.18.0"), staticSchema: "True", expect: true},
		{label: "enabled when env is TRUE (upper case)", conf: pinned("0.18.0"), staticSchema: "TRUE", expect: true},
		{label: "disabled for old SDK even when opted in", conf: pinned("0.16.12"), staticSchema: "1", expect: false},
		{label: "disabled for pre-release old SDK", conf: pinned("0.16.0a1"), staticSchema: "1", expect: false},
		{label: "enabled for SDK 0.17.0 when opted in", conf: pinned("0.17.0"), staticSchema: "1", expect: true},
		{label: "enabled for SDK 0.18.0 when opted in", conf: pinned("0.18.0"), staticSchema: "1", expect: true},
		{label: "enabled for unpinned SDK when opted in", conf: unpinned, staticSchema: "1", expect: true},
		{label: "disabled for old SDK via COG_SDK_WHEEL even when opted in", conf: unpinned, staticSchema: "1", wheel: "pypi:0.16.12", expect: false},
		{label: "enabled for new SDK via COG_SDK_WHEEL when opted in", conf: unpinned, staticSchema: "1", wheel: "pypi:0.18.0", expect: true},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			t.Setenv("COG_STATIC_SCHEMA", tc.staticSchema)
			t.Setenv("COG_SDK_WHEEL", tc.wheel)

			require.Equal(t, tc.expect, canUseStaticSchemaGen(tc.conf))
		})
	}
}


================================================
FILE: pkg/image/config.go
================================================
package image

import (
	"context"
	"encoding/json"
	"fmt"

	"github.com/docker/docker/api/types/image"

	"github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/docker/command"
)

// CogConfigFromManifest extracts the Cog config stored as JSON in the labels
// of an inspected image.
//
// It reads the label keyed by command.CogConfigLabelKey, falling back to the
// deprecated "org.cogmodel.config" label. An error is returned when neither
// label has a value or when the value is not valid JSON for config.Config.
// ctx is not used by the current implementation.
func CogConfigFromManifest(ctx context.Context, manifest *image.InspectResponse) (*config.Config, error) {
	labels := manifest.Config.Labels

	raw := labels[command.CogConfigLabelKey]
	if raw == "" {
		// Deprecated. Remove for 1.0.
		raw = labels["org.cogmodel.config"]
	}
	if raw == "" {
		// TODO[md]: find the tag/ref and return that in the error instead of the ID
		return nil, fmt.Errorf("Image %s does not appear to be a Cog model", friendlyName(manifest))
	}

	var parsed config.Config
	if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
		// TODO[md]: find the tag/ref and return that in the error instead of the ID
		return nil, fmt.Errorf("Failed to parse config from %s: %w", friendlyName(manifest), err)
	}
	return &parsed, nil
}

// friendlyName picks a display name for an inspected image: its first repo
// tag when it has any, otherwise its ID.
//
// The "org.opencontainers.image.title" label is deliberately not used; per
// the original author's note it appears to carry the base image name.
func friendlyName(manifest *image.InspectResponse) string {
	if len(manifest.RepoTags) == 0 {
		return manifest.ID
	}
	return manifest.RepoTags[0]
}


================================================
FILE: pkg/image/openapi_schema.go
================================================
package image

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"

	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/util/console"
)

// GenerateOpenAPISchema generates the OpenAPI schema by running the built Docker
// image with `python -m cog.command.openapi_schema`. This is the legacy path used
// for SDK versions < 0.17.0 where the schema must be generated at runtime via
// pydantic introspection rather than static analysis.
//
// sourceDir, when non-empty, is volume-mounted as /src. This is needed for
// ExcludeSource builds (cog serve/predict/train) where COPY . /src was skipped.
//
// When enableGPU is true and the run fails with docker.ErrMissingDeviceDriver
// (matched with errors.Is, so wrapped occurrences count too), the run is
// retried once without GPUs. Any other run failure, or stdout that is not
// valid JSON, is returned as an error after the container's stdout and stderr
// have been printed at info level.
func GenerateOpenAPISchema(ctx context.Context, dockerClient command.Command, imageName string, enableGPU bool, sourceDir string) (map[string]any, error) {
	console.Debugf("=== image.GenerateOpenAPISchema %s", imageName)
	var stdout bytes.Buffer
	var stderr bytes.Buffer

	gpus := ""
	if enableGPU {
		gpus = "all"
	}

	runOpts := command.RunOptions{
		Image: imageName,
		Args: []string{
			"python", "-m", "cog.command.openapi_schema",
		},
		GPUs: gpus,
	}
	if sourceDir != "" {
		runOpts.Volumes = []command.Volume{{Source: sourceDir, Destination: "/src"}}
	}

	err := docker.RunWithIO(ctx, dockerClient, runOpts, nil, &stdout, &stderr)

	// errors.Is rather than == so the fallback still fires if the sentinel
	// reaches us wrapped with extra context.
	if enableGPU && errors.Is(err, docker.ErrMissingDeviceDriver) {
		console.Debug(stdout.String())
		console.Debug(stderr.String())
		console.Debug("Missing device driver, re-trying without GPU")
		return GenerateOpenAPISchema(ctx, dockerClient, imageName, false, sourceDir)
	}

	if err != nil {
		console.Info(stdout.String())
		console.Info(stderr.String())
		return nil, err
	}

	var schema map[string]any
	if err := json.Unmarshal(stdout.Bytes(), &schema); err != nil {
		console.Info(stdout.String())
		console.Info(stderr.String())
		return nil, err
	}

	return schema, nil
}


================================================
FILE: pkg/image/pip_freeze.go
================================================
package image

import (
	"bytes"
	"context"

	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/util/console"
)

// GeneratePipFreeze runs `python -m pip freeze` inside the image imageName
// and returns what the command wrote to stdout.
//
// On failure the captured stdout and stderr are printed at info level and the
// run error is returned with an empty string.
func GeneratePipFreeze(ctx context.Context, dockerClient command.Command, imageName string) (string, error) {
	var (
		outBuf bytes.Buffer
		errBuf bytes.Buffer
	)

	runOpts := command.RunOptions{
		Image: imageName,
		Args:  []string{"python", "-m", "pip", "freeze"},
	}

	if err := docker.RunWithIO(ctx, dockerClient, runOpts, nil, &outBuf, &errBuf); err != nil {
		console.Info(outBuf.String())
		console.Info(errBuf.String())
		return "", err
	}

	return outBuf.String(), nil
}


================================================
FILE: pkg/model/artifact.go
================================================
package model

import v1 "github.com/google/go-containerregistry/pkg/v1"

// ArtifactType identifies the kind of artifact.
//
// Declared values start at 1 (iota + 1), so the zero value of ArtifactType
// does not correspond to any declared kind.
type ArtifactType int

const (
	// ArtifactTypeImage is a container image artifact.
	ArtifactTypeImage ArtifactType = iota + 1
	// ArtifactTypeWeight is a model weight artifact.
	ArtifactTypeWeight
)

// String returns the lowercase name of the artifact type: "image" or
// "weight". Any other value, including the zero value, yields "unknown".
func (t ArtifactType) String() string {
	if t == ArtifactTypeImage {
		return "image"
	}
	if t == ArtifactTypeWeight {
		return "weight"
	}
	return "unknown"
}

// ArtifactSpec declares what artifact will be produced.
// It contains all inputs needed to build that artifact.
// Specs are derived from analyzing the Source (cog.yaml + project directory).
type ArtifactSpec interface {
	// Type reports which kind of artifact this spec describes.
	Type() ArtifactType
	// Name returns the spec's logical name.
	Name() string
}

// Artifact is the immutable result of building a spec.
// It contains the OCI descriptor and enough information for a pusher to upload it.
type Artifact interface {
	// Type reports which kind of artifact this is.
	Type() ArtifactType
	// Name returns the artifact's logical name.
	Name() string
	// Descriptor returns the OCI descriptor for the artifact.
	Descriptor() v1.Descriptor
}


================================================
FILE: pkg/model/artifact_image.go
================================================
package model

import (
	"encoding/json"

	"github.com/getkin/kin-openapi/openapi3"
	v1 "github.com/google/go-containerregistry/pkg/v1"

	"github.com/replicate/cog/pkg/config"
	"github.com/replicate/cog/pkg/global"
)

// ImageSource indicates where an image was loaded from.
type ImageSource string

// Known ImageSource values.
const (
	ImageSourceLocal  ImageSource = "local"  // Loaded from the Docker daemon.
	ImageSourceRemote ImageSource = "remote" // Loaded from a registry.
	ImageSourceBuild  ImageSource = "build"  // Produced by a build that just ran.
)

// Platform describes the OS and architecture of an image.
type Platform struct {
	// OS is the operating system name.
	OS           string
	// Architecture is the CPU architecture name.
	Architecture string
	// Variant further qualifies Architecture.
	// NOTE(review): presumably the OCI platform "variant" field — confirm
	// against the code that populates it.
	Variant      string
}

// Label keys for Cog-specific metadata stored in image labels.
// Each key is global.LabelNamespace followed by a fixed suffix.
var (
	// LabelConfig is the key under which the cog.yaml config is stored.
	LabelConfig          = global.LabelNamespace + "config"
	// LabelVersion is the key under which the Cog version is stored.
	LabelVersion         = global.LabelNamespace + "version"
	// LabelOpenAPISchema is the key under which the OpenAPI schema is stored.
	LabelOpenAPISchema   = global.LabelNamespace + "openapi_schema"
	// LabelWeightsManifest is the key for the weights manifest label.
	LabelWeightsManifest = global.LabelNamespace + "r8_weights_manifest"
)

// =============================================================================
// ImageSpec
// =============================================================================

// ImageSpecOption configures optional fields on ImageSpec.
// Options are applied by NewImageSpec in the order they are passed.
type ImageSpecOption func(*ImageSpec)

// WithImageSecrets returns an option that assigns secrets to the spec's
// Secrets field. The slice is stored as given, not copied.
func WithImageSecrets(secrets []string) ImageSpecOption {
	apply := func(spec *ImageSpec) {
		spec.Secrets = secrets
	}
	return apply
}

// WithImageNoCache returns an option that sets the spec's NoCache field to
// noCache.
func WithImageNoCache(noCache bool) ImageSpecOption {
	apply := func(spec *ImageSpec) {
		spec.NoCache = noCache
	}
	return apply
}

// ImageSpec declares an image to be built.
// It implements ArtifactSpec.
//
// TODO: ImageBuilder currently reads build options from BuildOptions (passed at
// construction) rather than from ImageSpec fields. When the build pipeline fully
// migrates to specs, ImageName/Secrets/NoCache should be the source of truth.
type ImageSpec struct {
	// name is the spec's logical name, returned by Name.
	name      string
	// ImageName is the name of the image to build.
	ImageName string
	// Secrets holds build-time secrets; set via WithImageSecrets.
	Secrets   []string
	// NoCache requests a build without the build cache; set via WithImageNoCache.
	NoCache   bool
}

// NewImageSpec builds an ImageSpec with the given logical name and image
// name, then applies each option in opts in order.
func NewImageSpec(name, imageName string, opts ...ImageSpecOption) *ImageSpec {
	spec := new(ImageSpec)
	spec.name = name
	spec.ImageName = imageName

	for i := range opts {
		opts[i](spec)
	}
	return spec
}

// Type identifies this spec as describing a container image; it always
// returns ArtifactTypeImage.
func (s *ImageSpec) Type() ArtifactType {
	return ArtifactTypeImage
}

// Name returns the logical name the spec was created with.
func (s *ImageSpec) Name() string {
	return s.name
}

// =============================================================================
// ImageArtifact
// =============================================================================

// ImageArtifact represents an OCI container image.
// It serves as both the build artifact (in Model.Artifacts) and the general-purpose
// image metadata type throughout the codebase.
// It implements the Artifact interface.
type ImageArtifact struct {
	// Artifact fields (set when used as a build artifact)
	name       string        // Logical name, returned by Name.
	descriptor v1.Descriptor // OCI descriptor, returned by Descriptor.

	// Image metadata
	Reference string            // Full image reference (e.g., "r8.im/user/model:latest")
	Digest    string            // Content-addressable digest (sha256:...)
	Labels    map[string]string // Docker/OCI image labels; may be nil
	Platform  *Platform         // OS/architecture
	Source    ImageSource       // Where loaded from (local/remote/build)
}

// NewImageArtifact builds an ImageArtifact from a build result, setting only
// its logical name, OCI descriptor and image reference. All other metadata
// fields are left at their zero values.
func NewImageArtifact(name string, desc v1.Descriptor, reference string) *ImageArtifact {
	artifact := new(ImageArtifact)
	artifact.name = name
	artifact.descriptor = desc
	artifact.Reference = reference
	return artifact
}

// Type identifies this artifact as a container image; it always returns
// ArtifactTypeImage.
func (a *ImageArtifact) Type() ArtifactType {
	return ArtifactTypeImage
}

// Name returns the logical name the artifact was created with.
func (a *ImageArtifact) Name() string {
	return a.name
}

// Descriptor returns the OCI descriptor the artifact was created with.
func (a *ImageArtifact) Descriptor() v1.Descriptor {
	return a.descriptor
}

// =============================================================================
// Image metadata methods (formerly on *Image)
// =============================================================================

// IsCogModel reports whether the image carries the LabelConfig label, which
// marks it as a Cog model. Only the presence of the key matters, not its
// value.
func (a *ImageArtifact) IsCogModel() bool {
	// Looking up a key in a nil map is safe and reports "absent".
	_, present := a.Labels[LabelConfig]
	return present
}

// CogVersion returns the value of the LabelVersion label, or the empty
// string when the label is absent or the image has no labels.
func (a *ImageArtifact) CogVersion() string {
	// A nil map yields the zero value for any key.
	version := a.Labels[LabelVersion]
	return version
}

// Config returns the raw value of the LabelConfig label, or the empty string
// when the label is absent or the image has no labels.
func (a *ImageArtifact) Config() string {
	// A nil map yields the zero value for any key.
	raw := a.Labels[LabelConfig]
	return raw
}

// OpenAPISchema returns the raw value of the LabelOpenAPISchema label, or the
// empty string when the label is absent or the image has no labels.
func (a *ImageArtifact) OpenAPISchema() string {
	// A nil map yields the zero value for any key.
	raw := a.Labels[LabelOpenAPISchema]
	return raw
}

// ParsedConfig decodes the config label (see Config) as JSON into a
// config.Config.
//
// It returns (nil, nil) when the label is empty or absent, and the JSON
// decoding error when the label value is not valid.
func (a *ImageArtifact) ParsedConfig() (*config.Config, error) {
	raw := a.Config()
	if raw == "" {
		return nil, nil
	}

	var parsed config.Config
	err := json.Unmarshal([]byte(raw), &parsed)
	if err != nil {
		return nil, err
	}
	return &parsed, nil
}

// ParsedOpenAPISchema loads the OpenAPI schema label (see OpenAPISchema) with
// a fresh openapi3 loader.
//
// It returns (nil, nil) when the label is empty or absent, and the loader's
// error when the label value cannot be loaded.
func (a *ImageArtifact) ParsedOpenAPISchema() (*openapi3.T, error) {
	raw := a.OpenAPISchema()
	if raw == "" {
		return nil, nil
	}

	doc, err := openapi3.NewLoader().LoadFromData([]byte(raw))
	if err != nil {
		return nil, err
	}
	return doc, nil
}

// ToModel builds a Model from this image by parsing its labels.
//
// It returns ErrNotCogModel when the image lacks the config label, and the
// parsing error when either the config or the OpenAPI schema label is
// invalid. The returned Model references this artifact as its Image.
func (a *ImageArtifact) ToModel() (*Model, error) {
	if !a.IsCogModel() {
		return nil, ErrNotCogModel
	}

	parsedConfig, err := a.ParsedConfig()
	if err != nil {
		return nil, err
	}

	parsedSchema, err := a.ParsedOpenAPISchema()
	if err != nil {
		return nil, err
	}

	result := &Model{
		Image:      a,
		Config:     parsedConfig,
		Schema:     parsedSchema,
		CogVersion: a.CogVersion(),
	}
	return result, nil
}


================================================
FILE: pkg/model/artifact_image_test.go
================================================
package model

import (
	"testing"

	v1 "github.com/google/go-containerregistry/pkg/v1"
	"github.com/stretchr/testify/require"
)

func TestImageSpec_ImplementsArtifactSpec(t *testing.T) {
	spec := NewImageSpec("model", "r8.im/user/model:latest")

	var _ ArtifactSpec = spec // compile-time interface check

	require.Equal(t, ArtifactTypeImage, spec.Type())
	require.Equal(t, "model", spec.Name())
}

func TestImageSpec_Fields(t *testing.T) {
	spec := NewImageSpec("model", "r8.im/user/model:latest",
		WithImageSecrets([]string{"secret1", "secret2"}),
		WithImageNoCache(true),
	)

	require.Equal(t, "r8.im/user/model:latest", spec.ImageName)
	require.Equal(t, []string{"secret1", "secret2"}, spec.Secrets)
	require.True(t, spec.NoCache)
}

func TestImageSpec_DefaultFields(t *testing.T) {
	spec := NewImageSpec("model", "myimage:latest")

	require.Equal(t, "myimage:latest", spec.ImageName)
	require.Nil(t, spec.Secrets)
	require.False(t, spec.NoCache)
}

func TestImageArtifact_ImplementsArtifact(t *testing.T) {
	desc := v1.Descriptor{
		Digest: v1.Hash{Algorithm: "sha256", Hex: "abc123"},
		Size:   1024,
	}
	artifact := NewImageArtifact("model", desc, "r8.im/user/model@sha256:abc123")

	var _ Artifact = artifact // compile-time interface check

	require.Equal(t, ArtifactTypeImage, artifact.Type())
	require.Equal(t, "model", artifact.Name())
	require.Equal(t, desc, artifact.Descriptor())
}

func TestImageArtifact_Reference(t *testing.T) {
	desc := v1.Descriptor{
		Digest: v1.Hash{Algorithm: "sha256", Hex: "abc123"},
		Size:   1024,
	}
	artifact := NewImageArtifact("model", desc, "r8.im/user/model@sha256:abc123")

	require.Equal(t, "r8.im/user/model@sha256:abc123", artifact.Reference)
}


================================================
FILE: pkg/model/artifact_test.go
================================================
package model

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// TestArtifactType_String checks the string form of each declared artifact
// type and of the zero value.
func TestArtifactType_String(t *testing.T) {
	type stringCase struct {
		label string
		kind  ArtifactType
		want  string
	}

	cases := []stringCase{
		{label: "image type", kind: ArtifactTypeImage, want: "image"},
		{label: "weight type", kind: ArtifactTypeWeight, want: "weight"},
		{label: "zero value", kind: ArtifactType(0), want: "unknown"},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := tc.kind.String()
			require.Equal(t, tc.want, got)
		})
	}
}

// TestArtifactType_Values guards against the image and weight types ever
// sharing a value.
func TestArtifactType_Values(t *testing.T) {
	r := require.New(t)
	r.NotEqual(ArtifactTypeImage, ArtifactTypeWeight)
}


================================================
FILE: pkg/model/artifact_weight.go
================================================
package model

import (
	"time"

	v1 "github.com/google/go-containerregistry/pkg/v1"
)

// Media types for weight artifacts (OCI 1.1 conventions).
// All of them share the "application/vnd.cog.weight." prefix; compressed
// layer variants append "+gzip" or "+zstd" to the uncompressed layer type.
const (
	// MediaTypeWeightArtifact is the artifactType for weight manifests.
	MediaTypeWeightArtifact = "application/vnd.cog.weight.v1"
	// MediaTypeWeightConfig is the media type for weight config blobs.
	MediaTypeWeightConfig = "application/vnd.cog.weight.config.v1+json"
	// MediaTypeWeightLayer is the media type for uncompressed weight layers.
	MediaTypeWeightLayer = "application/vnd.cog.weight.layer.v1"
	// MediaTypeWeightLayerGzip is the media type for gzip-compressed weight layers.
	MediaTypeWeightLayerGzip = "application/vnd.cog.weight.layer.v1+gzip"
	// MediaTypeWeightLayerZstd is the media type for zstd-compressed weight layers (future).
	MediaTypeWeightLayerZstd = "application/vnd.cog.weight.layer.v1+zstd"
)

// Annotation keys for weight file layers in OCI manifests.
// NOTE(review): the per-key descriptions below are inferred from the key
// names; confirm against the code that writes these annotations.
const (
	// AnnotationWeightName keys the weight's name.
	AnnotationWeightName             = "vnd.cog.weight.name"
	// AnnotationWeightDest keys the weight's destination.
	AnnotationWeightDest             = "vnd.cog.weight.dest"
	// AnnotationWeightDigestOriginal keys the original digest of the weight.
	AnnotationWeightDigestOriginal   = "vnd.cog.weight.digest.original"
	// AnnotationWeightSizeUncompressed keys the uncompressed size of the weight.
	AnnotationWeightSizeUncompressed = "vnd.cog.weight.size.uncompressed"
)

// WeightSpec declares a weight artifact to be built.
// It implements ArtifactSpec.
type WeightSpec struct {
	// name is the spec's logical name, returned by Name.
	name string
	// Source is the local file path to the weight file.
	Source string
	// Target is the container mount path for this weight.
	Target string
}

// NewWeightSpec builds a WeightSpec from a logical name, the local source
// path of the weight file, and the container path it should be mounted at.
func NewWeightSpec(name, source, target string) *WeightSpec {
	spec := new(WeightSpec)
	spec.name = name
	spec.Source = source
	spec.Target = target
	return spec
}

// Type identifies this spec as describing model weights; it always returns
// ArtifactTypeWeight.
func (s *WeightSpec) Type() ArtifactType {
	return ArtifactTypeWeight
}

// Name returns the logical name the spec was created with.
func (s *WeightSpec) Name() string {
	return s.name
}

// WeightArtifact is a built weight artifact ready to push as an OCI artifact.
// It implements Artifact.
type WeightArtifact struct {
	// name is the artifact's logical name, returned by Name.
	name       string
	// descriptor is the OCI descriptor, returned by Descriptor.
	descriptor v1.Descriptor

	// FilePath is the local file path to the weight data (for pushing layers).
	FilePath string
	// Target is the container mount path for this weight.
	Target string
	// Config is the weight metadata for the config blob.
	Config WeightConfig
}

// NewWeightArtifact builds a WeightArtifact from a build result: its logical
// name, OCI descriptor, local file path, container mount path and config
// metadata.
func NewWeightArtifact(name string, desc v1.Descriptor, filePath, target string, cfg WeightConfig) *WeightArtifact {
	artifact := new(WeightArtifact)
	artifact.name = name
	artifact.descriptor = desc
	artifact.FilePath = filePath
	artifact.Target = target
	artifact.Config = cfg
	return artifact
}

// Type identifies this artifact as model weights; it always returns
// ArtifactTypeWeight.
func (a *WeightArtifact) Type() ArtifactType {
	return ArtifactTypeWeight
}

// Name returns the logical name the artifact was created with.
func (a *WeightArtifact) Name() string {
	return a.name
}

// Descriptor returns the OCI descriptor the weight artifact was created with.
func (a *WeightArtifact) Descriptor() v1.Descriptor {
	return a.descriptor
}

// WeightConfig contains metadata about a weight artifact.
// This is serialized as the config blob in the OCI manifest.
// The schema is versioned via SchemaVersion to allow evolution.
//
// JSON field names are camelCase, as given by the struct tags.
type WeightConfig struct {
	// SchemaVersion is the version of this config schema.
	SchemaVersion string    `json:"schemaVersion"`
	// CogVersion is a Cog version string.
	// NOTE(review): presumably the version of Cog that built the weight — confirm.
	CogVersion    string    `json:"cogVersion"`
	// Name is the weight's name.
	Name          string    `json:"name"`
	// Target is the weight's target path.
	Target        string    `json:"target"`
	Created       time.Time `json:"created"` // RFC 3339 format when serialized to JSON
}


================================================
FILE: pkg/model/artifact_weight_test.go
================================================
package model

import (
	"encoding/json"
	"testing"
	"time"

	v1 "github.com/google/go-containerregistry/pkg/v1"
	"github.com/stretchr/testify/require"
)

func TestWeightSpec_ImplementsArtifactSpec(t *testing.T) {
	spec := NewWeightSpec("my-model-weights", "/data/weights.bin", "/weights/model.bin")

	var _ ArtifactSpec = spec // compile-time interface check

	require.Equal(t, ArtifactTypeWeight, spec.Type())
	require.Equal(t, "my-model-weights", spec.Name())
}

func TestWeightSpec_Fields(t *testing.T) {
	spec := NewWeightSpec("llama-7b", "/data/llama-7b.safetensors", "/weights/llama-7b.safetensors")

	require.Equal(t, "/data/llama-7b.safetensors", spec.Source)
	require.Equal(t, "/weights/llama-7b.safetensors", spec.Target)
}

func TestWeightArtifact_ImplementsArtifact(t *testing.T) {
	desc := v1.Descriptor{
		Digest: v1.Hash{Algorithm: "sha256", Hex: "def456"},
		Size:   4096,
	}
	cfg := WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "my-weights",
		Target:        "/weights/model.bin",
		Created:       time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC),
	}
	artifact := NewWeightArtifact("my-weights", desc, "/data/weights.bin", "/weights/model.bin", cfg)

	var _ Artifact = artifact // compile-time interface check

	require.Equal(t, ArtifactTypeWeight, artifact.Type())
	require.Equal(t, "my-weights", artifact.Name())
	require.Equal(t, desc, artifact.Descriptor())
}

func TestWeightArtifact_Fields(t *testing.T) {
	desc := v1.Descriptor{
		Digest: v1.Hash{Algorithm: "sha256", Hex: "def456"},
		Size:   4096,
	}
	cfg := WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "my-weights",
		Target:        "/weights/model.bin",
		Created:       time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC),
	}
	artifact := NewWeightArtifact("my-weights", desc, "/data/weights.bin", "/weights/model.bin", cfg)

	require.Equal(t, "/data/weights.bin", artifact.FilePath)
	require.Equal(t, "/weights/model.bin", artifact.Target)
	require.Equal(t, cfg, artifact.Config)
}

func TestWeightConfig_JSONRoundTrip(t *testing.T) {
	original := WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "llama-7b",
		Target:        "/weights/llama-7b",
		Created:       time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC),
	}

	data, err := json.Marshal(original)
	require.NoError(t, err)

	// Verify JSON structure
	var raw map[string]any
	err = json.Unmarshal(data, &raw)
	require.NoError(t, err)
	require.Equal(t, "1.0", raw["schemaVersion"])
	require.Equal(t, "0.15.0", raw["cogVersion"])
	require.Equal(t, "llama-7b", raw["name"])
	require.Equal(t, "/weights/llama-7b", raw["target"])

	// Round-trip
	var decoded WeightConfig
	err = json.Unmarshal(data, &decoded)
	require.NoError(t, err)
	require.Equal(t, original.SchemaVersion, decoded.SchemaVersion)
	require.Equal(t, original.CogVersion, decoded.CogVersion)
	require.Equal(t, original.Name, decoded.Name)
	require.Equal(t, original.Target, decoded.Target)
	require.True(t, original.Created.Equal(decoded.Created))
}

// TestWeightMediaTypeConstants pins the literal value of every weight media
// type constant, so an accidental edit to any of them fails the test. This
// includes the zstd layer type, which is declared alongside the others.
func TestWeightMediaTypeConstants(t *testing.T) {
	// Verify media type constants have expected values
	require.Equal(t, "application/vnd.cog.weight.v1", MediaTypeWeightArtifact)
	require.Equal(t, "application/vnd.cog.weight.config.v1+json", MediaTypeWeightConfig)
	require.Equal(t, "application/vnd.cog.weight.layer.v1", MediaTypeWeightLayer)
	require.Equal(t, "application/vnd.cog.weight.layer.v1+gzip", MediaTypeWeightLayerGzip)
	require.Equal(t, "application/vnd.cog.weight.layer.v1+zstd", MediaTypeWeightLayerZstd)
}


================================================
FILE: pkg/model/builder.go
================================================
package model

import "context"

// Builder builds an artifact from a spec.
// Each builder handles one artifact type (image, weight, etc.).
type Builder interface {
	// Build produces the Artifact described by spec, or returns an error.
	Build(ctx context.Context, spec ArtifactSpec) (Artifact, error)
}


================================================
FILE: pkg/model/builder_test.go
================================================
package model

import (
	"context"
	"testing"

	v1 "github.com/google/go-containerregistry/pkg/v1"
	"github.com/stretchr/testify/require"
)

// mockBuilder is a test double that implements the Builder interface.
// Its Build method delegates to buildFn, which the test must set.
type mockBuilder struct {
	buildFn func(ctx context.Context, spec ArtifactSpec) (Artifact, error)
}

func (m *mockBuilder) Build(ctx context.Context, spec ArtifactSpec) (Artifact, error) {
	return m.buildFn(ctx, spec)
}

func TestBuilderInterface_Satisfiable(t *testing.T) {
	// Compile-time check: mockBuilder satisfies Builder.
	var _ Builder = &mockBuilder{}

	// Runtime check: a mock builder can be called and returns an artifact.
	mb := &mockBuilder{
		buildFn: func(_ context.Context, spec ArtifactSpec) (Artifact, error) {
			return NewImageArtifact(spec.Name(), v1.Descriptor{}, "test-ref"), nil
		},
	}

	artifact, err := mb.Build(context.Background(), NewImageSpec("test", "test-image"))
	require.NoError(t, err)
	require.Equal(t, "test", artifact.Name())
}


================================================
FILE: pkg/model/errors.go
================================================
package model

import "errors"

// Sentinel errors for Resolver operations.
//
// These values may be wrapped with additional context (for example via
// fmt.Errorf and %w), so match them with errors.Is rather than ==.
var (
	// ErrNotCogModel indicates the image exists but is not a valid Cog model.
	// This occurs when the image lacks the required run.cog.config label.
	ErrNotCogModel = errors.New("image is not a Cog model")

	// ErrNotFound indicates the image was not found in the requested location(s).
	ErrNotFound = errors.New("image not found")
)


================================================
FILE: pkg/model/errors_test.go
================================================
package model

import (
	"errors"
	"fmt"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestSentinelErrors(t *testing.T) {
	// Wrapping with %w must keep each sentinel discoverable through errors.Is.
	t.Run("ErrNotCogModel can be wrapped and detected", func(t *testing.T) {
		err := fmt.Errorf("failed to inspect image: %w", ErrNotCogModel)
		matched := errors.Is(err, ErrNotCogModel)
		require.True(t, matched)
	})

	t.Run("ErrNotFound can be wrapped and detected", func(t *testing.T) {
		err := fmt.Errorf("image my-image:latest: %w", ErrNotFound)
		matched := errors.Is(err, ErrNotFound)
		require.True(t, matched)
	})

	// Neither sentinel may be mistaken for the other, in either direction.
	t.Run("errors are distinct", func(t *testing.T) {
		pairs := [][2]error{
			{ErrNotCogModel, ErrNotFound},
			{ErrNotFound, ErrNotCogModel},
		}
		for _, pair := range pairs {
			require.False(t, errors.Is(pair[0], pair[1]))
		}
	})
}


================================================
FILE: pkg/model/factory.go
================================================
package model

import (
	"context"

	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/image"
	"github.com/replicate/cog/pkg/registry"
)

// Factory is the build backend interface.
// Different implementations handle different build strategies; each one turns
// a Source plus BuildOptions into an image and reports it as an ImageArtifact.
type Factory interface {
	// Build creates a Docker image from source and returns ImageArtifact metadata.
	// For dev mode (cog serve), set ExcludeSource=true in BuildOptions to skip
	// COPY . /src — the source directory is volume-mounted at runtime instead.
	Build(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error)

	// Name returns the factory name for logging/debugging.
	Name() string
}

// DockerfileFactory wraps existing Dockerfile-based build.
//
// The docker command and registry client it holds are handed to image.Build
// on every Build call.
type DockerfileFactory struct {
	docker   command.Command
	registry registry.Client
}

// NewDockerfileFactory returns a Factory backed by the existing
// Dockerfile-based build, using the given docker command and registry client.
func NewDockerfileFactory(docker command.Command, registry registry.Client) Factory {
	f := new(DockerfileFactory)
	f.docker = docker
	f.registry = registry
	return f
}

// Name reports the fixed identifier of this factory.
func (f *DockerfileFactory) Name() string {
	const factoryName = "dockerfile"
	return factoryName
}

// Build runs image.Build with the source, the build options, and the
// factory's docker command and registry client. On success the identifier
// returned by image.Build is stored as the artifact's Digest, opts.ImageName
// as its Reference, and the Source is marked as ImageSourceBuild. A build
// error is returned as-is.
func (f *DockerfileFactory) Build(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) {
	// image.Build takes its inputs positionally, so the order below matters.
	builtID, buildErr := image.Build(
		ctx,
		src.Config,
		src.ProjectDir,
		opts.ImageName,
		src.ConfigFilename,
		opts.Secrets,
		opts.NoCache,
		opts.SeparateWeights,
		opts.UseCudaBaseImage,
		opts.ProgressOutput,
		opts.SchemaFile,
		opts.DockerfileFile,
		opts.UseCogBaseImage,
		opts.Strip,
		opts.Precompile,
		opts.ExcludeSource,
		opts.SkipSchemaValidation,
		opts.SkipLabels,
		opts.Annotations,
		f.docker,
		f.registry,
	)
	if buildErr != nil {
		return nil, buildErr
	}

	artifact := new(ImageArtifact)
	artifact.Reference = opts.ImageName
	artifact.Digest = builtID
	artifact.Source = ImageSourceBuild
	return artifact, nil
}

// defaultFactory returns the default build backend.
//
// It currently always returns the Dockerfile-based factory; choosing a
// backend from environment variables is not implemented yet.
//
// TODO: When FrontendFactory is implemented, check COG_BUILDER env var.
// TODO: When CogpacksFactory is implemented, check COGPACK env var.
func defaultFactory(docker command.Command, registry registry.Client) Factory {
	backend := NewDockerfileFactory(docker, registry)
	return backend
}


================================================
FILE: pkg/model/factory_test.go
================================================
package model

import (
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/docker/dockertest"
	"github.com/replicate/cog/pkg/registry/registrytest"
)

func TestDockerfileFactory_Name(t *testing.T) {
	// The reported name is a fixed identifier, whatever the dependencies are.
	cmd := dockertest.NewMockCommand()
	reg := registrytest.NewMockRegistryClient()

	got := NewDockerfileFactory(cmd, reg).Name()

	require.Equal(t, "dockerfile", got)
}

func TestDockerfileFactory_ImplementsInterface(t *testing.T) {
	cmd := dockertest.NewMockCommand()
	reg := registrytest.NewMockRegistryClient()

	// Compile-time check: the constructor's result is usable as a Factory.
	var _ Factory = NewDockerfileFactory(cmd, reg)
}

func TestDefaultFactory_ReturnsDockerfileFactory(t *testing.T) {
	// The default backend must identify itself as the Dockerfile factory.
	cmd := dockertest.NewMockCommand()
	reg := registrytest.NewMockRegistryClient()

	got := defaultFactory(cmd, reg).Name()

	require.Equal(t, "dockerfile", got)
}


================================================
FILE: pkg/model/format.go
================================================
package model

import "os"

// OCIIndexEnabled reports whether the OCI Image Index push path is switched
// on, which is the case only when the COG_OCI_INDEX environment variable is
// exactly "1". With the gate on, builds produce weight artifacts and pushes
// create an OCI Image Index instead of a single image manifest.
//
// TODO(md): this is a temporary gate. Remove it (and always use the index
// path) once index compatibility has been validated with all registries.
func OCIIndexEnabled() bool {
	if flag := os.Getenv("COG_OCI_INDEX"); flag != "1" {
		return false
	}
	return true
}


================================================
FILE: pkg/model/format_test.go
================================================
package model

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestOCIIndexEnabled_Default(t *testing.T) {
	t.Setenv("COG_OCI_INDEX", "")
	require.False(t, OCIIndexEnabled())
}

func TestOCIIndexEnabled_Enabled(t *testing.T) {
	t.Setenv("COG_OCI_INDEX", "1")
	require.True(t, OCIIndexEnabled())
}

func TestOCIIndexEnabled_OtherValue(t *testing.T) {
	t.Setenv("COG_OCI_INDEX", "0")
	require.False(t, OCIIndexEnabled())
}


================================================
FILE: pkg/model/hash.go
================================================
package model

import (
	"crypto/sha256"
	"encoding/hex"
	"io"
	"os"
)

// hashFile streams the file at path through SHA-256 and returns the digest
// formatted as "sha256:<lowercase hex>" together with the number of bytes
// read. If the file cannot be opened or read, it returns an empty digest, a
// zero size and the underlying error.
func hashFile(path string) (digest string, size int64, err error) {
	file, openErr := os.Open(path)
	if openErr != nil {
		return "", 0, openErr
	}
	defer file.Close()

	// Copy straight into the hasher so the content is never held in memory.
	hasher := sha256.New()
	written, copyErr := io.Copy(hasher, file)
	if copyErr != nil {
		return "", 0, copyErr
	}

	sum := hasher.Sum(nil)
	return "sha256:" + hex.EncodeToString(sum), written, nil
}


================================================
FILE: pkg/model/image_builder.go
================================================
package model

import (
	"context"
	"fmt"

	v1 "github.com/google/go-containerregistry/pkg/v1"

	"github.com/replicate/cog/pkg/docker/command"
)

// ImageBuilder builds an ImageArtifact from an ImageSpec.
// It delegates to a Factory for the docker build, inspects the result
// to populate labels and the canonical digest, and returns a fully
// populated ImageArtifact.
//
// The source and build options are fixed at construction time and passed to
// the factory on every Build call; the docker command is used to inspect the
// image once it has been built.
type ImageBuilder struct {
	factory Factory
	docker  command.Command
	source  *Source
	opts    BuildOptions
}

// NewImageBuilder returns an ImageBuilder that builds with factory, inspects
// with docker, and uses source and opts for every build.
func NewImageBuilder(factory Factory, docker command.Command, source *Source, opts BuildOptions) *ImageBuilder {
	b := new(ImageBuilder)
	b.factory = factory
	b.docker = docker
	b.source = source
	b.opts = opts
	return b
}

// Build builds an ImageArtifact from an ImageSpec.
//
// It delegates the docker build to the Factory, inspects the built image and
// fills in the artifact's name, labels, digest and descriptor from the
// inspect result.
//
// An error is returned when spec is not an *ImageSpec, when the factory build
// or the inspect call fails, when inspect yields no response, or when the
// inspected image ID cannot be parsed as a digest.
func (b *ImageBuilder) Build(ctx context.Context, spec ArtifactSpec) (Artifact, error) {
	is, ok := spec.(*ImageSpec)
	if !ok {
		return nil, fmt.Errorf("image builder: expected *ImageSpec, got %T", spec)
	}

	// Build the image via the factory (returns partially populated ImageArtifact)
	img, err := b.factory.Build(ctx, b.source, b.opts)
	if err != nil {
		return nil, fmt.Errorf("image build failed: %w", err)
	}

	// Inspect the built image to get labels and canonical digest.
	// Prefer digest (ID) for stable lookups, fall back to reference.
	inspectRef := img.Digest
	if inspectRef == "" {
		inspectRef = img.Reference
	}

	resp, err := b.docker.Inspect(ctx, inspectRef)
	if err != nil {
		return nil, fmt.Errorf("inspect built image: %w", err)
	}
	if resp == nil {
		// Guard against a nil response so the field accesses below cannot panic.
		return nil, fmt.Errorf("inspect built image: empty response for %q", inspectRef)
	}

	// Populate the artifact with inspect results
	img.name = is.Name()
	// Config is a pointer in the inspect response; when it is absent the
	// artifact keeps whatever labels the factory set instead of panicking.
	if resp.Config != nil {
		img.Labels = resp.Config.Labels
	}
	img.Digest = resp.ID
	img.Source = ImageSourceBuild

	digest, err := v1.NewHash(resp.ID)
	if err != nil {
		return nil, fmt.Errorf("parse image digest %q: %w", resp.ID, err)
	}
	img.descriptor = v1.Descriptor{Digest: digest}

	return img, nil
}


================================================
FILE: pkg/model/image_builder_test.go
================================================
package model

import (
	"context"
	"errors"
	"testing"

	"github.com/docker/docker/api/types/image"
	dockerspec "github.com/moby/docker-image-spec/specs-go/v1"
	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/config"
)

// TestImageBuilder_HappyPath checks that a successful factory build followed
// by a successful inspect yields an ImageArtifact carrying the spec name, the
// requested reference, and the digest and labels reported by inspect.
func TestImageBuilder_HappyPath(t *testing.T) {
	const builtDigest = "sha256:a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"

	// Setup mock factory that returns a built image
	factory := &mockFactory{
		buildFunc: func(_ context.Context, _ *Source, opts BuildOptions) (*ImageArtifact, error) {
			return &ImageArtifact{
				Reference: opts.ImageName,
				Digest:    builtDigest,
				Source:    ImageSourceBuild,
			}, nil
		},
	}

	// Setup mock docker that returns inspect results with labels
	docker := &mockDocker{
		inspectFunc: func(_ context.Context, _ string) (*image.InspectResponse, error) {
			return &image.InspectResponse{
				ID: builtDigest,
				Config: &dockerspec.DockerOCIImageConfig{
					ImageConfig: ocispec.ImageConfig{
						Labels: map[string]string{
							"org.cogmodel.cog_version": "0.15.0",
						},
					},
				},
			}, nil
		},
	}

	src := NewSourceFromConfig(&config.Config{
		Image: "my-model:latest",
	}, "/project")

	ib := NewImageBuilder(factory, docker, src, BuildOptions{
		ImageName: "my-model:latest",
	})

	spec := NewImageSpec("model", "my-model:latest")
	artifact, err := ib.Build(context.Background(), spec)
	require.NoError(t, err)
	require.NotNil(t, artifact)

	// Type assertion
	ia, ok := artifact.(*ImageArtifact)
	require.True(t, ok, "expected *ImageArtifact, got %T", artifact)

	// Check artifact interface
	require.Equal(t, ArtifactTypeImage, ia.Type())
	require.Equal(t, "model", ia.Name())

	// Check descriptor has the digest
	desc := ia.Descriptor()
	require.Equal(t, builtDigest, desc.Digest.String())

	// Check image-specific fields, including the values copied from inspect
	require.Equal(t, "my-model:latest", ia.Reference)
	require.Equal(t, builtDigest, ia.Digest)
	require.Equal(t, "0.15.0", ia.Labels["org.cogmodel.cog_version"])
}

func TestImageBuilder_ErrorWrongSpecType(t *testing.T) {
	source := NewSourceFromConfig(&config.Config{}, "/project")
	builder := NewImageBuilder(&mockFactory{}, &mockDocker{}, source, BuildOptions{})

	// A weight spec is the wrong kind of input for the image builder.
	wrongSpec := NewWeightSpec("model", "model.bin", "/weights/model.bin")

	_, buildErr := builder.Build(context.Background(), wrongSpec)
	require.Error(t, buildErr)
	require.Contains(t, buildErr.Error(), "expected *ImageSpec")
}

func TestImageBuilder_ErrorFactoryBuildFails(t *testing.T) {
	factory := &mockFactory{
		buildFunc: func(_ context.Context, _ *Source, _ BuildOptions) (*ImageArtifact, error) {
			return nil, errors.New("docker build failed: out of disk")
		},
	}

	src := NewSourceFromConfig(&config.Config{}, "/project")
	ib := NewImageBuilder(factory, &mockDocker{}, src, BuildOptions{})

	spec := NewImageSpec("model", "test-image")
	_, err := ib.Build(context.Background(), spec)
	require.Error(t, err)
	require.Contains(t, err.Error(), "image build failed")
	require.Contains(t, err.Error(), "out of disk")
}

func TestImageBuilder_ErrorInspectFails(t *testing.T) {
	factory := &mockFactory{
		buildFunc: func(_ context.Context, _ *Source, opts BuildOptions) (*ImageArtifact, error) {
			return &ImageArtifact{
				Reference: opts.ImageName,
				Digest:    "sha256:a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2",
				Source:    ImageSourceBuild,
			}, nil
		},
	}

	docker := &mockDocker{
		inspectFunc: func(_ context.Context, _ string) (*image.InspectResponse, error) {
			return nil, errors.New("image not found")
		},
	}

	src := NewSourceFromConfig(&config.Config{}, "/project")
	ib := NewImageBuilder(factory, docker, src, BuildOptions{})

	spec := NewImageSpec("model", "test-image")
	_, err := ib.Build(context.Background(), spec)
	require.Error(t, err)
	require.Contains(t, err.Error(), "inspect built image")
}

func TestImageBuilder_ImplementsBuilderInterface(t *testing.T) {
	source := NewSourceFromConfig(&config.Config{}, "/project")
	imageBuilder := NewImageBuilder(&mockFactory{}, &mockDocker{}, source, BuildOptions{})

	// Compile-time check: *ImageBuilder must be assignable to Builder.
	var _ Builder = imageBuilder
}


================================================
FILE: pkg/model/image_pusher.go
================================================
package model

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"

	"github.com/google/go-containerregistry/pkg/name"
	v1 "github.com/google/go-containerregistry/pkg/v1"
	"github.com/google/go-containerregistry/pkg/v1/remote/transport"
	"github.com/google/go-containerregistry/pkg/v1/tarball"
	"github.com/google/go-containerregistry/pkg/v1/types"
	"golang.org/x/sync/errgroup"

	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/registry"
	"github.com/replicate/cog/pkg/util/console"
)

// ImagePusher pushes container images to a registry.
//
// It first attempts an OCI chunked push (export from Docker -> tarball ->
// push layers via registry client), then falls back to Docker's native push
// on any non-fatal error. This bypasses size limits on Docker's monolithic
// push path while maintaining backwards compatibility.
//
// The docker command is used both to export the image (ImageSave) and for
// the native push; the registry client receives the layers, config blob and
// manifest on the OCI path.
type ImagePusher struct {
	docker   command.Command
	registry registry.Client
}

// newImagePusher returns an ImagePusher that uses docker for exporting and
// native pushes and reg for the OCI push path.
func newImagePusher(docker command.Command, reg registry.Client) *ImagePusher {
	pusher := new(ImagePusher)
	pusher.docker = docker
	pusher.registry = reg
	return pusher
}

// imagePushOptions holds the resolved configuration for an image push.
//
// progressFn, when set, receives per-layer upload progress. onFallback, when
// set, is called just before an OCI push failure is handed over to Docker
// push. Both are optional and may be nil.
type imagePushOptions struct {
	progressFn func(PushProgress)
	onFallback func()
}

// ImagePushOption is a functional option for configuring ImagePusher.Push.
// Each option mutates the imagePushOptions that Push assembles before it
// starts pushing.
type ImagePushOption func(*imagePushOptions)

// WithProgressFn returns an option that registers fn as the receiver of
// per-layer upload progress updates.
func WithProgressFn(fn func(PushProgress)) ImagePushOption {
	setProgress := func(o *imagePushOptions) { o.progressFn = fn }
	return setProgress
}

// WithOnFallback returns an option that registers fn to be called when the
// OCI push has failed and the push is about to fall back to Docker push.
// Callers can use it to tear down any OCI-specific progress display before
// Docker push starts producing its own output.
func WithOnFallback(fn func()) ImagePushOption {
	setFallback := func(o *imagePushOptions) { o.onFallback = fn }
	return setFallback
}

// Push pushes a container image to the registry.
//
// When canOCIPush reports true, the OCI chunked push is attempted first. If
// it fails with an error that shouldFallbackToDocker accepts, the optional
// onFallback callback runs, a warning is logged, and the image is pushed with
// Docker instead; any other OCI failure is returned wrapped. When the OCI
// path is not available, Docker push is used directly.
//
// The artifact must be non-nil and have a non-empty Reference.
func (p *ImagePusher) Push(ctx context.Context, artifact *ImageArtifact, opts ...ImagePushOption) error {
	switch {
	case artifact == nil:
		return errors.New("image artifact is nil")
	case artifact.Reference == "":
		return errors.New("image artifact has no reference")
	}

	cfg := imagePushOptions{}
	for _, configure := range opts {
		configure(&cfg)
	}

	target := artifact.Reference
	if !p.canOCIPush() {
		return p.docker.Push(ctx, target)
	}

	ociErr := p.ociPush(ctx, target, cfg)
	if ociErr == nil {
		return nil
	}
	if !shouldFallbackToDocker(ociErr) {
		return fmt.Errorf("OCI chunked push: %w", ociErr)
	}

	// Let the caller clean up before Docker push writes its own output.
	if cfg.onFallback != nil {
		cfg.onFallback()
	}
	console.Warnf("OCI chunked push failed, falling back to Docker push: %v", sanitizeError(ociErr))

	return p.docker.Push(ctx, target)
}

// canOCIPush reports whether the OCI chunked push path can be used: the
// COG_PUSH_OCI environment variable must be exactly "1" and a registry client
// must be configured.
//
// The registry check matters because the OCI path calls methods on
// p.registry; with a nil client that would panic instead of letting Push use
// Docker push.
func (p *ImagePusher) canOCIPush() bool {
	return p.registry != nil && os.Getenv("COG_PUSH_OCI") == "1"
}

// ociPush exports the image from the Docker daemon into a temporary tar file,
// loads it with go-containerregistry, and hands it to pushImage, which pushes
// the layers, config and manifest. The temporary file is removed when the
// function returns.
func (p *ImagePusher) ociPush(ctx context.Context, imageRef string, opt imagePushOptions) error {
	console.Debugf("Exporting image %s from Docker daemon...", imageRef)

	parsed, err := name.ParseReference(imageRef, name.Insecure)
	if err != nil {
		return fmt.Errorf("parse image reference %q: %w", imageRef, err)
	}

	// Stream of the image in Docker tar format, straight from the daemon.
	exported, err := p.docker.ImageSave(ctx, imageRef)
	if err != nil {
		return fmt.Errorf("export image from daemon: %w", err)
	}
	defer exported.Close() //nolint:errcheck

	// The tarball loader needs a file on disk, so spool the stream to one.
	tarFile, err := os.CreateTemp("", "cog-image-*.tar")
	if err != nil {
		return fmt.Errorf("create temp tar file: %w", err)
	}
	tarPath := tarFile.Name()
	defer func() { _ = os.Remove(tarPath) }() //nolint:gosec // G703: path from os.CreateTemp, not user input
	defer tarFile.Close()                     //nolint:errcheck

	_, err = io.Copy(tarFile, exported)
	if err != nil {
		return fmt.Errorf("write image tar: %w", err)
	}
	// The export stream is no longer needed once it has been copied.
	_ = exported.Close()

	// tarball.ImageFromPath wants a tag; a non-tag reference (such as a
	// digest) falls back to "latest" in the same repository.
	var tag name.Tag
	switch r := parsed.(type) {
	case name.Tag:
		tag = r
	default:
		tag = parsed.Context().Tag("latest")
	}

	// The returned image reads layers from the tar on demand rather than
	// loading them all into memory at once.
	loaded, err := tarball.ImageFromPath(tarPath, &tag)
	if err != nil {
		return fmt.Errorf("load image from tar: %w", err)
	}

	return p.pushImage(ctx, imageRef, loaded, opt)
}

// pushImage uploads img to the registry in three steps: all layers, then the
// config blob, then the manifest under imageRef. It stops at the first step
// that fails and returns that error wrapped with the step name.
func (p *ImagePusher) pushImage(ctx context.Context, imageRef string, img v1.Image, opt imagePushOptions) error {
	repository := repoFromReference(imageRef)

	err := p.pushLayers(ctx, repository, img, opt)
	if err != nil {
		return fmt.Errorf("push layers: %w", err)
	}

	err = p.pushConfig(ctx, repository, img)
	if err != nil {
		return fmt.Errorf("push config: %w", err)
	}

	console.Debugf("Pushing image manifest for %s", imageRef)
	err = p.registry.PushImage(ctx, imageRef, img)
	if err != nil {
		return fmt.Errorf("push manifest: %w", err)
	}

	return nil
}

// pushLayers uploads every layer of img concurrently, one pushLayer call per
// layer, with at most GetPushConcurrency() uploads running at a time. It
// returns the first error reported by the group, or nil when the image has no
// layers.
func (p *ImagePusher) pushLayers(ctx context.Context, repo string, img v1.Image, opt imagePushOptions) error {
	layers, err := img.Layers()
	if err != nil {
		return fmt.Errorf("get image layers: %w", err)
	}

	total := len(layers)
	if total == 0 {
		return nil
	}

	limit := GetPushConcurrency()
	console.Debugf("Pushing %d layers with concurrency %d", total, limit)

	// Uploads share the group's derived context.
	group, groupCtx := errgroup.WithContext(ctx)
	group.SetLimit(limit)

	for _, l := range layers {
		group.Go(func() error {
			return p.pushLayer(groupCtx, repo, l, opt)
		})
	}

	return group.Wait()
}

// pushLayer uploads one layer to repo. When a progress callback is
// configured, each upload update is forwarded to it as a PushProgress tagged
// with the layer's digest.
func (p *ImagePusher) pushLayer(ctx context.Context, repo string, layer v1.Layer, opt imagePushOptions) error {
	digest, err := layer.Digest()
	if err != nil {
		return fmt.Errorf("get layer digest: %w", err)
	}

	size, err := layer.Size()
	if err != nil {
		return fmt.Errorf("get layer size: %w", err)
	}

	console.Debugf("Pushing layer %s (%d bytes)", digest, size)

	// Stays nil when the caller did not ask for progress.
	var report func(v1.Update)
	if notify := opt.progressFn; notify != nil {
		layerID := digest.String()
		report = func(u v1.Update) {
			notify(PushProgress{
				LayerDigest: layerID,
				Complete:    u.Complete,
				Total:       u.Total,
			})
		}
	}

	target := registry.WriteLayerOptions{
		Repo:  repo,
		Layer: layer,
	}
	if err := writeLayerWithProgress(ctx, p.registry, target, report); err != nil {
		return fmt.Errorf("push layer %s: %w", digest, err)
	}

	return nil
}

// pushConfig uploads the image's raw config to repo as a blob, wrapping it in
// a configBlobLayer so it can go through the registry client's WriteLayer.
func (p *ImagePusher) pushConfig(ctx context.Context, repo string, img v1.Image) error {
	raw, err := img.RawConfigFile()
	if err != nil {
		return fmt.Errorf("get config: %w", err)
	}

	hash, err := img.ConfigName()
	if err != nil {
		return fmt.Errorf("get config digest: %w", err)
	}

	console.Debugf("Pushing config blob %s (%d bytes)", hash, len(raw))

	blob := &configBlobLayer{data: raw, digest: hash}
	writeOpts := registry.WriteLayerOptions{
		Repo:  repo,
		Layer: blob,
	}
	return p.registry.WriteLayer(ctx, writeOpts)
}

// shouldFallbackToDocker reports whether an OCI push error may be retried
// with Docker push. It returns false for a nil error, for context
// cancellation or deadline errors, and for registry transport errors with
// status 401 or 403, since Docker push would fail with the same credentials.
// Every other error allows the fallback.
func shouldFallbackToDocker(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return false
	}

	var registryErr *transport.Error
	if !errors.As(err, &registryErr) {
		return true
	}

	status := registryErr.StatusCode
	return status != http.StatusUnauthorized && status != http.StatusForbidden
}

// sanitizeError shortens registry errors for display.
//
// A go-containerregistry transport.Error can contain the entire HTTP response
// body, which is unreadable in a terminal. When err wraps one, only the HTTP
// status code and its standard text are kept; any other error is returned
// unchanged.
func sanitizeError(err error) error {
	var registryErr *transport.Error
	if !errors.As(err, &registryErr) {
		return err
	}

	status := registryErr.StatusCode
	return fmt.Errorf("HTTP %d %s", status, http.StatusText(status))
}

// configBlobLayer adapts an in-memory config blob to the v1.Layer interface
// so it can be passed in WriteLayerOptions. data is the raw blob and digest
// is the hash reported for it.
type configBlobLayer struct {
	data   []byte
	digest v1.Hash
}

// Digest returns the digest the layer was constructed with.
func (c *configBlobLayer) Digest() (v1.Hash, error) {
	hash := c.digest
	return hash, nil
}

// DiffID returns the same hash as Digest: the blob is stored uncompressed, so
// its compressed and uncompressed forms are the same bytes.
func (c *configBlobLayer) DiffID() (v1.Hash, error) {
	hash := c.digest
	return hash, nil
}

// Compressed returns a reader over the raw blob.
func (c *configBlobLayer) Compressed() (io.ReadCloser, error) {
	reader := bytes.NewReader(c.data)
	return io.NopCloser(reader), nil
}

// Uncompressed returns a reader over the same raw blob as Compressed.
func (c *configBlobLayer) Uncompressed() (io.ReadCloser, error) {
	reader := bytes.NewReader(c.data)
	return io.NopCloser(reader), nil
}

// Size returns the length of the blob in bytes.
func (c *configBlobLayer) Size() (int64, error) {
	length := len(c.data)
	return int64(length), nil
}

// MediaType always reports the OCI config JSON media type.
func (c *configBlobLayer) MediaType() (types.MediaType, error) {
	const mediaType = types.OCIConfigJSON
	return mediaType, nil
}


================================================
FILE: pkg/model/image_pusher_test.go
================================================
package model

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"sync"
	"testing"

	"github.com/google/go-containerregistry/pkg/name"
	v1 "github.com/google/go-containerregistry/pkg/v1"
	"github.com/google/go-containerregistry/pkg/v1/empty"
	"github.com/google/go-containerregistry/pkg/v1/mutate"
	"github.com/google/go-containerregistry/pkg/v1/random"
	"github.com/google/go-containerregistry/pkg/v1/remote/transport"
	"github.com/google/go-containerregistry/pkg/v1/tarball"
	"github.com/google/go-containerregistry/pkg/v1/types"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/registry"
)

// ociMockClient is a registry.Client test double for ImagePusher tests. It
// records the layers and image references it receives and can be told to
// fail WriteLayer or PushImage through writeLayerErr and pushImageErr. The
// recording fields are guarded by mu.
type ociMockClient struct {
	mu              sync.Mutex
	writtenLayers   []v1.Hash
	pushedImages    []string
	writeLayerErr   error
	pushImageErr    error
	writeLayerCount int
}

// WriteLayer counts the call, then either returns the injected error or
// records the layer digest. When a progress channel is supplied it sends a
// single update that reports the layer as fully uploaded.
func (m *ociMockClient) WriteLayer(_ context.Context, opts registry.WriteLayerOptions) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.writeLayerCount++
	if injected := m.writeLayerErr; injected != nil {
		return injected
	}

	layerDigest, err := opts.Layer.Digest()
	if err != nil {
		return err
	}
	m.writtenLayers = append(m.writtenLayers, layerDigest)

	// Send progress if channel is provided
	if progress := opts.ProgressCh; progress != nil {
		layerSize, _ := opts.Layer.Size()
		progress <- v1.Update{Complete: layerSize, Total: layerSize}
	}
	return nil
}

// PushImage returns the injected error if one is set; otherwise it records
// the pushed reference.
func (m *ociMockClient) PushImage(_ context.Context, ref string, _ v1.Image) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if injected := m.pushImageErr; injected != nil {
		return injected
	}
	m.pushedImages = append(m.pushedImages, ref)
	return nil
}

// The remaining registry.Client methods are stubs that return zero values.

func (m *ociMockClient) Inspect(context.Context, string, *registry.Platform) (*registry.ManifestResult, error) {
	return nil, nil
}
func (m *ociMockClient) GetImage(context.Context, string, *registry.Platform) (v1.Image, error) {
	return nil, nil
}
func (m *ociMockClient) Exists(context.Context, string) (bool, error) { return false, nil }
func (m *ociMockClient) GetDescriptor(context.Context, string) (v1.Descriptor, error) {
	return v1.Descriptor{}, nil
}
func (m *ociMockClient) PushIndex(context.Context, string, v1.ImageIndex) error { return nil }

// testArtifact returns an *ImageArtifact whose only populated field is
// Reference, set to ref.
func testArtifact(ref string) *ImageArtifact {
	artifact := new(ImageArtifact)
	artifact.Reference = ref
	return artifact
}

// fakeImageSaveFunc returns a stand-in for Docker's ImageSave. Whatever
// reference it is called with, it serializes img into a Docker-format tar
// tagged with tagStr and returns a reader over those bytes.
func fakeImageSaveFunc(img v1.Image, tagStr string) func(context.Context, string) (io.ReadCloser, error) {
	save := func(_ context.Context, _ string) (io.ReadCloser, error) {
		parsedTag, err := name.NewTag(tagStr, name.Insecure)
		if err != nil {
			return nil, fmt.Errorf("parse tag: %w", err)
		}

		tarBytes := new(bytes.Buffer)
		err = tarball.MultiWrite(map[name.Tag]v1.Image{parsedTag: img}, tarBytes)
		if err != nil {
			return nil, fmt.Errorf("create test tar: %w", err)
		}

		return io.NopCloser(bytes.NewReader(tarBytes.Bytes())), nil
	}
	return save
}

// =============================================================================
// ImagePusher.Push — OCI chunked push tests
// =============================================================================

func TestImagePusher_Push(t *testing.T) {
	// Every subtest exercises the OCI path, so enable it once for all of them.
	t.Setenv("COG_PUSH_OCI", "1")

	t.Run("pushes layers config and manifest via OCI path", func(t *testing.T) {
		randomImg, err := random.Image(1024, 2) // two layers of 1KB each
		require.NoError(t, err)

		reg := &ociMockClient{}
		ref := "example.com/test/repo:v1"
		daemon := &mockDocker{imageSaveFunc: fakeImageSaveFunc(randomImg, ref)}

		err = newImagePusher(daemon, reg).Push(context.Background(), testArtifact(ref))
		require.NoError(t, err)

		// Two layers plus the config blob all go through WriteLayer.
		assert.Equal(t, 3, reg.writeLayerCount)

		// The manifest is pushed once, under the requested reference.
		require.Len(t, reg.pushedImages, 1)
		assert.Equal(t, ref, reg.pushedImages[0])
	})

	t.Run("reports progress via callback", func(t *testing.T) {
		randomImg, err := random.Image(1024, 1)
		require.NoError(t, err)

		reg := &ociMockClient{}
		ref := "example.com/test/repo:v1"
		daemon := &mockDocker{imageSaveFunc: fakeImageSaveFunc(randomImg, ref)}
		pusher := newImagePusher(daemon, reg)

		// Collected under a lock because the callback may run concurrently.
		var (
			mu      sync.Mutex
			updates []PushProgress
		)
		record := WithProgressFn(func(p PushProgress) {
			mu.Lock()
			defer mu.Unlock()
			updates = append(updates, p)
		})

		err = pusher.Push(context.Background(), testArtifact(ref), record)
		require.NoError(t, err)

		mu.Lock()
		defer mu.Unlock()
		assert.NotEmpty(t, updates)
		for _, u := range updates {
			assert.NotEmpty(t, u.LayerDigest)
			assert.True(t, u.Complete > 0)
			assert.True(t, u.Total > 0)
		}
	})

	t.Run("falls back to docker when WriteLayer fails", func(t *testing.T) {
		randomImg, err := random.Image(1024, 1)
		require.NoError(t, err)

		usedDocker := false
		reg := &ociMockClient{writeLayerErr: errors.New("upload failed")}
		ref := "example.com/test/repo:v1"
		daemon := &mockDocker{
			imageSaveFunc: fakeImageSaveFunc(randomImg, ref),
			pushFunc: func(_ context.Context, _ string) error {
				usedDocker = true
				return nil
			},
		}

		err = newImagePusher(daemon, reg).Push(context.Background(), testArtifact(ref))
		require.NoError(t, err)
		assert.True(t, usedDocker)
	})

	t.Run("falls back to docker when PushImage fails", func(t *testing.T) {
		randomImg, err := random.Image(1024, 1)
		require.NoError(t, err)

		usedDocker := false
		reg := &ociMockClient{pushImageErr: errors.New("manifest push failed")}
		ref := "example.com/test/repo:v1"
		daemon := &mockDocker{
			imageSaveFunc: fakeImageSaveFunc(randomImg, ref),
			pushFunc: func(_ context.Context, _ string) error {
				usedDocker = true
				return nil
			},
		}

		err = newImagePusher(daemon, reg).Push(context.Background(), testArtifact(ref))
		require.NoError(t, err)
		assert.True(t, usedDocker)
	})

	t.Run("falls back to docker when ImageSave fails", func(t *testing.T) {
		reg := &ociMockClient{}

		usedDocker := false
		daemon := &mockDocker{
			// The export itself fails, before anything reaches the registry.
			imageSaveFunc: func(_ context.Context, _ string) (io.ReadCloser, error) {
				return nil, errors.New("docker daemon unavailable")
			},
			pushFunc: func(_ context.Context, _ string) error {
				usedDocker = true
				return nil
			},
		}

		err := newImagePusher(daemon, reg).Push(context.Background(), testArtifact("example.com/test/repo:v1"))
		require.NoError(t, err)
		assert.True(t, usedDocker)
	})

	t.Run("handles empty image with no layers", func(t *testing.T) {
		emptyImg, err := mutate.Config(empty.Image, v1.Config{})
		require.NoError(t, err)

		reg := &ociMockClient{}
		ref := "example.com/test/repo:empty"
		daemon := &mockDocker{imageSaveFunc: fakeImageSaveFunc(emptyImg, ref)}

		err = newImagePusher(daemon, reg).Push(context.Background(), testArtifact(ref))
		require.NoError(t, err)

		// With no layers, the config blob is the only WriteLayer call.
		assert.Equal(t, 1, reg.writeLayerCount)
		require.Len(t, reg.pushedImages, 1)
	})
}

// =============================================================================
// ImagePusher.Push with artifact tests
// =============================================================================

// TestImagePusher_PushArtifact exercises Push on a pusher constructed with a
// nil registry client: the success path, rejection of nil / reference-less
// artifacts, and propagation of an error returned by the docker push stub.
func TestImagePusher_PushArtifact(t *testing.T) {
	const modelRef = "r8.im/user/model:latest"

	t.Run("pushes artifact by reference", func(t *testing.T) {
		var gotRef string
		daemon := &mockDocker{
			pushFunc: func(_ context.Context, ref string) error {
				gotRef = ref
				return nil
			},
		}

		// No registry is supplied, so the Docker push is used directly.
		p := newImagePusher(daemon, nil)
		pushErr := p.Push(context.Background(), &ImageArtifact{Reference: modelRef})

		require.NoError(t, pushErr)
		require.Equal(t, modelRef, gotRef)
	})

	t.Run("returns error for nil artifact", func(t *testing.T) {
		p := newImagePusher(&mockDocker{}, nil)

		// The literal nil is passed on purpose; do not route it through a typed variable.
		pushErr := p.Push(context.Background(), nil)

		require.Error(t, pushErr)
		require.Contains(t, pushErr.Error(), "nil")
	})

	t.Run("returns error for empty reference", func(t *testing.T) {
		p := newImagePusher(&mockDocker{}, nil)
		blank := &ImageArtifact{Reference: ""}

		pushErr := p.Push(context.Background(), blank)

		require.Error(t, pushErr)
		require.Contains(t, pushErr.Error(), "no reference")
	})

	t.Run("propagates docker push error", func(t *testing.T) {
		daemon := &mockDocker{
			pushFunc: func(_ context.Context, _ string) error {
				return errors.New("unauthorized: authentication required")
			},
		}

		p := newImagePusher(daemon, nil)
		pushErr := p.Push(context.Background(), &ImageArtifact{Reference: modelRef})

		// The stub's message must still be visible in whatever Push returns.
		require.Error(t, pushErr)
		require.Contains(t, pushErr.Error(), "unauthorized")
	})
}

// =============================================================================
// Docker fallback behavior tests
// =============================================================================

func TestImagePusher_Fallback(t *testing.T) {
	t.Setenv("COG_PUSH_OCI", "1")

	t.Run("uses OCI push when it succeeds", func(t *testing.T) {
		img, err := random.Image(512, 1)
		require.NoError(t, err)

		mock := &ociMockClient{}
		tag := "example.com/test/repo:v1"
		docker := &mockDocker{
			imageSaveFunc: fakeImageSaveFunc(img, tag),
			pushFunc: func(_ context.Context, _ string) error {
				t.Fatal("docker push should not be called when OCI succeeds")
				return nil
			},
		}

		pusher := newImagePusher(docker, mock)

		err = pusher.Push(context.Background(), testArtifact(tag))
		require.NoError(t, err)
	})

	t.Run("falls back to docker on OCI error", func(t *testing.T) {
		var dockerPushed bool
		mock := &ociMockClient{writeLayerErr: errors.New("connection reset")}
		tag := "example.com/test/repo:v1"

		img, err := random.Image(512, 1)
		require.NoError(t, err)

		docker := &mockDocker{
			imageSaveFunc: fakeImageSaveFunc(img, tag),
			pushFunc: func(_ context.Context, _ string) error {
				dockerPushed = true
				return nil
			},
		}

		pusher := newImagePusher(docker, mock)

		err = pusher.Push(context.Background(), testArtifact(tag))
		require.NoError(t, err)
		assert.True(t, dockerPushed)
	})

	t.Run("does not fall back on context cancellation", func(t *testing.T) {
		ctx, cancel := context.WithCancel(context.Background())
		cancel() // cancel immediately

		mock := &ociMockClient{}
		tag := "example.com/test/repo:v1"
		docker := &mockDocker{
			imageSaveFunc: func(ctx context.Context, _ string) (io.ReadCloser, error) {
				return nil, ctx.Err()
			},
			pushFunc: func(_ context.Context, _ string) error {
				t.Fatal("docker push should not be called on context cancellation")
				return nil
			},
		}

		pusher := newImagePusher(docker, mock)

		err := pusher.Push(ctx, testArtifact(tag))
		require.Error(t, err)
	})

	t.Run("does not fall back on 401 unauthorized", func(t *testing.T) {
		mock := &ociMockClient{writeLayerErr: &transport.Error{StatusCode: 401}}
		tag := "example.com/test/repo:v1"

		img, err := random.Image(512, 1)
		require.NoError(t, err)

		docker := &mockDocker{
			imageSaveFunc: fakeImageSaveFunc(img, tag),
			pushFunc: func(_ context.Context, _ string) error {
				t.Fatal("docker push should not be called on auth error")
				return nil
			},
		}

		pusher := newImagePusher(docker, mock)

		err = pusher.Push(context.Background(), testArtifact(tag))
		require.Error(t, err)
		assert.Contains(t, err.Error(), "OCI chunked push")
	})

	t.Run("does not fall back on 403 forbidden", func(t *testing.T) {
		mock := &ociMockClient{writeLayerErr: &transport.Error{StatusCode: 403}}
		tag := "example.com/test/repo:v1"

		img, err := random.Image(512, 1)
		require.NoError(t, err)

		docker := &mockDocker{
			imageSaveFunc: fakeImageSaveFunc(img, tag),
			pushFunc: func(_ context.Context, _ string) error {
				t.Fatal("docker push should not be called on auth error")
				return nil
			},
		}

		pusher := newImagePusher(docker, mock)

		err = pusher.Push(context.Background(), testArtifact(tag))
		require.Error(t, err)
		assert.Contains(t, err.Error(), "OCI chunked push")
	})

	t.Run("uses docker directly when registry is nil", func(t *testing.T) {
		var dockerPushed bool
		docker := &mockDocker{
			pushFunc: func(_ context.Context, _ string) error {
				dockerPushed = true
				return nil
			},
		}

		pusher := newImagePusher(docker, nil)

		err := pusher.Push(context.Background(), testArtifact("example.com/test/repo:v1"))
		require.NoError(t, err)
		assert.True(t, dockerPushed)
	})
}

// =============================================================================
// shouldFallbackToDocker tests
// =============================================================================

// TestShouldFallbackToDocker pins the decision table of shouldFallbackToDocker:
// nil, context errors and 401/403 transport errors (bare or wrapped) yield
// false; the remaining listed errors yield true.
func TestShouldFallbackToDocker(t *testing.T) {
	type fallbackCase struct {
		name     string
		err      error
		expected bool
	}

	cases := []fallbackCase{
		// Inputs for which no docker fallback is expected.
		{name: "nil error", err: nil, expected: false},
		{name: "context canceled", err: context.Canceled, expected: false},
		{name: "context deadline", err: context.DeadlineExceeded, expected: false},
		{name: "401 unauthorized", err: &transport.Error{StatusCode: 401}, expected: false},
		{name: "403 forbidden", err: &transport.Error{StatusCode: 403}, expected: false},
		{name: "wrapped 401", err: fmt.Errorf("push failed: %w", &transport.Error{StatusCode: 401}), expected: false},
		// Inputs for which the docker fallback is expected.
		{name: "500 server error", err: &transport.Error{StatusCode: 500}, expected: true},
		{name: "network error", err: errors.New("connection refused"), expected: true},
		{name: "unknown error", err: errors.New("something unexpected"), expected: true},
		{name: "export error", err: errors.New("export OCI layout: daemon error"), expected: true},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got := shouldFallbackToDocker(tc.err)
			assert.Equal(t, tc.expected, got)
		})
	}
}

// =============================================================================
// sanitizeError tests
// =============================================================================

func TestSanitizeError(t *testing.T) {
	t.Run("strips HTML body from transport error", func(t *testing.T) {
		htmlBody := `413 Request Entity Too Large

413 Request Entity Too Large


cloudflare
` transportErr := &transport.Error{ StatusCode: 413, Errors: nil, } // The rawBody field is unexported, so we test via the wrapped error behavior. // A transport.Error with no Errors slice and status 413 produces a message // that includes the raw body — sanitizeError should replace it. _ = htmlBody // illustrates the problem scenario result := sanitizeError(transportErr) assert.Equal(t, "HTTP 413 Request Entity Too Large", result.Error()) }) t.Run("strips body from 502 transport error", func(t *testing.T) { transportErr := &transport.Error{ StatusCode: 502, } result := sanitizeError(transportErr) assert.Equal(t, "HTTP 502 Bad Gateway", result.Error()) }) t.Run("passes through non-transport errors unchanged", func(t *testing.T) { err := errors.New("connection refused") result := sanitizeError(err) assert.Equal(t, "connection refused", result.Error()) }) t.Run("passes through wrapped transport errors", func(t *testing.T) { transportErr := &transport.Error{ StatusCode: 413, } wrapped := fmt.Errorf("pushing layer: %w", transportErr) result := sanitizeError(wrapped) assert.Equal(t, "HTTP 413 Request Entity Too Large", result.Error()) }) } // ============================================================================= // OnFallback callback tests // ============================================================================= func TestImagePusher_OnFallback(t *testing.T) { t.Setenv("COG_PUSH_OCI", "1") t.Run("calls OnFallback before docker push on OCI failure", func(t *testing.T) { var callOrder []string mock := &ociMockClient{writeLayerErr: errors.New("connection reset")} tag := "example.com/test/repo:v1" img, err := random.Image(512, 1) require.NoError(t, err) docker := &mockDocker{ imageSaveFunc: fakeImageSaveFunc(img, tag), pushFunc: func(_ context.Context, _ string) error { callOrder = append(callOrder, "docker-push") return nil }, } pusher := newImagePusher(docker, mock) err = pusher.Push(context.Background(), testArtifact(tag), WithOnFallback(func() { callOrder 
= append(callOrder, "on-fallback") })) require.NoError(t, err) assert.Equal(t, []string{"on-fallback", "docker-push"}, callOrder) }) t.Run("does not call OnFallback when OCI push succeeds", func(t *testing.T) { mock := &ociMockClient{} tag := "example.com/test/repo:v1" img, err := random.Image(512, 1) require.NoError(t, err) docker := &mockDocker{ imageSaveFunc: fakeImageSaveFunc(img, tag), } pusher := newImagePusher(docker, mock) var fallbackCalled bool err = pusher.Push(context.Background(), testArtifact(tag), WithOnFallback(func() { fallbackCalled = true })) require.NoError(t, err) assert.False(t, fallbackCalled) }) } // ============================================================================= // configBlobLayer tests // ============================================================================= func TestConfigBlobLayer(t *testing.T) { data := []byte(`{"architecture":"amd64","os":"linux"}`) digest := v1.Hash{Algorithm: "sha256", Hex: "abc123"} layer := &configBlobLayer{data: data, digest: digest} t.Run("Digest", func(t *testing.T) { d, err := layer.Digest() require.NoError(t, err) assert.Equal(t, digest, d) }) t.Run("DiffID equals Digest for uncompressed config", func(t *testing.T) { d, err := layer.DiffID() require.NoError(t, err) assert.Equal(t, digest, d) }) t.Run("Size", func(t *testing.T) { size, err := layer.Size() require.NoError(t, err) assert.Equal(t, int64(len(data)), size) }) t.Run("MediaType", func(t *testing.T) { mt, err := layer.MediaType() require.NoError(t, err) assert.Equal(t, types.OCIConfigJSON, mt) }) t.Run("Compressed returns data", func(t *testing.T) { rc, err := layer.Compressed() require.NoError(t, err) defer rc.Close() got, err := io.ReadAll(rc) require.NoError(t, err) assert.Equal(t, data, got) }) t.Run("Uncompressed returns data", func(t *testing.T) { rc, err := layer.Uncompressed() require.NoError(t, err) defer rc.Close() got, err := io.ReadAll(rc) require.NoError(t, err) assert.Equal(t, data, got) }) } // 
============================================================================= // GetPushConcurrency tests // ============================================================================= func TestGetPushConcurrency(t *testing.T) { t.Run("returns default when env not set", func(t *testing.T) { t.Setenv("COG_PUSH_CONCURRENCY", "") assert.Equal(t, DefaultPushConcurrency, GetPushConcurrency()) }) t.Run("returns env var value", func(t *testing.T) { t.Setenv("COG_PUSH_CONCURRENCY", "8") assert.Equal(t, 8, GetPushConcurrency()) }) t.Run("returns default for invalid value", func(t *testing.T) { t.Setenv("COG_PUSH_CONCURRENCY", "not-a-number") assert.Equal(t, DefaultPushConcurrency, GetPushConcurrency()) }) t.Run("returns default for zero", func(t *testing.T) { t.Setenv("COG_PUSH_CONCURRENCY", "0") assert.Equal(t, DefaultPushConcurrency, GetPushConcurrency()) }) t.Run("returns default for negative", func(t *testing.T) { t.Setenv("COG_PUSH_CONCURRENCY", "-1") assert.Equal(t, DefaultPushConcurrency, GetPushConcurrency()) }) } ================================================ FILE: pkg/model/image_test.go ================================================ package model import ( "testing" "github.com/getkin/kin-openapi/openapi3" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" ) func TestImage_IsCogModel(t *testing.T) { tests := []struct { name string image *ImageArtifact expect bool }{ { name: "nil labels", image: &ImageArtifact{Labels: nil}, expect: false, }, { name: "empty labels", image: &ImageArtifact{Labels: map[string]string{}}, expect: false, }, { name: "has cog config label", image: &ImageArtifact{ Labels: map[string]string{ LabelConfig: `{"build": {}}`, }, }, expect: true, }, { name: "has other labels but not cog config", image: &ImageArtifact{ Labels: map[string]string{ "some.other.label": "value", }, }, expect: false, }, { name: "has cog version but not config", image: &ImageArtifact{ Labels: map[string]string{ LabelVersion: "0.10.0", }, }, 
expect: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.image.IsCogModel() require.Equal(t, tt.expect, result) }) } } func TestImage_CogVersion(t *testing.T) { tests := []struct { name string image *ImageArtifact expect string }{ { name: "nil labels", image: &ImageArtifact{Labels: nil}, expect: "", }, { name: "empty labels", image: &ImageArtifact{Labels: map[string]string{}}, expect: "", }, { name: "has version label", image: &ImageArtifact{ Labels: map[string]string{ LabelVersion: "0.10.0", }, }, expect: "0.10.0", }, { name: "has other labels but not version", image: &ImageArtifact{ Labels: map[string]string{ LabelConfig: `{"build": {}}`, }, }, expect: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.image.CogVersion() require.Equal(t, tt.expect, result) }) } } func TestImage_Config(t *testing.T) { tests := []struct { name string image *ImageArtifact expect string }{ { name: "nil labels", image: &ImageArtifact{Labels: nil}, expect: "", }, { name: "has config label", image: &ImageArtifact{ Labels: map[string]string{ LabelConfig: `{"build": {"python_version": "3.11"}}`, }, }, expect: `{"build": {"python_version": "3.11"}}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.image.Config() require.Equal(t, tt.expect, result) }) } } func TestImage_OpenAPISchema(t *testing.T) { tests := []struct { name string image *ImageArtifact expect string }{ { name: "nil labels", image: &ImageArtifact{Labels: nil}, expect: "", }, { name: "has openapi schema label", image: &ImageArtifact{ Labels: map[string]string{ LabelOpenAPISchema: `{"openapi": "3.0.0"}`, }, }, expect: `{"openapi": "3.0.0"}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.image.OpenAPISchema() require.Equal(t, tt.expect, result) }) } } func TestImageSource_Values(t *testing.T) { // Verify the constants have expected string values require.Equal(t, ImageSource("local"), 
ImageSourceLocal) require.Equal(t, ImageSource("remote"), ImageSourceRemote) require.Equal(t, ImageSource("build"), ImageSourceBuild) } func TestLabelKeys(t *testing.T) { // Verify label keys have expected prefixes require.Equal(t, "run.cog.config", LabelConfig) require.Equal(t, "run.cog.version", LabelVersion) require.Equal(t, "run.cog.openapi_schema", LabelOpenAPISchema) require.Equal(t, "run.cog.r8_weights_manifest", LabelWeightsManifest) } // ============================================================================= // Parsed accessor tests // ============================================================================= func TestImage_ParsedConfig(t *testing.T) { tests := []struct { name string image *ImageArtifact expectNil bool expectErr bool checkConfig func(t *testing.T, cfg *config.Config) }{ { name: "nil labels returns nil without error", image: &ImageArtifact{Labels: nil}, expectNil: true, expectErr: false, }, { name: "empty labels returns nil without error", image: &ImageArtifact{Labels: map[string]string{}}, expectNil: true, expectErr: false, }, { name: "missing config label returns nil without error", image: &ImageArtifact{ Labels: map[string]string{ LabelVersion: "0.10.0", }, }, expectNil: true, expectErr: false, }, { name: "valid config JSON parses correctly", image: &ImageArtifact{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.12","gpu":true},"predict":"predict.py:Predictor"}`, }, }, expectNil: false, expectErr: false, checkConfig: func(t *testing.T, cfg *config.Config) { require.Equal(t, "3.12", cfg.Build.PythonVersion) require.True(t, cfg.Build.GPU) require.Equal(t, "predict.py:Predictor", cfg.Predict) }, }, { name: "invalid JSON returns error", image: &ImageArtifact{ Labels: map[string]string{ LabelConfig: `{invalid json`, }, }, expectNil: true, expectErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg, err := tt.image.ParsedConfig() if tt.expectErr { require.Error(t, err) } else { 
require.NoError(t, err) } if tt.expectNil { require.Nil(t, cfg) } else { require.NotNil(t, cfg) if tt.checkConfig != nil { tt.checkConfig(t, cfg) } } }) } } func TestImage_ToModel(t *testing.T) { tests := []struct { name string image *ImageArtifact expectErr error checkModel func(t *testing.T, m *Model) }{ { name: "not a cog model returns ErrNotCogModel", image: &ImageArtifact{Labels: map[string]string{}}, expectErr: ErrNotCogModel, }, { name: "valid cog model with config and schema", image: &ImageArtifact{ Reference: "my-image:latest", Digest: "sha256:abc123", Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.12"},"predict":"predict.py:Predictor"}`, LabelVersion: "0.10.0", LabelOpenAPISchema: `{"openapi":"3.0.2","info":{"title":"Cog","version":"0.1.0"},"paths":{}}`, }, Source: ImageSourceLocal, }, checkModel: func(t *testing.T, m *Model) { require.NotNil(t, m.Image) require.Equal(t, "my-image:latest", m.Image.Reference) require.Equal(t, "0.10.0", m.CogVersion) require.NotNil(t, m.Config) require.Equal(t, "3.12", m.Config.Build.PythonVersion) require.NotNil(t, m.Schema) require.Equal(t, "Cog", m.Schema.Info.Title) }, }, { name: "valid cog model without schema", image: &ImageArtifact{ Labels: map[string]string{ LabelConfig: `{"build":{}}`, LabelVersion: "0.10.0", }, }, checkModel: func(t *testing.T, m *Model) { require.NotNil(t, m.Config) require.Nil(t, m.Schema) }, }, { name: "invalid config JSON returns error", image: &ImageArtifact{ Labels: map[string]string{ LabelConfig: `{invalid`, }, }, expectErr: nil, // Will have an error, just not ErrNotCogModel }, { name: "invalid schema JSON returns error", image: &ImageArtifact{ Labels: map[string]string{ LabelConfig: `{"build":{}}`, LabelOpenAPISchema: `{invalid schema`, }, }, expectErr: nil, // Will have an error, just not ErrNotCogModel }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m, err := tt.image.ToModel() if tt.expectErr != nil { require.ErrorIs(t, err, tt.expectErr) 
return } if tt.name == "invalid config JSON returns error" || tt.name == "invalid schema JSON returns error" { require.Error(t, err) return } require.NoError(t, err) if tt.checkModel != nil { tt.checkModel(t, m) } }) } } func TestImage_ParsedOpenAPISchema(t *testing.T) { tests := []struct { name string image *ImageArtifact expectNil bool expectErr bool checkSchema func(t *testing.T, schema *openapi3.T) }{ { name: "nil labels returns nil without error", image: &ImageArtifact{Labels: nil}, expectNil: true, expectErr: false, }, { name: "empty labels returns nil without error", image: &ImageArtifact{Labels: map[string]string{}}, expectNil: true, expectErr: false, }, { name: "missing schema label returns nil without error", image: &ImageArtifact{ Labels: map[string]string{ LabelConfig: `{"build":{}}`, }, }, expectNil: true, expectErr: false, }, { name: "valid OpenAPI JSON parses correctly", image: &ImageArtifact{ Labels: map[string]string{ LabelOpenAPISchema: `{"openapi":"3.0.2","info":{"title":"Cog","version":"0.1.0"},"paths":{}}`, }, }, expectNil: false, expectErr: false, checkSchema: func(t *testing.T, schema *openapi3.T) { require.Equal(t, "3.0.2", schema.OpenAPI) require.Equal(t, "Cog", schema.Info.Title) require.Equal(t, "0.1.0", schema.Info.Version) }, }, { name: "invalid JSON returns error", image: &ImageArtifact{ Labels: map[string]string{ LabelOpenAPISchema: `{invalid json`, }, }, expectNil: true, expectErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { schema, err := tt.image.ParsedOpenAPISchema() if tt.expectErr { require.Error(t, err) } else { require.NoError(t, err) } if tt.expectNil { require.Nil(t, schema) } else { require.NotNil(t, schema) if tt.checkSchema != nil { tt.checkSchema(t, schema) } } }) } } ================================================ FILE: pkg/model/index.go ================================================ // pkg/model/index.go package model // Index represents an OCI Image Index containing multiple 
manifests. type Index struct { // Digest is the index digest (sha256:...). Digest string // Reference is the full image reference. Reference string // MediaType is typically application/vnd.oci.image.index.v1+json. MediaType string // Manifests are the child manifests in this index. Manifests []IndexManifest } // IndexManifest represents a single manifest within an index. type IndexManifest struct { // Digest is the manifest digest. Digest string // MediaType is the manifest media type. MediaType string // Size is the manifest size in bytes. Size int64 // Platform is the target platform (nil for artifacts). Platform *Platform // Annotations are OCI annotations on this manifest. Annotations map[string]string // Type is derived from platform/annotations (image or weights). Type ManifestType } // ManifestType identifies the type of manifest in an index. type ManifestType string const ( // ManifestTypeImage is a runnable container image. ManifestTypeImage ManifestType = "image" // ManifestTypeWeights is a weights artifact. ManifestTypeWeights ManifestType = "weights" ) // Annotation keys for weights manifests. const ( AnnotationReferenceType = "vnd.cog.reference.type" AnnotationReferenceDigest = "vnd.cog.reference.digest" ) // Annotation values. const ( // AnnotationValueWeights is the value for AnnotationReferenceType indicating a weights manifest. AnnotationValueWeights = "weights" ) // Platform values for non-platform-specific artifacts. const ( // PlatformUnknown is used for artifacts that are not platform-specific (e.g., weights). 
PlatformUnknown = "unknown" ) ================================================ FILE: pkg/model/index_factory.go ================================================ package model import ( "fmt" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/empty" "github.com/google/go-containerregistry/pkg/v1/mutate" "github.com/google/go-containerregistry/pkg/v1/types" ) // IndexBuilder builds an OCI Image Index from pre-pushed manifest descriptors. type IndexBuilder struct { imageDescriptor *v1.Descriptor imagePlatform *v1.Platform weightDescriptors []weightDescEntry } // weightDescEntry pairs a weight descriptor with the image digest it references. type weightDescEntry struct { descriptor v1.Descriptor imageDigest string name string target string } // NewIndexBuilder creates a new index builder. func NewIndexBuilder() *IndexBuilder { return &IndexBuilder{} } // SetImageDescriptor sets the image manifest descriptor. func (b *IndexBuilder) SetImageDescriptor(desc v1.Descriptor, platform *v1.Platform) { b.imageDescriptor = &desc b.imagePlatform = platform } // AddWeightDescriptor adds a weight manifest descriptor. // imageDigest is the digest of the model image, used in the reference annotation. // name and target are optional weight metadata for index annotations. func (b *IndexBuilder) AddWeightDescriptor(desc v1.Descriptor, imageDigest, name, target string) { b.weightDescriptors = append(b.weightDescriptors, weightDescEntry{ descriptor: desc, imageDigest: imageDigest, name: name, target: target, }) } // BuildFromDescriptors creates an OCI Image Index from pre-pushed manifest descriptors. // This works with bare descriptors returned from push operations, avoiding the need // to fetch images back from the registry. 
func (b *IndexBuilder) BuildFromDescriptors() (v1.ImageIndex, error) { if b.imageDescriptor == nil { return nil, fmt.Errorf("image descriptor not set") } idx := mutate.IndexMediaType(empty.Index, types.OCIImageIndex) // Add image manifest idx = mutate.AppendManifests(idx, mutate.IndexAddendum{ Add: &descriptorAppendable{desc: *b.imageDescriptor}, Descriptor: v1.Descriptor{ MediaType: b.imageDescriptor.MediaType, Size: b.imageDescriptor.Size, Digest: b.imageDescriptor.Digest, Platform: b.imagePlatform, }, }) // Add weight manifest(s) for _, entry := range b.weightDescriptors { annotations := map[string]string{ AnnotationReferenceType: AnnotationValueWeights, } if entry.imageDigest != "" { annotations[AnnotationReferenceDigest] = entry.imageDigest } if entry.name != "" { annotations[AnnotationWeightName] = entry.name } if entry.target != "" { annotations[AnnotationWeightDest] = entry.target } idx = mutate.AppendManifests(idx, mutate.IndexAddendum{ Add: &descriptorAppendable{desc: entry.descriptor}, Descriptor: v1.Descriptor{ MediaType: entry.descriptor.MediaType, Size: entry.descriptor.Size, Digest: entry.descriptor.Digest, Platform: &v1.Platform{ OS: PlatformUnknown, Architecture: PlatformUnknown, }, Annotations: annotations, }, }) } return idx, nil } // descriptorAppendable wraps a v1.Descriptor to implement mutate.Appendable. // This allows building an OCI index from descriptors without needing full v1.Image objects. 
type descriptorAppendable struct { desc v1.Descriptor } func (d *descriptorAppendable) MediaType() (types.MediaType, error) { return d.desc.MediaType, nil } func (d *descriptorAppendable) Digest() (v1.Hash, error) { return d.desc.Digest, nil } func (d *descriptorAppendable) Size() (int64, error) { return d.desc.Size, nil } ================================================ FILE: pkg/model/index_factory_test.go ================================================ package model import ( "testing" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/types" "github.com/stretchr/testify/require" ) func TestIndexBuilder_BuildFromDescriptors(t *testing.T) { t.Run("builds index from image and weight descriptors", func(t *testing.T) { imgDesc := v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 1234, Digest: v1.Hash{ Algorithm: "sha256", Hex: "aaaa", }, } weightDesc := v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 5678, Digest: v1.Hash{ Algorithm: "sha256", Hex: "bbbb", }, } builder := NewIndexBuilder() builder.SetImageDescriptor(imgDesc, &v1.Platform{OS: "linux", Architecture: "amd64"}) builder.AddWeightDescriptor(weightDesc, imgDesc.Digest.String(), "model-v1", "/cache/model.safetensors") idx, err := builder.BuildFromDescriptors() require.NoError(t, err) idxManifest, err := idx.IndexManifest() require.NoError(t, err) require.Len(t, idxManifest.Manifests, 2) // First entry: image with platform require.Equal(t, imgDesc.Digest, idxManifest.Manifests[0].Digest) require.Equal(t, imgDesc.Size, idxManifest.Manifests[0].Size) require.Equal(t, "linux", idxManifest.Manifests[0].Platform.OS) require.Equal(t, "amd64", idxManifest.Manifests[0].Platform.Architecture) // Second entry: weight artifact with unknown platform and annotations require.Equal(t, weightDesc.Digest, idxManifest.Manifests[1].Digest) require.Equal(t, weightDesc.Size, idxManifest.Manifests[1].Size) require.Equal(t, PlatformUnknown, 
idxManifest.Manifests[1].Platform.OS) require.Equal(t, PlatformUnknown, idxManifest.Manifests[1].Platform.Architecture) require.Equal(t, AnnotationValueWeights, idxManifest.Manifests[1].Annotations[AnnotationReferenceType]) require.Equal(t, imgDesc.Digest.String(), idxManifest.Manifests[1].Annotations[AnnotationReferenceDigest]) require.Equal(t, "model-v1", idxManifest.Manifests[1].Annotations[AnnotationWeightName]) require.Equal(t, "/cache/model.safetensors", idxManifest.Manifests[1].Annotations[AnnotationWeightDest]) }) t.Run("builds index with multiple weight descriptors", func(t *testing.T) { imgDesc := v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 1000, Digest: v1.Hash{Algorithm: "sha256", Hex: "img111"}, } weight1 := v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 2000, Digest: v1.Hash{Algorithm: "sha256", Hex: "w1111"}, } weight2 := v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 3000, Digest: v1.Hash{Algorithm: "sha256", Hex: "w2222"}, } builder := NewIndexBuilder() builder.SetImageDescriptor(imgDesc, &v1.Platform{OS: "linux", Architecture: "amd64"}) builder.AddWeightDescriptor(weight1, imgDesc.Digest.String(), "weight-1", "/weights/w1.bin") builder.AddWeightDescriptor(weight2, imgDesc.Digest.String(), "weight-2", "/weights/w2.bin") idx, err := builder.BuildFromDescriptors() require.NoError(t, err) idxManifest, err := idx.IndexManifest() require.NoError(t, err) require.Len(t, idxManifest.Manifests, 3) // 1 image + 2 weights }) t.Run("requires image descriptor", func(t *testing.T) { builder := NewIndexBuilder() _, err := builder.BuildFromDescriptors() require.Error(t, err) require.Contains(t, err.Error(), "image descriptor not set") }) t.Run("builds index without weight descriptors", func(t *testing.T) { imgDesc := v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 1234, Digest: v1.Hash{Algorithm: "sha256", Hex: "aaaa"}, } builder := NewIndexBuilder() builder.SetImageDescriptor(imgDesc, &v1.Platform{OS: "linux", 
Architecture: "amd64"}) idx, err := builder.BuildFromDescriptors() require.NoError(t, err) idxManifest, err := idx.IndexManifest() require.NoError(t, err) require.Len(t, idxManifest.Manifests, 1) }) t.Run("index has OCI media type", func(t *testing.T) { imgDesc := v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 1234, Digest: v1.Hash{Algorithm: "sha256", Hex: "aaaa"}, } builder := NewIndexBuilder() builder.SetImageDescriptor(imgDesc, &v1.Platform{OS: "linux", Architecture: "amd64"}) idx, err := builder.BuildFromDescriptors() require.NoError(t, err) mt, err := idx.MediaType() require.NoError(t, err) require.Equal(t, types.OCIImageIndex, mt) }) } ================================================ FILE: pkg/model/index_test.go ================================================ // pkg/model/index_test.go package model import ( "testing" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/stretchr/testify/require" ) func TestModel_IsBundle(t *testing.T) { t.Run("returns false with no artifacts", func(t *testing.T) { m := &Model{} require.False(t, m.IsBundle()) }) t.Run("returns false with only image artifact", func(t *testing.T) { m := &Model{ Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, }, } require.False(t, m.IsBundle()) }) t.Run("returns true with weight artifacts", func(t *testing.T) { m := &Model{ Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, NewWeightArtifact("w1", v1.Descriptor{}, "/tmp/w1", "/weights/w1", WeightConfig{}), }, } require.True(t, m.IsBundle()) }) } func TestManifestType(t *testing.T) { require.Equal(t, ManifestType("image"), ManifestTypeImage) require.Equal(t, ManifestType("weights"), ManifestTypeWeights) } ================================================ FILE: pkg/model/model.go ================================================ package model import ( "encoding/json" "github.com/getkin/kin-openapi/openapi3" 
"github.com/replicate/cog/pkg/config" ) // Model represents a Cog model extracted from an image. type Model struct { Image *ImageArtifact // Underlying OCI image Config *config.Config // Parsed cog.yaml Schema *openapi3.T // OpenAPI schema CogVersion string // Version of cog used to build // Index is the OCI Image Index (populated when inspecting a pushed model). Index *Index // TODO(md): OCIIndex is a temporary gate. When true, Push() creates an OCI // Image Index with weight artifacts. When false, Push() does a plain docker push. // Remove this field once index pushes are validated with all registries. OCIIndex bool // Artifacts is the collection of all artifacts produced by building this model. // Populated by Resolver.Build(). Contains ImageArtifact and WeightArtifact instances. Artifacts []Artifact } // HasGPU returns true if the model requires GPU. func (m *Model) HasGPU() bool { return m.Config != nil && m.Config.Build != nil && m.Config.Build.GPU } // SchemaJSON returns the OpenAPI schema as JSON bytes, or nil if no schema. func (m *Model) SchemaJSON() ([]byte, error) { if m.Schema == nil { return nil, nil } return json.Marshal(m.Schema) } // ImageRef returns the image reference string, or empty string if no image. func (m *Model) ImageRef() string { if m.Image == nil { return "" } return m.Image.Reference } // IsBundle returns true if this model has weight artifacts. func (m *Model) IsBundle() bool { return len(m.WeightArtifacts()) > 0 } // GetImageArtifact returns the first ImageArtifact from the artifacts collection, // or nil if none exists. func (m *Model) GetImageArtifact() *ImageArtifact { for _, a := range m.Artifacts { if ia, ok := a.(*ImageArtifact); ok { return ia } } return nil } // WeightArtifacts returns all WeightArtifact instances from the artifacts collection. 
func (m *Model) WeightArtifacts() []*WeightArtifact { var weights []*WeightArtifact for _, a := range m.Artifacts { if wa, ok := a.(*WeightArtifact); ok { weights = append(weights, wa) } } return weights } // ArtifactsByType returns all artifacts matching the given type. func (m *Model) ArtifactsByType(t ArtifactType) []Artifact { var result []Artifact for _, a := range m.Artifacts { if a.Type() == t { result = append(result, a) } } return result } ================================================ FILE: pkg/model/model_test.go ================================================ package model import ( "testing" "time" "github.com/getkin/kin-openapi/openapi3" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" ) func TestModel_HasGPU(t *testing.T) { tests := []struct { name string model *Model expect bool }{ { name: "nil config", model: &Model{Config: nil}, expect: false, }, { name: "nil build", model: &Model{Config: &config.Config{Build: nil}}, expect: false, }, { name: "GPU false", model: &Model{Config: &config.Config{Build: &config.Build{GPU: false}}}, expect: false, }, { name: "GPU true", model: &Model{Config: &config.Config{Build: &config.Build{GPU: true}}}, expect: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.model.HasGPU() require.Equal(t, tt.expect, result) }) } } func TestModel_SchemaJSON(t *testing.T) { tests := []struct { name string model *Model expectNil bool expectJSON string }{ { name: "nil schema", model: &Model{Schema: nil}, expectNil: true, }, { name: "schema with openapi version", model: &Model{ Schema: &openapi3.T{ OpenAPI: "3.0.0", }, }, expectNil: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := tt.model.SchemaJSON() require.NoError(t, err) if tt.expectNil { require.Nil(t, result) } else { require.NotNil(t, result) // Verify it's valid JSON containing expected field require.Contains(t, 
string(result), `"openapi"`) } }) } } func TestModel_ImageRef(t *testing.T) { tests := []struct { name string model *Model expect string }{ { name: "nil image", model: &Model{Image: nil}, expect: "", }, { name: "with image reference", model: &Model{ Image: &ImageArtifact{Reference: "r8.im/user/model@sha256:abc123"}, }, expect: "r8.im/user/model@sha256:abc123", }, { name: "with empty reference", model: &Model{ Image: &ImageArtifact{Reference: ""}, }, expect: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.model.ImageRef() require.Equal(t, tt.expect, result) }) } } func TestModel_GetImageArtifact(t *testing.T) { imgArtifact := NewImageArtifact("model", v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "abc123"}, Size: 1024}, "r8.im/user/model@sha256:abc123", ) weightArtifact := NewWeightArtifact("weights", v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "def456"}, Size: 4096}, "/data/weights.bin", "/weights/model.bin", WeightConfig{SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "weights", Target: "/weights/model.bin", Created: time.Now()}, ) tests := []struct { name string model *Model expectNil bool }{ { name: "no artifacts", model: &Model{}, expectNil: true, }, { name: "nil artifacts", model: &Model{Artifacts: nil}, expectNil: true, }, { name: "only weight artifacts", model: &Model{Artifacts: []Artifact{weightArtifact}}, expectNil: true, }, { name: "has image artifact", model: &Model{Artifacts: []Artifact{imgArtifact, weightArtifact}}, expectNil: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.model.GetImageArtifact() if tt.expectNil { require.Nil(t, result) } else { require.NotNil(t, result) require.Equal(t, ArtifactTypeImage, result.Type()) require.Equal(t, "model", result.Name()) } }) } } func TestModel_WeightArtifacts(t *testing.T) { imgArtifact := NewImageArtifact("model", v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "abc123"}, Size: 1024}, 
"r8.im/user/model@sha256:abc123", ) w1 := NewWeightArtifact("llama", v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "w1"}, Size: 4096}, "/data/llama.bin", "/weights/llama.bin", WeightConfig{SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "llama", Target: "/weights/llama.bin", Created: time.Now()}, ) w2 := NewWeightArtifact("embeddings", v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "w2"}, Size: 2048}, "/data/embed.bin", "/weights/embed.bin", WeightConfig{SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "embeddings", Target: "/weights/embed.bin", Created: time.Now()}, ) tests := []struct { name string model *Model expect int }{ {name: "no artifacts", model: &Model{}, expect: 0}, {name: "only image", model: &Model{Artifacts: []Artifact{imgArtifact}}, expect: 0}, {name: "one weight", model: &Model{Artifacts: []Artifact{imgArtifact, w1}}, expect: 1}, {name: "two weights", model: &Model{Artifacts: []Artifact{imgArtifact, w1, w2}}, expect: 2}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.model.WeightArtifacts() require.Len(t, result, tt.expect) }) } } func TestModel_ArtifactsByType(t *testing.T) { imgArtifact := NewImageArtifact("model", v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "abc123"}, Size: 1024}, "r8.im/user/model@sha256:abc123", ) w1 := NewWeightArtifact("llama", v1.Descriptor{Digest: v1.Hash{Algorithm: "sha256", Hex: "w1"}, Size: 4096}, "/data/llama.bin", "/weights/llama.bin", WeightConfig{SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "llama", Target: "/weights/llama.bin", Created: time.Now()}, ) m := &Model{Artifacts: []Artifact{imgArtifact, w1}} images := m.ArtifactsByType(ArtifactTypeImage) require.Len(t, images, 1) require.Equal(t, "model", images[0].Name()) weights := m.ArtifactsByType(ArtifactTypeWeight) require.Len(t, weights, 1) require.Equal(t, "llama", weights[0].Name()) } ================================================ FILE: pkg/model/options.go 
================================================ package model import "github.com/replicate/cog/pkg/config" // BuildOptions contains all settings for building a Cog image. // This consolidates the many parameters previously passed to image.Build(). type BuildOptions struct { // ImageName is the output image name (required). ImageName string // NoCache disables build cache. NoCache bool // SeparateWeights builds weights as a separate layer. SeparateWeights bool // Strip removes debug symbols from binaries. Strip bool // Precompile precompiles Python bytecode. Precompile bool // UseCudaBaseImage controls CUDA base image usage: "auto", "true", or "false". UseCudaBaseImage string // UseCogBaseImage controls cog base image usage. nil means auto-detect. UseCogBaseImage *bool // Secrets are build-time secrets to pass to the build. Secrets []string // ProgressOutput controls build output format: "auto", "plain", or "tty". ProgressOutput string // Annotations are extra labels to add to the image. Annotations map[string]string // SchemaFile is a custom OpenAPI schema file path. SchemaFile string // DockerfileFile is a custom Dockerfile path. DockerfileFile string // WeightsLockPath is the path to weights.lock file. // Default: weights.lock in project directory. WeightsLockPath string // TODO(md): OCIIndex is a temporary gate. When true, builds produce weight // artifacts and pushes create an OCI Image Index. Set via COG_OCI_INDEX=1. // Remove this field once index pushes are validated with all registries. OCIIndex bool // ExcludeSource skips the COPY . /src step in the generated Dockerfile. // Used by `cog serve` to produce an image identical to `cog build` minus // the source copy — the source directory is volume-mounted at runtime. // All other layers (wheel installs, apt, etc.) are shared with `cog build` // via Docker layer caching. ExcludeSource bool // SkipSchemaValidation skips OpenAPI schema generation and validation. 
// Used by `cog run` which executes arbitrary commands and may not have // a predictor or trainer defined in cog.yaml. SkipSchemaValidation bool // SkipLabels skips adding Cog metadata labels to the built image. // Used by `cog run`, `cog predict`, `cog serve`, and `cog train` where // the image is for local use only and not being distributed. SkipLabels bool } // WithDefaults returns a copy of BuildOptions with defaults applied from Source. // This fills in sensible defaults for any unset fields. func (o BuildOptions) WithDefaults(src *Source) BuildOptions { // Default image name from project directory if o.ImageName == "" { o.ImageName = config.DockerImageName(src.ProjectDir) } // Default progress output if o.ProgressOutput == "" { o.ProgressOutput = "auto" } return o } ================================================ FILE: pkg/model/options_test.go ================================================ package model import ( "testing" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" ) func TestBuildOptions_WithDefaults_ImageName(t *testing.T) { src := &Source{ Config: &config.Config{Build: &config.Build{}}, ProjectDir: "/path/to/my-project", } opts := BuildOptions{} opts = opts.WithDefaults(src) // config.DockerImageName normalizes the name require.Equal(t, "cog-my-project", opts.ImageName) } func TestBuildOptions_WithDefaults_PreservesExplicitImageName(t *testing.T) { src := &Source{ Config: &config.Config{Build: &config.Build{}}, ProjectDir: "/path/to/my-project", } opts := BuildOptions{ImageName: "my-custom-image"} opts = opts.WithDefaults(src) require.Equal(t, "my-custom-image", opts.ImageName) } func TestBuildOptions_WithDefaults_ProgressOutput(t *testing.T) { src := &Source{ Config: &config.Config{Build: &config.Build{}}, ProjectDir: "/path/to/project", } opts := BuildOptions{} opts = opts.WithDefaults(src) require.Equal(t, "auto", opts.ProgressOutput) } func TestBuildOptions_WithDefaults_PreservesExplicitProgressOutput(t *testing.T) { 
src := &Source{ Config: &config.Config{Build: &config.Build{}}, ProjectDir: "/path/to/project", } opts := BuildOptions{ProgressOutput: "plain"} opts = opts.WithDefaults(src) require.Equal(t, "plain", opts.ProgressOutput) } func TestBuildOptions_WithDefaults_NilBuildConfig(t *testing.T) { src := &Source{ Config: &config.Config{Build: nil}, ProjectDir: "/path/to/project", } opts := BuildOptions{} opts = opts.WithDefaults(src) // Should not panic and should apply other defaults require.Equal(t, "cog-project", opts.ImageName) require.Equal(t, "auto", opts.ProgressOutput) } func TestBuildOptions_WithDefaults_NilConfig(t *testing.T) { src := &Source{ Config: nil, ProjectDir: "/path/to/project", } opts := BuildOptions{} opts = opts.WithDefaults(src) // Should not panic and should apply other defaults require.Equal(t, "cog-project", opts.ImageName) require.Equal(t, "auto", opts.ProgressOutput) } func TestBuildOptions_AllFieldsPreserved(t *testing.T) { src := &Source{ Config: &config.Config{Build: &config.Build{}}, ProjectDir: "/path/to/project", } useCogBase := true opts := BuildOptions{ ImageName: "my-image", NoCache: true, SeparateWeights: true, Strip: true, Precompile: true, UseCudaBaseImage: "true", UseCogBaseImage: &useCogBase, Secrets: []string{"secret1", "secret2"}, ProgressOutput: "tty", Annotations: map[string]string{"key": "value"}, SchemaFile: "/path/to/schema.json", DockerfileFile: "/path/to/Dockerfile", WeightsLockPath: "/path/to/weights.lock", } result := opts.WithDefaults(src) require.Equal(t, "my-image", result.ImageName) require.True(t, result.NoCache) require.True(t, result.SeparateWeights) require.True(t, result.Strip) require.True(t, result.Precompile) require.Equal(t, "true", result.UseCudaBaseImage) require.NotNil(t, result.UseCogBaseImage) require.True(t, *result.UseCogBaseImage) require.Equal(t, []string{"secret1", "secret2"}, result.Secrets) require.Equal(t, "tty", result.ProgressOutput) require.Equal(t, map[string]string{"key": "value"}, 
const (
	// DefaultPushConcurrency is how many uploads run at once when nothing
	// overrides it; it applies to both image layers and weight artifacts.
	// This matches Docker's default concurrency for layer uploads, which is a
	// reasonable baseline for OCI pushes as well.
	DefaultPushConcurrency = 5

	// envPushConcurrency names the environment variable that overrides
	// DefaultPushConcurrency.
	envPushConcurrency = "COG_PUSH_CONCURRENCY"
)

// GetPushConcurrency returns the number of concurrent uploads to use.
//
// The COG_PUSH_CONCURRENCY environment variable wins when it holds a positive
// decimal integer. If it is unset, empty, unparseable, zero, or negative, the
// result is DefaultPushConcurrency.
func GetPushConcurrency() int {
	raw := os.Getenv(envPushConcurrency)
	if raw == "" {
		return DefaultPushConcurrency
	}

	parsed, err := strconv.Atoi(raw)
	if err != nil || parsed <= 0 {
		// Invalid overrides are ignored rather than reported.
		return DefaultPushConcurrency
	}
	return parsed
}
Total int64 } // writeLayerWithProgress pushes a layer via registry.WriteLayer, managing the // progress channel lifecycle (create, drain, close) on behalf of the caller. // // onProgress is called for each v1.Update from the registry. If nil, no progress // channel is created and no goroutine is spawned. func writeLayerWithProgress(ctx context.Context, reg registry.Client, opts registry.WriteLayerOptions, onProgress func(v1.Update)) error { var progressCh chan v1.Update var progressDone chan struct{} if onProgress != nil { progressCh = make(chan v1.Update, 100) progressDone = make(chan struct{}) go func() { defer close(progressDone) for update := range progressCh { onProgress(update) } }() opts.ProgressCh = progressCh } writeErr := reg.WriteLayer(ctx, opts) // Close the progress channel ourselves — WriteLayer sends to it but does not close it. if progressCh != nil { close(progressCh) } if progressDone != nil { <-progressDone } return writeErr } ================================================ FILE: pkg/model/pusher.go ================================================ // pkg/model/pusher.go package model import ( "context" "fmt" "github.com/google/go-containerregistry/pkg/name" v1 "github.com/google/go-containerregistry/pkg/v1" "golang.org/x/sync/errgroup" "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/registry" ) // PushOptions configures push behavior. type PushOptions struct { // ProjectDir is the base directory for resolving weight file paths. // // Deprecated: Artifacts carry their own file paths. ProjectDir string // FilePaths maps weight name identifiers to their file paths. // // Deprecated: Use Model.Artifacts instead — WeightArtifact carries FilePath. FilePaths map[string]string // Platform specifies the target platform for bundle indexes. // Default: linux/amd64 Platform *Platform // ImageProgressFn is an optional callback for reporting per-layer upload progress // during OCI chunked image push. 
Each call includes the layer digest, bytes // completed, and total bytes. ImageProgressFn func(PushProgress) // OnFallback is called when OCI push fails and the push is about to fall // back to Docker push. This allows the caller to clean up any OCI-specific // progress display before Docker push starts its own output. OnFallback func() } // ============================================================================= // BundlePusher - pushes OCI Index with image + weights // ============================================================================= // BundlePusher pushes bundles (OCI Index with image + weight artifacts). // It orchestrates ImagePusher and WeightPusher, then assembles the OCI index // from the pushed manifest descriptors. type BundlePusher struct { imagePusher *ImagePusher weightPusher *WeightPusher registry registry.Client } // NewBundlePusher creates a new BundlePusher from docker and registry clients. // Both sub-pushers (image and weight) are created internally to keep // construction unified — callers don't need to know about ImagePusher or // WeightPusher directly. func NewBundlePusher(docker command.Command, reg registry.Client) *BundlePusher { return &BundlePusher{ imagePusher: newImagePusher(docker, reg), weightPusher: NewWeightPusher(reg), registry: reg, } } // Push pushes the model as an OCI Index with weight artifacts. // It reads Model.Artifacts to find the image and weight artifacts to push. func (p *BundlePusher) Push(ctx context.Context, m *Model, opts PushOptions) error { // Extract artifacts from model imgArtifact := m.GetImageArtifact() if imgArtifact == nil { return fmt.Errorf("no image artifact in model") } weightArtifacts := m.WeightArtifacts() // Derive repo from image reference (strip tag/digest for weight pushes) repo := repoFromReference(imgArtifact.Reference) // 1. 
Push image via OCI chunked push (falls back to Docker push on error) var imagePushOpts []ImagePushOption if opts.ImageProgressFn != nil { imagePushOpts = append(imagePushOpts, WithProgressFn(opts.ImageProgressFn)) } if opts.OnFallback != nil { imagePushOpts = append(imagePushOpts, WithOnFallback(opts.OnFallback)) } if err := p.imagePusher.Push(ctx, imgArtifact, imagePushOpts...); err != nil { return fmt.Errorf("push image %q: %w", imgArtifact.Reference, err) } // 2. Get image manifest descriptor (lightweight HEAD request) imgDesc, err := p.registry.GetDescriptor(ctx, imgArtifact.Reference) if err != nil { return fmt.Errorf("get image descriptor: %w", err) } // 3. Push weight artifacts concurrently (if any) var weightResults []*WeightPushResult if len(weightArtifacts) > 0 { weightResults, err = p.pushWeights(ctx, repo, weightArtifacts) if err != nil { return err } } // 4. Build OCI index from pushed descriptors platform := opts.Platform if platform == nil { platform = &Platform{OS: "linux", Architecture: "amd64"} } builder := NewIndexBuilder() builder.SetImageDescriptor(imgDesc, &v1.Platform{ OS: platform.OS, Architecture: platform.Architecture, Variant: platform.Variant, }) for i, wr := range weightResults { builder.AddWeightDescriptor(wr.Descriptor, imgDesc.Digest.String(), weightArtifacts[i].Name(), weightArtifacts[i].Target) } idx, err := builder.BuildFromDescriptors() if err != nil { return fmt.Errorf("build OCI index: %w", err) } // 5. Push OCI index (overwrites the tag with the index) if err := p.registry.PushIndex(ctx, imgArtifact.Reference, idx); err != nil { return fmt.Errorf("push OCI index: %w", err) } return nil } // pushWeights pushes all weight artifacts concurrently (bounded by GetPushConcurrency) // and returns their results in the same order as the input slice. // If any weight push fails, remaining pushes are canceled and the first error is returned. 
func (p *BundlePusher) pushWeights(ctx context.Context, repo string, weights []*WeightArtifact) ([]*WeightPushResult, error) { ordered := make([]*WeightPushResult, len(weights)) g, ctx := errgroup.WithContext(ctx) g.SetLimit(GetPushConcurrency()) for i, wa := range weights { g.Go(func() error { result, err := p.weightPusher.Push(ctx, repo, wa) if err != nil { return fmt.Errorf("push weight %q: %w", wa.Name(), err) } ordered[i] = result return nil }) } if err := g.Wait(); err != nil { return nil, err } return ordered, nil } // repoFromReference extracts the repository (without tag or digest) from an image reference. // "r8.im/user/model:latest" -> "r8.im/user/model" // "r8.im/user/model@sha256:abc" -> "r8.im/user/model" // "localhost:5000/model:latest" -> "localhost:5000/model" func repoFromReference(ref string) string { parsed, err := name.ParseReference(ref, name.Insecure) if err != nil { return ref // fallback: return as-is if unparseable } return parsed.Context().String() } ================================================ FILE: pkg/model/pusher_test.go ================================================ package model import ( "context" "errors" "os" "path/filepath" "sync" "sync/atomic" "testing" "time" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/types" "github.com/stretchr/testify/require" ) // ============================================================================= // BundlePusher tests // ============================================================================= func TestBundlePusher_Push(t *testing.T) { t.Run("returns error when no image artifact in model", func(t *testing.T) { docker := &mockDocker{} reg := &mockRegistry{} pusher := NewBundlePusher(docker, reg) m := &Model{ Image: nil, Artifacts: []Artifact{}, // no image artifact } err := pusher.Push(context.Background(), m, PushOptions{}) require.Error(t, err) require.Contains(t, err.Error(), "no image artifact") }) t.Run("pushes image-only model 
as single-entry index", func(t *testing.T) { docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } imgDesc := v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 1234, Digest: v1.Hash{Algorithm: "sha256", Hex: "imgonly"}, } reg := &mockRegistry{ getDescriptorFunc: func(ctx context.Context, ref string) (v1.Descriptor, error) { return imgDesc, nil }, pushIndexFunc: func(ctx context.Context, ref string, idx v1.ImageIndex) error { // Verify index has exactly 1 entry (image only, no weights) idxManifest, err := idx.IndexManifest() require.NoError(t, err) require.Len(t, idxManifest.Manifests, 1) require.Equal(t, imgDesc.Digest, idxManifest.Manifests[0].Digest) require.Equal(t, "linux", idxManifest.Manifests[0].Platform.OS) return nil }, } pusher := NewBundlePusher(docker, reg) m := &Model{ Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, // no weight artifacts — image-only model }, } err := pusher.Push(context.Background(), m, PushOptions{}) require.NoError(t, err) }) t.Run("full push flow succeeds with single weight", func(t *testing.T) { // Create temp weight file dir := t.TempDir() weightPath := filepath.Join(dir, "model.safetensors") require.NoError(t, os.WriteFile(weightPath, []byte("fake weight data"), 0o644)) // Track call sequence (mutex-protected for goroutine safety) var mu sync.Mutex var callOrder []string track := func(entry string) { mu.Lock() callOrder = append(callOrder, entry) mu.Unlock() } docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { track("docker:push:" + ref) return nil }, } imgDesc := v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 1234, Digest: v1.Hash{Algorithm: "sha256", Hex: "imgdigest"}, } reg := &mockRegistry{ getDescriptorFunc: func(ctx context.Context, ref string) (v1.Descriptor, error) { track("registry:getDescriptor:" + ref) return imgDesc, nil }, 
pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { track("registry:pushImage:" + ref) return nil }, pushIndexFunc: func(ctx context.Context, ref string, idx v1.ImageIndex) error { track("registry:pushIndex:" + ref) // Verify the index structure idxManifest, err := idx.IndexManifest() require.NoError(t, err) require.Len(t, idxManifest.Manifests, 2) // image + 1 weight // First manifest: image with platform require.Equal(t, imgDesc.Digest, idxManifest.Manifests[0].Digest) require.Equal(t, "linux", idxManifest.Manifests[0].Platform.OS) require.Equal(t, "amd64", idxManifest.Manifests[0].Platform.Architecture) // Second manifest: weight with annotations require.Equal(t, PlatformUnknown, idxManifest.Manifests[1].Platform.OS) require.Equal(t, AnnotationValueWeights, idxManifest.Manifests[1].Annotations[AnnotationReferenceType]) require.Equal(t, imgDesc.Digest.String(), idxManifest.Manifests[1].Annotations[AnnotationReferenceDigest]) return nil }, } pusher := NewBundlePusher(docker, reg) m := &Model{ Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, NewWeightArtifact("model-v1", v1.Descriptor{ Digest: v1.Hash{Algorithm: "sha256", Hex: "aabbccddee112233445566778899aabb"}, }, weightPath, "/weights/model.safetensors", WeightConfig{ SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "model-v1", Target: "/weights/model.safetensors", Created: time.Now().UTC(), }), }, } err := pusher.Push(context.Background(), m, PushOptions{ Platform: &Platform{OS: "linux", Architecture: "amd64"}, }) require.NoError(t, err) // Verify the call sequence: // 1. Push image via docker // 2. Get image descriptor from registry (lightweight HEAD) // 3. Push weight via registry (single combined tag) // 4. 
Push OCI index to registry require.Len(t, callOrder, 4) require.Equal(t, "docker:push:r8.im/user/model:latest", callOrder[0]) require.Equal(t, "registry:getDescriptor:r8.im/user/model:latest", callOrder[1]) require.Equal(t, "registry:pushImage:r8.im/user/model:weights-model-v1-aabbccddee11", callOrder[2]) require.Equal(t, "registry:pushIndex:r8.im/user/model:latest", callOrder[3]) }) t.Run("uses default platform when not specified", func(t *testing.T) { dir := t.TempDir() weightPath := filepath.Join(dir, "model.bin") require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } reg := &mockRegistry{ getDescriptorFunc: func(ctx context.Context, ref string) (v1.Descriptor, error) { return v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 100, Digest: v1.Hash{Algorithm: "sha256", Hex: "abc"}, }, nil }, pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { return nil }, pushIndexFunc: func(ctx context.Context, ref string, idx v1.ImageIndex) error { idxManifest, _ := idx.IndexManifest() // Default platform should be linux/amd64 require.Equal(t, "linux", idxManifest.Manifests[0].Platform.OS) require.Equal(t, "amd64", idxManifest.Manifests[0].Platform.Architecture) return nil }, } pusher := NewBundlePusher(docker, reg) m := &Model{ Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, NewWeightArtifact("w1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", Target: "/weights/model.bin", Created: time.Now().UTC(), }), }, } err := pusher.Push(context.Background(), m, PushOptions{}) require.NoError(t, err) }) t.Run("returns error when image push fails", func(t *testing.T) { dir := t.TempDir() weightPath := filepath.Join(dir, "model.bin") require.NoError(t, os.WriteFile(weightPath, 
[]byte("test"), 0o644)) docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return errors.New("unauthorized: authentication required") }, } reg := &mockRegistry{} pusher := NewBundlePusher(docker, reg) m := &Model{ Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, NewWeightArtifact("w1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", Target: "/weights/model.bin", Created: time.Now().UTC(), }), }, } err := pusher.Push(context.Background(), m, PushOptions{}) require.Error(t, err) require.Contains(t, err.Error(), "push image") require.Contains(t, err.Error(), "unauthorized") }) t.Run("returns error when get descriptor fails", func(t *testing.T) { dir := t.TempDir() weightPath := filepath.Join(dir, "model.bin") require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } reg := &mockRegistry{ getDescriptorFunc: func(ctx context.Context, ref string) (v1.Descriptor, error) { return v1.Descriptor{}, errors.New("manifest not found") }, } pusher := NewBundlePusher(docker, reg) m := &Model{ Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, NewWeightArtifact("w1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", Target: "/weights/model.bin", Created: time.Now().UTC(), }), }, } err := pusher.Push(context.Background(), m, PushOptions{}) require.Error(t, err) require.Contains(t, err.Error(), "get image descriptor") }) t.Run("returns error when weight push fails", func(t *testing.T) { dir := t.TempDir() weightPath := filepath.Join(dir, "model.bin") require.NoError(t, os.WriteFile(weightPath, []byte("test"), 
0o644)) docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } reg := &mockRegistry{ getDescriptorFunc: func(ctx context.Context, ref string) (v1.Descriptor, error) { return v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 100, Digest: v1.Hash{Algorithm: "sha256", Hex: "abc"}, }, nil }, pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { return errors.New("weight push failed: quota exceeded") }, } pusher := NewBundlePusher(docker, reg) m := &Model{ Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, NewWeightArtifact("w1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", Target: "/weights/model.bin", Created: time.Now().UTC(), }), }, } err := pusher.Push(context.Background(), m, PushOptions{}) require.Error(t, err) require.Contains(t, err.Error(), "push weight") require.Contains(t, err.Error(), "w1") }) t.Run("returns error when index push fails", func(t *testing.T) { dir := t.TempDir() weightPath := filepath.Join(dir, "model.bin") require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644)) docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } reg := &mockRegistry{ getDescriptorFunc: func(ctx context.Context, ref string) (v1.Descriptor, error) { return v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 100, Digest: v1.Hash{Algorithm: "sha256", Hex: "abc"}, }, nil }, pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { return nil }, pushIndexFunc: func(ctx context.Context, ref string, idx v1.ImageIndex) error { return errors.New("index push failed") }, } pusher := NewBundlePusher(docker, reg) m := &Model{ Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: 
"r8.im/user/model:latest"}, NewWeightArtifact("w1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{ SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", Target: "/weights/model.bin", Created: time.Now().UTC(), }), }, } err := pusher.Push(context.Background(), m, PushOptions{}) require.Error(t, err) require.Contains(t, err.Error(), "push OCI index") }) t.Run("pushes multiple weights concurrently", func(t *testing.T) { dir := t.TempDir() weight1Path := filepath.Join(dir, "model1.bin") weight2Path := filepath.Join(dir, "model2.bin") require.NoError(t, os.WriteFile(weight1Path, []byte("weight 1 data"), 0o644)) require.NoError(t, os.WriteFile(weight2Path, []byte("weight 2 data"), 0o644)) docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } // Use atomic counter — safe for concurrent access from goroutines var pushedWeightCount atomic.Int32 reg := &mockRegistry{ getDescriptorFunc: func(ctx context.Context, ref string) (v1.Descriptor, error) { return v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 100, Digest: v1.Hash{Algorithm: "sha256", Hex: "abc"}, }, nil }, pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error { pushedWeightCount.Add(1) return nil }, pushIndexFunc: func(ctx context.Context, ref string, idx v1.ImageIndex) error { idxManifest, _ := idx.IndexManifest() require.Len(t, idxManifest.Manifests, 3) // 1 image + 2 weights return nil }, } pusher := NewBundlePusher(docker, reg) m := &Model{ Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, NewWeightArtifact("w1", v1.Descriptor{ Digest: v1.Hash{Algorithm: "sha256", Hex: "aaaa111122223333444455556666777788889999aaaabbbbccccddddeeee0000"}, }, weight1Path, "/weights/model1.bin", WeightConfig{ SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w1", Target: "/weights/model1.bin", Created: time.Now().UTC(), }), 
NewWeightArtifact("w2", v1.Descriptor{ Digest: v1.Hash{Algorithm: "sha256", Hex: "bbbb111122223333444455556666777788889999aaaabbbbccccddddeeee0000"}, }, weight2Path, "/weights/model2.bin", WeightConfig{ SchemaVersion: "1.0", CogVersion: "0.15.0", Name: "w2", Target: "/weights/model2.bin", Created: time.Now().UTC(), }), }, } err := pusher.Push(context.Background(), m, PushOptions{}) require.NoError(t, err) require.Equal(t, int32(2), pushedWeightCount.Load()) // both weights pushed (1 tag each) }) } // ============================================================================= // Resolver.Push tests // ============================================================================= func TestResolver_Push(t *testing.T) { t.Run("default uses docker push", func(t *testing.T) { var dockerPushed bool docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { dockerPushed = true return nil }, } reg := &mockRegistry{} resolver := NewResolver(docker, reg) m := &Model{ Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, }, } err := resolver.Push(context.Background(), m, PushOptions{}) require.NoError(t, err) require.True(t, dockerPushed, "standalone should use docker push") }) t.Run("OCIIndex false uses docker push", func(t *testing.T) { var dockerPushed bool docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { dockerPushed = true return nil }, } reg := &mockRegistry{} resolver := NewResolver(docker, reg) m := &Model{ // OCIIndex not set (false by default) Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, }, } err := resolver.Push(context.Background(), m, PushOptions{}) require.NoError(t, err) require.True(t, dockerPushed, "default format should use docker push") }) t.Run("OCIIndex true produces an OCI index", func(t *testing.T) { 
var indexPushed bool docker := &mockDocker{ pushFunc: func(ctx context.Context, ref string) error { return nil }, } reg := &mockRegistry{ getDescriptorFunc: func(ctx context.Context, ref string) (v1.Descriptor, error) { return v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: 100, Digest: v1.Hash{Algorithm: "sha256", Hex: "abc"}, }, nil }, pushIndexFunc: func(ctx context.Context, ref string, idx v1.ImageIndex) error { indexPushed = true return nil }, } resolver := NewResolver(docker, reg) m := &Model{ OCIIndex: true, Image: &ImageArtifact{Reference: "r8.im/user/model:latest"}, Artifacts: []Artifact{ &ImageArtifact{name: "model", Reference: "r8.im/user/model:latest"}, }, } err := resolver.Push(context.Background(), m, PushOptions{}) require.NoError(t, err) require.True(t, indexPushed, "OCIIndex=true should push an OCI index") }) t.Run("standalone returns error when image nil", func(t *testing.T) { docker := &mockDocker{} reg := &mockRegistry{} resolver := NewResolver(docker, reg) m := &Model{ Image: nil, Artifacts: []Artifact{}, } err := resolver.Push(context.Background(), m, PushOptions{}) require.Error(t, err) require.Contains(t, err.Error(), "no image artifact") }) t.Run("OCIIndex true returns error when no image artifact", func(t *testing.T) { docker := &mockDocker{} reg := &mockRegistry{} resolver := NewResolver(docker, reg) m := &Model{ OCIIndex: true, Image: nil, Artifacts: []Artifact{}, } err := resolver.Push(context.Background(), m, PushOptions{}) require.Error(t, err) require.Contains(t, err.Error(), "no image artifact") }) } // ============================================================================= // PushOptions tests // ============================================================================= func TestPushOptions(t *testing.T) { t.Run("ProjectDir field", func(t *testing.T) { opts := PushOptions{ ProjectDir: "/path/to/project", } require.Equal(t, "/path/to/project", opts.ProjectDir) }) t.Run("Platform field", func(t *testing.T) { opts 
:= PushOptions{ Platform: &Platform{OS: "linux", Architecture: "arm64"}, } require.Equal(t, "linux", opts.Platform.OS) require.Equal(t, "arm64", opts.Platform.Architecture) }) } // ============================================================================= // repoFromReference tests // ============================================================================= func TestRepoFromReference(t *testing.T) { tests := []struct { input string want string }{ {"r8.im/user/model:latest", "r8.im/user/model"}, {"r8.im/user/model:v1.0", "r8.im/user/model"}, {"r8.im/user/model@sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "r8.im/user/model"}, {"r8.im/user/model", "r8.im/user/model"}, {"registry.example.com/org/model:tag", "registry.example.com/org/model"}, {"localhost:5000/model:latest", "localhost:5000/model"}, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { got := repoFromReference(tt.input) require.Equal(t, tt.want, got) }) } } ================================================ FILE: pkg/model/ref.go ================================================ package model import ( "fmt" "github.com/google/go-containerregistry/pkg/name" "github.com/replicate/cog/pkg/global" ) // ParseOption configures how image references are parsed. type ParseOption func(*parseOptions) type parseOptions struct { nameOpts []name.Option } // Insecure allows parsing references to registries that use HTTP // or have invalid/self-signed certificates. // Use this for local development registries like localhost:5000. func Insecure() ParseOption { return func(o *parseOptions) { o.nameOpts = append(o.nameOpts, name.Insecure) } } // WithDefaultRegistry sets the registry to use when the reference // doesn't include one. By default, Docker Hub (index.docker.io) is used. 
func WithDefaultRegistry(registry string) ParseOption { return func(o *parseOptions) { o.nameOpts = append(o.nameOpts, name.WithDefaultRegistry(registry)) } } // WithDefaultTag sets the tag to use when the reference doesn't // include one. By default, "latest" is used. func WithDefaultTag(tag string) ParseOption { return func(o *parseOptions) { o.nameOpts = append(o.nameOpts, name.WithDefaultTag(tag)) } } // ParsedRef wraps a validated and parsed image reference. type ParsedRef struct { // Original is the input string before parsing. Original string // Ref is the underlying parsed reference from go-containerregistry. Ref name.Reference } // ParseRef validates and parses an image reference. func ParseRef(ref string, opts ...ParseOption) (*ParsedRef, error) { var po parseOptions for _, opt := range opts { opt(&po) } parsed, err := name.ParseReference(ref, po.nameOpts...) if err != nil { return nil, fmt.Errorf("invalid image reference %q: %w", ref, err) } return &ParsedRef{ Original: ref, Ref: parsed, }, nil } // String returns the fully-qualified canonical reference string. func (p *ParsedRef) String() string { return p.Ref.Name() } // Registry returns the registry host (e.g., "r8.im", "index.docker.io"). func (p *ParsedRef) Registry() string { return p.Ref.Context().RegistryStr() } // Repository returns the repository path (e.g., "user/model", "library/nginx"). func (p *ParsedRef) Repository() string { return p.Ref.Context().RepositoryStr() } // Tag returns the tag (e.g., "v1", "latest") or empty string if this is a digest reference. func (p *ParsedRef) Tag() string { if t, ok := p.Ref.(name.Tag); ok { return t.TagStr() } return "" } // Digest returns the digest (e.g., "sha256:...") or empty string if this is a tag reference. func (p *ParsedRef) Digest() string { if d, ok := p.Ref.(name.Digest); ok { return d.DigestStr() } return "" } // IsTag returns true if the reference includes a tag. 
func (p *ParsedRef) IsTag() bool { _, ok := p.Ref.(name.Tag) return ok } // IsDigest returns true if the reference includes a digest. func (p *ParsedRef) IsDigest() bool { _, ok := p.Ref.(name.Digest) return ok } // IsReplicate returns true if the registry is the Replicate registry (r8.im). func (p *ParsedRef) IsReplicate() bool { return p.Registry() == global.ReplicateRegistryHost } ================================================ FILE: pkg/model/ref_test.go ================================================ package model import ( "testing" "github.com/stretchr/testify/require" ) func TestParseRef(t *testing.T) { tests := []struct { name string ref string opts []ParseOption wantRegistry string wantRepo string wantTag string wantDigest string wantReplicate bool wantErr bool errContains string }{ { name: "basic tag", ref: "nginx:latest", wantRegistry: "index.docker.io", wantRepo: "library/nginx", wantTag: "latest", wantReplicate: false, }, { name: "implicit latest tag", ref: "nginx", wantRegistry: "index.docker.io", wantRepo: "library/nginx", wantTag: "latest", wantReplicate: false, }, { name: "replicate registry", ref: "r8.im/user/model:v1", wantRegistry: "r8.im", wantRepo: "user/model", wantTag: "v1", wantReplicate: true, }, { name: "replicate registry implicit latest", ref: "r8.im/user/model", wantRegistry: "r8.im", wantRepo: "user/model", wantTag: "latest", wantReplicate: true, }, { name: "non-replicate registry", ref: "ghcr.io/foo/bar:v1", wantRegistry: "ghcr.io", wantRepo: "foo/bar", wantTag: "v1", wantReplicate: false, }, { name: "digest reference", ref: "nginx@sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", wantRegistry: "index.docker.io", wantRepo: "library/nginx", wantTag: "", wantDigest: "sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", wantReplicate: false, }, { name: "replicate with digest", ref: "r8.im/user/model@sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", wantRegistry: 
"r8.im", wantRepo: "user/model", wantTag: "", wantDigest: "sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", wantReplicate: true, }, { name: "custom registry with port", ref: "localhost:5000/myimage:test", opts: []ParseOption{Insecure()}, wantRegistry: "localhost:5000", wantRepo: "myimage", wantTag: "test", wantReplicate: false, }, { name: "with default registry option", ref: "user/model:v1", opts: []ParseOption{WithDefaultRegistry("r8.im")}, wantRegistry: "r8.im", wantRepo: "user/model", wantTag: "v1", wantReplicate: true, }, { name: "with default tag option", ref: "nginx", opts: []ParseOption{WithDefaultTag("stable")}, wantRegistry: "index.docker.io", wantRepo: "library/nginx", wantTag: "stable", wantReplicate: false, }, { name: "invalid reference", ref: ":::invalid", wantErr: true, errContains: `invalid image reference ":::invalid"`, }, { name: "empty reference", ref: "", wantErr: true, errContains: `invalid image reference ""`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { parsed, err := ParseRef(tt.ref, tt.opts...) 
if tt.wantErr { require.Error(t, err) if tt.errContains != "" { require.Contains(t, err.Error(), tt.errContains) } return } require.NoError(t, err) require.NotNil(t, parsed) require.Equal(t, tt.ref, parsed.Original, "Original should preserve input") require.Equal(t, tt.wantRegistry, parsed.Registry(), "Registry mismatch") require.Equal(t, tt.wantRepo, parsed.Repository(), "Repository mismatch") require.Equal(t, tt.wantTag, parsed.Tag(), "Tag mismatch") require.Equal(t, tt.wantDigest, parsed.Digest(), "Digest mismatch") require.Equal(t, tt.wantReplicate, parsed.IsReplicate(), "IsReplicate mismatch") }) } } func TestParsedRef_IsTag(t *testing.T) { tagRef, err := ParseRef("nginx:latest") require.NoError(t, err) require.True(t, tagRef.IsTag()) require.False(t, tagRef.IsDigest()) digestRef, err := ParseRef("nginx@sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") require.NoError(t, err) require.False(t, digestRef.IsTag()) require.True(t, digestRef.IsDigest()) } func TestParsedRef_String(t *testing.T) { tests := []struct { name string ref string wantStr string }{ { name: "bare image gets fully qualified", ref: "nginx", wantStr: "index.docker.io/library/nginx:latest", }, { name: "replicate ref", ref: "r8.im/user/model:v1", wantStr: "r8.im/user/model:v1", }, { name: "digest ref", ref: "nginx@sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", wantStr: "index.docker.io/library/nginx@sha256:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { parsed, err := ParseRef(tt.ref) require.NoError(t, err) require.Equal(t, tt.wantStr, parsed.String()) }) } } func TestParseOptions(t *testing.T) { t.Run("multiple options can be combined", func(t *testing.T) { parsed, err := ParseRef("myimage", WithDefaultRegistry("r8.im"), WithDefaultTag("v2"), ) require.NoError(t, err) require.Equal(t, "r8.im", parsed.Registry()) require.Equal(t, "v2", parsed.Tag()) 
require.True(t, parsed.IsReplicate()) }) t.Run("insecure allows localhost registries", func(t *testing.T) { parsed, err := ParseRef("localhost:5000/test:v1", Insecure()) require.NoError(t, err) require.Equal(t, "localhost:5000", parsed.Registry()) require.Equal(t, "test", parsed.Repository()) require.Equal(t, "v1", parsed.Tag()) }) } ================================================ FILE: pkg/model/ref_types.go ================================================ package model import "context" // Ref represents something that can be resolved to a Model. // This interface enables declarative model resolution - callers specify // "what they have" (a tag, local image, or source to build) and the // Resolver figures out how to produce a Model. type Ref interface { // resolve is unexported to keep the interface internal. // External code uses Resolver.Resolve() instead of calling this directly. resolve(ctx context.Context, r *Resolver) (*Model, error) } // Resolve resolves any Ref to a Model. // This is the main entry point for declarative model resolution. func (r *Resolver) Resolve(ctx context.Context, ref Ref) (*Model, error) { return ref.resolve(ctx, r) } // ============================================================================= // TagRef - resolves an image by tag/digest, trying local then remote // ============================================================================= // TagRef resolves an image by tag or digest reference. // It uses the default Load behavior: try remote registry first, // then fall back to local docker daemon if not found remotely. type TagRef struct { Parsed *ParsedRef } // FromTag parses and validates a tag reference. // Returns an error immediately if the reference is invalid. 
func FromTag(ref string) (*TagRef, error) { parsed, err := ParseRef(ref) if err != nil { return nil, err } return &TagRef{Parsed: parsed}, nil } func (t *TagRef) resolve(ctx context.Context, r *Resolver) (*Model, error) { // Use default Inspect behavior (PreferRemote) return r.Inspect(ctx, t.Parsed) } // ============================================================================= // LocalRef - explicitly loads from docker daemon only // ============================================================================= // LocalRef resolves an image from the local docker daemon only. // It will not fall back to remote registry if the image is not found locally. type LocalRef struct { Parsed *ParsedRef } // FromLocal parses and validates a reference for local resolution. // Returns an error immediately if the reference is invalid. func FromLocal(ref string) (*LocalRef, error) { parsed, err := ParseRef(ref) if err != nil { return nil, err } return &LocalRef{Parsed: parsed}, nil } func (l *LocalRef) resolve(ctx context.Context, r *Resolver) (*Model, error) { return r.Inspect(ctx, l.Parsed, LocalOnly()) } // ============================================================================= // RemoteRef - explicitly loads from registry only // ============================================================================= // RemoteRef resolves an image from a remote registry only. // It will not check the local docker daemon. type RemoteRef struct { Parsed *ParsedRef } // FromRemote parses and validates a reference for remote resolution. // Returns an error immediately if the reference is invalid. 
func FromRemote(ref string) (*RemoteRef, error) { parsed, err := ParseRef(ref) if err != nil { return nil, err } return &RemoteRef{Parsed: parsed}, nil } func (rr *RemoteRef) resolve(ctx context.Context, r *Resolver) (*Model, error) { return r.Inspect(ctx, rr.Parsed, RemoteOnly()) } // ============================================================================= // BuildRef - creates a Model by building from source // ============================================================================= // BuildRef resolves to a Model by building from source. // This wraps a Source and BuildOptions for deferred building. type BuildRef struct { Source *Source Options BuildOptions } // FromBuild creates a BuildRef from source and options. // Unlike the other From* functions, this does not validate eagerly - // validation happens at build time. func FromBuild(src *Source, opts BuildOptions) *BuildRef { return &BuildRef{Source: src, Options: opts} } func (b *BuildRef) resolve(ctx context.Context, r *Resolver) (*Model, error) { return r.Build(ctx, b.Source, b.Options) } ================================================ FILE: pkg/model/ref_types_test.go ================================================ package model import ( "context" "errors" "testing" "github.com/docker/docker/api/types/image" dockerspec "github.com/moby/docker-image-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/registry" ) // ============================================================================= // FromTag tests // ============================================================================= func TestFromTag_ValidRef(t *testing.T) { ref, err := FromTag("my-image:latest") require.NoError(t, err) require.NotNil(t, ref) require.NotNil(t, ref.Parsed) require.Equal(t, "my-image:latest", ref.Parsed.Original) } func TestFromTag_ValidRefWithRegistry(t *testing.T) { ref, err := 
FromTag("r8.im/user/model:v1") require.NoError(t, err) require.NotNil(t, ref) require.Equal(t, "r8.im", ref.Parsed.Registry()) require.Equal(t, "v1", ref.Parsed.Tag()) } func TestFromTag_InvalidRef(t *testing.T) { ref, err := FromTag("INVALID::REF") require.Error(t, err) require.Nil(t, ref) require.Contains(t, err.Error(), "invalid image reference") } // ============================================================================= // FromLocal tests // ============================================================================= func TestFromLocal_ValidRef(t *testing.T) { ref, err := FromLocal("my-image:latest") require.NoError(t, err) require.NotNil(t, ref) require.NotNil(t, ref.Parsed) require.Equal(t, "my-image:latest", ref.Parsed.Original) } func TestFromLocal_InvalidRef(t *testing.T) { ref, err := FromLocal("INVALID::REF") require.Error(t, err) require.Nil(t, ref) require.Contains(t, err.Error(), "invalid image reference") } // ============================================================================= // FromRemote tests // ============================================================================= func TestFromRemote_ValidRef(t *testing.T) { ref, err := FromRemote("r8.im/user/model") require.NoError(t, err) require.NotNil(t, ref) require.NotNil(t, ref.Parsed) require.Equal(t, "r8.im", ref.Parsed.Registry()) } func TestFromRemote_InvalidRef(t *testing.T) { ref, err := FromRemote("INVALID::REF") require.Error(t, err) require.Nil(t, ref) require.Contains(t, err.Error(), "invalid image reference") } // ============================================================================= // FromBuild tests // ============================================================================= func TestFromBuild(t *testing.T) { src := &Source{ Config: &config.Config{Predict: "predict.py:Predictor"}, ProjectDir: "/path/to/project", } opts := BuildOptions{ ImageName: "my-built-image:latest", NoCache: true, } ref := FromBuild(src, opts) require.NotNil(t, ref) require.Same(t, 
src, ref.Source) require.Equal(t, "my-built-image:latest", ref.Options.ImageName) require.True(t, ref.Options.NoCache) } func TestFromBuild_NilSource(t *testing.T) { // FromBuild should accept nil source - validation happens at resolve time ref := FromBuild(nil, BuildOptions{ImageName: "test"}) require.NotNil(t, ref) require.Nil(t, ref.Source) } // ============================================================================= // TagRef.resolve tests // ============================================================================= func TestTagRef_Resolve_Success(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, }, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := FromTag("my-image:latest") require.NoError(t, err) model, err := resolver.Resolve(context.Background(), ref) require.NoError(t, err) require.NotNil(t, model) require.Equal(t, "0.10.0", model.CogVersion) } func TestTagRef_Resolve_FallsBackToLocal(t *testing.T) { localCalled := false remoteCalled := false docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { localCalled = true return &image.InspectResponse{ ID: "sha256:local123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, }, }, nil }, } reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { remoteCalled = true return nil, registry.NotFoundError }, } resolver := NewResolver(docker, reg) ref, err := FromTag("my-image:latest") require.NoError(t, err) model, err 
:= resolver.Resolve(context.Background(), ref) require.NoError(t, err) require.NotNil(t, model) require.True(t, remoteCalled, "TagRef should try remote first") require.True(t, localCalled, "TagRef should fall back to local") require.Equal(t, ImageSourceLocal, model.Image.Source) } // ============================================================================= // LocalRef.resolve tests // ============================================================================= func TestLocalRef_Resolve_Success(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: "sha256:local123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.9.0", }, }, }, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := FromLocal("my-image:latest") require.NoError(t, err) model, err := resolver.Resolve(context.Background(), ref) require.NoError(t, err) require.NotNil(t, model) require.Equal(t, ImageSourceLocal, model.Image.Source) require.Equal(t, "0.9.0", model.CogVersion) } func TestLocalRef_Resolve_NotFound(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return nil, errors.New("No such image: my-image:latest") }, } reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { t.Fatal("LocalRef should not fall back to remote") return nil, nil }, } resolver := NewResolver(docker, reg) ref, err := FromLocal("my-image:latest") require.NoError(t, err) _, err = resolver.Resolve(context.Background(), ref) require.Error(t, err) require.Contains(t, err.Error(), "not found locally") } // ============================================================================= // RemoteRef.resolve tests // 
============================================================================= func TestRemoteRef_Resolve_Success(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { t.Fatal("RemoteRef should not check local docker") return nil, nil }, } reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { return ®istry.ManifestResult{ SchemaVersion: 2, Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, nil }, } resolver := NewResolver(docker, reg) ref, err := FromRemote("r8.im/user/model") require.NoError(t, err) model, err := resolver.Resolve(context.Background(), ref) require.NoError(t, err) require.NotNil(t, model) require.Equal(t, ImageSourceRemote, model.Image.Source) require.Equal(t, "0.10.0", model.CogVersion) } func TestRemoteRef_Resolve_NotFound(t *testing.T) { docker := &mockDocker{} reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { return nil, registry.NotFoundError }, } resolver := NewResolver(docker, reg) ref, err := FromRemote("r8.im/user/model") require.NoError(t, err) _, err = resolver.Resolve(context.Background(), ref) require.Error(t, err) require.Contains(t, err.Error(), "not found in registry") } // ============================================================================= // BuildRef.resolve tests // ============================================================================= func TestBuildRef_Resolve_Success(t *testing.T) { buildCalled := false docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: "sha256:a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ 
LabelVersion: "0.11.0", LabelConfig: `{"build":{"gpu":true}}`, }, }, }, }, nil }, } factory := &mockFactory{ name: "test", buildFunc: func(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) { buildCalled = true require.Equal(t, "my-built-image", opts.ImageName) return &ImageArtifact{Reference: opts.ImageName, Source: ImageSourceBuild}, nil }, } resolver := NewResolver(docker, &mockRegistry{}).WithFactory(factory) src := &Source{ Config: &config.Config{Predict: "predict.py:Predictor"}, ProjectDir: "/tmp/test", } ref := FromBuild(src, BuildOptions{ImageName: "my-built-image"}) model, err := resolver.Resolve(context.Background(), ref) require.NoError(t, err) require.NotNil(t, model) require.True(t, buildCalled, "BuildRef should call factory.Build") require.Equal(t, "0.11.0", model.CogVersion) } func TestBuildRef_Resolve_BuildError(t *testing.T) { factory := &mockFactory{ name: "test", buildFunc: func(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) { return nil, errors.New("build failed: missing dependencies") }, } resolver := NewResolver(&mockDocker{}, &mockRegistry{}).WithFactory(factory) src := &Source{ Config: &config.Config{}, ProjectDir: "/tmp/test", } ref := FromBuild(src, BuildOptions{ImageName: "my-image"}) _, err := resolver.Resolve(context.Background(), ref) require.Error(t, err) require.Contains(t, err.Error(), "build failed") } // ============================================================================= // Resolver.Resolve dispatch tests // ============================================================================= func TestResolver_Resolve_DispatchesCorrectly(t *testing.T) { // This test verifies that Resolver.Resolve correctly dispatches to each Ref type tests := []struct { name string ref Ref expectLocal bool }{ { name: "TagRef dispatches to Load (default behavior)", ref: func() Ref { r, _ := FromTag("my-image:latest") return r }(), expectLocal: true, // TagRef tries remote first, falls back 
to local }, { name: "LocalRef dispatches to local only", ref: func() Ref { r, _ := FromLocal("my-image:latest") return r }(), expectLocal: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { localCalled := false docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { localCalled = true return &image.InspectResponse{ ID: "sha256:test", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, }, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}) _, err := resolver.Resolve(context.Background(), tt.ref) require.NoError(t, err) require.Equal(t, tt.expectLocal, localCalled) }) } } // ============================================================================= // Ref interface compile-time checks // ============================================================================= // Compile-time check that all types implement Ref interface var ( _ Ref = (*TagRef)(nil) _ Ref = (*LocalRef)(nil) _ Ref = (*RemoteRef)(nil) _ Ref = (*BuildRef)(nil) ) ================================================ FILE: pkg/model/resolver.go ================================================ package model import ( "context" "errors" "fmt" "path/filepath" "strings" "github.com/docker/docker/api/types/image" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/registry" ) // Option configures how Resolver methods behave. type Option func(*options) type options struct { localOnly bool remoteOnly bool preferLocal bool // default: true platform *registry.Platform } func defaultOptions() *options { return &options{} // Default: preferRemote (try registry first, fall back to local) } // LocalOnly loads only from the local docker daemon. // Returns an error if the image is not found locally. 
func LocalOnly() Option { return func(o *options) { o.localOnly = true o.remoteOnly = false o.preferLocal = false } } // RemoteOnly loads only from the remote registry. // Does not check the local docker daemon. func RemoteOnly() Option { return func(o *options) { o.remoteOnly = true o.localOnly = false o.preferLocal = false } } // PreferLocal tries local docker daemon first, falls back to remote on not-found. func PreferLocal() Option { return func(o *options) { o.preferLocal = true o.localOnly = false o.remoteOnly = false } } // PreferRemote tries remote registry first, falls back to local on not-found. // This is the default behavior. func PreferRemote() Option { return func(o *options) { o.preferLocal = false o.localOnly = false o.remoteOnly = false } } // WithPlatform sets the platform for remote registry queries. func WithPlatform(p *registry.Platform) Option { return func(o *options) { o.platform = p } } // Resolver orchestrates building and loading Models. type Resolver struct { docker command.Command registry registry.Client factory Factory imagePusher *ImagePusher } // NewResolver creates a Resolver with the default factory. func NewResolver(docker command.Command, reg registry.Client) *Resolver { return &Resolver{ docker: docker, registry: reg, factory: defaultFactory(docker, reg), imagePusher: newImagePusher(docker, reg), } } // WithFactory sets a custom factory and returns the Resolver for chaining. func (r *Resolver) WithFactory(f Factory) *Resolver { r.factory = f return r } // Inspect returns Model metadata for a parsed ref without pulling. // By default (PreferLocal), tries local docker daemon first, then remote registry. // Only falls back on "not found" errors; real errors (docker down, auth) are surfaced. // Returns ErrNotCogModel if the image is not a valid Cog model. 
func (r *Resolver) Inspect(ctx context.Context, ref *ParsedRef, opts ...Option) (*Model, error) { o := defaultOptions() for _, opt := range opts { opt(o) } switch { case o.localOnly: return r.loadLocal(ctx, ref) case o.remoteOnly: return r.loadRemote(ctx, ref, o.platform) case o.preferLocal: model, localErr := r.loadLocal(ctx, ref) if localErr == nil { return model, nil } // Check the underlying error before the wrapper adds "not found" text if !isNotFoundError(errors.Unwrap(localErr)) { return nil, localErr // Real error, don't mask it } return r.loadRemote(ctx, ref, o.platform) default: // PreferRemote model, remoteErr := r.loadRemote(ctx, ref, o.platform) if remoteErr == nil { return model, nil } // Check the underlying error before the wrapper adds "not found" text if !isNotFoundError(errors.Unwrap(remoteErr)) { return nil, remoteErr } return r.loadLocal(ctx, ref) } } // InspectByID returns Model metadata from the local docker daemon by image ID. // This supports both full IDs (sha256:...) and short IDs (e.g., "9056219a5fb2"). // Use this when you have an image ID rather than a tagged reference. // Returns ErrNotCogModel if the image is not a valid Cog model. func (r *Resolver) InspectByID(ctx context.Context, id string) (*Model, error) { resp, err := r.docker.Inspect(ctx, id) if err != nil { return nil, fmt.Errorf("failed to load image by ID %s: %w", id, err) } // Use the canonical ID from the response as the reference img := &ImageArtifact{ Reference: resp.ID, Digest: resp.ID, Labels: resp.Config.Labels, Source: ImageSourceLocal, } model, err := img.ToModel() if err != nil { return nil, fmt.Errorf("image %s: %w", id, err) } return model, nil } // Pull ensures a Model is locally available for running. // It first checks if the image exists locally. If not, it pulls from the registry. // Returns ErrNotCogModel if the image is not a valid Cog model. // Returns ErrNotFound if the image cannot be found locally or remotely. 
func (r *Resolver) Pull(ctx context.Context, ref *ParsedRef, opts ...Option) (*Model, error) { o := defaultOptions() for _, opt := range opts { opt(o) } // First, try to inspect locally model, err := r.loadLocal(ctx, ref) if err == nil { return model, nil } // If local-only mode, don't try to pull if o.localOnly { return nil, fmt.Errorf("image %s: %w", ref.Original, ErrNotFound) } // If local image exists but isn't a Cog model, don't try to pull // (pulling won't change the existing image) if errors.Is(err, ErrNotCogModel) { return nil, err } // Check if it's a "not found" error (safe to try pull) if !isNotFoundError(errors.Unwrap(err)) { // Real error (connection refused, etc.) - don't mask it return nil, err } // Pull the image // TODO: Support platform option for multi-platform images _, err = r.docker.Pull(ctx, ref.String(), false) if err != nil { if isNotFoundError(err) { return nil, fmt.Errorf("image %s: %w", ref.Original, ErrNotFound) } return nil, fmt.Errorf("failed to pull image %s: %w", ref.Original, err) } // Inspect the now-local image return r.loadLocal(ctx, ref) } // Build creates a Model by building from source. 
func (r *Resolver) Build(ctx context.Context, src *Source, opts BuildOptions) (*Model, error) { if src == nil { return nil, fmt.Errorf("source is required for Build") } if src.Config == nil { return nil, fmt.Errorf("source.Config is required for Build") } if src.ProjectDir == "" { return nil, fmt.Errorf("source.ProjectDir is required for Build") } opts = opts.WithDefaults(src) // Build image artifact via ImageBuilder ib := NewImageBuilder(r.factory, r.docker, src, opts) imageSpec := NewImageSpec("model", opts.ImageName) imgResult, err := ib.Build(ctx, imageSpec) if err != nil { return nil, err } ia, ok := imgResult.(*ImageArtifact) if !ok { return nil, fmt.Errorf("unexpected artifact type from image builder: %T", imgResult) } m, err := r.modelFromImage(ia, src.Config) if err != nil { return nil, err } m.OCIIndex = opts.OCIIndex m.Artifacts = []Artifact{ia} // Build weight artifacts if OCI index mode is enabled lockPath := opts.WeightsLockPath if lockPath == "" { lockPath = filepath.Join(src.ProjectDir, WeightsLockFilename) } if opts.OCIIndex && len(src.Config.Weights) > 0 { wb := NewWeightBuilder(src, m.CogVersion, lockPath) for _, ws := range src.Config.Weights { spec := NewWeightSpec(ws.Name, ws.Source, ws.Target) artifact, buildErr := wb.Build(ctx, spec) if buildErr != nil { return nil, fmt.Errorf("build weight %q: %w", ws.Name, buildErr) } m.Artifacts = append(m.Artifacts, artifact) } } return m, nil } // Push pushes a Model to a container registry. // // Uses the OCI chunked push path (via ImagePusher) which bypasses Docker's // monolithic push and supports layers of any size through chunked uploads. // Falls back to legacy Docker push if OCI push is not available. 
func (r *Resolver) Push(ctx context.Context, m *Model, opts PushOptions) error { if m.OCIIndex { pusher := NewBundlePusher(r.docker, r.registry) return pusher.Push(ctx, m, opts) } imgArtifact := m.GetImageArtifact() if imgArtifact == nil { return fmt.Errorf("no image artifact in model") } var imagePushOpts []ImagePushOption if opts.ImageProgressFn != nil { imagePushOpts = append(imagePushOpts, WithProgressFn(opts.ImageProgressFn)) } if opts.OnFallback != nil { imagePushOpts = append(imagePushOpts, WithOnFallback(opts.OnFallback)) } return r.imagePusher.Push(ctx, imgArtifact, imagePushOpts...) } // loadLocal loads a Model from the local docker daemon. func (r *Resolver) loadLocal(ctx context.Context, ref *ParsedRef) (*Model, error) { resp, err := r.docker.Inspect(ctx, ref.String()) if err != nil { if isNotFoundError(err) { return nil, fmt.Errorf("image %s not found locally: %w", ref.Original, err) } return nil, fmt.Errorf("failed to inspect local image %s: %w", ref.Original, err) } return r.modelFromInspect(ref, resp, ImageSourceLocal) } // loadRemote loads a Model from the remote registry. func (r *Resolver) loadRemote(ctx context.Context, ref *ParsedRef, platform *registry.Platform) (*Model, error) { manifest, err := r.registry.Inspect(ctx, ref.String(), platform) if err != nil { if errors.Is(err, registry.NotFoundError) { return nil, fmt.Errorf("image %s not found in registry: %w", ref.Original, err) } return nil, fmt.Errorf("failed to inspect remote image %s: %w", ref.Original, err) } return r.modelFromManifest(ref, manifest, ImageSourceRemote) } // modelFromImage creates a Model from ImageArtifact with a known config (post-build). // Uses the provided config rather than parsing from labels. 
func (r *Resolver) modelFromImage(img *ImageArtifact, cfg *config.Config) (*Model, error) { schema, err := img.ParsedOpenAPISchema() if err != nil { return nil, fmt.Errorf("failed to parse schema from image labels: %w", err) } return &Model{ Image: img, Config: cfg, Schema: schema, CogVersion: img.CogVersion(), }, nil } // modelFromInspect creates a Model from docker inspect response. // Returns ErrNotCogModel if the image is not a valid Cog model. func (r *Resolver) modelFromInspect(ref *ParsedRef, resp *image.InspectResponse, source ImageSource) (*Model, error) { img := &ImageArtifact{ Reference: ref.String(), Digest: resp.ID, Labels: resp.Config.Labels, Source: source, } model, err := img.ToModel() if err != nil { return nil, fmt.Errorf("image %s: %w", ref.Original, err) } return model, nil } // modelFromManifest creates a Model from registry manifest. // Returns ErrNotCogModel if the image is not a valid Cog model. func (r *Resolver) modelFromManifest(ref *ParsedRef, manifest *registry.ManifestResult, source ImageSource) (*Model, error) { // Check if this is an OCI Index (v2 format) if isOCIIndex(manifest) { return r.modelFromIndex(ref, manifest, source) } // Standard image (v1 format) img := &ImageArtifact{ Reference: ref.String(), Digest: manifest.Config, // Config digest serves as image ID Labels: manifest.Labels, Source: source, } m, err := img.ToModel() if err != nil { return nil, fmt.Errorf("image %s: %w", ref.Original, err) } return m, nil } // modelFromIndex creates a Model from an OCI Image Index. // It extracts the image manifest and weights manifest from the index. 
func (r *Resolver) modelFromIndex(ref *ParsedRef, manifest *registry.ManifestResult, source ImageSource) (*Model, error) { // Find the image manifest (skip unknown/unknown platform artifacts) imgManifest := findImageManifest(manifest.Manifests, nil) if imgManifest == nil { return nil, fmt.Errorf("no image manifest found in index %s", ref.Original) } // Create ImageArtifact from the image manifest img := &ImageArtifact{ Reference: ref.String(), Digest: imgManifest.Digest, Labels: manifest.Labels, // Labels come from the index inspection Source: source, Platform: &Platform{ OS: imgManifest.OS, Architecture: imgManifest.Architecture, Variant: imgManifest.Variant, }, } m, err := img.ToModel() if err != nil { return nil, fmt.Errorf("image %s: %w", ref.Original, err) } m.Index = &Index{ Digest: manifest.Digest, // Content-addressable digest from registry Reference: ref.String(), MediaType: manifest.MediaType, Manifests: make([]IndexManifest, len(manifest.Manifests)), } // Populate index manifests for i, pm := range manifest.Manifests { im := IndexManifest{ Digest: pm.Digest, MediaType: pm.MediaType, Size: pm.Size, Annotations: pm.Annotations, } if pm.OS != "" { im.Platform = &Platform{ OS: pm.OS, Architecture: pm.Architecture, Variant: pm.Variant, } } // Determine manifest type if pm.OS == PlatformUnknown && pm.Annotations != nil && pm.Annotations[AnnotationReferenceType] == AnnotationValueWeights { im.Type = ManifestTypeWeights } else { im.Type = ManifestTypeImage } m.Index.Manifests[i] = im } return m, nil } // isOCIIndex checks if the manifest result is an OCI Image Index. func isOCIIndex(mr *registry.ManifestResult) bool { return mr.IsIndex() } // findWeightsManifest finds the weights manifest in an index. // Returns nil if no weights manifest is found. 
func findWeightsManifest(manifests []registry.PlatformManifest) *registry.PlatformManifest { for i := range manifests { m := &manifests[i] if m.Annotations != nil && m.Annotations[AnnotationReferenceType] == AnnotationValueWeights { return m } } return nil } // findImageManifest finds the model image manifest in an index. // If platform is specified, matches on OS/Architecture. // Skips artifacts (platform: unknown/unknown). func findImageManifest(manifests []registry.PlatformManifest, platform *registry.Platform) *registry.PlatformManifest { for i := range manifests { m := &manifests[i] // Skip artifacts (unknown platform) if m.OS == PlatformUnknown { continue } // Match platform if specified if platform != nil { if m.OS != platform.OS || m.Architecture != platform.Architecture { continue } } return m } return nil } // isNotFoundError checks if an error indicates "not found" vs a real error. // Only "not found" errors should trigger fallback to alternative source. func isNotFoundError(err error) bool { if err == nil { return false } // Don't treat context errors as "not found" if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false } // Check for registry NotFoundError if errors.Is(err, registry.NotFoundError) { return true } // Check for common not-found patterns in error strings errStr := err.Error() return strings.Contains(errStr, "not found") || strings.Contains(errStr, "No such image") || strings.Contains(errStr, "manifest unknown") || strings.Contains(errStr, "NAME_UNKNOWN") } ================================================ FILE: pkg/model/resolver_test.go ================================================ package model import ( "context" "errors" "io" "os" "path/filepath" "testing" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/image" v1 "github.com/google/go-containerregistry/pkg/v1" dockerspec "github.com/moby/docker-image-spec/specs-go/v1" ocispec 
"github.com/opencontainers/image-spec/specs-go/v1" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/registry" ) // mockDocker implements command.Command for testing. type mockDocker struct { inspectFunc func(ctx context.Context, ref string) (*image.InspectResponse, error) pullFunc func(ctx context.Context, ref string, force bool) (*image.InspectResponse, error) pushFunc func(ctx context.Context, ref string) error imageSaveFunc func(ctx context.Context, imageRef string) (io.ReadCloser, error) } func (m *mockDocker) Inspect(ctx context.Context, ref string) (*image.InspectResponse, error) { if m.inspectFunc != nil { return m.inspectFunc(ctx, ref) } return nil, errors.New("not implemented") } func (m *mockDocker) Pull(ctx context.Context, ref string, force bool) (*image.InspectResponse, error) { if m.pullFunc != nil { return m.pullFunc(ctx, ref, force) } return nil, errors.New("mockDocker.Pull not implemented") } func (m *mockDocker) Push(ctx context.Context, ref string) error { if m.pushFunc != nil { return m.pushFunc(ctx, ref) } return errors.New("mockDocker.Push not implemented") } func (m *mockDocker) LoadUserInformation(ctx context.Context, registryHost string) (*command.UserInfo, error) { return nil, errors.New("mockDocker.LoadUserInformation not implemented") } func (m *mockDocker) ImageExists(ctx context.Context, ref string) (bool, error) { return false, errors.New("mockDocker.ImageExists not implemented") } func (m *mockDocker) ContainerLogs(ctx context.Context, containerID string, w io.Writer) error { return errors.New("mockDocker.ContainerLogs not implemented") } func (m *mockDocker) ContainerInspect(ctx context.Context, id string) (*container.InspectResponse, error) { return nil, errors.New("mockDocker.ContainerInspect not implemented") } func (m *mockDocker) ContainerStop(ctx context.Context, containerID string) error { return 
errors.New("mockDocker.ContainerStop not implemented") } func (m *mockDocker) RemoveImage(ctx context.Context, ref string) error { return errors.New("mockDocker.RemoveImage not implemented") } func (m *mockDocker) ImageBuild(ctx context.Context, options command.ImageBuildOptions) (string, error) { return "", errors.New("mockDocker.ImageBuild not implemented") } func (m *mockDocker) Run(ctx context.Context, options command.RunOptions) error { return errors.New("mockDocker.Run not implemented") } func (m *mockDocker) ContainerStart(ctx context.Context, options command.RunOptions) (string, error) { return "", errors.New("mockDocker.ContainerStart not implemented") } func (m *mockDocker) ImageSave(ctx context.Context, imageRef string) (io.ReadCloser, error) { if m.imageSaveFunc != nil { return m.imageSaveFunc(ctx, imageRef) } return nil, errors.New("mockDocker.ImageSave not implemented") } // mockRegistry implements registry.Client for testing. type mockRegistry struct { inspectFunc func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) getImageFunc func(ctx context.Context, ref string, platform *registry.Platform) (v1.Image, error) getDescriptorFunc func(ctx context.Context, ref string) (v1.Descriptor, error) pushImageFunc func(ctx context.Context, ref string, img v1.Image) error pushIndexFunc func(ctx context.Context, ref string, idx v1.ImageIndex) error writeLayerFunc func(ctx context.Context, opts registry.WriteLayerOptions) error } func (m *mockRegistry) Inspect(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { if m.inspectFunc != nil { return m.inspectFunc(ctx, ref, platform) } return nil, registry.NotFoundError } func (m *mockRegistry) GetImage(ctx context.Context, ref string, platform *registry.Platform) (v1.Image, error) { if m.getImageFunc != nil { return m.getImageFunc(ctx, ref, platform) } return nil, errors.New("mockRegistry.GetImage not implemented") } func (m 
*mockRegistry) GetDescriptor(ctx context.Context, ref string) (v1.Descriptor, error) { if m.getDescriptorFunc != nil { return m.getDescriptorFunc(ctx, ref) } return v1.Descriptor{}, errors.New("mockRegistry.GetDescriptor not implemented") } func (m *mockRegistry) Exists(ctx context.Context, ref string) (bool, error) { return false, errors.New("mockRegistry.Exists not implemented") } func (m *mockRegistry) PushImage(ctx context.Context, ref string, img v1.Image) error { if m.pushImageFunc != nil { return m.pushImageFunc(ctx, ref, img) } return errors.New("mockRegistry.PushImage not implemented") } func (m *mockRegistry) PushIndex(ctx context.Context, ref string, idx v1.ImageIndex) error { if m.pushIndexFunc != nil { return m.pushIndexFunc(ctx, ref, idx) } return errors.New("mockRegistry.PushIndex not implemented") } func (m *mockRegistry) WriteLayer(ctx context.Context, opts registry.WriteLayerOptions) error { if m.writeLayerFunc != nil { return m.writeLayerFunc(ctx, opts) } // Default: no-op. The caller (WeightPusher) owns closing ProgressCh. return nil } // mockFactory implements Factory for testing. 
type mockFactory struct { name string buildFunc func(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) } func (f *mockFactory) Build(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) { if f.buildFunc != nil { return f.buildFunc(ctx, src, opts) } return &ImageArtifact{Reference: opts.ImageName, Source: ImageSourceBuild}, nil } func (f *mockFactory) Name() string { return f.name } func TestNewResolver(t *testing.T) { docker := &mockDocker{} reg := &mockRegistry{} resolver := NewResolver(docker, reg) require.NotNil(t, resolver) require.Equal(t, "dockerfile", resolver.factory.Name()) } func TestResolver_WithFactory(t *testing.T) { docker := &mockDocker{} reg := &mockRegistry{} resolver := NewResolver(docker, reg) require.Equal(t, "dockerfile", resolver.factory.Name()) customFactory := &mockFactory{name: "custom"} result := resolver.WithFactory(customFactory) // WithFactory returns the same resolver for chaining require.Same(t, resolver, result) require.Equal(t, "custom", resolver.factory.Name()) } func TestResolver_Inspect_LocalOnly_Found(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, }, }, nil }, } reg := &mockRegistry{} resolver := NewResolver(docker, reg) ref, err := ParseRef("my-image:latest") require.NoError(t, err) model, err := resolver.Inspect(context.Background(), ref, LocalOnly()) require.NoError(t, err) require.NotNil(t, model) require.Equal(t, ImageSourceLocal, model.Image.Source) require.Equal(t, "0.10.0", model.CogVersion) require.NotNil(t, model.Config) require.Equal(t, "3.11", model.Config.Build.PythonVersion) } func TestResolver_Inspect_LocalOnly_NotFound(t *testing.T) { docker := 
&mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return nil, errors.New("No such image: my-image:latest") }, } reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { t.Fatal("should not call registry when LocalOnly") return nil, nil }, } resolver := NewResolver(docker, reg) ref, err := ParseRef("my-image:latest") require.NoError(t, err) _, err = resolver.Inspect(context.Background(), ref, LocalOnly()) require.Error(t, err) require.Contains(t, err.Error(), "not found locally") } func TestResolver_Inspect_RemoteOnly_Found(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { t.Fatal("should not call docker.Inspect when RemoteOnly") return nil, nil }, } reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { return ®istry.ManifestResult{ SchemaVersion: 2, MediaType: "application/vnd.docker.distribution.manifest.v2+json", Config: "sha256:configdigest", Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, nil }, } resolver := NewResolver(docker, reg) ref, err := ParseRef("r8.im/user/model") require.NoError(t, err) model, err := resolver.Inspect(context.Background(), ref, RemoteOnly()) require.NoError(t, err) require.NotNil(t, model) require.Equal(t, ImageSourceRemote, model.Image.Source) require.Equal(t, "0.10.0", model.CogVersion) } func TestResolver_Inspect_RemoteOnly_NotFound(t *testing.T) { docker := &mockDocker{} reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { return nil, registry.NotFoundError }, } resolver := NewResolver(docker, reg) ref, err := ParseRef("r8.im/user/model") require.NoError(t, err) _, err = 
resolver.Inspect(context.Background(), ref, RemoteOnly()) require.Error(t, err) require.Contains(t, err.Error(), "not found in registry") } func TestResolver_Inspect_RemoteOnly_NotCogModel(t *testing.T) { docker := &mockDocker{} reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { return ®istry.ManifestResult{ SchemaVersion: 2, MediaType: "application/vnd.docker.distribution.manifest.v2+json", Config: "sha256:configdigest", Labels: map[string]string{ // No Cog labels - just a regular image "maintainer": "someone@example.com", }, }, nil }, } resolver := NewResolver(docker, reg) ref, err := ParseRef("nginx:latest") require.NoError(t, err) _, err = resolver.Inspect(context.Background(), ref, RemoteOnly()) require.Error(t, err) require.ErrorIs(t, err, ErrNotCogModel) } func TestResolver_Inspect_PreferLocal_FoundLocally(t *testing.T) { localCalled := false remoteCalled := false docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { localCalled = true return &image.InspectResponse{ ID: "sha256:local123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.9.0", }, }, }, }, nil }, } reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { remoteCalled = true return nil, nil }, } resolver := NewResolver(docker, reg) ref, err := ParseRef("my-image:latest") require.NoError(t, err) model, err := resolver.Inspect(context.Background(), ref, PreferLocal()) require.NoError(t, err) require.True(t, localCalled, "should try local first") require.False(t, remoteCalled, "should not call remote when local succeeds") require.Equal(t, ImageSourceLocal, model.Image.Source) } func TestResolver_Inspect_PreferLocal_Fallback(t *testing.T) { localCalled := 
false remoteCalled := false docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { localCalled = true return nil, errors.New("No such image") }, } reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { remoteCalled = true return ®istry.ManifestResult{ SchemaVersion: 2, Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, nil }, } resolver := NewResolver(docker, reg) ref, err := ParseRef("my-image:latest") require.NoError(t, err) model, err := resolver.Inspect(context.Background(), ref, PreferLocal()) require.NoError(t, err) require.True(t, localCalled, "should try local first") require.True(t, remoteCalled, "should fall back to remote") require.Equal(t, ImageSourceRemote, model.Image.Source) } func TestResolver_Inspect_PreferLocal_NoFallbackOnRealError(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return nil, errors.New("connection refused") // Real error, not "not found" }, } reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { t.Fatal("should not fall back to remote on real error") return nil, nil }, } resolver := NewResolver(docker, reg) ref, err := ParseRef("my-image:latest") require.NoError(t, err) _, err = resolver.Inspect(context.Background(), ref, PreferLocal()) require.Error(t, err) require.Contains(t, err.Error(), "failed to inspect local image") require.Contains(t, err.Error(), "connection refused") } func TestResolver_Inspect_PreferRemote_FoundRemotely(t *testing.T) { localCalled := false remoteCalled := false docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { localCalled = true return nil, nil }, } reg := &mockRegistry{ inspectFunc: func(ctx 
context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { remoteCalled = true return ®istry.ManifestResult{ SchemaVersion: 2, Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, nil }, } resolver := NewResolver(docker, reg) ref, err := ParseRef("my-image:latest") require.NoError(t, err) model, err := resolver.Inspect(context.Background(), ref, PreferRemote()) require.NoError(t, err) require.False(t, localCalled, "should not try local when remote succeeds") require.True(t, remoteCalled, "should try remote first") require.Equal(t, ImageSourceRemote, model.Image.Source) } func TestResolver_Inspect_PreferRemote_Fallback(t *testing.T) { localCalled := false remoteCalled := false docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { localCalled = true return &image.InspectResponse{ ID: "sha256:local123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, }, }, nil }, } reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { remoteCalled = true return nil, errors.New("manifest unknown") }, } resolver := NewResolver(docker, reg) ref, err := ParseRef("my-image:latest") require.NoError(t, err) model, err := resolver.Inspect(context.Background(), ref, PreferRemote()) require.NoError(t, err) require.True(t, remoteCalled, "should try remote first") require.True(t, localCalled, "should fall back to local") require.Equal(t, ImageSourceLocal, model.Image.Source) } func TestResolver_Inspect_PreferRemote_NoFallbackOnRealError(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { t.Fatal("should not fall back to local on real error") return nil, nil }, } reg 
:= &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { return nil, errors.New("authentication required") }, } resolver := NewResolver(docker, reg) ref, err := ParseRef("my-image:latest") require.NoError(t, err) _, err = resolver.Inspect(context.Background(), ref, PreferRemote()) require.Error(t, err) require.Contains(t, err.Error(), "failed to inspect remote image") require.Contains(t, err.Error(), "authentication required") } func TestResolver_Inspect_WithPlatform(t *testing.T) { var capturedPlatform *registry.Platform docker := &mockDocker{} reg := &mockRegistry{ inspectFunc: func(ctx context.Context, ref string, platform *registry.Platform) (*registry.ManifestResult, error) { capturedPlatform = platform return ®istry.ManifestResult{ SchemaVersion: 2, Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, nil }, } resolver := NewResolver(docker, reg) ref, err := ParseRef("my-image") require.NoError(t, err) platform := ®istry.Platform{OS: "linux", Architecture: "amd64"} _, err = resolver.Inspect(context.Background(), ref, RemoteOnly(), WithPlatform(platform)) require.NoError(t, err) require.NotNil(t, capturedPlatform) require.Equal(t, "linux", capturedPlatform.OS) require.Equal(t, "amd64", capturedPlatform.Architecture) } func TestResolver_Inspect_ParsesConfigFromLabels(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"gpu":true,"python_version":"3.12"},"predict":"predict.py:Predictor"}`, LabelVersion: "0.11.0", }, }, }, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := ParseRef("my-image:latest") require.NoError(t, err) model, err := 
resolver.Inspect(context.Background(), ref, LocalOnly()) require.NoError(t, err) require.NotNil(t, model.Config) require.NotNil(t, model.Config.Build) require.True(t, model.Config.Build.GPU) require.Equal(t, "3.12", model.Config.Build.PythonVersion) require.Equal(t, "predict.py:Predictor", model.Config.Predict) } func TestResolver_Inspect_InvalidConfigJSON(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{invalid json`, }, }, }, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := ParseRef("my-image:latest") require.NoError(t, err) _, err = resolver.Inspect(context.Background(), ref, LocalOnly()) require.Error(t, err) // Error should contain the JSON parse error message require.Contains(t, err.Error(), "invalid character") } func TestResolver_Inspect_NoConfigLabel_ReturnsErrNotCogModel(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ // No LabelConfig - just version label LabelVersion: "0.10.0", }, }, }, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := ParseRef("my-image:latest") require.NoError(t, err) _, err = resolver.Inspect(context.Background(), ref, LocalOnly()) // Without LabelConfig, image is not a valid Cog model require.Error(t, err) require.ErrorIs(t, err, ErrNotCogModel) } func TestResolver_Inspect_NotCogModel(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: 
ocispec.ImageConfig{ Labels: map[string]string{ // No Cog labels at all - just some random image "maintainer": "someone@example.com", }, }, }, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := ParseRef("nginx:latest") require.NoError(t, err) _, err = resolver.Inspect(context.Background(), ref, LocalOnly()) require.Error(t, err) require.ErrorIs(t, err, ErrNotCogModel) require.Contains(t, err.Error(), "nginx:latest") } func TestResolver_InspectByID_Found(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { // Verify the ID was passed directly (not mangled by ParseRef) require.Equal(t, "9056219a5fb2", ref) return &image.InspectResponse{ ID: "sha256:9056219a5fb2abc123def456", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, }, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}) model, err := resolver.InspectByID(context.Background(), "9056219a5fb2") require.NoError(t, err) require.NotNil(t, model) require.Equal(t, ImageSourceLocal, model.Image.Source) require.Equal(t, "sha256:9056219a5fb2abc123def456", model.Image.Digest) require.Equal(t, "sha256:9056219a5fb2abc123def456", model.Image.Reference) require.Equal(t, "0.10.0", model.CogVersion) require.NotNil(t, model.Config) require.Equal(t, "3.11", model.Config.Build.PythonVersion) } func TestResolver_InspectByID_FullSHA(t *testing.T) { fullID := "sha256:9056219a5fb2abc123def456789" docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { require.Equal(t, fullID, ref) return &image.InspectResponse{ ID: fullID, Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, }, }, nil }, } resolver := NewResolver(docker, 
&mockRegistry{}) model, err := resolver.InspectByID(context.Background(), fullID) require.NoError(t, err) require.NotNil(t, model) require.Equal(t, fullID, model.Image.Digest) } func TestResolver_InspectByID_NotCogModel(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ // No Cog labels "maintainer": "someone", }, }, }, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}) _, err := resolver.InspectByID(context.Background(), "abc123") require.Error(t, err) require.ErrorIs(t, err, ErrNotCogModel) } func TestResolver_InspectByID_NotFound(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return nil, errors.New("No such image: abc123") }, } resolver := NewResolver(docker, &mockRegistry{}) _, err := resolver.InspectByID(context.Background(), "abc123") require.Error(t, err) require.Contains(t, err.Error(), "failed to load image by ID") } // ============================================================================= // Pull tests // ============================================================================= func TestResolver_Pull_AlreadyLocal(t *testing.T) { pullCalled := false docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"gpu":false}}`, LabelVersion: "0.10.0", }, }, }, }, nil }, pullFunc: func(ctx context.Context, ref string, force bool) (*image.InspectResponse, error) { pullCalled = true return nil, nil }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := ParseRef("my-image:latest") require.NoError(t, err) model, 
err := resolver.Pull(context.Background(), ref) require.NoError(t, err) require.False(t, pullCalled, "should not pull when image exists locally") require.NotNil(t, model) require.Equal(t, "0.10.0", model.CogVersion) } func TestResolver_Pull_NotLocal_PullsAndReturns(t *testing.T) { pullCalled := false inspectCalls := 0 docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { inspectCalls++ if inspectCalls == 1 { // First call: not found locally return nil, errors.New("No such image") } // After pull: found return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"gpu":true}}`, LabelVersion: "0.10.0", }, }, }, }, nil }, pullFunc: func(ctx context.Context, ref string, force bool) (*image.InspectResponse, error) { pullCalled = true return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"gpu":true}}`, LabelVersion: "0.10.0", }, }, }, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := ParseRef("r8.im/user/model:latest") require.NoError(t, err) model, err := resolver.Pull(context.Background(), ref) require.NoError(t, err) require.True(t, pullCalled, "should call Pull when image not local") require.NotNil(t, model) require.True(t, model.HasGPU()) } func TestResolver_Pull_NotCogModel(t *testing.T) { inspectCalls := 0 docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { inspectCalls++ if inspectCalls == 1 { // First call: not found locally return nil, errors.New("No such image") } // After pull: found but not a Cog model return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ // Not a Cog model "some.label": 
"value", }, }, }, }, nil }, pullFunc: func(ctx context.Context, ref string, force bool) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: "sha256:abc123", Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ "some.label": "value", }, }, }, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := ParseRef("not-cog:latest") require.NoError(t, err) _, err = resolver.Pull(context.Background(), ref) require.Error(t, err) require.ErrorIs(t, err, ErrNotCogModel) } func TestResolver_Pull_LocalOnly_NotFound(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return nil, errors.New("No such image") }, pullFunc: func(ctx context.Context, ref string, force bool) (*image.InspectResponse, error) { t.Fatal("should not pull when LocalOnly") return nil, nil }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := ParseRef("my-image:latest") require.NoError(t, err) _, err = resolver.Pull(context.Background(), ref, LocalOnly()) require.Error(t, err) require.ErrorIs(t, err, ErrNotFound) } func TestResolver_Pull_PullFails(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return nil, errors.New("No such image") }, pullFunc: func(ctx context.Context, ref string, force bool) (*image.InspectResponse, error) { return nil, errors.New("manifest unknown") }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := ParseRef("nonexistent:latest") require.NoError(t, err) _, err = resolver.Pull(context.Background(), ref) require.Error(t, err) require.ErrorIs(t, err, ErrNotFound) } func TestResolver_Pull_LocalInspectRealError(t *testing.T) { docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return nil, errors.New("connection refused") }, pullFunc: func(ctx context.Context, ref string, force bool) 
(*image.InspectResponse, error) { t.Fatal("should not pull when local inspect has real error") return nil, nil }, } resolver := NewResolver(docker, &mockRegistry{}) ref, err := ParseRef("my-image:latest") require.NoError(t, err) _, err = resolver.Pull(context.Background(), ref) require.Error(t, err) require.Contains(t, err.Error(), "connection refused") } // ============================================================================= // Build tests // ============================================================================= func TestResolver_Build_NoWeightsManifestWithoutWeights(t *testing.T) { validDigest := "sha256:a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: validDigest, Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.10.0", }, }, }, }, nil }, } factory := &mockFactory{} resolver := NewResolver(docker, &mockRegistry{}).WithFactory(factory) src := &Source{ Config: &config.Config{Build: &config.Build{}}, ProjectDir: t.TempDir(), } m, err := resolver.Build(context.Background(), src, BuildOptions{ ImageName: "test-image", }) require.NoError(t, err) require.False(t, m.IsBundle()) require.Empty(t, m.WeightArtifacts()) } func TestResolver_Build_PopulatesArtifacts(t *testing.T) { imageDigest := "sha256:a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: imageDigest, Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.15.0", }, }, }, }, nil }, } factory := &mockFactory{ buildFunc: func(ctx context.Context, src 
*Source, opts BuildOptions) (*ImageArtifact, error) { return &ImageArtifact{ Reference: opts.ImageName, Digest: imageDigest, Source: ImageSourceBuild, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}).WithFactory(factory) src := &Source{ Config: &config.Config{Build: &config.Build{}}, ProjectDir: t.TempDir(), } m, err := resolver.Build(context.Background(), src, BuildOptions{ ImageName: "test-image:latest", }) require.NoError(t, err) require.NotNil(t, m.Artifacts, "Build should populate Artifacts") require.Len(t, m.Artifacts, 1, "should have exactly one artifact (image)") // Verify it's an ImageArtifact with correct data imgArtifact := m.GetImageArtifact() require.NotNil(t, imgArtifact, "should contain an ImageArtifact") require.Equal(t, "model", imgArtifact.Name()) require.Equal(t, ArtifactTypeImage, imgArtifact.Type()) require.Equal(t, "test-image:latest", imgArtifact.Reference) require.Equal(t, imageDigest, imgArtifact.Descriptor().Digest.String()) } func TestResolver_Build_PopulatesWeightArtifacts(t *testing.T) { imageDigest := "sha256:a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: imageDigest, Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.15.0", }, }, }, }, nil }, } factory := &mockFactory{ buildFunc: func(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) { return &ImageArtifact{ Reference: opts.ImageName, Digest: imageDigest, Source: ImageSourceBuild, }, nil }, } resolver := NewResolver(docker, &mockRegistry{}).WithFactory(factory) // Create a temp directory with a real weight file dir := t.TempDir() weightContent := []byte("test weight for resolver build") require.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), 
weightContent, 0o644)) src := &Source{ Config: &config.Config{ Build: &config.Build{}, Weights: []config.WeightSource{ {Name: "my-model", Source: "model.safetensors", Target: "/srv/weights/model.safetensors"}, }, }, ProjectDir: dir, } m, err := resolver.Build(context.Background(), src, BuildOptions{ ImageName: "test-image:latest", OCIIndex: true, }) require.NoError(t, err) require.NotNil(t, m.Artifacts) // Should have 2 artifacts: 1 image + 1 weight require.Len(t, m.Artifacts, 2, "should have image + weight artifacts") // Verify image artifact imgArtifact := m.GetImageArtifact() require.NotNil(t, imgArtifact) require.Equal(t, "model", imgArtifact.Name()) // Verify weight artifact weightArtifacts := m.WeightArtifacts() require.Len(t, weightArtifacts, 1) wa := weightArtifacts[0] require.Equal(t, "my-model", wa.Name()) require.Equal(t, ArtifactTypeWeight, wa.Type()) require.Equal(t, "/srv/weights/model.safetensors", wa.Target) require.Equal(t, filepath.Join(dir, "model.safetensors"), wa.FilePath) // Weight config should be populated require.Equal(t, "1.0", wa.Config.SchemaVersion) require.Equal(t, "my-model", wa.Config.Name) require.Equal(t, "/srv/weights/model.safetensors", wa.Config.Target) require.False(t, wa.Config.Created.IsZero()) } func TestResolver_Build_WithWeightsLoadsManifest(t *testing.T) { imageDigest := "sha256:a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" docker := &mockDocker{ inspectFunc: func(ctx context.Context, ref string) (*image.InspectResponse, error) { return &image.InspectResponse{ ID: imageDigest, Config: &dockerspec.DockerOCIImageConfig{ ImageConfig: ocispec.ImageConfig{ Labels: map[string]string{ LabelConfig: `{"build":{"python_version":"3.11"}}`, LabelVersion: "0.15.0", }, }, }, }, nil }, } factory := &mockFactory{ buildFunc: func(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) { return &ImageArtifact{ Reference: opts.ImageName, Digest: imageDigest, Source: ImageSourceBuild, }, nil }, } 
resolver := NewResolver(docker, &mockRegistry{}).WithFactory(factory) dir := t.TempDir() require.NoError(t, os.WriteFile(filepath.Join(dir, "model.bin"), []byte("test weights"), 0o644)) src := &Source{ Config: &config.Config{ Build: &config.Build{}, Weights: []config.WeightSource{ {Name: "my-model", Source: "model.bin", Target: "/weights/model.bin"}, }, }, ProjectDir: dir, } m, err := resolver.Build(context.Background(), src, BuildOptions{ ImageName: "test-image:latest", OCIIndex: true, }) require.NoError(t, err) require.True(t, m.IsBundle()) require.True(t, m.OCIIndex) // Should have 2 artifacts: image + weight require.Len(t, m.Artifacts, 2) require.NotNil(t, m.GetImageArtifact()) require.Len(t, m.WeightArtifacts(), 1) // Weight artifacts should be populated require.Len(t, m.WeightArtifacts(), 1) } func TestIndexDetectionHelpers(t *testing.T) { t.Run("findWeightsManifest", func(t *testing.T) { manifests := []registry.PlatformManifest{ {Digest: "sha256:image123", OS: "linux", Architecture: "amd64"}, { Digest: "sha256:weights456", OS: PlatformUnknown, Architecture: PlatformUnknown, Annotations: map[string]string{ AnnotationReferenceType: AnnotationValueWeights, }, }, } wm := findWeightsManifest(manifests) require.NotNil(t, wm) require.Equal(t, "sha256:weights456", wm.Digest) }) t.Run("findWeightsManifest not found", func(t *testing.T) { manifests := []registry.PlatformManifest{ {Digest: "sha256:image123", OS: "linux", Architecture: "amd64"}, } wm := findWeightsManifest(manifests) require.Nil(t, wm) }) t.Run("findImageManifest", func(t *testing.T) { manifests := []registry.PlatformManifest{ {Digest: "sha256:image123", OS: "linux", Architecture: "amd64"}, {Digest: "sha256:weights456", OS: PlatformUnknown, Architecture: PlatformUnknown}, } platform := ®istry.Platform{OS: "linux", Architecture: "amd64"} im := findImageManifest(manifests, platform) require.NotNil(t, im) require.Equal(t, "sha256:image123", im.Digest) }) t.Run("findImageManifest skips unknown", func(t 
*testing.T) { manifests := []registry.PlatformManifest{ {Digest: "sha256:weights456", OS: PlatformUnknown, Architecture: PlatformUnknown}, } im := findImageManifest(manifests, nil) require.Nil(t, im) }) t.Run("findImageManifest no platform filter", func(t *testing.T) { manifests := []registry.PlatformManifest{ {Digest: "sha256:arm123", OS: "linux", Architecture: "arm64"}, {Digest: "sha256:weights456", OS: PlatformUnknown, Architecture: PlatformUnknown}, } im := findImageManifest(manifests, nil) require.NotNil(t, im) require.Equal(t, "sha256:arm123", im.Digest) }) t.Run("findImageManifest platform mismatch", func(t *testing.T) { manifests := []registry.PlatformManifest{ {Digest: "sha256:arm123", OS: "linux", Architecture: "arm64"}, {Digest: "sha256:weights456", OS: PlatformUnknown, Architecture: PlatformUnknown}, } platform := ®istry.Platform{OS: "linux", Architecture: "amd64"} im := findImageManifest(manifests, platform) require.Nil(t, im) }) t.Run("isOCIIndex with index", func(t *testing.T) { mr := ®istry.ManifestResult{ MediaType: "application/vnd.oci.image.index.v1+json", } require.True(t, isOCIIndex(mr)) }) t.Run("isOCIIndex with single manifest", func(t *testing.T) { mr := ®istry.ManifestResult{ MediaType: "application/vnd.oci.image.manifest.v1+json", } require.False(t, isOCIIndex(mr)) }) } func TestIsNotFoundError(t *testing.T) { tests := []struct { name string err error expected bool }{ { name: "nil error", err: nil, expected: false, }, { name: "No such image", err: errors.New("No such image: my-image:latest"), expected: true, }, { name: "not found", err: errors.New("image not found"), expected: true, }, { name: "manifest unknown", err: errors.New("manifest unknown: repository does not exist"), expected: true, }, { name: "NAME_UNKNOWN", err: errors.New("NAME_UNKNOWN: repository name not known to registry"), expected: true, }, { name: "connection refused", err: errors.New("connection refused"), expected: false, }, { name: "authentication required", err: 
errors.New("authentication required"), expected: false, }, { name: "context canceled", err: context.Canceled, expected: false, }, { name: "context deadline exceeded", err: context.DeadlineExceeded, expected: false, }, { name: "registry NotFoundError", err: registry.NotFoundError, expected: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := isNotFoundError(tt.err) require.Equal(t, tt.expected, result) }) } } ================================================ FILE: pkg/model/source.go ================================================ package model import ( "os" "path/filepath" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/util/console" ) // Source represents a Cog project ready to build. // It combines the parsed configuration with the project directory location. type Source struct { Config *config.Config ProjectDir string ConfigFilename string // Base filename like "cog.yaml" or "my-config.yaml" Warnings []config.DeprecationWarning } // NewSource loads configuration from the given path and returns a Source. // The configPath can be a filename (e.g., "cog.yaml") which will be searched // for in the current directory and parent directories. 
func NewSource(configPath string) (*Source, error) { if configPath == "" { configPath = "cog.yaml" } // Find the root project directory rootDir, err := config.GetProjectDir(configPath) if err != nil { return nil, err } // Open and read the config file fullPath := filepath.Join(rootDir, configPath) f, err := os.Open(fullPath) if err != nil { return nil, &config.ParseError{Filename: configPath, Err: err} } defer f.Close() result, err := config.Load(f, rootDir) if err != nil { // Add filename context to parse errors if not already present if parseErr, ok := err.(*config.ParseError); ok && parseErr.Filename == "" { parseErr.Filename = configPath } return nil, err } // Display deprecation warnings for _, w := range result.Warnings { console.Warnf("%s", w.Error()) } return &Source{ Config: result.Config, ProjectDir: result.RootDir, ConfigFilename: filepath.Base(configPath), Warnings: result.Warnings, }, nil } // NewSourceFromConfig creates a Source from an existing Config. // Use this when you already have a parsed config and know the project directory. func NewSourceFromConfig(cfg *config.Config, projectDir string) *Source { return &Source{ Config: cfg, ProjectDir: projectDir, ConfigFilename: "cog.yaml", } } // ArtifactSpecs returns the artifact declarations derived from this source. // Always produces at least one ImageSpec. Produces a WeightSpec for each // weight declared in the config. Returns nil if Config is nil. 
func (s *Source) ArtifactSpecs() []ArtifactSpec { if s.Config == nil { return nil } var specs []ArtifactSpec // Always have an image artifact specs = append(specs, NewImageSpec("model", s.Config.Image)) // Add weight specs from config for _, w := range s.Config.Weights { specs = append(specs, NewWeightSpec(w.Name, w.Source, w.Target)) } return specs } ================================================ FILE: pkg/model/source_test.go ================================================ package model import ( "testing" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" ) func TestNewSourceFromConfig(t *testing.T) { cfg := &config.Config{ Build: &config.Build{ GPU: true, PythonVersion: "3.11", }, } projectDir := "/path/to/project" src := NewSourceFromConfig(cfg, projectDir) require.NotNil(t, src) require.Equal(t, cfg, src.Config) require.Equal(t, projectDir, src.ProjectDir) } func TestNewSourceFromConfig_NilConfig(t *testing.T) { src := NewSourceFromConfig(nil, "/path/to/project") require.NotNil(t, src) require.Nil(t, src.Config) require.Equal(t, "/path/to/project", src.ProjectDir) } func TestSource_ArtifactSpecs_NoWeights(t *testing.T) { cfg := &config.Config{ Image: "r8.im/user/model", Build: &config.Build{ GPU: true, PythonVersion: "3.11", }, } src := NewSourceFromConfig(cfg, "/path/to/project") specs := src.ArtifactSpecs() require.Len(t, specs, 1) // First spec should be an ImageSpec imgSpec, ok := specs[0].(*ImageSpec) require.True(t, ok, "first spec should be *ImageSpec") require.Equal(t, ArtifactTypeImage, imgSpec.Type()) require.Equal(t, "model", imgSpec.Name()) require.Equal(t, "r8.im/user/model", imgSpec.ImageName) } func TestSource_ArtifactSpecs_WithWeights(t *testing.T) { cfg := &config.Config{ Image: "r8.im/user/model", Build: &config.Build{PythonVersion: "3.11"}, Weights: []config.WeightSource{ {Name: "llama-7b", Source: "/data/llama-7b.safetensors", Target: "/weights/llama-7b.safetensors"}, {Name: "embeddings", Source: 
"/data/embeddings.bin", Target: "/weights/embeddings.bin"}, }, } src := NewSourceFromConfig(cfg, "/path/to/project") specs := src.ArtifactSpecs() require.Len(t, specs, 3) // 1 image + 2 weights // First is always the image imgSpec, ok := specs[0].(*ImageSpec) require.True(t, ok, "first spec should be *ImageSpec") require.Equal(t, ArtifactTypeImage, imgSpec.Type()) // Remaining are weight specs in order w1, ok := specs[1].(*WeightSpec) require.True(t, ok, "second spec should be *WeightSpec") require.Equal(t, ArtifactTypeWeight, w1.Type()) require.Equal(t, "llama-7b", w1.Name()) require.Equal(t, "/data/llama-7b.safetensors", w1.Source) require.Equal(t, "/weights/llama-7b.safetensors", w1.Target) w2, ok := specs[2].(*WeightSpec) require.True(t, ok, "third spec should be *WeightSpec") require.Equal(t, "embeddings", w2.Name()) require.Equal(t, "/data/embeddings.bin", w2.Source) require.Equal(t, "/weights/embeddings.bin", w2.Target) } func TestSource_ArtifactSpecs_EmptyImageName(t *testing.T) { cfg := &config.Config{ Build: &config.Build{PythonVersion: "3.11"}, } src := NewSourceFromConfig(cfg, "/path/to/project") specs := src.ArtifactSpecs() require.Len(t, specs, 1) imgSpec, ok := specs[0].(*ImageSpec) require.True(t, ok) require.Equal(t, "", imgSpec.ImageName) // empty is fine; BuildOptions fills it later } func TestSource_ArtifactSpecs_NilConfig(t *testing.T) { src := NewSourceFromConfig(nil, "/path/to/project") specs := src.ArtifactSpecs() require.Nil(t, specs) } ================================================ FILE: pkg/model/weight_builder.go ================================================ package model import ( "context" "fmt" "os" "path/filepath" "time" v1 "github.com/google/go-containerregistry/pkg/v1" ) // WeightBuilder builds WeightArtifact from WeightSpec. // It hashes the source file, creates a WeightConfig, and manages a lockfile as build cache. 
type WeightBuilder struct { source *Source cogVersion string lockPath string } // NewWeightBuilder creates a WeightBuilder. // lockPath is where the weights.lock file is read/written as a build cache. func NewWeightBuilder(source *Source, cogVersion, lockPath string) *WeightBuilder { return &WeightBuilder{ source: source, cogVersion: cogVersion, lockPath: lockPath, } } // Build builds a WeightArtifact from a WeightSpec. // It resolves the source file, computes its SHA256 digest, and creates the artifact // with a versioned WeightConfig. func (b *WeightBuilder) Build(ctx context.Context, spec ArtifactSpec) (Artifact, error) { ws, ok := spec.(*WeightSpec) if !ok { return nil, fmt.Errorf("weight builder: expected *WeightSpec, got %T", spec) } select { case <-ctx.Done(): return nil, ctx.Err() default: } // Resolve file path absPath := filepath.Join(b.source.ProjectDir, ws.Source) // Stat the file to check existence and size fi, err := os.Stat(absPath) if err != nil { if os.IsNotExist(err) { return nil, fmt.Errorf("weight source not found: %s", ws.Source) } return nil, fmt.Errorf("stat weight file %s: %w", ws.Source, err) } // Check lockfile cache: if we have a matching entry (name + size), skip hashing. // NOTE: This cache only checks name + file size. Same-size modifications (rare for // weight files) won't be detected. Delete the lockfile to force re-hashing. // TODO: Consider adding mtime to the cache key for stronger invalidation. 
var digestStr string var size int64 if cached := b.findCachedEntry(ws.Name(), fi.Size()); cached != nil { digestStr = cached.Digest size = cached.Size } else { // Cache miss: hash the file digestStr, size, err = hashFile(absPath) if err != nil { return nil, fmt.Errorf("hash weight file %s: %w", ws.Source, err) } } // Parse as v1.Hash for the descriptor digest, err := v1.NewHash(digestStr) if err != nil { return nil, fmt.Errorf("parse digest: %w", err) } // Build the WeightConfig cfg := WeightConfig{ SchemaVersion: "1.0", CogVersion: b.cogVersion, Name: ws.Name(), Target: ws.Target, Created: time.Now().UTC(), } // Build the descriptor desc := v1.Descriptor{ Digest: digest, Size: size, MediaType: MediaTypeWeightLayer, } // Update lockfile if err := b.updateLockfile(ws, digestStr, size); err != nil { return nil, fmt.Errorf("update lockfile: %w", err) } return NewWeightArtifact(ws.Name(), desc, absPath, ws.Target, cfg), nil } // findCachedEntry checks the lockfile for an entry matching name and fileSize. // Returns the cached WeightFile if found and size matches, nil otherwise. func (b *WeightBuilder) findCachedEntry(name string, fileSize int64) *WeightFile { if _, err := os.Stat(b.lockPath); err != nil { return nil } lock, err := LoadWeightsLock(b.lockPath) if err != nil { return nil } for i, f := range lock.Files { if f.Name == name && f.Size == fileSize { return &lock.Files[i] } } return nil } // updateLockfile loads the existing lockfile (if any), adds or updates // the entry for the given weight, and saves it back. func (b *WeightBuilder) updateLockfile(ws *WeightSpec, digest string, size int64) error { // Load existing lockfile, or start fresh. // LoadWeightsLock wraps the underlying error, so we check the raw file first. 
lock := &WeightsLock{ Version: "1.0", Created: time.Now().UTC(), } if _, err := os.Stat(b.lockPath); err == nil { existing, loadErr := LoadWeightsLock(b.lockPath) if loadErr != nil { return fmt.Errorf("load existing lockfile: %w", loadErr) } lock = existing } entry := WeightFile{ Name: ws.Name(), Dest: ws.Target, Digest: digest, DigestOriginal: digest, Size: size, SizeUncompressed: size, MediaType: MediaTypeWeightLayer, } // Update existing entry or append updated := false for i, f := range lock.Files { if f.Name == ws.Name() { lock.Files[i] = entry updated = true break } } if !updated { lock.Files = append(lock.Files, entry) } return lock.Save(b.lockPath) } ================================================ FILE: pkg/model/weight_builder_test.go ================================================ package model import ( "context" "crypto/sha256" "encoding/hex" "os" "path/filepath" "testing" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" ) func TestWeightBuilder_HappyPath(t *testing.T) { // Setup: real temp file as a weight source tmpDir := t.TempDir() weightContent := []byte("test weight data for builder") weightFile := filepath.Join(tmpDir, "model.safetensors") err := os.WriteFile(weightFile, weightContent, 0o644) require.NoError(t, err) // Compute expected digest hash := sha256.Sum256(weightContent) expectedDigest := "sha256:" + hex.EncodeToString(hash[:]) // Create source with config that has one weight src := NewSourceFromConfig(&config.Config{ Weights: []config.WeightSource{ {Name: "my-model", Source: "model.safetensors", Target: "/srv/weights/model.safetensors"}, }, }, tmpDir) // Create a WeightBuilder lockPath := filepath.Join(tmpDir, "weights.lock") wb := NewWeightBuilder(src, "0.15.0", lockPath) // Build from the weight spec spec := NewWeightSpec("my-model", "model.safetensors", "/srv/weights/model.safetensors") artifact, err := wb.Build(context.Background(), spec) require.NoError(t, err) require.NotNil(t, artifact) // Type 
assertion: should be a *WeightArtifact wa, ok := artifact.(*WeightArtifact) require.True(t, ok, "expected *WeightArtifact, got %T", artifact) // Check artifact interface require.Equal(t, ArtifactTypeWeight, wa.Type()) require.Equal(t, "my-model", wa.Name()) // Check descriptor desc := wa.Descriptor() require.Equal(t, expectedDigest, desc.Digest.String()) require.Equal(t, int64(len(weightContent)), desc.Size) // Check weight-specific fields require.Equal(t, weightFile, wa.FilePath) require.Equal(t, "/srv/weights/model.safetensors", wa.Target) // Check WeightConfig require.Equal(t, "1.0", wa.Config.SchemaVersion) require.Equal(t, "0.15.0", wa.Config.CogVersion) require.Equal(t, "my-model", wa.Config.Name) require.Equal(t, "/srv/weights/model.safetensors", wa.Config.Target) require.False(t, wa.Config.Created.IsZero(), "Created should be set") } func TestWeightBuilder_WritesLockfile(t *testing.T) { // After Build(), a weights.lock should be written/updated at lockPath. tmpDir := t.TempDir() weightContent := []byte("lockfile test weight") err := os.WriteFile(filepath.Join(tmpDir, "model.bin"), weightContent, 0o644) require.NoError(t, err) hash := sha256.Sum256(weightContent) expectedDigest := "sha256:" + hex.EncodeToString(hash[:]) src := NewSourceFromConfig(&config.Config{ Weights: []config.WeightSource{ {Name: "my-model", Source: "model.bin", Target: "/weights/model.bin"}, }, }, tmpDir) lockPath := filepath.Join(tmpDir, "weights.lock") wb := NewWeightBuilder(src, "0.15.0", lockPath) spec := NewWeightSpec("my-model", "model.bin", "/weights/model.bin") _, err = wb.Build(context.Background(), spec) require.NoError(t, err) // Lockfile should exist _, err = os.Stat(lockPath) require.NoError(t, err, "lockfile should be created") // Load and verify lockfile contents lock, err := LoadWeightsLock(lockPath) require.NoError(t, err) require.Equal(t, "1.0", lock.Version) require.Len(t, lock.Files, 1) wf := lock.Files[0] require.Equal(t, "my-model", wf.Name) require.Equal(t, 
"/weights/model.bin", wf.Dest) require.Equal(t, expectedDigest, wf.Digest) require.Equal(t, int64(len(weightContent)), wf.Size) } func TestWeightBuilder_UpdatesExistingLockfile(t *testing.T) { // If a lockfile already exists with entries, Build() should add/update the entry // for the built weight without losing other entries. tmpDir := t.TempDir() // Create two weight files content1 := []byte("weight one data") content2 := []byte("weight two data") err := os.WriteFile(filepath.Join(tmpDir, "w1.bin"), content1, 0o644) require.NoError(t, err) err = os.WriteFile(filepath.Join(tmpDir, "w2.bin"), content2, 0o644) require.NoError(t, err) src := NewSourceFromConfig(&config.Config{ Weights: []config.WeightSource{ {Name: "weight-1", Source: "w1.bin", Target: "/weights/w1.bin"}, {Name: "weight-2", Source: "w2.bin", Target: "/weights/w2.bin"}, }, }, tmpDir) lockPath := filepath.Join(tmpDir, "weights.lock") wb := NewWeightBuilder(src, "0.15.0", lockPath) // Build first weight spec1 := NewWeightSpec("weight-1", "w1.bin", "/weights/w1.bin") _, err = wb.Build(context.Background(), spec1) require.NoError(t, err) // Build second weight spec2 := NewWeightSpec("weight-2", "w2.bin", "/weights/w2.bin") _, err = wb.Build(context.Background(), spec2) require.NoError(t, err) // Lockfile should contain both entries lock, err := LoadWeightsLock(lockPath) require.NoError(t, err) require.Len(t, lock.Files, 2) names := map[string]bool{} for _, f := range lock.Files { names[f.Name] = true } require.True(t, names["weight-1"]) require.True(t, names["weight-2"]) } func TestWeightBuilder_CacheHit(t *testing.T) { // When a lockfile entry exists with matching name and size, // the builder should use the cached digest without re-hashing. 
tmpDir := t.TempDir() weightContent := []byte("cached weight data") err := os.WriteFile(filepath.Join(tmpDir, "model.bin"), weightContent, 0o644) require.NoError(t, err) hash := sha256.Sum256(weightContent) expectedDigest := "sha256:" + hex.EncodeToString(hash[:]) src := NewSourceFromConfig(&config.Config{ Weights: []config.WeightSource{ {Name: "my-model", Source: "model.bin", Target: "/weights/model.bin"}, }, }, tmpDir) lockPath := filepath.Join(tmpDir, "weights.lock") wb := NewWeightBuilder(src, "0.15.0", lockPath) // First build — populates lockfile spec := NewWeightSpec("my-model", "model.bin", "/weights/model.bin") artifact1, err := wb.Build(context.Background(), spec) require.NoError(t, err) // Second build — should hit cache artifact2, err := wb.Build(context.Background(), spec) require.NoError(t, err) // Both builds should produce the same digest wa1 := artifact1.(*WeightArtifact) wa2 := artifact2.(*WeightArtifact) require.Equal(t, expectedDigest, wa1.Descriptor().Digest.String()) require.Equal(t, expectedDigest, wa2.Descriptor().Digest.String()) // Lockfile should still have exactly one entry (not duplicated) lock, err := LoadWeightsLock(lockPath) require.NoError(t, err) require.Len(t, lock.Files, 1) require.Equal(t, "my-model", lock.Files[0].Name) } func TestWeightBuilder_CacheMiss_SizeChanged(t *testing.T) { // When the file size changes, the builder should re-hash. 
tmpDir := t.TempDir() content1 := []byte("original content") err := os.WriteFile(filepath.Join(tmpDir, "model.bin"), content1, 0o644) require.NoError(t, err) src := NewSourceFromConfig(&config.Config{ Weights: []config.WeightSource{ {Name: "my-model", Source: "model.bin", Target: "/weights/model.bin"}, }, }, tmpDir) lockPath := filepath.Join(tmpDir, "weights.lock") wb := NewWeightBuilder(src, "0.15.0", lockPath) spec := NewWeightSpec("my-model", "model.bin", "/weights/model.bin") // First build _, err = wb.Build(context.Background(), spec) require.NoError(t, err) // Change the file (different size) content2 := []byte("modified content with different size!!") err = os.WriteFile(filepath.Join(tmpDir, "model.bin"), content2, 0o644) require.NoError(t, err) // Second build — should detect size change and re-hash artifact2, err := wb.Build(context.Background(), spec) require.NoError(t, err) wa2 := artifact2.(*WeightArtifact) hash2 := sha256.Sum256(content2) expectedDigest2 := "sha256:" + hex.EncodeToString(hash2[:]) require.Equal(t, expectedDigest2, wa2.Descriptor().Digest.String()) require.Equal(t, int64(len(content2)), wa2.Descriptor().Size) } func TestWeightBuilder_ErrorWrongSpecType(t *testing.T) { tmpDir := t.TempDir() src := NewSourceFromConfig(&config.Config{}, tmpDir) lockPath := filepath.Join(tmpDir, "weights.lock") wb := NewWeightBuilder(src, "0.15.0", lockPath) // Pass an ImageSpec instead of WeightSpec imageSpec := NewImageSpec("model", "test-image") _, err := wb.Build(context.Background(), imageSpec) require.Error(t, err) require.Contains(t, err.Error(), "expected *WeightSpec") } func TestWeightBuilder_ErrorFileNotFound(t *testing.T) { tmpDir := t.TempDir() src := NewSourceFromConfig(&config.Config{}, tmpDir) lockPath := filepath.Join(tmpDir, "weights.lock") wb := NewWeightBuilder(src, "0.15.0", lockPath) spec := NewWeightSpec("missing", "nonexistent.bin", "/weights/nonexistent.bin") _, err := wb.Build(context.Background(), spec) require.Error(t, err) 
require.Contains(t, err.Error(), "weight source not found") } func TestWeightBuilder_ErrorContextCancelled(t *testing.T) { tmpDir := t.TempDir() err := os.WriteFile(filepath.Join(tmpDir, "model.bin"), []byte("data"), 0o644) require.NoError(t, err) src := NewSourceFromConfig(&config.Config{}, tmpDir) lockPath := filepath.Join(tmpDir, "weights.lock") wb := NewWeightBuilder(src, "0.15.0", lockPath) ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately spec := NewWeightSpec("model", "model.bin", "/weights/model.bin") _, err = wb.Build(ctx, spec) require.Error(t, err) require.ErrorIs(t, err, context.Canceled) } func TestWeightBuilder_ImplementsBuilderInterface(t *testing.T) { tmpDir := t.TempDir() src := NewSourceFromConfig(&config.Config{}, tmpDir) lockPath := filepath.Join(tmpDir, "weights.lock") // Compile-time check var _ Builder = NewWeightBuilder(src, "0.1.0", lockPath) } ================================================ FILE: pkg/model/weight_pusher.go ================================================ package model import ( "context" "crypto/sha256" "encoding/hex" "encoding/json" "fmt" "os" "strings" "sync" "time" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/empty" "github.com/google/go-containerregistry/pkg/v1/mutate" "github.com/google/go-containerregistry/pkg/v1/tarball" "github.com/google/go-containerregistry/pkg/v1/types" "github.com/replicate/cog/pkg/registry" ) // WeightPushOptions configures optional behavior for WeightPusher.Push. type WeightPushOptions struct { // ProgressFn is an optional callback for reporting upload progress. ProgressFn func(PushProgress) // RetryFn is an optional callback for reporting retry attempts. // Return false to abort the retry. RetryFn func(WeightRetryEvent) bool } // WeightRetryEvent reports a retry attempt for a weight file upload. type WeightRetryEvent struct { // Name identifies which file is being retried. 
Name string // Attempt is the current retry attempt number (1-indexed). Attempt int // MaxAttempts is the maximum number of retry attempts. MaxAttempts int // Err is the error that caused the retry. Err error // NextRetryIn is the duration until the next retry attempt. NextRetryIn time.Duration } // WeightPushResult contains the result of pushing a single weight artifact. type WeightPushResult struct { // Ref is the full image reference for the pushed weight manifest (e.g., "registry/repo:weights-name-abc123"). Ref string // Descriptor is the OCI descriptor for the pushed weight manifest. Descriptor v1.Descriptor } // WeightPusher pushes a WeightArtifact as a proper OCI artifact manifest // with config blob and tarball layers. The layer blob is pushed via // registry.WriteLayer (which supports multipart uploads, progress, and retry), // followed by the manifest via PushImage. type WeightPusher struct { registry registry.Client } // NewWeightPusher creates a new WeightPusher. func NewWeightPusher(reg registry.Client) *WeightPusher { return &WeightPusher{registry: reg} } // Push pushes a WeightArtifact to the registry as an OCI artifact manifest. // The layer blob is pushed first via WriteLayer (multipart uploads, progress, retry), // then the manifest is pushed via PushImage. // Returns the descriptor of the pushed manifest. 
func (p *WeightPusher) Push(ctx context.Context, repo string, artifact *WeightArtifact, opts ...WeightPushOptions) (*WeightPushResult, error) { if artifact == nil { return nil, fmt.Errorf("artifact is nil") } if repo == "" { return nil, fmt.Errorf("repo is required") } // Merge options (use first if provided) var opt WeightPushOptions if len(opts) > 0 { opt = opts[0] } // Verify the weight file exists if _, err := os.Stat(artifact.FilePath); err != nil { return nil, fmt.Errorf("weight file %q: %w", artifact.FilePath, err) } // Build the OCI artifact image (config blob + tarball layer) img, err := buildWeightImage(artifact) if err != nil { return nil, fmt.Errorf("build weight image: %w", err) } // Extract the layer to push via WriteLayer (gets multipart + progress + retry) layers, err := img.Layers() if err != nil { return nil, fmt.Errorf("get image layers: %w", err) } if len(layers) != 1 { return nil, fmt.Errorf("expected 1 layer, got %d", len(layers)) } layer := layers[0] // Build progress callback var onProgress func(v1.Update) if opt.ProgressFn != nil { onProgress = func(update v1.Update) { opt.ProgressFn(PushProgress{ Complete: update.Complete, Total: update.Total, }) } } // Build retry configuration if callback is provided var retryConfig *registry.RetryConfig if opt.RetryFn != nil { retryConfig = ®istry.RetryConfig{ OnRetry: func(event registry.RetryEvent) bool { return opt.RetryFn(WeightRetryEvent{ Name: artifact.Name(), Attempt: event.Attempt, MaxAttempts: event.MaxAttempts, Err: event.Err, NextRetryIn: event.NextRetryIn, }) }, } } // 1. Push layer blob via WriteLayer (multipart uploads, progress, retry) writeErr := writeLayerWithProgress(ctx, p.registry, registry.WriteLayerOptions{ Repo: repo, Layer: layer, Retry: retryConfig, }, onProgress) if writeErr != nil { return nil, fmt.Errorf("push weight layer: %w", writeErr) } // 2. Push manifest via PushImage with a single tag combining name and digest. 
// The layer blob is already in the registry, so PushImage will skip re-uploading it. // Tag format: :weights--<12chars> (e.g., :weights-model-v1-383d1f4afa43) // // We use the artifact's descriptor digest (original file hash from the lock file), // NOT the tarball layer digest. This ensures that `weights inspect` can look up the tag // using the same digest stored in weights.lock, independent of the transport format. tag := WeightTag(artifact.Name(), artifact.Descriptor().Digest.String()) ref := repo + ":" + tag if err := p.registry.PushImage(ctx, ref, img); err != nil { return nil, fmt.Errorf("push weight manifest (%s): %w", tag, err) } // Build result descriptor from the pushed image desc, err := descriptorFromImage(img) if err != nil { return nil, fmt.Errorf("compute manifest descriptor: %w", err) } return &WeightPushResult{Ref: ref, Descriptor: desc}, nil } // buildWeightImage creates an OCI artifact image with a config blob (WeightConfig JSON) // and a tarball layer for the weight file. func buildWeightImage(artifact *WeightArtifact) (v1.Image, error) { // 1. Create the base image with OCI manifest media type img := mutate.MediaType(empty.Image, types.OCIManifestSchema1) // 2. Create tarball layer from the weight file. // WithCompressedCaching memoizes the compressed output so that Digest() and // Compressed() see identical bytes. Without this, gzip non-determinism between // separate passes causes DIGEST_INVALID errors on large uploads. layer, err := tarball.LayerFromFile(artifact.FilePath, tarball.WithMediaType(types.MediaType(MediaTypeWeightLayer)), tarball.WithCompressedCaching, ) if err != nil { return nil, fmt.Errorf("create tarball layer: %w", err) } // 3. Append the layer img, err = mutate.AppendLayers(img, layer) if err != nil { return nil, fmt.Errorf("append weight layer: %w", err) } // 4. 
Serialize the WeightConfig as the config blob configJSON, err := json.Marshal(artifact.Config) if err != nil { return nil, fmt.Errorf("marshal weight config: %w", err) } // 5. Wrap to set custom config blob, config media type, and artifactType return &weightManifestImage{ Image: img, configBlob: configJSON, }, nil } // descriptorFromImage computes the v1.Descriptor for a built image manifest. func descriptorFromImage(img v1.Image) (v1.Descriptor, error) { digest, err := img.Digest() if err != nil { return v1.Descriptor{}, fmt.Errorf("get digest: %w", err) } rawManifest, err := img.RawManifest() if err != nil { return v1.Descriptor{}, fmt.Errorf("get raw manifest: %w", err) } return v1.Descriptor{ MediaType: types.OCIManifestSchema1, Size: int64(len(rawManifest)), Digest: digest, }, nil } // weightOCIManifest extends v1.Manifest with artifactType for OCI 1.1 support. // go-containerregistry v0.20.5's v1.Manifest struct does not include artifactType, // so we serialize it ourselves. type weightOCIManifest struct { SchemaVersion int64 `json:"schemaVersion"` MediaType types.MediaType `json:"mediaType,omitempty"` Config v1.Descriptor `json:"config"` Layers []v1.Descriptor `json:"layers"` Annotations map[string]string `json:"annotations,omitempty"` ArtifactType string `json:"artifactType,omitempty"` } // weightManifestImage wraps a v1.Image to set a custom config blob with // the correct media type and artifactType. This produces a proper OCI 1.1 // artifact manifest for weight data. // // The raw manifest is cached on first computation to ensure deterministic // digests across multiple calls (e.g., during remote.Write which calls // both RawManifest and Digest). type weightManifestImage struct { v1.Image configBlob []byte rawManifest []byte rawManifestErr error rawOnce sync.Once } // RawConfigFile returns the WeightConfig JSON as the config blob. 
func (w *weightManifestImage) RawConfigFile() ([]byte, error) {
	blob := w.configBlob
	return blob, nil
}

// Digest hashes the (memoized) raw manifest bytes with SHA-256 and reports the
// result as a v1.Hash. Any error from RawManifest is passed through unchanged
// together with a zero-valued hash.
func (w *weightManifestImage) Digest() (v1.Hash, error) {
	var digest v1.Hash

	manifestBytes, err := w.RawManifest()
	if err != nil {
		return digest, err
	}

	sum := sha256.Sum256(manifestBytes)
	digest.Algorithm = "sha256"
	digest.Hex = hex.EncodeToString(sum[:])
	return digest, nil
}

// ArtifactType reports the weight artifact media type. It satisfies the
// withArtifactType interface consulted by partial.Descriptor.
func (w *weightManifestImage) ArtifactType() (string, error) {
	return MediaTypeWeightArtifact, nil
}

// Manifest returns a copy of the wrapped image's manifest whose config
// descriptor has been replaced by one describing the custom config blob
// (weight config media type, blob length, SHA-256 digest of the blob).
func (w *weightManifestImage) Manifest() (*v1.Manifest, error) {
	base, err := w.Image.Manifest()
	if err != nil {
		return nil, err
	}

	// Describe our own config blob rather than the wrapped image's config.
	sum := sha256.Sum256(w.configBlob)
	configDesc := v1.Descriptor{
		MediaType: types.MediaType(MediaTypeWeightConfig),
		Size:      int64(len(w.configBlob)),
		Digest:    v1.Hash{Algorithm: "sha256", Hex: hex.EncodeToString(sum[:])},
	}

	// Patch a deep copy so the wrapped image's manifest is left untouched.
	patched := base.DeepCopy()
	patched.Config = configDesc
	return patched, nil
}

// RawManifest returns the JSON encoding of the modified manifest, including the
// artifactType field. The encoding is computed once and memoized so that
// repeated calls (and Digest) always observe the same bytes.
func (w *weightManifestImage) RawManifest() ([]byte, error) { w.rawOnce.Do(func() { m, err := w.Manifest() if err != nil { w.rawManifestErr = err return } // Build the OCI manifest with artifactType (not in v1.Manifest struct) ociManifest := weightOCIManifest{ SchemaVersion: m.SchemaVersion, MediaType: m.MediaType, Config: m.Config, Layers: m.Layers, Annotations: m.Annotations, ArtifactType: MediaTypeWeightArtifact, } w.rawManifest, w.rawManifestErr = json.Marshal(ociManifest) }) return w.rawManifest, w.rawManifestErr } // ============================================================================= // Weight tag helpers // ============================================================================= const weightTagPrefix = "weights-" // WeightTag returns the tag for a weight manifest combining name and digest. // The digest should be in "sha256:abc123..." format. // Returns e.g., "weights-model-v1-abc123def456" (12-char hex suffix). // Falls back to "weights-" if digest is empty or invalid. 
func WeightTag(name, digest string) string {
	// Split "<algorithm>:<hex>" once; the algorithm part is not used.
	// NOTE: the local name hex shadows the encoding/hex import inside this
	// function only.
	_, hex, ok := strings.Cut(digest, ":")
	if !ok || hex == "" {
		// No separator, or nothing after it: fall back to the name-only tag.
		return weightTagPrefix + name
	}
	// Keep at most the first 12 characters of the hex part.
	short := hex
	if len(short) > 12 {
		short = short[:12]
	}
	return weightTagPrefix + name + "-" + short
}

================================================
FILE: pkg/model/weight_pusher_test.go
================================================
package model

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"testing"
	"time"

	v1 "github.com/google/go-containerregistry/pkg/v1"
	"github.com/google/go-containerregistry/pkg/v1/types"
	"github.com/stretchr/testify/require"

	"github.com/replicate/cog/pkg/registry"
)

// TestWeightPusher_Push_ReturnsErrorForNilArtifact checks that Push rejects a
// nil artifact with an "artifact is nil" error.
func TestWeightPusher_Push_ReturnsErrorForNilArtifact(t *testing.T) {
	reg := &mockRegistry{}
	pusher := NewWeightPusher(reg)

	_, err := pusher.Push(context.Background(), "r8.im/user/model", nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "artifact is nil")
}

// TestWeightPusher_Push_ReturnsErrorForMissingFile checks that Push fails with
// a "weight file" error when the artifact's file path does not exist.
func TestWeightPusher_Push_ReturnsErrorForMissingFile(t *testing.T) {
	reg := &mockRegistry{}
	pusher := NewWeightPusher(reg)

	artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, "/nonexistent/path/weights.bin", "/weights/model.bin", WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "model-v1",
		Target:        "/weights/model.bin",
		Created:       time.Now().UTC(),
	})

	_, err := pusher.Push(context.Background(), "r8.im/user/model", artifact)
	require.Error(t, err)
	require.Contains(t, err.Error(), "weight file")
}

// TestWeightPusher_Push_PushesCorrectOCIArtifact pushes a real temp file through
// a mock registry and inspects the image handed to PushImage: the tag, manifest
// media type, config blob media type and contents, the single layer and its
// media type, and the descriptor returned to the caller.
func TestWeightPusher_Push_PushesCorrectOCIArtifact(t *testing.T) {
	// Create a temp weight file
	dir := t.TempDir()
	weightPath := filepath.Join(dir, "model.safetensors")
	weightContent := []byte("fake weight data for testing tarball layer creation")
	require.NoError(t, os.WriteFile(weightPath, weightContent, 0o644))

	// A fixed timestamp lets the round-tripped config be compared exactly.
	created := time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC)
	cfg := WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "model-v1",
		Target:        "/weights/model.safetensors",
		Created:       created,
	}
	// The first 12 hex characters of this digest are expected in the pushed tag.
	desc := v1.Descriptor{
		Digest: v1.Hash{Algorithm: "sha256", Hex: "aabbccddee112233445566778899aabb00112233445566778899aabbccddeeff"},
	}
	artifact := NewWeightArtifact("model-v1", desc, weightPath, "/weights/model.safetensors", cfg)

	// Capture what gets pushed
	var pushedRefs []string
	var pushedImg v1.Image
	reg := &mockRegistry{
		pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error {
			pushedRefs = append(pushedRefs, ref)
			pushedImg = img
			return nil
		},
	}

	pusher := NewWeightPusher(reg)
	result, err := pusher.Push(context.Background(), "r8.im/user/model", artifact)
	require.NoError(t, err)
	require.NotNil(t, result)

	// Verify the image was pushed with a single combined tag
	require.Len(t, pushedRefs, 1)
	require.Equal(t, "r8.im/user/model:weights-model-v1-aabbccddee11", pushedRefs[0])
	require.NotNil(t, pushedImg)

	// Verify manifest structure
	manifest, err := pushedImg.Manifest()
	require.NoError(t, err)
	require.Equal(t, types.OCIManifestSchema1, manifest.MediaType)

	// Verify config blob has correct media type
	require.Equal(t, types.MediaType(MediaTypeWeightConfig), manifest.Config.MediaType)

	// Verify config blob content is correct WeightConfig JSON
	configBlob, err := pushedImg.RawConfigFile()
	require.NoError(t, err)
	var parsedConfig WeightConfig
	require.NoError(t, json.Unmarshal(configBlob, &parsedConfig))
	require.Equal(t, "1.0", parsedConfig.SchemaVersion)
	require.Equal(t, "0.15.0", parsedConfig.CogVersion)
	require.Equal(t, "model-v1", parsedConfig.Name)
	require.Equal(t, "/weights/model.safetensors", parsedConfig.Target)
	require.Equal(t, created, parsedConfig.Created)

	// Verify there's exactly one layer (single file = single layer)
	require.Len(t, manifest.Layers, 1)

	// Verify layer media type
	require.Equal(t, types.MediaType(MediaTypeWeightLayer), manifest.Layers[0].MediaType)

	// Verify the layer reports a positive size; the exact blob size is not
	// asserted here.
	require.Greater(t, manifest.Layers[0].Size, int64(0))

	// Verify the result contains a valid descriptor
	require.NotEmpty(t, result.Descriptor.Digest.String())
	require.Greater(t, result.Descriptor.Size, int64(0))
}

// TestWeightPusher_Push_PropagatesPushError checks that an error returned by
// the registry's PushImage surfaces from Push, prefixed with
// "push weight manifest" and still containing the original message.
func TestWeightPusher_Push_PropagatesPushError(t *testing.T) {
	dir := t.TempDir()
	weightPath := filepath.Join(dir, "model.bin")
	require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644))

	artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "model-v1",
		Target:        "/weights/model.bin",
		Created:       time.Now().UTC(),
	})

	// Only pushImageFunc is set; presumably the mock's WriteLayer succeeds by
	// default so that the manifest push is reached — confirm against mockRegistry.
	reg := &mockRegistry{
		pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error {
			return fmt.Errorf("unauthorized: authentication required")
		},
	}

	pusher := NewWeightPusher(reg)
	_, err := pusher.Push(context.Background(), "r8.im/user/model", artifact)
	require.Error(t, err)
	require.Contains(t, err.Error(), "push weight manifest")
	require.Contains(t, err.Error(), "unauthorized")
}

// TestWeightPusher_Push_RawManifestContainsArtifactType decodes the raw
// manifest JSON of the pushed image and checks the top-level artifactType, the
// config media type, and the single layer's media type.
func TestWeightPusher_Push_RawManifestContainsArtifactType(t *testing.T) {
	dir := t.TempDir()
	weightPath := filepath.Join(dir, "model.bin")
	require.NoError(t, os.WriteFile(weightPath, []byte("test weight data"), 0o644))

	artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "model-v1",
		Target:        "/weights/model.bin",
		Created:       time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC),
	})

	var pushedImg v1.Image
	reg := &mockRegistry{
		pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error {
			pushedImg = img
			return nil
		},
	}

	pusher := NewWeightPusher(reg)
	_, err := pusher.Push(context.Background(), "r8.im/user/model", artifact)
	require.NoError(t, err)

	// Parse raw manifest JSON to verify artifactType field
	rawManifest, err := pushedImg.RawManifest()
	require.NoError(t, err)

	var manifestJSON map[string]any
	require.NoError(t, json.Unmarshal(rawManifest, &manifestJSON))

	// artifactType must be present at the manifest level (OCI 1.1)
	require.Equal(t, MediaTypeWeightArtifact, manifestJSON["artifactType"])

	// config.mediaType must be the weight config type
	configMap, ok := manifestJSON["config"].(map[string]any)
	require.True(t, ok, "config should be an object")
	require.Equal(t, MediaTypeWeightConfig, configMap["mediaType"])

	// layers should have exactly one entry with the weight layer media type
	layers, ok := manifestJSON["layers"].([]any)
	require.True(t, ok, "layers should be an array")
	require.Len(t, layers, 1)
	layerMap, ok := layers[0].(map[string]any)
	require.True(t, ok, "layer should be an object")
	require.Equal(t, MediaTypeWeightLayer, layerMap["mediaType"])
}

// TestWeightPusher_Push_ReturnsErrorForEmptyRepo checks that Push rejects an
// empty repo string with a "repo is required" error, even for a valid artifact.
func TestWeightPusher_Push_ReturnsErrorForEmptyRepo(t *testing.T) {
	dir := t.TempDir()
	weightPath := filepath.Join(dir, "model.bin")
	require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644))

	artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "model-v1",
		Target:        "/weights/model.bin",
		Created:       time.Now().UTC(),
	})

	reg := &mockRegistry{}
	pusher := NewWeightPusher(reg)

	_, err := pusher.Push(context.Background(), "", artifact)
	require.Error(t, err)
	require.Contains(t, err.Error(), "repo is required")
}

// TestWeightPusher_Push_ReportsProgressViaWriteLayer checks that updates sent
// on the WriteLayer progress channel reach the caller's ProgressFn, in order
// and with their Complete/Total values intact.
func TestWeightPusher_Push_ReportsProgressViaWriteLayer(t *testing.T) {
	dir := t.TempDir()
	weightPath := filepath.Join(dir, "model.bin")
	require.NoError(t, os.WriteFile(weightPath, []byte("test weight data for progress tracking"), 0o644))

	artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "model-v1",
		Target:        "/weights/model.bin",
		Created:       time.Now().UTC(),
	})

	// Track progress updates received via callback
	var (
		mu       sync.Mutex
		progress []PushProgress
	)

	// Mock WriteLayer to simulate progress updates (caller owns closing the channel)
	reg := &mockRegistry{
		writeLayerFunc: func(ctx context.Context, opts registry.WriteLayerOptions) error {
			// Simulate progress updates like the real registry client
			if opts.ProgressCh != nil {
				opts.ProgressCh <- v1.Update{Complete: 500, Total: 1000}
				opts.ProgressCh <- v1.Update{Complete: 1000, Total: 1000}
			}
			return nil
		},
		pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error {
			return nil
		},
	}

	pusher := NewWeightPusher(reg)
	result, err := pusher.Push(context.Background(), "r8.im/user/model", artifact, WeightPushOptions{
		// The callback takes the mutex because it may be invoked from a
		// goroutine other than the test's — presumably the progress drain
		// inside the push path; confirm against writeLayerWithProgress.
		ProgressFn: func(p PushProgress) {
			mu.Lock()
			defer mu.Unlock()
			progress = append(progress, p)
		},
	})

	require.NoError(t, err)
	require.NotNil(t, result)

	// Verify we received progress updates
	mu.Lock()
	defer mu.Unlock()
	require.GreaterOrEqual(t, len(progress), 2, "should receive at least 2 progress updates")

	// Verify progress updates contain expected values
	require.Equal(t, int64(500), progress[0].Complete)
	require.Equal(t, int64(1000), progress[0].Total)
	require.Equal(t, int64(1000), progress[1].Complete)
	require.Equal(t, int64(1000), progress[1].Total)
}

// TestWeightPusher_Push_ForwardsRetryCallback checks that a retry event raised
// by the registry is delivered to the caller's RetryFn as a WeightRetryEvent
// carrying the artifact name and the registry event's fields.
func TestWeightPusher_Push_ForwardsRetryCallback(t *testing.T) {
	dir := t.TempDir()
	weightPath := filepath.Join(dir, "model.bin")
	require.NoError(t, os.WriteFile(weightPath, []byte("test weight data"), 0o644))

	artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "model-v1",
		Target:        "/weights/model.bin",
		Created:       time.Now().UTC(),
	})

	// Mock WriteLayer to capture the retry config and invoke it
	var retryEvents []WeightRetryEvent
	reg := &mockRegistry{
		writeLayerFunc: func(ctx context.Context, opts registry.WriteLayerOptions) error {
			// Simulate the registry invoking the retry callback
			if opts.Retry != nil && opts.Retry.OnRetry != nil {
				opts.Retry.OnRetry(registry.RetryEvent{
					Attempt:     1,
					MaxAttempts: 3,
					Err:         fmt.Errorf("connection reset"),
					NextRetryIn: 2 * time.Second,
				})
			}
			return nil
		},
		pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error {
			return nil
		},
	}

	pusher := NewWeightPusher(reg)
	_, err := pusher.Push(context.Background(), "r8.im/user/model", artifact, WeightPushOptions{
		RetryFn: func(event WeightRetryEvent) bool {
			retryEvents = append(retryEvents, event)
			return true
		},
	})

	require.NoError(t, err)
	require.Len(t, retryEvents, 1)
	require.Equal(t, "model-v1", retryEvents[0].Name)
	require.Equal(t, 1, retryEvents[0].Attempt)
	require.Equal(t, 3, retryEvents[0].MaxAttempts)
	require.Contains(t, retryEvents[0].Err.Error(), "connection reset")
	require.Equal(t, 2*time.Second, retryEvents[0].NextRetryIn)
}

// TestWeightPusher_Push_WriteLayerErrorReported checks that a WriteLayer
// failure surfaces from Push, prefixed with "push weight layer" and still
// containing the original message.
func TestWeightPusher_Push_WriteLayerErrorReported(t *testing.T) {
	dir := t.TempDir()
	weightPath := filepath.Join(dir, "model.bin")
	require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644))

	artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "model-v1",
		Target:        "/weights/model.bin",
		Created:       time.Now().UTC(),
	})

	reg := &mockRegistry{
		writeLayerFunc: func(ctx context.Context, opts registry.WriteLayerOptions) error {
			return fmt.Errorf("upload failed: 503 Service Unavailable")
		},
	}

	pusher := NewWeightPusher(reg)
	_, err := pusher.Push(context.Background(), "r8.im/user/model", artifact)
	require.Error(t, err)
	require.Contains(t, err.Error(), "push weight layer")
	require.Contains(t, err.Error(), "503 Service Unavailable")
}

// TestWeightPusher_Push_PropagatesContextCancellation checks that Push reports
// a "context canceled" error when called with an already-cancelled context and
// a PushImage mock that returns ctx.Err().
func TestWeightPusher_Push_PropagatesContextCancellation(t *testing.T) {
	dir := t.TempDir()
	weightPath := filepath.Join(dir, "model.bin")
	require.NoError(t, os.WriteFile(weightPath, []byte("test"), 0o644))

	artifact := NewWeightArtifact("model-v1", v1.Descriptor{}, weightPath, "/weights/model.bin", WeightConfig{
		SchemaVersion: "1.0",
		CogVersion:    "0.15.0",
		Name:          "model-v1",
		Target:        "/weights/model.bin",
		Created:       time.Now().UTC(),
	})

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // Cancel immediately

	reg := &mockRegistry{
		pushImageFunc: func(ctx context.Context, ref string, img v1.Image) error {
			return ctx.Err()
		},
	}

	pusher := NewWeightPusher(reg)
	_, err := pusher.Push(ctx, "r8.im/user/model", artifact)
	require.Error(t, err)
	require.Contains(t, err.Error(), "context canceled")
}

================================================
FILE: pkg/model/weights.go
================================================
package model

// WeightFile represents a single weight file entry in a weights lockfile or manifest.
// The Name field is an identifier/handle (like a Docker volume name), not a filename.
type WeightFile struct {
	// Name is the identifier/handle for this weight (e.g., "personaplex-7b-v1", "model-v42.5").
	// This is a logical name that maps to deployment blob metadata, not a file path.
	Name string `json:"name"`
	// Dest is the mount path in the container (e.g., /cache/model.safetensors).
	Dest string `json:"dest"`
	// DigestOriginal is the SHA256 of the uncompressed file (canonical ID).
	DigestOriginal string `json:"digestOriginal"`
	// Digest is the SHA256 of the compressed blob (OCI layer ID).
	Digest string `json:"digest"`
	// Size is the compressed size in bytes.
	Size int64 `json:"size"`
	// SizeUncompressed is the original size in bytes.
	SizeUncompressed int64 `json:"sizeUncompressed"`
	// MediaType is the OCI layer media type (e.g., application/vnd.cog.weight.layer.v1+gzip).
	MediaType string `json:"mediaType"`
	// ContentType is the file's MIME type (e.g., application/octet-stream).
	// It is omitted from the JSON encoding when empty.
	ContentType string `json:"contentType,omitempty"`
}

================================================
FILE: pkg/model/weights_lock.go
================================================
// pkg/model/weights_lock.go
package model

import (
	"encoding/json"
	"fmt"
	"os"
	"time"
)

// WeightsLockFilename is the default filename for the weights lock file.
const WeightsLockFilename = "weights.lock"

// WeightsLock represents a weights.lock file that pins weight file metadata.
// This is a placeholder format that will be replaced by the declarative weights implementation. type WeightsLock struct { // Version is the lockfile format version. Version string `json:"version"` // Created is when the lockfile was generated. Created time.Time `json:"created"` // Files are the weight file entries. Files []WeightFile `json:"files"` } // ParseWeightsLock parses a weights.lock JSON document. func ParseWeightsLock(data []byte) (*WeightsLock, error) { var lock WeightsLock if err := json.Unmarshal(data, &lock); err != nil { return nil, fmt.Errorf("parse weights.lock: %w", err) } return &lock, nil } // LoadWeightsLock loads a weights.lock file from disk. func LoadWeightsLock(path string) (*WeightsLock, error) { data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read weights.lock: %w", err) } return ParseWeightsLock(data) } // Save writes the weights.lock to disk. func (wl *WeightsLock) Save(path string) error { data, err := json.MarshalIndent(wl, "", " ") if err != nil { return fmt.Errorf("marshal weights.lock: %w", err) } if err := os.WriteFile(path, data, 0o644); err != nil { return fmt.Errorf("write weights.lock: %w", err) } return nil } ================================================ FILE: pkg/model/weights_lock_test.go ================================================ // pkg/model/weights_lock_test.go package model import ( "os" "path/filepath" "testing" "time" "github.com/stretchr/testify/require" ) func TestWeightsLock(t *testing.T) { t.Run("parse valid lockfile", func(t *testing.T) { json := `{ "version": "1", "created": "2026-01-30T12:00:00Z", "files": [ { "name": "model.safetensors", "dest": "/cache/model.safetensors", "digestOriginal": "sha256:abc123", "digest": "sha256:def456", "size": 1000, "sizeUncompressed": 2000, "mediaType": "application/vnd.cog.weights.layer.v1+gzip" } ] }` lock, err := ParseWeightsLock([]byte(json)) require.NoError(t, err) require.Equal(t, "1", lock.Version) require.Len(t, lock.Files, 1) 
// TrimExt returns s with its final extension removed, where the extension is
// whatever path.Ext reports for s (the suffix starting at the last dot of the
// last slash-separated element). Strings without an extension are returned
// unchanged.
func TrimExt(s string) string {
	ext := go_path.Ext(s)
	if ext == "" {
		// Nothing to strip.
		return s
	}
	return strings.TrimSuffix(s, ext)
}
strconv.Atoi(ext) return err == nil } ================================================ FILE: pkg/path/path_test.go ================================================ package path import ( "testing" "github.com/stretchr/testify/require" ) func TestTrimExt(t *testing.T) { path := TrimExt("/mydir/myoutput.bmp") require.Equal(t, path, "/mydir/myoutput") } func TestIsExtInteger(t *testing.T) { require.True(t, IsExtInteger(".0")) } ================================================ FILE: pkg/predict/api.go ================================================ package predict import "github.com/replicate/cog/pkg/config" type HelpResponse struct { Arguments map[string]*config.RunArgument `json:"arguments"` } ================================================ FILE: pkg/predict/input.go ================================================ package predict import ( "encoding/json" "fmt" "os" "path/filepath" "reflect" "strconv" "strings" "github.com/getkin/kin-openapi/openapi3" "github.com/mitchellh/go-homedir" "github.com/vincent-petithory/dataurl" "github.com/replicate/cog/pkg/util/mime" ) type Input struct { String *string File *string Array *[]any Json *json.RawMessage Float *float32 Int *int32 } type Inputs map[string]Input func NewInputs(keyVals map[string][]string, schema *openapi3.T) (Inputs, error) { return NewInputsForMode(keyVals, schema, false) } func NewInputsForMode(keyVals map[string][]string, schema *openapi3.T, isTrain bool) (Inputs, error) { schemaKey := "Input" if isTrain { schemaKey = "TrainingInput" } var inputComponent *openapi3.SchemaRef for name, component := range schema.Components.Schemas { if name == schemaKey { inputComponent = component break } } // Fallback: if TrainingInput not found, try Input (legacy schemas) if inputComponent == nil && isTrain { for name, component := range schema.Components.Schemas { if name == "Input" { inputComponent = component break } } } input := Inputs{} for key, vals := range keyVals { if len(vals) == 1 { val := vals[0] if 
strings.HasPrefix(val, "@") { val = val[1:] input[key] = Input{File: &val} } else { // Check if we should explicitly parse the JSON based on a known schema if inputComponent != nil { properties, err := inputComponent.JSONLookup("properties") if err != nil { return input, err } propertiesSchemas := properties.(openapi3.Schemas) property, err := propertiesSchemas.JSONLookup(key) if err == nil { propertySchema := property.(*openapi3.Schema) // Resolve allOf/$ref to find the actual type. // cog-schema-gen emits allOf:[{$ref: ...}] for choices/enums, // where the referenced schema has the concrete type. propertySchema = resolveSchemaType(propertySchema) switch { case propertySchema.Type.Is("object"): encodedVal := json.RawMessage(val) input[key] = Input{Json: &encodedVal} continue case propertySchema.Type.Is("array"): var parsed any err := json.Unmarshal([]byte(val), &parsed) if err == nil { t := reflect.TypeOf(parsed) if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { encodedVal := json.RawMessage(val) input[key] = Input{Json: &encodedVal} continue } } var arr = []any{val} input[key] = Input{Array: &arr} continue case propertySchema.Type.Is("number"): value, err := strconv.ParseInt(val, 10, 32) if err == nil { valueInt := int32(value) input[key] = Input{Int: &valueInt} continue } else { value, err := strconv.ParseFloat(val, 32) if err != nil { return input, err } float := float32(value) input[key] = Input{Float: &float} continue } case propertySchema.Type.Is("integer"): value, err := strconv.ParseInt(val, 10, 32) if err != nil { return input, err } valueInt := int32(value) input[key] = Input{Int: &valueInt} continue } } } input[key] = Input{String: &val} } } else if len(vals) > 1 { var anyVals = make([]any, len(vals)) for i, v := range vals { anyVals[i] = v } input[key] = Input{Array: &anyVals} } } return input, nil } func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs { input := Inputs{} for key, val := range keyVals { if 
strings.HasPrefix(val, "@") { val = filepath.Join(baseDir, val[1:]) input[key] = Input{File: &val} } else { input[key] = Input{String: &val} } } return input } func (inputs *Inputs) toMap() (map[string]any, error) { keyVals := map[string]any{} for key, input := range *inputs { switch { case input.String != nil: // Directly assign the string value keyVals[key] = *input.String case input.File != nil: // Single file handling: read content and convert to a data URL dataURL, err := fileToDataURL(*input.File) if err != nil { return keyVals, err } keyVals[key] = dataURL case input.Array != nil: // Handle array, potentially containing file paths dataURLs := make([]string, len(*input.Array)) for i, elem := range *input.Array { if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") { filePath := str[1:] // Remove '@' prefix dataURL, err := fileToDataURL(filePath) if err != nil { return keyVals, err } dataURLs[i] = dataURL } else if ok { // Directly use the string if it's not a file path dataURLs[i] = str } } keyVals[key] = dataURLs case input.Json != nil: keyVals[key] = *input.Json case input.Float != nil: keyVals[key] = *input.Float case input.Int != nil: keyVals[key] = *input.Int } } return keyVals, nil } // Helper function to read file content and convert to a data URL func fileToDataURL(filePath string) (string, error) { // Expand home directory if necessary expandedVal, err := homedir.Expand(filePath) if err != nil { return "", fmt.Errorf("error expanding homedir for '%s': %w", filePath, err) } content, err := os.ReadFile(expandedVal) if err != nil { return "", err } mimeType := mime.TypeByExtension(filepath.Ext(expandedVal)) dataURL := dataurl.New(content, mimeType).String() return dataURL, nil } // resolveSchemaType walks through allOf/anyOf/$ref wrappers to find a schema // that has a concrete Type set. 
This is needed because the static schema gen // emits allOf:[{$ref: "#/components/schemas/Foo"}] for enum/choices fields, // where the referenced schema carries the type (e.g. "integer") but the wrapper does not. func resolveSchemaType(s *openapi3.Schema) *openapi3.Schema { if s.Type != nil && s.Type.Slice() != nil { return s } // Check allOf entries for _, ref := range s.AllOf { if ref.Value != nil && ref.Value.Type != nil && ref.Value.Type.Slice() != nil { return ref.Value } } // Check anyOf entries for _, ref := range s.AnyOf { if ref.Value != nil && ref.Value.Type != nil && ref.Value.Type.Slice() != nil { return ref.Value } } return s } ================================================ FILE: pkg/predict/predictor.go ================================================ package predict import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "strings" "time" "github.com/getkin/kin-openapi/openapi3" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/util/console" ) type status string type HealthcheckResponse struct { Status string `json:"status"` } type RequestContext struct { ReplicateAPIToken string `json:"replicate_api_token,omitempty"` } type Request struct { // TODO: could this be Inputs? 
// waitForContainerReady blocks until the container's /health-check endpoint
// reports READY, polling roughly every 100ms.
//
// It returns nil once the status is "READY", and an error when:
//   - more than timeout has elapsed since polling started (checked at the top
//     of each iteration, before sleeping);
//   - inspecting the container fails, or the container is "exited"/"dead";
//   - a 200 response body cannot be decoded as a HealthcheckResponse;
//   - the status is "SETUP_FAILED" or any value other than "STARTING"/"READY".
//
// Connection failures and non-200 responses are not errors: they are treated
// as "not ready yet" and the loop keeps polling.
func (p *Predictor) waitForContainerReady(ctx context.Context, timeout time.Duration) error {
	url := fmt.Sprintf("http://localhost:%d/health-check", p.port)

	start := time.Now()
	for {
		if time.Since(start) > timeout {
			return fmt.Errorf("Timed out")
		}

		// Fixed pause between polls; note this sleep does not observe ctx.
		time.Sleep(100 * time.Millisecond)

		// Bail out early if the container has already stopped instead of
		// waiting for the full timeout.
		cont, err := p.dockerClient.ContainerInspect(ctx, p.containerID)
		if err != nil {
			return fmt.Errorf("Failed to get container status: %w", err)
		}

		if cont.State != nil && (cont.State.Status == "exited" || cont.State.Status == "dead") {
			return fmt.Errorf("Container exited unexpectedly")
		}

		// The probe runs in a closure so its deferred cancel/Close fire at the
		// end of every iteration rather than when the outer function returns.
		// A (nil, nil) result means "no usable answer yet".
		healthcheck, err := func() (*HealthcheckResponse, error) {
			// Each probe gets its own 250ms budget derived from the caller's ctx.
			ctx, cancel := context.WithTimeout(ctx, 250*time.Millisecond)
			defer cancel()

			req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
			if err != nil {
				return nil, fmt.Errorf("Failed to create HTTP request to %s: %w", url, err)
			}
			resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL from localhost health check
			if err != nil {
				// Request failed (e.g. server not listening yet): keep polling.
				return nil, nil
			}
			defer resp.Body.Close()

			if resp.StatusCode != http.StatusOK {
				// Any non-200 answer is also treated as "not ready yet".
				return nil, nil
			}
			healthcheck := &HealthcheckResponse{}
			if err := json.NewDecoder(resp.Body).Decode(healthcheck); err != nil {
				return nil, fmt.Errorf("Container healthcheck returned invalid response: %w", err)
			}
			return healthcheck, nil
		}()
		if err != nil {
			return err
		}
		if healthcheck == nil {
			continue
		}

		// These status values are defined in python/cog/server/http.py
		switch healthcheck.Status {
		case "STARTING":
			continue
		case "SETUP_FAILED":
			return fmt.Errorf("Model setup failed")
		case "READY":
			return nil
		default:
			return fmt.Errorf("Container healthcheck returned unexpected status: %s", healthcheck.Status)
		}
	}
}
fmt.Errorf("Failed to create HTTP request to %s: %w", url, err) } req.Header.Set("Content-Type", "application/json") req.Close = true httpClient := &http.Client{} resp, err := httpClient.Do(req) //nolint:gosec // G704: URL from localhost prediction endpoint if err != nil { return nil, fmt.Errorf("Failed to POST HTTP request to %s: %w", url, err) } defer resp.Body.Close() if resp.StatusCode == http.StatusUnprocessableEntity { errorResponse := &ValidationErrorResponse{} if err := json.NewDecoder(resp.Body).Decode(errorResponse); err != nil { return nil, fmt.Errorf("/%s call returned status 422, and the response body failed to decode: %w", p.endpoint(), err) } return nil, p.buildInputValidationErrorMessage(errorResponse) } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("/%s call returned status %d", p.endpoint(), resp.StatusCode) } prediction := &Response{} if err = json.NewDecoder(resp.Body).Decode(prediction); err != nil { return nil, fmt.Errorf("Failed to decode prediction response: %w", err) } return prediction, nil } func (p *Predictor) GetSchema() (*openapi3.T, error) { resp, err := http.Get(fmt.Sprintf("http://localhost:%d/openapi.json", p.port)) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("Failed to get OpenAPI schema: %d", resp.StatusCode) } body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } return openapi3.NewLoader().LoadFromData(body) } func (p *Predictor) endpoint() string { if p.isTrain { return "trainings" } return "predictions" } func (p *Predictor) url() string { return fmt.Sprintf("http://localhost:%d/%s", p.port, p.endpoint()) } func (p *Predictor) buildInputValidationErrorMessage(errorResponse *ValidationErrorResponse) error { errorMessages := []string{} for _, validationError := range errorResponse.Detail { if len(validationError.Location) != 3 || validationError.Location[0] != "body" || validationError.Location[1] != "input" { responseBody, _ := 
json.MarshalIndent(errorResponse, "", "\t") return fmt.Errorf("/%s call returned status 422, and there was an unexpected message in response:\n\n%s", p.endpoint(), responseBody) } errorMessages = append(errorMessages, fmt.Sprintf("- %s: %s", validationError.Location[2], validationError.Message)) } command := "predict" if p.isTrain { command = "train" } return fmt.Errorf( `The inputs you passed could not be validated: %[2]s You can provide an input with -i. For example: cog %[1]s -i blur=3.5 If your input is a local file, you need to prefix the path with @ to tell Cog to read the file contents. For example: cog %[1]s -i path=@image.jpg`, command, strings.Join(errorMessages, "\n"), ) } ================================================ FILE: pkg/provider/generic/generic.go ================================================ package generic import ( "bufio" "context" "fmt" "os" "strings" "golang.org/x/term" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/provider" "github.com/replicate/cog/pkg/util/console" ) // GenericProvider works with any OCI-compliant registry type GenericProvider struct{} // New creates a new GenericProvider func New() *GenericProvider { return &GenericProvider{} } func (p *GenericProvider) Name() string { return "generic" } func (p *GenericProvider) MatchesRegistry(host string) bool { return true // Fallback - matches everything } func (p *GenericProvider) Login(ctx context.Context, opts provider.LoginOptions) error { console.InfoUnformattedf("Logging in to %s", opts.Host) console.InfoUnformatted("") // TODO: support non-interactive login with token stdin for generic registries // Prompt for username fmt.Print("Username: ") reader := bufio.NewReader(os.Stdin) username, err := reader.ReadString('\n') if err != nil { return fmt.Errorf("failed to read username: %w", err) } username = strings.TrimSpace(username) if username == "" { return fmt.Errorf("username cannot be empty") } // Prompt for password (hidden input) 
fmt.Print("Password: ") passwordBytes, err := term.ReadPassword(int(os.Stdin.Fd())) //nolint:gosec // G115: Fd() fits in int on all supported platforms if err != nil { return fmt.Errorf("failed to read password: %w", err) } fmt.Println() // newline after hidden input password := string(passwordBytes) if password == "" { return fmt.Errorf("password cannot be empty") } // Save credentials using Docker's credential system if err := docker.SaveLoginToken(ctx, opts.Host, username, password); err != nil { return fmt.Errorf("failed to save credentials: %w", err) } console.Successf("Login succeeded for %s", console.Bold(opts.Host)) return nil } func (p *GenericProvider) PostPush(ctx context.Context, opts provider.PushOptions, pushErr error) error { // No special post-push handling for generic registries // Just show a simple success message if push succeeded if pushErr == nil { console.Successf("Image %s pushed", console.Bold(opts.Image)) } return nil } // Verify interface compliance at compile time var _ provider.Provider = (*GenericProvider)(nil) ================================================ FILE: pkg/provider/generic/generic_test.go ================================================ package generic import ( "context" "errors" "testing" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/provider" ) func TestGenericProvider_Name(t *testing.T) { p := New() require.Equal(t, "generic", p.Name()) } func TestGenericProvider_MatchesRegistry(t *testing.T) { p := New() // Generic provider matches everything (it's the fallback) require.True(t, p.MatchesRegistry("ghcr.io")) require.True(t, p.MatchesRegistry("docker.io")) require.True(t, p.MatchesRegistry("ecr.aws")) require.True(t, p.MatchesRegistry("anything.example.com")) } func TestGenericProvider_Login(t *testing.T) { // Login() prompts for username/password interactively and saves credentials // via Docker's credential system. 
// PushOptions contains all options for a push operation
type PushOptions struct {
	// Image is the image reference being pushed (e.g. "ghcr.io/org/model").
	Image string
	// Config is the project configuration associated with the push.
	// NOTE(review): presumably the parsed cog.yaml; confirm against callers.
	Config *config.Config
	// ProjectDir is the project directory for the push.
	// NOTE(review): presumably the directory containing cog.yaml; confirm against callers.
	ProjectDir string
}

// LoginOptions contains the options passed to Provider.Login.
type LoginOptions struct {
	// TokenStdin requests a non-interactive login that reads the token from
	// stdin. Providers may ignore it if they only support interactive login.
	TokenStdin bool
	// Host is the registry host to log in to.
	Host string
}

// Provider encapsulates registry-specific behavior
type Provider interface {
	// Name returns the provider identifier (e.g., "replicate", "generic")
	Name() string

	// MatchesRegistry returns true if this provider handles the given registry host
	MatchesRegistry(host string) bool

	// Login performs provider-specific authentication
	Login(ctx context.Context, opts LoginOptions) error

	// PostPush is called after push attempt (success or failure)
	// - Shows success message (e.g., Replicate model URL)
	// - May transform errors into provider-specific messages
	// - pushErr is nil on success, contains the push error on failure
	PostPush(ctx context.Context, opts PushOptions, pushErr error) error
}
================================================ FILE: pkg/provider/registry.go ================================================ package provider import ( "strings" "sync" ) // defaultRegistry is the global singleton registry var defaultRegistry *Registry // DefaultRegistry returns the global provider registry, initializing it on first call // The registry is pre-populated with Replicate and Generic providers func DefaultRegistry() *Registry { if defaultRegistry == nil { defaultRegistry = NewRegistry() // Note: providers are registered by init() functions in their respective packages // via RegisterProvider(), or can be set up explicitly } return defaultRegistry } // RegisterProvider adds a provider to the default registry // This should be called from init() functions in provider packages func RegisterProvider(p Provider) { DefaultRegistry().Register(p) } // Registry manages provider lookup and registration type Registry struct { providers []Provider mu sync.RWMutex } // NewRegistry creates a new Registry with no providers registered func NewRegistry() *Registry { return &Registry{ providers: make([]Provider, 0), } } // Register adds a provider to the registry // Providers are checked in registration order, so register more specific // providers before generic fallback providers func (r *Registry) Register(p Provider) { r.mu.Lock() defer r.mu.Unlock() r.providers = append(r.providers, p) } // ForImage returns the provider for a given image name // It extracts the registry host from the image and delegates to ForHost func (r *Registry) ForImage(image string) Provider { host := ExtractHost(image) return r.ForHost(host) } // ForHost returns the provider for a given registry host // Returns the first provider that matches, or nil if none match func (r *Registry) ForHost(host string) Provider { r.mu.RLock() defer r.mu.RUnlock() for _, p := range r.providers { if p.MatchesRegistry(host) { return p } } return nil } // ExtractHost extracts the registry host from an image 
// ExtractHost returns the registry host part of an image reference, defaulting
// to Docker Hub ("docker.io") when the reference does not name a registry.
//
// Examples:
//   - "r8.im/user/model" -> "r8.im"
//   - "ghcr.io/owner/repo:tag" -> "ghcr.io"
//   - "gcr.io/project/image" -> "gcr.io"
//   - "docker.io/library/nginx" -> "docker.io"
//   - "nginx" -> "docker.io" (Docker Hub default)
//   - "myregistry.com:5000/image" -> "myregistry.com:5000"
//   - "localhost:5000/image" -> "localhost:5000"
func ExtractHost(image string) string {
	const dockerHub = "docker.io"

	// Ignore anything from the digest marker onwards (@sha256:...).
	name, _, _ := strings.Cut(image, "@")

	// Without a slash there is no registry component at all; this also covers
	// the empty reference and bare names such as "nginx" or "nginx:latest".
	slash := strings.IndexByte(name, '/')
	if slash < 0 {
		return dockerHub
	}

	// The leading component is a registry host only if it looks like one:
	// it contains a dot (gcr.io), a port separator (localhost:5000), or is
	// exactly "localhost". Otherwise it is a Docker Hub user/org.
	candidate := name[:slash]
	switch {
	case candidate == "localhost", strings.ContainsAny(candidate, ".:"):
		return candidate
	default:
		return dockerHub
	}
}
*testing.T) { r := NewRegistry() replicateProvider := &mockProvider{ name: "replicate", matches: func(host string) bool { return host == "r8.im" }, } genericProvider := &mockProvider{ name: "generic", matches: func(host string) bool { return true }, } // Register replicate first (more specific), then generic (fallback) r.Register(replicateProvider) r.Register(genericProvider) t.Run("matches replicate", func(t *testing.T) { p := r.ForHost("r8.im") require.NotNil(t, p) require.Equal(t, "replicate", p.Name()) }) t.Run("falls back to generic", func(t *testing.T) { p := r.ForHost("ghcr.io") require.NotNil(t, p) require.Equal(t, "generic", p.Name()) }) t.Run("empty host falls back to generic", func(t *testing.T) { p := r.ForHost("") require.NotNil(t, p) require.Equal(t, "generic", p.Name()) }) } func TestRegistry_ForImage(t *testing.T) { r := NewRegistry() replicateProvider := &mockProvider{ name: "replicate", matches: func(host string) bool { return host == "r8.im" }, } genericProvider := &mockProvider{ name: "generic", matches: func(host string) bool { return true }, } r.Register(replicateProvider) r.Register(genericProvider) tests := []struct { image string expectedName string }{ {"r8.im/user/model", "replicate"}, {"r8.im/user/model:v1", "replicate"}, {"ghcr.io/owner/repo", "generic"}, {"gcr.io/project/image:tag", "generic"}, {"docker.io/library/nginx", "generic"}, {"nginx", "generic"}, {"myregistry.com/image", "generic"}, } for _, tt := range tests { t.Run(tt.image, func(t *testing.T) { p := r.ForImage(tt.image) require.NotNil(t, p) require.Equal(t, tt.expectedName, p.Name()) }) } } func TestRegistry_NoProviders(t *testing.T) { r := NewRegistry() p := r.ForHost("any.registry.io") require.Nil(t, p) } func TestExtractHost(t *testing.T) { tests := []struct { image string expected string }{ // Replicate {"r8.im/user/model", "r8.im"}, {"r8.im/user/model:v1", "r8.im"}, {"r8.im/user/model:latest", "r8.im"}, // GitHub Container Registry {"ghcr.io/owner/repo", "ghcr.io"}, 
import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"runtime"
	"strings"

	"golang.org/x/term"

	"github.com/replicate/cog/pkg/docker"
	"github.com/replicate/cog/pkg/docker/command"
	"github.com/replicate/cog/pkg/global"
	"github.com/replicate/cog/pkg/provider"
	"github.com/replicate/cog/pkg/util/console"
)
because that variable // gets updated by the --registry flag, which would cause us to incorrectly // match any registry the user specifies. return host == global.DefaultReplicateRegistryHost } // Login performs login to the registry with options func (p *ReplicateProvider) Login(ctx context.Context, opts provider.LoginOptions) error { var ( token string err error ) if opts.TokenStdin { token, err = readTokenFromStdin() if err != nil { return err } } else { token, err = readTokenInteractively(opts.Host) if err != nil { return err } } token = strings.TrimSpace(token) if err := checkTokenFormat(token); err != nil { return err } username, err := verifyToken(opts.Host, token) if err != nil { return err } if err := docker.SaveLoginToken(ctx, opts.Host, username, token); err != nil { return err } console.Successf("You've successfully authenticated as %s! You can now use the %s registry.", console.Bold(username), console.Bold(opts.Host)) return nil } func (p *ReplicateProvider) PostPush(ctx context.Context, opts provider.PushOptions, pushErr error) error { if pushErr != nil { // Return Replicate-specific error message for repository not found errors if command.IsNotFoundError(pushErr) { return fmt.Errorf(`Unable to find existing Replicate model for %s. Go to replicate.com and create a new model before pushing. If the model already exists, you may be getting this error because you're not logged in as owner of the model. 
// readTokenFromStdin reads the authentication token from standard input,
// consuming it until EOF.
//
// The contents are returned unmodified (including any trailing newline);
// trimming is left to the caller. The data is read through os.Stdin rather than
// by opening the "/dev/stdin" path, because that path does not exist on every
// platform (for example Windows).
func readTokenFromStdin() (string, error) {
	var token strings.Builder
	// bufio.Reader implements io.WriterTo, which drains the reader to EOF.
	if _, err := bufio.NewReader(os.Stdin).WriteTo(&token); err != nil {
		return "", fmt.Errorf("failed to read token from stdin: %w", err)
	}
	return token.String(), nil
}
// getDisplayTokenURL asks the registry host for the URL of the web page that
// displays the user's authentication token.
//
// It issues GET <host>/cog/v1/display-token-url and expects a JSON object with
// a "url" field. A 404 response is reported as "not the Replicate registry";
// any other non-200 status, an undecodable body, or an empty URL is returned
// as an error.
func getDisplayTokenURL(registryHost string) (string, error) {
	resp, err := http.Get(addressWithScheme(registryHost) + "/cog/v1/display-token-url")
	if err != nil {
		return "", fmt.Errorf("failed to log in to %s: %w", registryHost, err)
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusOK:
		// Fall through to decoding below.
	case http.StatusNotFound:
		return "", fmt.Errorf("%s is not the Replicate registry\nPlease log in using 'docker login'", registryHost)
	default:
		return "", fmt.Errorf("%s returned HTTP status %d", registryHost, resp.StatusCode)
	}

	var payload struct {
		URL string `json:"url"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return "", fmt.Errorf("failed to parse response from %s: %w", registryHost, err)
	}
	// An empty URL would send the user to a blank page; surface it instead.
	if payload.URL == "" {
		return "", fmt.Errorf("%s did not return a token URL", registryHost)
	}
	return payload.URL, nil
}

// addressWithScheme returns address unchanged when it already contains a
// scheme separator ("://"), and otherwise prefixes it with "https://".
func addressWithScheme(address string) string {
	if strings.Contains(address, "://") {
		return address
	}
	return "https://" + address
}
switch runtime.GOOS { case "linux": _ = exec.Command("xdg-open", urlToOpen).Start() case "windows": _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", urlToOpen).Start() case "darwin": _ = exec.Command("open", urlToOpen).Start() } } // checkTokenFormat validates the token isn't an API token func checkTokenFormat(token string) error { if strings.HasPrefix(token, "r8_") { return fmt.Errorf("that looks like a Replicate API token, not a CLI auth token. Please fetch a token from https://replicate.com/auth/token to log in") } return nil } // verifyToken validates the token with Replicate and returns the username func verifyToken(registryHost string, token string) (username string, err error) { if token == "" { return "", fmt.Errorf("token is empty") } resp, err := http.PostForm(addressWithScheme(registryHost)+"/cog/v1/verify-token", url.Values{ "token": []string{token}, }) if err != nil { return "", fmt.Errorf("failed to verify token: %w", err) } defer resp.Body.Close() if resp.StatusCode == http.StatusNotFound { return "", fmt.Errorf("user does not exist") } if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("failed to verify token, got status %d", resp.StatusCode) } body := &struct { Username string `json:"username"` }{} if err := json.NewDecoder(resp.Body).Decode(body); err != nil { return "", err } return body.Username, nil } // Verify interface compliance at compile time var _ provider.Provider = (*ReplicateProvider)(nil) ================================================ FILE: pkg/provider/replicate/replicate_test.go ================================================ package replicate import ( "context" "encoding/json" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/docker/command" "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/provider" ) func TestReplicateProvider_Name(t *testing.T) { p := New() require.Equal(t, "replicate", p.Name()) } func 
TestReplicateProvider_MatchesRegistry(t *testing.T) { p := New() // Should match default r8.im require.True(t, p.MatchesRegistry("r8.im")) // Should match the current global registry host (in case it was overridden) require.True(t, p.MatchesRegistry(global.ReplicateRegistryHost)) // Should not match other registries require.False(t, p.MatchesRegistry("ghcr.io")) require.False(t, p.MatchesRegistry("docker.io")) require.False(t, p.MatchesRegistry("gcr.io")) require.False(t, p.MatchesRegistry("myregistry.example.com")) } func TestReplicateProvider_PostPush(t *testing.T) { p := New() opts := provider.PushOptions{ Image: "r8.im/user/model", } t.Run("success", func(t *testing.T) { err := p.PostPush(context.Background(), opts, nil) require.NoError(t, err) }) t.Run("repository not found error", func(t *testing.T) { // Simulate a NotFoundError from docker push (repository doesn't exist) pushErr := &command.NotFoundError{Ref: "r8.im/user/model", Object: "repository"} err := p.PostPush(context.Background(), opts, pushErr) require.Error(t, err) require.Contains(t, err.Error(), "Unable to find existing Replicate model") require.Contains(t, err.Error(), "replicate.com and create a new model") }) t.Run("tag not found error", func(t *testing.T) { // Tag not found errors should also trigger the helpful message pushErr := &command.NotFoundError{Ref: "r8.im/user/model:v1", Object: "tag"} err := p.PostPush(context.Background(), opts, pushErr) require.Error(t, err) require.Contains(t, err.Error(), "Unable to find existing Replicate model") }) } func TestCheckTokenFormat(t *testing.T) { tests := []struct { name string token string wantErr bool }{ { name: "valid CLI token", token: "abc123def456", wantErr: false, }, { name: "API token rejected", token: "r8_abc123", wantErr: true, }, { name: "empty token allowed (separate validation)", token: "", wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := checkTokenFormat(tt.token) if tt.wantErr { 
require.Error(t, err) require.Contains(t, err.Error(), "API token") } else { require.NoError(t, err) } }) } } func TestVerifyToken(t *testing.T) { t.Run("successful verification", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "/cog/v1/verify-token", r.URL.Path) require.Equal(t, "POST", r.Method) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]string{"username": "testuser"}) })) defer server.Close() username, err := verifyToken(server.URL, "valid-token") require.NoError(t, err) require.Equal(t, "testuser", username) }) t.Run("empty token", func(t *testing.T) { _, err := verifyToken("http://localhost", "") require.Error(t, err) require.Contains(t, err.Error(), "empty") }) t.Run("user not found", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) })) defer server.Close() _, err := verifyToken(server.URL, "unknown-token") require.Error(t, err) require.Contains(t, err.Error(), "does not exist") }) t.Run("server error", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) defer server.Close() _, err := verifyToken(server.URL, "some-token") require.Error(t, err) require.Contains(t, err.Error(), "500") }) } func TestGetDisplayTokenURL(t *testing.T) { t.Run("successful fetch", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "/cog/v1/display-token-url", r.URL.Path) require.Equal(t, "GET", r.Method) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]string{"url": "https://replicate.com/auth/token"}) })) defer server.Close() url, err := getDisplayTokenURL(server.URL) require.NoError(t, err) require.Equal(t, "https://replicate.com/auth/token", url) }) t.Run("not replicate 
registry", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) })) defer server.Close() _, err := getDisplayTokenURL(server.URL) require.Error(t, err) require.Contains(t, err.Error(), "not the Replicate registry") }) } func TestAddressWithScheme(t *testing.T) { tests := []struct { input string expected string }{ {"r8.im", "https://r8.im"}, {"https://r8.im", "https://r8.im"}, {"http://localhost:8080", "http://localhost:8080"}, {"myregistry.com", "https://myregistry.com"}, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { result := addressWithScheme(tt.input) require.Equal(t, tt.expected, result) }) } } ================================================ FILE: pkg/provider/setup/setup.go ================================================ // Package setup initializes the default provider registry package setup import ( "sync" "github.com/replicate/cog/pkg/provider" "github.com/replicate/cog/pkg/provider/generic" "github.com/replicate/cog/pkg/provider/replicate" ) var once sync.Once // Init initializes the default provider registry with all built-in providers // This function is idempotent - it only runs once even if called multiple times func Init() { once.Do(func() { registry := provider.DefaultRegistry() // Register Replicate provider first (more specific) registry.Register(replicate.New()) // Register Generic provider last (fallback for any OCI registry) registry.Register(generic.New()) }) } ================================================ FILE: pkg/provider/setup/setup_test.go ================================================ package setup import ( "testing" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/provider" ) func TestInit(t *testing.T) { // Call Init multiple times - should be idempotent Init() Init() registry := provider.DefaultRegistry() // Replicate images should get the Replicate provider p := 
type (
	// Platform identifies a target platform by operating system, CPU
	// architecture and optional architecture variant.
	Platform struct {
		OS           string
		Architecture string
		Variant      string
	}

	// PlatformManifest describes one entry of a multi-platform image index:
	// the entry's digest, media type and size together with the platform it
	// targets and any annotations attached to it.
	PlatformManifest struct {
		Digest       string
		MediaType    string
		Size         int64
		OS           string
		Architecture string
		Variant      string
		Annotations  map[string]string
	}

	// RetryEvent carries the details of a single retry attempt.
	RetryEvent struct {
		// Attempt is the 1-indexed number of the current retry attempt.
		Attempt int
		// MaxAttempts is the upper bound on retry attempts.
		MaxAttempts int
		// Err is the error that triggered the retry.
		Err error
		// NextRetryIn is how long to wait before the next attempt.
		NextRetryIn time.Duration
	}

	// RetryCallback is invoked whenever a retry occurs. Returning false
	// aborts any further retrying.
	RetryCallback func(event RetryEvent) bool
)
// Client is the set of registry operations used by this package's callers:
// reading manifests and images, and pushing images, indexes and single layers.
type Client interface {
	// Read methods

	// Inspect resolves imageRef and returns a summary of its manifest.
	// In RegistryClient a nil platform describes the reference as-is (an
	// index or a single image), while a non-nil platform selects the matching
	// entry of a manifest list; NotFoundError is returned when the registry
	// reports the manifest or repository as unknown.
	Inspect(ctx context.Context, imageRef string, platform *Platform) (*ManifestResult, error)
	// GetImage returns the image behind imageRef. In RegistryClient a nil
	// platform is only valid for single-platform references, and a non-nil
	// platform requires the reference to be a manifest list.
	GetImage(ctx context.Context, imageRef string, platform *Platform) (v1.Image, error)
	// Exists reports whether imageRef can be found in the registry.
	// NOTE(review): the implementation is not visible here; confirm how
	// errors other than "not found" are surfaced.
	Exists(ctx context.Context, imageRef string) (bool, error)

	// GetDescriptor returns the OCI descriptor for an image reference without downloading
	// the full image. This is a lightweight HEAD request useful for building OCI indexes
	// from already-pushed manifests.
	GetDescriptor(ctx context.Context, imageRef string) (v1.Descriptor, error)

	// Write methods for OCI index support

	// PushImage pushes img to the registry under ref.
	PushImage(ctx context.Context, ref string, img v1.Image) error
	// PushIndex pushes the image index idx to the registry under ref.
	// NOTE(review): the push tests state that only the index manifest is
	// written and child images must be pushed beforehand; confirm against
	// the implementation.
	PushIndex(ctx context.Context, ref string, idx v1.ImageIndex) error

	// WriteLayer pushes a single layer (blob) to a repository with retry and optional progress reporting.
	// This method handles transient failures automatically with exponential backoff.
	// Use WriteLayerOptions to configure progress reporting and retry callbacks.
	WriteLayer(ctx context.Context, opts WriteLayerOptions) error
}
this should get added back // and be part of a normal integration suite we run on all target platforms t.Skip("skipping integration tests") } registry := registry_testhelpers.StartTestRegistry(t) t.Run("it returns an index for multi-platform images when a platform isn't provided", func(t *testing.T) { imageRef := registry.ImageRef("alpine:latest") client := NewRegistryClient() resp, err := client.Inspect(t.Context(), imageRef, nil) require.NoError(t, err) require.NotNil(t, resp) assert.True(t, resp.IsIndex(), "expected index") json.NewEncoder(os.Stdout).Encode(resp) }) t.Run("it returns a single platform image when a platform is provided", func(t *testing.T) { imageRef := registry.ImageRef("alpine:latest") client := NewRegistryClient() resp, err := client.Inspect(t.Context(), imageRef, &Platform{OS: "linux", Architecture: "amd64"}) require.NoError(t, err) require.NotNil(t, resp) assert.False(t, resp.IsIndex(), "expected single platform image") assert.True(t, resp.IsSinglePlatform(), "expected single platform image") json.NewEncoder(os.Stdout).Encode(resp) }) t.Run("when a repo does not exist", func(t *testing.T) { imageRef := registry.ImageRef("i-do-not-exist:latest") client := NewRegistryClient() resp, err := client.Inspect(t.Context(), imageRef, nil) assert.ErrorIs(t, err, NotFoundError, "expected not found error") assert.Nil(t, resp) }) t.Run("when a repo with a slashdoes not exist", func(t *testing.T) { imageRef := registry.ImageRef("i-do-not-exist/with-a-slash:latest") client := NewRegistryClient() resp, err := client.Inspect(t.Context(), imageRef, nil) assert.ErrorIs(t, err, NotFoundError, "expected not found error") assert.Nil(t, resp) }) t.Run("when the repo exists but the tag does not", func(t *testing.T) { imageRef := registry.ImageRef("alpine:not-found") client := NewRegistryClient() resp, err := client.Inspect(t.Context(), imageRef, nil) assert.ErrorIs(t, err, NotFoundError, "expected not found error") assert.Nil(t, resp) }) t.Run("when the repo and tag 
exist but platform does not", func(t *testing.T) { imageRef := registry.ImageRef("alpine:latest") client := NewRegistryClient() resp, err := client.Inspect(t.Context(), imageRef, &Platform{OS: "windows", Architecture: "i386"}) assert.ErrorContains(t, err, "platform not found") assert.Nil(t, resp) }) } func TestIsRetryableError(t *testing.T) { tests := []struct { name string err error expected bool }{ { name: "nil error", err: nil, expected: false, }, { name: "io.EOF", err: io.EOF, expected: true, }, { name: "io.ErrUnexpectedEOF", err: io.ErrUnexpectedEOF, expected: true, }, { name: "syscall.EPIPE (broken pipe)", err: syscall.EPIPE, expected: true, }, { name: "syscall.ECONNRESET", err: syscall.ECONNRESET, expected: true, }, { name: "net.ErrClosed", err: net.ErrClosed, expected: true, }, { name: "context.Canceled", err: context.Canceled, expected: false, }, { name: "context.DeadlineExceeded", err: context.DeadlineExceeded, expected: false, }, { name: "generic error (not retryable)", err: errors.New("something completely different"), expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := isRetryableError(tt.err) assert.Equal(t, tt.expected, result, "isRetryableError(%v) = %v, want %v", tt.err, result, tt.expected) }) } } ================================================ FILE: pkg/registry/config.go ================================================ package registry import ( "os" "strconv" ) const ( // DefaultChunkSize is the size (in bytes) of each chunk in a multipart upload. // This is used as a fallback when the registry does not advertise chunk size // limits via OCI-Chunk-Min-Length / OCI-Chunk-Max-Length headers. // 96 MB stays under common CDN/proxy request body limits while still being // large enough to reduce HTTP round-trips for multi-GB files. DefaultChunkSize = 96 * 1024 * 1024 // 96 MB // DefaultMultipartThreshold is the minimum blob size (in bytes) before using multipart upload. 
// Blobs smaller than this are uploaded in a single request to avoid multipart overhead. // Set higher than DefaultChunkSize so that blobs that would fit in a single chunk // are uploaded in one request, avoiding unnecessary multipart overhead. DefaultMultipartThreshold = 128 * 1024 * 1024 // 128 MB // chunkSizeMargin is subtracted from the server's OCI-Chunk-Max-Length to stay // safely under the limit (e.g. for HTTP framing overhead). chunkSizeMargin = 64 * 1024 // 64 KB // envPushDefaultChunkSize sets the default chunk size for multipart uploads. // This is only used when the registry does not advertise OCI-Chunk-Max-Length. // When the registry does advertise a maximum, the server's limit takes precedence. envPushDefaultChunkSize = "COG_PUSH_DEFAULT_CHUNK_SIZE" // envMultipartThreshold overrides the minimum blob size for multipart uploads. envMultipartThreshold = "COG_PUSH_MULTIPART_THRESHOLD" ) // getDefaultChunkSize returns the client-configured default chunk size for multipart uploads. // This is used as a fallback when the registry does not advertise chunk size limits. func getDefaultChunkSize() int64 { if v := os.Getenv(envPushDefaultChunkSize); v != "" { if n, err := strconv.ParseInt(v, 10, 64); err == nil && n > 0 { return n } } return DefaultChunkSize } // getMultipartThreshold returns the minimum blob size for multipart uploads. 
func getMultipartThreshold() int64 { if v := os.Getenv(envMultipartThreshold); v != "" { if n, err := strconv.ParseInt(v, 10, 64); err == nil && n > 0 { return n } } return DefaultMultipartThreshold } ================================================ FILE: pkg/registry/config_test.go ================================================ package registry import ( "testing" "github.com/stretchr/testify/assert" ) func TestGetDefaultChunkSize(t *testing.T) { t.Run("returns default when env not set", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "") assert.Equal(t, int64(DefaultChunkSize), getDefaultChunkSize()) }) t.Run("returns env var value", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "134217728") // 128 MB assert.Equal(t, int64(134217728), getDefaultChunkSize()) }) t.Run("returns default for invalid value", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "xyz") assert.Equal(t, int64(DefaultChunkSize), getDefaultChunkSize()) }) t.Run("returns default for zero", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "0") assert.Equal(t, int64(DefaultChunkSize), getDefaultChunkSize()) }) t.Run("returns default for negative", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "-100") assert.Equal(t, int64(DefaultChunkSize), getDefaultChunkSize()) }) } func TestEffectiveChunkSize(t *testing.T) { t.Run("uses client default when server provides no limits", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "") s := uploadSession{} assert.Equal(t, int64(DefaultChunkSize), s.effectiveChunkSize()) }) t.Run("uses env var default when server provides no limits", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "50000000") // 50 MB s := uploadSession{} assert.Equal(t, int64(50000000), s.effectiveChunkSize()) }) t.Run("server max takes precedence over client default", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "") serverMax := int64(90 * 1024 * 1024) s := uploadSession{ChunkMaxBytes: serverMax} expected := serverMax - 
chunkSizeMargin assert.Equal(t, expected, s.effectiveChunkSize()) }) t.Run("server max takes precedence even when larger than client default", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "") // Server max of 200 MB -- server still dictates, not the client default serverMax := int64(200 * 1024 * 1024) s := uploadSession{ChunkMaxBytes: serverMax} expected := serverMax - chunkSizeMargin assert.Equal(t, expected, s.effectiveChunkSize()) }) t.Run("server max takes precedence over env var", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "50000000") // 50 MB -- ignored when server provides max serverMax := int64(100 * 1000 * 1000) // 100 MB s := uploadSession{ChunkMaxBytes: serverMax} expected := serverMax - chunkSizeMargin assert.Equal(t, expected, s.effectiveChunkSize()) }) t.Run("handles very small server max gracefully", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "") s := uploadSession{ChunkMaxBytes: 1000} // 1000 bytes, smaller than margin // Margin is bigger than max, so we use the max directly assert.Equal(t, int64(1000), s.effectiveChunkSize()) }) t.Run("server min does not raise chunk size when already above it", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "") // Server says min=5MiB max=90MiB; max-margin is well above min, so min has no effect serverMax := int64(90 * 1024 * 1024) s := uploadSession{ChunkMinBytes: 5 * 1024 * 1024, ChunkMaxBytes: serverMax} expected := serverMax - chunkSizeMargin assert.Equal(t, expected, s.effectiveChunkSize()) }) t.Run("server min clamps up a too-small client default", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "1000") // 1 KB, below server min serverMin := int64(5 * 1024 * 1024) // 5 MiB s := uploadSession{ChunkMinBytes: serverMin} assert.Equal(t, serverMin, s.effectiveChunkSize()) }) t.Run("server min clamps up when max minus margin falls below min", func(t *testing.T) { t.Setenv(envPushDefaultChunkSize, "") // Contrived: max is just above min, so max-margin < min. 
Min should win. serverMin := int64(5 * 1024 * 1024) serverMax := serverMin + chunkSizeMargin/2 // max - margin < min s := uploadSession{ChunkMinBytes: serverMin, ChunkMaxBytes: serverMax} assert.Equal(t, serverMin, s.effectiveChunkSize()) }) } func TestGetMultipartThreshold(t *testing.T) { t.Run("returns default when env not set", func(t *testing.T) { t.Setenv(envMultipartThreshold, "") assert.Equal(t, int64(DefaultMultipartThreshold), getMultipartThreshold()) }) t.Run("returns env var value", func(t *testing.T) { t.Setenv(envMultipartThreshold, "104857600") // 100 MB assert.Equal(t, int64(104857600), getMultipartThreshold()) }) t.Run("returns default for invalid value", func(t *testing.T) { t.Setenv(envMultipartThreshold, "abc") assert.Equal(t, int64(DefaultMultipartThreshold), getMultipartThreshold()) }) t.Run("returns default for zero", func(t *testing.T) { t.Setenv(envMultipartThreshold, "0") assert.Equal(t, int64(DefaultMultipartThreshold), getMultipartThreshold()) }) t.Run("returns default for negative", func(t *testing.T) { t.Setenv(envMultipartThreshold, "-50") assert.Equal(t, int64(DefaultMultipartThreshold), getMultipartThreshold()) }) } ================================================ FILE: pkg/registry/manifest_result.go ================================================ package registry import "github.com/google/go-containerregistry/pkg/v1/types" type ManifestResult struct { SchemaVersion int64 MediaType string // Digest is the content-addressable digest of the manifest (sha256:...). 
Digest string Manifests []PlatformManifest Layers []string Config string Labels map[string]string } func (m *ManifestResult) IsIndex() bool { return m.MediaType == string(types.OCIImageIndex) || m.MediaType == string(types.DockerManifestList) } func (m *ManifestResult) IsSinglePlatform() bool { return !m.IsIndex() } ================================================ FILE: pkg/registry/push_test.go ================================================ package registry import ( "context" "testing" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/empty" "github.com/google/go-containerregistry/pkg/v1/mutate" "github.com/google/go-containerregistry/pkg/v1/types" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/registry_testhelpers" ) func TestPushOperations(t *testing.T) { if testing.Short() { t.Skip("skipping integration tests in short mode") } // Start a test registry using testcontainers registry := registry_testhelpers.StartTestRegistry(t) registryAddr := registry.RegistryHost() ctx := context.Background() client := NewRegistryClient() t.Run("push image", func(t *testing.T) { img := empty.Image img, _ = mutate.Config(img, v1.Config{}) err := client.PushImage(ctx, registryAddr+"/test/push-test:v1", img) require.NoError(t, err) // Verify it exists exists, err := client.Exists(ctx, registryAddr+"/test/push-test:v1") require.NoError(t, err) require.True(t, exists) }) t.Run("push index", func(t *testing.T) { img := empty.Image img, _ = mutate.Config(img, v1.Config{}) // Push the child image first — PushIndex only writes the index manifest, // it does not recursively push child manifests/blobs. 
// NotFoundError is the sentinel error returned when an image reference cannot
// be found. Inspect returns it when the registry reports the manifest or the
// repository name as unknown; callers match it with errors.Is.
//
//nolint:staticcheck // ST1012: exported API, renaming would be breaking change
var NotFoundError = errors.New("image reference not found")

// chunkBufPool pools byte slices used for multipart chunk reads to avoid
// re-allocating large buffers (up to ~96 MB each) on every layer upload.
// The pool has no New function, so Get returns nil when the pool is empty.
var chunkBufPool = sync.Pool{} //nolint:gochecknoglobals
transport http.RoundTripper } func NewRegistryClient() Client { return &RegistryClient{ transport: http1OnlyTransport(), } } // remoteOptions returns the common remote.Option set for go-containerregistry calls, // including authentication, context, and HTTP/1.1 transport. func (c *RegistryClient) remoteOptions(ctx context.Context) []remote.Option { return []remote.Option{ remote.WithContext(ctx), remote.WithAuthFromKeychain(authn.DefaultKeychain), remote.WithTransport(c.transport), } } func (c *RegistryClient) Inspect(ctx context.Context, imageRef string, platform *Platform) (*ManifestResult, error) { ref, err := name.ParseReference(imageRef, name.Insecure) if err != nil { return nil, fmt.Errorf("parsing reference: %w", err) } desc, err := remote.Get(ref, c.remoteOptions(ctx)...) if err != nil { if checkError(err, transport.ManifestUnknownErrorCode, transport.NameUnknownErrorCode) { return nil, NotFoundError } return nil, fmt.Errorf("fetching descriptor: %w", err) } mediaType := desc.MediaType if platform == nil { switch mediaType { case types.OCIImageIndex, types.DockerManifestList: idx, err := desc.ImageIndex() if err != nil { return nil, fmt.Errorf("loading image index: %w", err) } indexManifest, err := idx.IndexManifest() if err != nil { return nil, fmt.Errorf("getting index manifest: %w", err) } result := &ManifestResult{ SchemaVersion: indexManifest.SchemaVersion, MediaType: string(mediaType), Digest: desc.Digest.String(), } for _, m := range indexManifest.Manifests { result.Manifests = append(result.Manifests, PlatformManifest{ Digest: m.Digest.String(), MediaType: string(m.MediaType), Size: m.Size, OS: m.Platform.OS, Architecture: m.Platform.Architecture, Variant: m.Platform.Variant, Annotations: m.Annotations, }) } // For indexes, pick a default image to get labels from. // Prefer linux/amd64, otherwise use the first manifest. defaultImg, err := pickDefaultImage(ref, indexManifest, c.remoteOptions(ctx)...) 
if err != nil { return nil, fmt.Errorf("failed to read image config from index: %w", err) } configFile, err := defaultImg.ConfigFile() if err != nil { return nil, fmt.Errorf("failed to get image config: %w", err) } result.Labels = configFile.Config.Labels return result, nil case types.OCIManifestSchema1, types.DockerManifestSchema2: img, err := desc.Image() if err != nil { return nil, fmt.Errorf("loading image: %w", err) } manifest, err := img.Manifest() if err != nil { return nil, fmt.Errorf("getting manifest: %w", err) } configFile, err := img.ConfigFile() if err != nil { return nil, fmt.Errorf("getting config file: %w", err) } result := &ManifestResult{ SchemaVersion: manifest.SchemaVersion, MediaType: string(mediaType), Digest: desc.Digest.String(), Config: manifest.Config.Digest.String(), Labels: configFile.Config.Labels, } for _, layer := range manifest.Layers { result.Layers = append(result.Layers, layer.Digest.String()) } return result, nil default: return nil, fmt.Errorf("unsupported media type: %s", mediaType) } } // platform is set, we expect a manifest list or error if mediaType != types.OCIImageIndex && mediaType != types.DockerManifestList { return nil, fmt.Errorf("image is not a manifest list but platform was specified") } idx, err := desc.ImageIndex() if err != nil { return nil, fmt.Errorf("loading image index: %w", err) } indexManifest, err := idx.IndexManifest() if err != nil { return nil, fmt.Errorf("getting index manifest: %w", err) } var matchedDigest string for _, m := range indexManifest.Manifests { if m.Platform.OS == platform.OS && m.Platform.Architecture == platform.Architecture && m.Platform.Variant == platform.Variant { matchedDigest = m.Digest.String() break } } if matchedDigest == "" { return nil, fmt.Errorf("platform not found in manifest list") } digestRef, err := name.NewDigest(ref.Context().Name() + "@" + matchedDigest) if err != nil { return nil, fmt.Errorf("creating digest ref: %w", err) } manifestDesc, err := 
remote.Get(digestRef, c.remoteOptions(ctx)...) if err != nil { return nil, fmt.Errorf("fetching platform manifest: %w", err) } img, err := manifestDesc.Image() if err != nil { return nil, fmt.Errorf("loading platform image: %w", err) } manifest, err := img.Manifest() if err != nil { return nil, fmt.Errorf("getting manifest: %w", err) } configFile, err := img.ConfigFile() if err != nil { return nil, fmt.Errorf("getting config file: %w", err) } result := &ManifestResult{ SchemaVersion: manifest.SchemaVersion, MediaType: string(manifestDesc.MediaType), Digest: manifestDesc.Digest.String(), Config: manifest.Config.Digest.String(), Labels: configFile.Config.Labels, } for _, layer := range manifest.Layers { result.Layers = append(result.Layers, layer.Digest.String()) } return result, nil } func (c *RegistryClient) GetImage(ctx context.Context, imageRef string, platform *Platform) (v1.Image, error) { ref, err := name.ParseReference(imageRef, name.Insecure) if err != nil { return nil, fmt.Errorf("parsing reference: %w", err) } desc, err := remote.Get(ref, c.remoteOptions(ctx)...) 
if err != nil { return nil, fmt.Errorf("fetching descriptor: %w", err) } mediaType := desc.MediaType // If no platform is specified and it's a single image, return it directly if platform == nil { switch mediaType { case types.OCIManifestSchema1, types.DockerManifestSchema2: return desc.Image() case types.OCIImageIndex, types.DockerManifestList: return nil, fmt.Errorf("platform must be specified for multi-platform image") default: return nil, fmt.Errorf("unsupported media type: %s", mediaType) } } // For platform-specific requests, we need to handle manifest lists if mediaType != types.OCIImageIndex && mediaType != types.DockerManifestList { return nil, fmt.Errorf("image is not a manifest list but platform was specified") } idx, err := desc.ImageIndex() if err != nil { return nil, fmt.Errorf("loading image index: %w", err) } indexManifest, err := idx.IndexManifest() if err != nil { return nil, fmt.Errorf("getting index manifest: %w", err) } // Find the matching platform manifest var matchedDigest string for _, m := range indexManifest.Manifests { if m.Platform.OS == platform.OS && m.Platform.Architecture == platform.Architecture && m.Platform.Variant == platform.Variant { matchedDigest = m.Digest.String() break } } if matchedDigest == "" { return nil, fmt.Errorf("platform not found in manifest list") } // Get the image for the matched digest digestRef, err := name.NewDigest(ref.Context().Name() + "@" + matchedDigest) if err != nil { return nil, fmt.Errorf("creating digest ref: %w", err) } manifestDesc, err := remote.Get(digestRef, c.remoteOptions(ctx)...) if err != nil { return nil, fmt.Errorf("fetching platform manifest: %w", err) } return manifestDesc.Image() } // GetDescriptor returns the OCI descriptor for an image reference using a HEAD request. // This is lightweight — it does not download the full manifest or image layers. 
func (c *RegistryClient) GetDescriptor(ctx context.Context, imageRef string) (v1.Descriptor, error) { ref, err := name.ParseReference(imageRef, name.Insecure) if err != nil { return v1.Descriptor{}, fmt.Errorf("parsing reference: %w", err) } desc, err := remote.Head(ref, c.remoteOptions(ctx)...) if err != nil { return v1.Descriptor{}, fmt.Errorf("head request for %s: %w", imageRef, err) } return *desc, nil } func (c *RegistryClient) Exists(ctx context.Context, imageRef string) (bool, error) { if _, err := c.Inspect(ctx, imageRef, nil); err != nil { if errors.Is(err, NotFoundError) { return false, nil } return false, err } return true, nil } func checkError(err error, codes ...transport.ErrorCode) bool { if err == nil { return false } var e *transport.Error if errors.As(err, &e) { for _, diagnosticErr := range e.Errors { if slices.Contains(codes, diagnosticErr.Code) { return true } } } return false } // PushImage pushes a single image to a registry. func (c *RegistryClient) PushImage(ctx context.Context, ref string, img v1.Image) error { parsedRef, err := name.ParseReference(ref, name.Insecure) if err != nil { return fmt.Errorf("parsing reference: %w", err) } if err := remote.Write(parsedRef, img, c.remoteOptions(ctx)...); err != nil { return fmt.Errorf("pushing image %s: %w", ref, err) } return nil } // PushIndex pushes an OCI Image Index to a registry. func (c *RegistryClient) PushIndex(ctx context.Context, ref string, idx v1.ImageIndex) error { parsedRef, err := name.ParseReference(ref, name.Insecure) if err != nil { return fmt.Errorf("parsing reference: %w", err) } // Use remote.Put instead of remote.WriteIndex because all child manifests // (image + weights) are already pushed to the registry. WriteIndex would // try to recursively resolve and push children via idx.Image(), which fails // for our descriptor-only index. Put just writes the index manifest. 
if err := remote.Put(parsedRef, idx, c.remoteOptions(ctx)...); err != nil { return fmt.Errorf("pushing index %s: %w", ref, err) } return nil } // http1OnlyTransport returns an http.Transport that only speaks HTTP/1.1. // HTTP/2 is avoided for all registry operations because high-throughput uploads // suffer from head-of-line blocking and stream errors (RST_STREAM INTERNAL_ERROR) // when pushed through CDN/proxy edges. Multiple concurrent HTTP/1.1 connections // outperform a single multiplexed HTTP/2 connection for large blob uploads. func http1OnlyTransport() *http.Transport { t := http.DefaultTransport.(*http.Transport).Clone() t.TLSClientConfig = tlsConfigHTTP1Only(t.TLSClientConfig) // ForceAttemptHTTP2 is true by default on cloned transports; disable it. t.ForceAttemptHTTP2 = false return t } // tlsConfigHTTP1Only returns a TLS config that only advertises HTTP/1.1 via ALPN. func tlsConfigHTTP1Only(base *tls.Config) *tls.Config { if base == nil { base = &tls.Config{MinVersion: tls.VersionTLS12} } cfg := base.Clone() cfg.NextProtos = []string{"http/1.1"} return cfg } // DefaultRetryBackoff returns the default retry backoff configuration for weight pushes. // It retries 5 times with exponential backoff starting at 2 seconds. func DefaultRetryBackoff() remote.Backoff { return remote.Backoff{ Duration: 2 * time.Second, Factor: 2.0, Jitter: 0.1, Steps: 5, } } // isRetryableError determines if an error should trigger a retry. // This matches the go-containerregistry default retry predicate plus additional cases. func isRetryableError(err error) bool { if err == nil { return false } // Check for context cancellation - don't retry these if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false } // Check for temporary errors (network issues, etc.) 
var tempErr interface{ Temporary() bool } if errors.As(err, &tempErr) && tempErr.Temporary() { return true } // Check for common transient errors if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) || errors.Is(err, net.ErrClosed) { return true } // Check for retryable HTTP status codes in transport errors var transportErr *transport.Error if errors.As(err, &transportErr) { switch transportErr.StatusCode { case http.StatusRequestTimeout, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout, 499, // nginx-specific, client closed request 522: // Cloudflare-specific, connection timeout return true } } // Check for network operation errors (connection refused, timeout, etc.) var netErr *net.OpError if errors.As(err, &netErr) { return true } // Check for DNS errors var dnsErr *net.DNSError if errors.As(err, &dnsErr) { return dnsErr.Temporary() } return false } // WriteLayer pushes a single layer with retry and optional progress reporting. // This implements retry at the application level with callbacks for CLI feedback. // Unlike the standard remote.WriteLayer, this implementation performs multipart uploads // using Content-Range headers to upload the blob in chunks. 
func (c *RegistryClient) WriteLayer(ctx context.Context, opts WriteLayerOptions) error { parsedRepo, err := name.NewRepository(opts.Repo, name.Insecure) if err != nil { return fmt.Errorf("parsing repository: %w", err) } // Determine retry configuration backoff := DefaultRetryBackoff() if opts.Retry != nil && opts.Retry.Backoff != nil { backoff = *opts.Retry.Backoff } var lastErr error currentDelay := backoff.Duration for attempt := 1; attempt <= backoff.Steps; attempt++ { // Check for context cancellation select { case <-ctx.Done(): return ctx.Err() default: } // Attempt the push using custom multipart upload err := c.writeLayerMultipart(ctx, parsedRepo, opts) if err == nil { return nil // Success } lastErr = err // Check if this error is retryable if !isRetryableError(err) { return fmt.Errorf("pushing layer to %s: %w", opts.Repo, err) } // Don't retry if this was the last attempt if attempt >= backoff.Steps { break } // Calculate next delay with randomized jitter to avoid thundering herd nextDelay := currentDelay if backoff.Jitter > 0 { jitterAmount := time.Duration(float64(currentDelay) * backoff.Jitter * rand.Float64()) //nolint:gosec nextDelay = currentDelay + jitterAmount } // Invoke retry callback if configured if opts.Retry != nil && opts.Retry.OnRetry != nil { event := RetryEvent{ Attempt: attempt, MaxAttempts: backoff.Steps, Err: err, NextRetryIn: nextDelay, } if !opts.Retry.OnRetry(event) { // Callback returned false, abort retrying return fmt.Errorf("pushing layer to %s (retry aborted): %w", opts.Repo, err) } } // Wait before retrying select { case <-ctx.Done(): return ctx.Err() case <-time.After(nextDelay): } // Update delay for next iteration currentDelay = time.Duration(float64(currentDelay) * backoff.Factor) } return fmt.Errorf("pushing layer to %s (after %d attempts): %w", opts.Repo, backoff.Steps, lastErr) } // writeLayerMultipart uploads a layer using multipart uploads with Content-Range headers. 
// This is a custom implementation that supports chunked uploads compatible with the // server-side code provided. func (c *RegistryClient) writeLayerMultipart(ctx context.Context, repo name.Repository, opts WriteLayerOptions) error { // Get layer metadata digest, err := opts.Layer.Digest() if err != nil { return fmt.Errorf("getting layer digest: %w", err) } size, err := opts.Layer.Size() if err != nil { return fmt.Errorf("getting layer size: %w", err) } // Create authenticated HTTP client auth, err := authn.Resolve(ctx, authn.DefaultKeychain, repo) if err != nil { return fmt.Errorf("resolving auth: %w", err) } scopes := []string{repo.Scope(transport.PushScope)} tr, err := transport.NewWithContext(ctx, repo.Registry, auth, c.transport, scopes) if err != nil { return fmt.Errorf("creating transport: %w", err) } client := &http.Client{Transport: tr} // Check if blob already exists exists, err := c.checkBlobExists(ctx, client, repo, digest) if err != nil { return fmt.Errorf("checking blob existence: %w", err) } if exists { if opts.ProgressCh != nil { opts.ProgressCh <- v1.Update{Complete: size, Total: size} } return nil } // Initiate upload session, err := c.initiateUpload(ctx, client, repo) if err != nil { return fmt.Errorf("initiating upload: %w", err) } // Upload the blob in chunks finalLocation, err := c.uploadBlobChunks(ctx, client, repo, opts.Layer, session, size, opts.ProgressCh) if err != nil { return fmt.Errorf("uploading blob chunks: %w", err) } // Commit the upload using the final location (which contains updated state hash) err = c.commitUpload(ctx, client, finalLocation, digest) if err != nil { return fmt.Errorf("committing upload: %w", err) } return nil } // checkBlobExists checks if a blob already exists in the repository. 
func (c *RegistryClient) checkBlobExists(ctx context.Context, client *http.Client, repo name.Repository, digest v1.Hash) (bool, error) { u := url.URL{ Scheme: repo.Scheme(), Host: repo.RegistryStr(), Path: fmt.Sprintf("/v2/%s/blobs/%s", repo.RepositoryStr(), digest.String()), } req, err := http.NewRequestWithContext(ctx, http.MethodHead, u.String(), nil) if err != nil { return false, err } resp, err := client.Do(req) //nolint:gosec // G704: URL from registry reference, not user input if err != nil { return false, err } defer resp.Body.Close() if err := transport.CheckError(resp, http.StatusOK, http.StatusNotFound); err != nil { return false, err } return resp.StatusCode == http.StatusOK, nil } // uploadSession holds the result of initiating a blob upload, including the // upload location URL and any server-advertised chunk size constraints. type uploadSession struct { // Location is the URL to which blob data should be uploaded. Location string // ChunkMinBytes is the minimum chunk size the server accepts (from OCI-Chunk-Min-Length). // Zero means the server did not advertise a minimum. ChunkMinBytes int64 // ChunkMaxBytes is the maximum chunk size the server accepts (from OCI-Chunk-Max-Length). // Zero means the server did not advertise a maximum. ChunkMaxBytes int64 } // effectiveChunkSize returns the chunk size to use for uploads. // The server-advertised OCI-Chunk-Max-Length always takes precedence: when // present, we use (max - margin) to stay safely under the limit regardless // of any client-side configuration. The result is also clamped to be at least // OCI-Chunk-Min-Length when the server advertises one. The client default // (COG_PUSH_DEFAULT_CHUNK_SIZE env var or DefaultChunkSize) is only used when // the server does not advertise a maximum. 
func (s uploadSession) effectiveChunkSize() int64 { var chunkSize = getDefaultChunkSize() // Start with client default as baseline if s.ChunkMaxBytes > 0 { // Server advertised a maximum — use it minus a small margin. chunkSize = s.ChunkMaxBytes - chunkSizeMargin if chunkSize <= 0 { // Degenerate case: margin bigger than max. Use max directly. chunkSize = s.ChunkMaxBytes } } // Enforce the server-advertised minimum. if s.ChunkMinBytes > 0 && chunkSize < s.ChunkMinBytes { chunkSize = s.ChunkMinBytes } return chunkSize } // initiateUpload initiates a blob upload and returns an uploadSession containing // the upload location URL and server-advertised chunk size limits. func (c *RegistryClient) initiateUpload(ctx context.Context, client *http.Client, repo name.Repository) (uploadSession, error) { u := url.URL{ Scheme: repo.Scheme(), Host: repo.RegistryStr(), Path: fmt.Sprintf("/v2/%s/blobs/uploads/", repo.RepositoryStr()), } req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), nil) if err != nil { return uploadSession{}, err } req.Header.Set("Content-Type", "application/json") resp, err := client.Do(req) //nolint:gosec // G704: URL from registry reference, not user input if err != nil { return uploadSession{}, err } defer resp.Body.Close() if err := transport.CheckError(resp, http.StatusAccepted); err != nil { return uploadSession{}, err } loc := resp.Header.Get("Location") if loc == "" { return uploadSession{}, errors.New("missing Location header in initiate upload response") } // Resolve relative URLs locURL, err := url.Parse(loc) if err != nil { return uploadSession{}, fmt.Errorf("parsing location URL: %w", err) } baseURL := url.URL{ Scheme: repo.Scheme(), Host: repo.RegistryStr(), } session := uploadSession{ Location: baseURL.ResolveReference(locURL).String(), } // Parse OCI chunk size headers if the registry advertises them. 
if v := resp.Header.Get("OCI-Chunk-Min-Length"); v != "" { if n, parseErr := strconv.ParseInt(v, 10, 64); parseErr == nil && n > 0 { session.ChunkMinBytes = n } } if v := resp.Header.Get("OCI-Chunk-Max-Length"); v != "" { if n, parseErr := strconv.ParseInt(v, 10, 64); parseErr == nil && n > 0 { session.ChunkMaxBytes = n } } return session, nil } // uploadBlobChunks uploads a blob using either multipart or single-part upload depending on server support. // The repo parameter is needed to restart the upload session if multipart fails. // The session carries the upload location and any server-advertised chunk size limits // (OCI-Chunk-Min-Length / OCI-Chunk-Max-Length). // Returns the final upload location URL which must be used for committing the upload. func (c *RegistryClient) uploadBlobChunks(ctx context.Context, client *http.Client, repo name.Repository, layer v1.Layer, session uploadSession, totalSize int64, progressCh chan<- v1.Update) (string, error) { // The chunk size is determined by the server's OCI-Chunk-Max-Length header // (minus a small margin). When the server does not advertise a maximum, // the client falls back to COG_PUSH_DEFAULT_CHUNK_SIZE or DefaultChunkSize (96 MiB). // COG_PUSH_MULTIPART_THRESHOLD controls the minimum blob size for multipart upload (default: 128 MiB). 
var ( multipartThreshold = getMultipartThreshold() chunkSize = session.effectiveChunkSize() location = session.Location ) if totalSize > multipartThreshold { finalLocation, newLocation, fallback, err := c.tryMultipartWithFallback(ctx, client, repo, layer, location, totalSize, chunkSize, progressCh) if err != nil { return "", err } if !fallback { return finalLocation, nil } // Multipart not supported, continue with single-part using the new location location = newLocation } // Single-part upload for small blobs or servers that don't support multipart blob, err := layer.Compressed() if err != nil { return "", fmt.Errorf("getting compressed blob: %w", err) } defer blob.Close() finalLocation, err := c.uploadBlobSingle(ctx, client, location, blob, totalSize, progressCh) if err != nil { return "", err } return finalLocation, nil } // tryMultipartWithFallback attempts multipart upload and handles fallback if not supported. // Returns (finalLocation, newLocation, fallback, error): // - If multipart succeeds: (finalLocation, "", false, nil) // - If multipart not supported: ("", newLocation, true, nil) - caller should use single-part with newLocation // - If error: ("", "", false, error) func (c *RegistryClient) tryMultipartWithFallback(ctx context.Context, client *http.Client, repo name.Repository, layer v1.Layer, location string, totalSize int64, chunkSize int64, progressCh chan<- v1.Update) (finalLocation string, newLocation string, fallback bool, err error) { blob, err := layer.Compressed() if err != nil { return "", "", false, fmt.Errorf("getting compressed blob: %w", err) } defer blob.Close() finalLocation, err = c.tryMultipartUpload(ctx, client, location, blob, totalSize, chunkSize, progressCh) if err == nil { return finalLocation, "", false, nil } // Check if error indicates multipart not supported var transportErr *transport.Error if errors.As(err, &transportErr) && (transportErr.StatusCode == http.StatusRequestedRangeNotSatisfiable || transportErr.StatusCode == 
http.StatusBadRequest) { // Multipart not supported - restart upload session for single-part fallback newSession, err := c.initiateUpload(ctx, client, repo) if err != nil { return "", "", false, fmt.Errorf("restarting upload after multipart failure: %w", err) } return "", newSession.Location, true, nil } return "", "", false, err } // tryMultipartUpload attempts to upload using Content-Range headers. // Returns the final location or an error. func (c *RegistryClient) tryMultipartUpload(ctx context.Context, client *http.Client, location string, blob io.Reader, totalSize int64, chunkSize int64, progressCh chan<- v1.Update) (string, error) { var uploaded int64 // Reuse chunk buffers via pool to reduce memory pressure when pushing // multiple layers concurrently (default concurrency 5 × up to 96 MB each). // No need to zero the buffer before reuse: io.ReadFull overwrites from // index 0, and we slice to buffer[:n] so stale bytes are never sent. var buffer []byte if v, ok := chunkBufPool.Get().(*[]byte); ok && int64(len(*v)) == chunkSize { buffer = *v } else { buffer = make([]byte, chunkSize) } defer func() { chunkBufPool.Put(&buffer) }() currentLocation := location for uploaded < totalSize { // Read the next chunk n, err := io.ReadFull(blob, buffer) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { return "", fmt.Errorf("reading blob: %w", err) } if n == 0 { break } chunk := buffer[:n] start := uploaded end := uploaded + int64(n) - 1 // Range is inclusive // Upload the chunk with Content-Range (progress is reported within uploadChunk) newLocation, err := c.uploadChunk(ctx, client, currentLocation, chunk, start, end, totalSize, progressCh) if err != nil { return "", err } // Update location for next chunk (server may change it) if newLocation != "" { currentLocation = newLocation } uploaded += int64(n) // Check for context cancellation select { case <-ctx.Done(): return "", ctx.Err() default: } } return currentLocation, nil } // 
uploadBlobSingle uploads the entire blob in one request without Content-Range headers. func (c *RegistryClient) uploadBlobSingle(ctx context.Context, client *http.Client, location string, blob io.Reader, totalSize int64, progressCh chan<- v1.Update) (string, error) { // Wrap the reader to report progress var uploaded int64 reader := &progressReader{ reader: blob, onRead: func(n int) { uploaded += int64(n) if progressCh != nil { // Cap at totalSize defensively complete := min(uploaded, totalSize) select { case progressCh <- v1.Update{Complete: complete, Total: totalSize}: default: // Don't block if channel is full } } }, } req, err := http.NewRequestWithContext(ctx, http.MethodPatch, location, reader) if err != nil { return "", err } req.Header.Set("Content-Type", "application/octet-stream") req.ContentLength = totalSize resp, err := client.Do(req) //nolint:gosec // G704: URL from registry upload session, not user input if err != nil { return "", err } defer resp.Body.Close() if err := transport.CheckError(resp, http.StatusAccepted, http.StatusNoContent, http.StatusCreated); err != nil { return "", err } // Return the updated Location header — the registry includes upload state // that commitUpload needs for the final PUT. if loc := resp.Header.Get("Location"); loc != "" { locURL, parseErr := url.Parse(loc) if parseErr == nil { baseURL := url.URL{Scheme: "http", Host: req.URL.Host} if req.URL.Scheme != "" { baseURL.Scheme = req.URL.Scheme } return baseURL.ResolveReference(locURL).String(), nil } } return location, nil } // uploadChunk uploads a single chunk of a blob with Content-Range header. // Returns the new location URL if the server returns one. // If progressCh is provided, progress updates are sent as bytes are uploaded. // Progress updates occur approximately every 32-64KB based on HTTP client buffer size. 
func (c *RegistryClient) uploadChunk(ctx context.Context, client *http.Client, location string, chunk []byte, start, end int64, totalSize int64, progressCh chan<- v1.Update) (string, error) { // Wrap the chunk reader to report progress as bytes are written var reader io.Reader if progressCh != nil { var chunkUploaded int64 reader = &progressReader{ reader: bytes.NewReader(chunk), onRead: func(n int) { chunkUploaded += int64(n) // Cap at totalSize defensively complete := min(start+chunkUploaded, totalSize) select { case progressCh <- v1.Update{Complete: complete, Total: totalSize}: default: // Don't block if channel is full } }, } } else { reader = bytes.NewReader(chunk) } req, err := http.NewRequestWithContext(ctx, http.MethodPatch, location, reader) if err != nil { return "", err } req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Length", strconv.FormatInt(int64(len(chunk)), 10)) req.Header.Set("Content-Range", fmt.Sprintf("%d-%d", start, end)) resp, err := client.Do(req) //nolint:gosec // G704: URL from registry upload session, not user input if err != nil { return "", err } defer resp.Body.Close() if err := transport.CheckError(resp, http.StatusAccepted, http.StatusNoContent, http.StatusCreated); err != nil { return "", err } // Get the new location for the next chunk newLocation := resp.Header.Get("Location") if newLocation != "" { // Resolve relative URLs locURL, err := url.Parse(newLocation) if err != nil { return "", fmt.Errorf("parsing location URL: %w", err) } // Parse the original location to get the base URL origURL, err := url.Parse(location) if err != nil { return "", fmt.Errorf("parsing original location URL: %w", err) } return origURL.ResolveReference(locURL).String(), nil } return "", nil } // progressReader wraps an io.Reader to report progress. 
type progressReader struct { reader io.Reader onRead func(int) } func (pr *progressReader) Read(p []byte) (int, error) { n, err := pr.reader.Read(p) if n > 0 { pr.onRead(n) } return n, err } // commitUpload finalizes the upload by sending a PUT request with the digest. func (c *RegistryClient) commitUpload(ctx context.Context, client *http.Client, location string, digest v1.Hash) error { u, err := url.Parse(location) if err != nil { return fmt.Errorf("parsing location URL: %w", err) } // Add digest query parameter q := u.Query() q.Set("digest", digest.String()) u.RawQuery = q.Encode() req, err := http.NewRequestWithContext(ctx, http.MethodPut, u.String(), nil) if err != nil { return err } req.Header.Set("Content-Type", "application/octet-stream") resp, err := client.Do(req) //nolint:gosec // G704: URL from registry upload session, not user input if err != nil { return err } defer resp.Body.Close() return transport.CheckError(resp, http.StatusCreated) } // pickDefaultImage selects an image from a manifest index to use for fetching labels. // Prefers linux/amd64, otherwise returns the first image manifest. // Returns an error if no suitable image is found or if fetching fails. func pickDefaultImage(ref name.Reference, idx *v1.IndexManifest, opts ...remote.Option) (v1.Image, error) { var targetDigest string // First, look for linux/amd64 for _, m := range idx.Manifests { if m.Platform != nil && m.Platform.OS == "linux" && m.Platform.Architecture == "amd64" { targetDigest = m.Digest.String() break } } // Fall back to first manifest if targetDigest == "" && len(idx.Manifests) > 0 { targetDigest = idx.Manifests[0].Digest.String() } if targetDigest == "" { return nil, fmt.Errorf("index for %s contains no manifests", ref.String()) } digestRef, err := name.NewDigest(ref.Context().Name()+"@"+targetDigest, name.Insecure) if err != nil { return nil, fmt.Errorf("failed to create digest reference: %w", err) } desc, err := remote.Get(digestRef, opts...) 
if err != nil { return nil, fmt.Errorf("failed to fetch image %s: %w", digestRef.String(), err) } img, err := desc.Image() if err != nil { return nil, fmt.Errorf("failed to load image %s: %w", digestRef.String(), err) } return img, nil } ================================================ FILE: pkg/registry/registrytest/mock_client.go ================================================ package registrytest import ( "context" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/replicate/cog/pkg/registry" ) type MockRegistryClient struct { mockImages map[string]bool } func NewMockRegistryClient() *MockRegistryClient { return &MockRegistryClient{ mockImages: map[string]bool{}, } } func (c *MockRegistryClient) Exists(ctx context.Context, imageRef string) (bool, error) { _, exists := c.mockImages[imageRef] return exists, nil } func (c *MockRegistryClient) GetImage(ctx context.Context, imageRef string, platform *registry.Platform) (v1.Image, error) { return nil, nil } func (c *MockRegistryClient) Inspect(ctx context.Context, imageRef string, platform *registry.Platform) (*registry.ManifestResult, error) { return nil, nil } func (c *MockRegistryClient) AddMockImage(imageRef string) { c.mockImages[imageRef] = true } func (c *MockRegistryClient) PushImage(ctx context.Context, ref string, img v1.Image) error { c.mockImages[ref] = true return nil } func (c *MockRegistryClient) PushIndex(ctx context.Context, ref string, idx v1.ImageIndex) error { c.mockImages[ref] = true return nil } func (c *MockRegistryClient) GetDescriptor(ctx context.Context, imageRef string) (v1.Descriptor, error) { return v1.Descriptor{}, nil } func (c *MockRegistryClient) WriteLayer(ctx context.Context, opts registry.WriteLayerOptions) error { return nil } ================================================ FILE: pkg/registry_testhelpers/registry_container.go ================================================ package registry_testhelpers import ( "context" "fmt" "path" "path/filepath" "runtime" 
"strconv" "strings" "testing" "time" "github.com/docker/docker/api/types/container" "github.com/docker/go-connections/nat" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/crane" "github.com/google/go-containerregistry/pkg/name" "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/registry" "github.com/testcontainers/testcontainers-go/wait" "golang.org/x/crypto/bcrypt" dockerregistry "github.com/docker/docker/api/types/registry" "github.com/replicate/cog/pkg/util" ) // StartTestRegistry starts a test registry container on a random local port populated // with image data from the testdata/docker directory. It returns a RegistryContainer // that can be used to inspect the registry and generate absolute image references. It will // automatically be cleaned when the test finishes. // This is safe to run concurrently across multiple tests. func StartTestRegistry(t *testing.T, opts ...Option) *RegistryContainer { t.Helper() container, cleanup, err := StartTestRegistryWithCleanup(t.Context(), opts...) require.NoError(t, err, "Failed to start registry container") // Register cleanup with testing.T t.Cleanup(cleanup) return container } // StartTestRegistryWithCleanup starts a test registry and returns a cleanup function. // Use this when you don't have a *testing.T (e.g., in testscript harness). // The caller is responsible for calling the cleanup function when done. 
func StartTestRegistryWithCleanup(ctx context.Context, opts ...Option) (*RegistryContainer, func(), error) { options := &options{} for _, opt := range opts { opt(options) } _, filename, _, _ := runtime.Caller(0) testdataDir := filepath.Join(filepath.Dir(filename), "testdata", "docker") // Pick a port in the insecure range (Docker considers localhost:1-9999 as insecure) port, err := util.PickFreePort(1024, 9999) if err != nil { return nil, nil, fmt.Errorf("pick free port: %w", err) } containerCustomizers := []testcontainers.ContainerCustomizer{ testcontainers.WithFiles(testcontainers.ContainerFile{ HostFilePath: testdataDir, ContainerFilePath: "/var/lib/registry/", FileMode: 0o755, }), testcontainers.WithWaitStrategy( wait.ForHTTP("/").WithPort("5000/tcp"). WithStartupTimeout(10 * time.Second), ), testcontainers.WithHostConfigModifier(func(hostConfig *container.HostConfig) { hostConfig.PortBindings = map[nat.Port][]nat.PortBinding{ nat.Port("5000/tcp"): {{HostIP: "0.0.0.0", HostPort: strconv.Itoa(port)}}, } }), } if options.auth != nil { htpasswd, err := generateHtpasswd(options.auth.Username, options.auth.Password) if err != nil { return nil, nil, fmt.Errorf("generate htpasswd: %w", err) } containerCustomizers = append(containerCustomizers, registry.WithHtpasswd(htpasswd), ) } registryContainer, err := registry.Run( ctx, "registry:3", containerCustomizers..., ) if err != nil { return nil, nil, fmt.Errorf("start registry container: %w", err) } cleanup := func() { if registryContainer != nil { _ = registryContainer.Terminate(context.Background()) } } return &RegistryContainer{ Container: registryContainer, options: options, }, cleanup, nil } type RegistryContainer struct { Container *registry.RegistryContainer options *options } func (c *RegistryContainer) ImageRef(ref string) string { return path.Join(c.Container.RegistryName, ref) } func (c *RegistryContainer) ImageRefForTest(t *testing.T, label string) string { if label == "" { label = fmt.Sprintf("test-%d", 
time.Now().Unix()) } repo := strings.ToLower(t.Name()) return c.ImageRef(fmt.Sprintf("%s:%s", repo, label)) } func (c *RegistryContainer) CloneRepo(t *testing.T, existingRepo, newRepo string) string { existingRepo = c.ImageRef(existingRepo) newRepo = c.ImageRef(newRepo) err := crane.CopyRepository(existingRepo, newRepo) require.NoError(t, err, "Failed to clone repo %q to %q", existingRepo, newRepo) return newRepo } func (c *RegistryContainer) CloneRepoForTest(t *testing.T, repo string) string { return c.CloneRepo(t, repo, strings.ToLower(t.Name())) } func (c *RegistryContainer) ImageExists(t *testing.T, ref string) error { parsedRef, err := name.ParseReference(ref, name.WithDefaultRegistry(c.RegistryHost())) require.NoError(t, err) var opts []remote.Option if c.options.auth != nil { opts = append(opts, remote.WithAuth(authn.FromConfig(authn.AuthConfig{ Username: c.options.auth.Username, Password: c.options.auth.Password, }))) } _, err = remote.Head(parsedRef, opts...) return err } func (c *RegistryContainer) RegistryHost() string { return c.Container.RegistryName } type Option func(*options) func WithAuth(username, password string) func(*options) { return func(o *options) { o.auth = &dockerregistry.AuthConfig{ Username: username, Password: password, } } } type options struct { auth *dockerregistry.AuthConfig } func generateHtpasswd(username, password string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return "", err } return fmt.Sprintf("%s:%s", username, string(hash)), nil } ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/blobs/sha256/1c/1c4eef651f65e2f7daee7ee785882ac164b02b78fb74503052a26dc061c90474/data ================================================ { "schemaVersion": 2, "mediaType": "application/vnd.oci.image.manifest.v1+json", "config": { "mediaType": "application/vnd.oci.image.config.v1+json", "digest": 
"sha256:aded1e1a5b3705116fa0a92ba074a5e0b0031647d9c315983ccba2ee5428ec8b", "size": 581 }, "layers": [ { "mediaType": "application/vnd.oci.image.layer.v1.tar+gzip", "digest": "sha256:f18232174bc91741fdf3da96d85011092101a032a93a388b79e99e69c2d5c870", "size": 3642247 } ], "annotations": { "com.docker.official-images.bashbrew.arch": "amd64", "org.opencontainers.image.base.name": "scratch", "org.opencontainers.image.created": "2025-02-14T03:28:36Z", "org.opencontainers.image.revision": "17fe3d1e2d2cbf54d745139eab749c252e35b883", "org.opencontainers.image.source": "https://github.com/alpinelinux/docker-alpine.git#17fe3d1e2d2cbf54d745139eab749c252e35b883:x86_64", "org.opencontainers.image.url": "https://hub.docker.com/_/alpine", "org.opencontainers.image.version": "3.21.3" } } ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/blobs/sha256/75/757d680068d77be46fd1ea20fb21db16f150468c5e7079a08a2e4705aec096ac/data ================================================ { "schemaVersion": 2, "mediaType": "application/vnd.oci.image.manifest.v1+json", "config": { "mediaType": "application/vnd.oci.image.config.v1+json", "digest": "sha256:8d591b0b7dea080ea3be9e12ae563eebf9869168ffced1cb25b2470a3d9fe15e", "size": 597 }, "layers": [ { "mediaType": "application/vnd.oci.image.layer.v1.tar+gzip", "digest": "sha256:6e771e15690e2fabf2332d3a3b744495411d6e0b00b2aea64419b58b0066cf81", "size": 3993029 } ], "annotations": { "com.docker.official-images.bashbrew.arch": "arm64v8", "org.opencontainers.image.base.name": "scratch", "org.opencontainers.image.created": "2025-02-14T03:28:36Z", "org.opencontainers.image.revision": "17fe3d1e2d2cbf54d745139eab749c252e35b883", "org.opencontainers.image.source": "https://github.com/alpinelinux/docker-alpine.git#17fe3d1e2d2cbf54d745139eab749c252e35b883:aarch64", "org.opencontainers.image.url": "https://hub.docker.com/_/alpine", "org.opencontainers.image.version": "3.21.3" } } 
================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/blobs/sha256/8d/8d591b0b7dea080ea3be9e12ae563eebf9869168ffced1cb25b2470a3d9fe15e/data ================================================ {"architecture":"arm64","config":{"Env":["PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"],"Cmd":["/bin/sh"],"WorkingDir":"/"},"created":"2025-02-14T03:28:36Z","history":[{"created":"2025-02-14T03:28:36Z","created_by":"ADD alpine-minirootfs-3.21.3-aarch64.tar.gz / # buildkit","comment":"buildkit.dockerfile.v0"},{"created":"2025-02-14T03:28:36Z","created_by":"CMD [\"/bin/sh\"]","comment":"buildkit.dockerfile.v0","empty_layer":true}],"os":"linux","rootfs":{"type":"layers","diff_ids":["sha256:a16e98724c05975ee8c40d8fe389c3481373d34ab20a1cf52ea2accc43f71f4c"]},"variant":"v8"} ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/blobs/sha256/9a/9a0ff41dccad7a96f324a4655a715c623ed3511c7336361ffa9dadcecbdb99e5/data ================================================ 
{"schemaVersion":2,"mediaType":"application/vnd.oci.image.index.v1+json","manifests":[{"mediaType":"application/vnd.oci.image.manifest.v1+json","size":1022,"digest":"sha256:1c4eef651f65e2f7daee7ee785882ac164b02b78fb74503052a26dc061c90474","annotations":{"com.docker.official-images.bashbrew.arch":"amd64","org.opencontainers.image.base.name":"scratch","org.opencontainers.image.created":"2025-02-14T18:27:58Z","org.opencontainers.image.revision":"17fe3d1e2d2cbf54d745139eab749c252e35b883","org.opencontainers.image.source":"https://github.com/alpinelinux/docker-alpine.git#17fe3d1e2d2cbf54d745139eab749c252e35b883:x86_64","org.opencontainers.image.url":"https://hub.docker.com/_/alpine","org.opencontainers.image.version":"3.21.3"},"platform":{"architecture":"amd64","os":"linux"}},{"mediaType":"application/vnd.oci.image.manifest.v1+json","size":1025,"digest":"sha256:757d680068d77be46fd1ea20fb21db16f150468c5e7079a08a2e4705aec096ac","annotations":{"com.docker.official-images.bashbrew.arch":"arm64v8","org.opencontainers.image.base.name":"scratch","org.opencontainers.image.created":"2025-02-14T18:27:49Z","org.opencontainers.image.revision":"17fe3d1e2d2cbf54d745139eab749c252e35b883","org.opencontainers.image.source":"https://github.com/alpinelinux/docker-alpine.git#17fe3d1e2d2cbf54d745139eab749c252e35b883:aarch64","org.opencontainers.image.url":"https://hub.docker.com/_/alpine","org.opencontainers.image.version":"3.21.3"},"platform":{"architecture":"arm64","os":"linux"}}]} ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/blobs/sha256/ad/aded1e1a5b3705116fa0a92ba074a5e0b0031647d9c315983ccba2ee5428ec8b/data ================================================ {"architecture":"amd64","config":{"Env":["PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"],"Cmd":["/bin/sh"],"WorkingDir":"/"},"created":"2025-02-14T03:28:36Z","history":[{"created":"2025-02-14T03:28:36Z","created_by":"ADD 
alpine-minirootfs-3.21.3-x86_64.tar.gz / # buildkit","comment":"buildkit.dockerfile.v0"},{"created":"2025-02-14T03:28:36Z","created_by":"CMD [\"/bin/sh\"]","comment":"buildkit.dockerfile.v0","empty_layer":true}],"os":"linux","rootfs":{"type":"layers","diff_ids":["sha256:08000c18d16dadf9553d747a58cf44023423a9ab010aab96cf263d2216b8b350"]}} ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/repositories/alpine/_layers/sha256/6e771e15690e2fabf2332d3a3b744495411d6e0b00b2aea64419b58b0066cf81/link ================================================ sha256:6e771e15690e2fabf2332d3a3b744495411d6e0b00b2aea64419b58b0066cf81 ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/repositories/alpine/_layers/sha256/8d591b0b7dea080ea3be9e12ae563eebf9869168ffced1cb25b2470a3d9fe15e/link ================================================ sha256:8d591b0b7dea080ea3be9e12ae563eebf9869168ffced1cb25b2470a3d9fe15e ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/repositories/alpine/_layers/sha256/aded1e1a5b3705116fa0a92ba074a5e0b0031647d9c315983ccba2ee5428ec8b/link ================================================ sha256:aded1e1a5b3705116fa0a92ba074a5e0b0031647d9c315983ccba2ee5428ec8b ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/repositories/alpine/_layers/sha256/f18232174bc91741fdf3da96d85011092101a032a93a388b79e99e69c2d5c870/link ================================================ sha256:f18232174bc91741fdf3da96d85011092101a032a93a388b79e99e69c2d5c870 ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/repositories/alpine/_manifests/revisions/sha256/1c4eef651f65e2f7daee7ee785882ac164b02b78fb74503052a26dc061c90474/link ================================================ 
sha256:1c4eef651f65e2f7daee7ee785882ac164b02b78fb74503052a26dc061c90474 ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/repositories/alpine/_manifests/revisions/sha256/757d680068d77be46fd1ea20fb21db16f150468c5e7079a08a2e4705aec096ac/link ================================================ sha256:757d680068d77be46fd1ea20fb21db16f150468c5e7079a08a2e4705aec096ac ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/repositories/alpine/_manifests/revisions/sha256/9a0ff41dccad7a96f324a4655a715c623ed3511c7336361ffa9dadcecbdb99e5/link ================================================ sha256:9a0ff41dccad7a96f324a4655a715c623ed3511c7336361ffa9dadcecbdb99e5 ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/repositories/alpine/_manifests/tags/latest/current/link ================================================ sha256:9a0ff41dccad7a96f324a4655a715c623ed3511c7336361ffa9dadcecbdb99e5 ================================================ FILE: pkg/registry_testhelpers/testdata/docker/registry/v2/repositories/alpine/_manifests/tags/latest/index/sha256/9a0ff41dccad7a96f324a4655a715c623ed3511c7336361ffa9dadcecbdb99e5/link ================================================ sha256:9a0ff41dccad7a96f324a4655a715c623ed3511c7336361ffa9dadcecbdb99e5 ================================================ FILE: pkg/requirements/requirements.go ================================================ package requirements import ( "bufio" "errors" "fmt" "os" "path/filepath" "regexp" "strings" "github.com/replicate/cog/pkg/util/files" ) const RequirementsFile = "requirements.txt" const OverridesFile = "overrides.txt" func GenerateRequirements(tmpDir string, path string, fileName string) (string, error) { bs, err := os.ReadFile(path) if err != nil { return "", err } requirements := string(bs) // Check against the old requirements requirementsFile := 
filepath.Join(tmpDir, fileName) if err := files.WriteIfDifferent(requirementsFile, requirements); err != nil { return "", err } return requirementsFile, err } func CurrentRequirements(tmpDir string) (string, error) { requirementsFile := filepath.Join(tmpDir, RequirementsFile) _, err := os.Stat(requirementsFile) if err != nil { if errors.Is(err, os.ErrNotExist) { return "", nil } return "", err } return requirementsFile, nil } func ReadRequirements(path string) ([]string, error) { re := regexp.MustCompile(`(?m)^\s*-e\s+\.\s*$`) fh, err := os.Open(path) if err != nil { return nil, err } defer fh.Close() // Use scanner to handle CRLF endings scanner := bufio.NewScanner(fh) scanner.Split(scanLinesWithContinuations) var requirements []string for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) // Skip empty lines and comment lines if strings.HasPrefix(line, "#") { continue } if re.MatchString(line) { continue } // Remove any trailing comments if idx := strings.Index(line, "#"); idx >= 0 { line = line[:idx] } if line != "" { requirements = append(requirements, line) } } return requirements, scanner.Err() } // scanLinesWithContinuations is a modified version of bufio.ScanLines that // also handles line continuations (lines ending with a backslash). func scanLinesWithContinuations(data []byte, atEOF bool) (advance int, token []byte, err error) { // If we're at EOF and there's no data, return nil if atEOF && len(data) == 0 { return 0, nil, nil } var line []byte start := 0 for i := range data { if data[i] == '\n' { end := i if end > 0 && data[end-1] == '\r' { end-- } // Add this segment to our accumulated line line = append(line, data[start:end]...) 
// pinnedRequirementRe recognises, anywhere in a requirement line, one of:
//
//	name==version          (groups 1 and 2)
//	--find-links=<value>   (group 3)
//	-f <value>             (group 4)
//	--extra-index-url=<v>  (group 5)
//
// The name may contain letters, digits, '-', '_' and '.', so dotted
// distribution names such as "ruamel.yaml" are captured in full.
var pinnedRequirementRe = regexp.MustCompile(`(?:([a-zA-Z0-9\-_.]+)==([^ ]+)|--find-links=([^\s]+)|-f\s+([^\s]+)|--extra-index-url=([^\s]+))`)

// SplitPinnedPythonRequirement returns the name, version, findLinks, and extraIndexURLs from a requirements.txt line
// in the form name==version [--find-links=] [-f ] [--extra-index-url=]
//
// The options may appear before or after the pinned package and may be
// repeated; findLinks collects both --find-links= and -f values in the order
// they appear. An error is returned, with all other results zero, when the
// line contains none of the recognised parts or no name==version pair.
func SplitPinnedPythonRequirement(requirement string) (name string, version string, findLinks []string, extraIndexURLs []string, err error) {
	matches := pinnedRequirementRe.FindAllStringSubmatch(requirement, -1)
	if matches == nil {
		return "", "", nil, nil, fmt.Errorf("Package %s is not in the expected format", requirement)
	}

	for _, match := range matches {
		// Each match fills only the groups of the alternative that fired; the
		// others are empty strings.
		if match[1] != "" {
			name = match[1]
		}
		if match[2] != "" {
			version = match[2]
		}
		if match[3] != "" {
			findLinks = append(findLinks, match[3])
		}
		if match[4] != "" {
			findLinks = append(findLinks, match[4])
		}
		if match[5] != "" {
			extraIndexURLs = append(extraIndexURLs, match[5])
		}
	}

	// name and version are only ever assigned non-empty values, so an empty
	// string means no name==version pair was seen.
	if name == "" || version == "" {
		return "", "", nil, nil, fmt.Errorf("Package name or version is missing in %s", requirement)
	}

	return name, version, findLinks, extraIndexURLs, nil
}
var (
	// versionClauseRe matches one comparison operator followed by its
	// operand. Only real specifier operators are accepted (===, ==, !=, <=,
	// >=, ~=, <, >); a lone '=' as in "--hash=sha256:..." is not an operator.
	// "===" is listed first so it is not read as "==" followed by "=1.0".
	versionClauseRe = regexp.MustCompile(`(?:===|[<>=!~]=|[<>])\s*([^\s,;|]+)`)

	// directReferenceRe matches the target of a "name @ url-or-path" requirement.
	directReferenceRe = regexp.MustCompile(`@\s*([^\s]+)`)
)

// Versions returns every version operand mentioned in pipRequirement, in
// order of appearance, followed by the target of an "@" direct reference when
// one is present. For example "pkg>= 1.0, < 1.4 || > 2.0" yields
// ["1.0", "1.4", "2.0"] and "pkg @ https://host/pkg.whl" yields
// ["https://host/pkg.whl"]. The result is nil when nothing matches.
//
// NOTE(review): operands of comparisons inside environment markers (text
// after ';') are returned as well, exactly as before.
func Versions(pipRequirement string) []string {
	var versions []string

	// Match standard specifier versions
	for _, match := range versionClauseRe.FindAllStringSubmatch(pipRequirement, -1) {
		versions = append(versions, match[1])
	}

	// Match @ file/url version
	if match := directReferenceRe.FindStringSubmatch(pipRequirement); len(match) > 1 {
		versions = append(versions, match[1])
	}

	return versions
}
t.TempDir() reqFile := path.Join(srcDir, "requirements.txt") err := os.WriteFile(reqFile, []byte("torch==\\\n2.5.1\ntorchvision==\\\r\n2.5.1"), 0o644) require.NoError(t, err) requirements, err := ReadRequirements(reqFile) require.NoError(t, err) require.Equal(t, []string{"torch==2.5.1", "torchvision==2.5.1"}, requirements) } func TestReadRequirementsStripComments(t *testing.T) { srcDir := t.TempDir() reqFile := path.Join(srcDir, "requirements.txt") err := os.WriteFile(reqFile, []byte("torch==\\\n2.5.1# Heres my comment\ntorchvision==2.5.1\n# Heres a beginning of line comment"), 0o644) require.NoError(t, err) requirements, err := ReadRequirements(reqFile) require.NoError(t, err) require.Equal(t, []string{"torch==2.5.1", "torchvision==2.5.1"}, requirements) } func TestReadRequirementsComplex(t *testing.T) { srcDir := t.TempDir() reqFile := path.Join(srcDir, "requirements.txt") err := os.WriteFile(reqFile, []byte(`foo==1.0.0 # complex requirements fastapi>=0.6,<1 flask>0.4 # comments! # blank lines! 
# arguments -f http://example.com`), 0o644) require.NoError(t, err) requirements, err := ReadRequirements(reqFile) require.NoError(t, err) require.Equal(t, []string{"foo==1.0.0", "fastapi>=0.6,<1", "flask>0.4", "-f http://example.com"}, requirements) } func TestReadRequirementsLongLine(t *testing.T) { srcDir := t.TempDir() reqFile := path.Join(srcDir, "requirements.txt") err := os.WriteFile(reqFile, []byte(` antlr4-python3-runtime==4.9.3 \ --hash=sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b colorama==0.4.6 ; sys_platform == 'win32' \ --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 contourpy==1.3.2 \ --hash=sha256:0475b1f6604896bc7c53bb070e355e9321e1bc0d381735421a2d2068ec56531f \ --hash=sha256:106fab697af11456fcba3e352ad50effe493a90f893fca6c2ca5c033820cea92 \ --hash=sha256:15ce6ab60957ca74cff444fe66d9045c1fd3e92c8936894ebd1f3eef2fff075f \ --hash=sha256:1c48188778d4d2f3d48e4643fb15d8608b1d01e4b4d6b0548d9b336c28fc9b6f \ --hash=sha256:3859783aefa2b8355697f16642695a5b9792e7a46ab86da1118a4a23a51a33d7 \ --hash=sha256:3d80b2c0300583228ac98d0a927a1ba6a2ba6b8a742463c564f1d419ee5b211e \ --hash=sha256:3f9e896f447c5c8618f1edb2bafa9a4030f22a575ec418ad70611450720b5b08 \ --hash=sha256:434f0adf84911c924519d2b08fc10491dd282b20bdd3fa8f60fd816ea0b48841 \ --hash=sha256:49b65a95d642d4efa8f64ba12558fcb83407e58a2dfba9d796d77b63ccfcaff5 \ --hash=sha256:4caf2bcd2969402bf77edc4cb6034c7dd7c0803213b3523f111eb7460a51b8d2 \ --hash=sha256:532fd26e715560721bb0d5fc7610fce279b3699b018600ab999d1be895b09415 \ --hash=sha256:5ebac872ba09cb8f2131c46b8739a7ff71de28a24c869bcad554477eb089a878 \ --hash=sha256:5f5964cdad279256c084b69c3f412b7801e15356b16efa9d78aa974041903da0 \ --hash=sha256:65a887a6e8c4cd0897507d814b14c54a8c2e2aa4ac9f7686292f9769fcf9a6ab \ --hash=sha256:6a37a2fb93d4df3fc4c0e363ea4d16f83195fc09c891bc8ce072b9d084853445 \ 
--hash=sha256:70771a461aaeb335df14deb6c97439973d253ae70660ca085eec25241137ef43 \ --hash=sha256:71e2bd4a1c4188f5c2b8d274da78faab884b59df20df63c34f74aa1813c4427c \ --hash=sha256:745b57db7758f3ffc05a10254edd3182a2a83402a89c00957a8e8a22f5582823 \ --hash=sha256:78e9253c3de756b3f6a5174d024c4835acd59eb3f8e2ca13e775dbffe1558f69 \ --hash=sha256:82199cb78276249796419fe36b7386bd8d2cc3f28b3bc19fe2454fe2e26c4c15 \ --hash=sha256:8b7fc0cd78ba2f4695fd0a6ad81a19e7e3ab825c31b577f384aa9d7817dc3bef \ --hash=sha256:8c5acb8dddb0752bf252e01a3035b21443158910ac16a3b0d20e7fed7d534ce5 \ --hash=sha256:8c942a01d9163e2e5cfb05cb66110121b8d07ad438a17f9e766317bcb62abf73 \ --hash=sha256:90df94c89a91b7362e1142cbee7568f86514412ab8a2c0d0fca72d7e91b62912 \ --hash=sha256:970e9173dbd7eba9b4e01aab19215a48ee5dd3f43cef736eebde064a171f89a5 \ --hash=sha256:977e98a0e0480d3fe292246417239d2d45435904afd6d7332d8455981c408b85 \ --hash=sha256:b6945942715a034c671b7fc54f9588126b0b8bf23db2696e3ca8328f3ff0ab54 \ --hash=sha256:b7cd50c38f500bbcc9b6a46643a40e0913673f869315d8e70de0438817cb7773 \ --hash=sha256:c49f73e61f1f774650a55d221803b101d966ca0c5a2d6d5e4320ec3997489441 \ --hash=sha256:c66c4906cdbc50e9cba65978823e6e00b45682eb09adbb78c9775b74eb222422 \ --hash=sha256:c6c4639a9c22230276b7bffb6a850dfc8258a2521305e1faefe804d006b2e532 \ --hash=sha256:c85bb486e9be652314bb5b9e2e3b0d1b2e643d5eec4992c0fbe8ac71775da739 \ --hash=sha256:cc829960f34ba36aad4302e78eabf3ef16a3a100863f0d4eeddf30e8a485a03b \ --hash=sha256:d0e589ae0d55204991450bb5c23f571c64fe43adaa53f93fc902a84c96f52fe1 \ --hash=sha256:d14f12932a8d620e307f715857107b1d1845cc44fdb5da2bc8e850f5ceba9f87 \ --hash=sha256:d32530b534e986374fc19eaa77fcb87e8a99e5431499949b828312bdcd20ac52 \ --hash=sha256:d6658ccc7251a4433eebd89ed2672c2ed96fba367fd25ca9512aa92a4b46c4f1 \ --hash=sha256:d91a3ccc7fea94ca0acab82ceb77f396d50a1f67412efe4c526f5d20264e6ecd \ --hash=sha256:de39db2604ae755316cb5967728f4bea92685884b1e767b7c24e983ef5f771cb \ 
--hash=sha256:de425af81b6cea33101ae95ece1f696af39446db9682a0b56daaa48cfc29f38f \ --hash=sha256:e1578f7eafce927b168752ed7e22646dad6cd9bca673c60bff55889fa236ebf9 \ --hash=sha256:e298e7e70cf4eb179cc1077be1c725b5fd131ebc81181bf0c03525c8abc297fd \ --hash=sha256:eab0f6db315fa4d70f1d8ab514e527f0366ec021ff853d7ed6a2d33605cf4b83 \ --hash=sha256:f26b383144cf2d2c29f01a1e8170f50dacf0eac02d64139dcd709a8ac4eb3cfe cycler==0.12.1 \ --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c`), 0o644) require.NoError(t, err) requirements, err := ReadRequirements(reqFile) require.NoError(t, err) checkRequirements(t, []string{ "antlr4-python3-runtime==4.9.3 --hash=sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b", "colorama==0.4.6 ; sys_platform == 'win32' --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", "contourpy==1.3.2 --hash=sha256:0475b1f6604896bc7c53bb070e355e9321e1bc0d381735421a2d2068ec56531f --hash=sha256:106fab697af11456fcba3e352ad50effe493a90f893fca6c2ca5c033820cea92 --hash=sha256:15ce6ab60957ca74cff444fe66d9045c1fd3e92c8936894ebd1f3eef2fff075f --hash=sha256:1c48188778d4d2f3d48e4643fb15d8608b1d01e4b4d6b0548d9b336c28fc9b6f --hash=sha256:3859783aefa2b8355697f16642695a5b9792e7a46ab86da1118a4a23a51a33d7 --hash=sha256:3d80b2c0300583228ac98d0a927a1ba6a2ba6b8a742463c564f1d419ee5b211e --hash=sha256:3f9e896f447c5c8618f1edb2bafa9a4030f22a575ec418ad70611450720b5b08 --hash=sha256:434f0adf84911c924519d2b08fc10491dd282b20bdd3fa8f60fd816ea0b48841 --hash=sha256:49b65a95d642d4efa8f64ba12558fcb83407e58a2dfba9d796d77b63ccfcaff5 --hash=sha256:4caf2bcd2969402bf77edc4cb6034c7dd7c0803213b3523f111eb7460a51b8d2 --hash=sha256:532fd26e715560721bb0d5fc7610fce279b3699b018600ab999d1be895b09415 --hash=sha256:5ebac872ba09cb8f2131c46b8739a7ff71de28a24c869bcad554477eb089a878 
--hash=sha256:5f5964cdad279256c084b69c3f412b7801e15356b16efa9d78aa974041903da0 --hash=sha256:65a887a6e8c4cd0897507d814b14c54a8c2e2aa4ac9f7686292f9769fcf9a6ab --hash=sha256:6a37a2fb93d4df3fc4c0e363ea4d16f83195fc09c891bc8ce072b9d084853445 --hash=sha256:70771a461aaeb335df14deb6c97439973d253ae70660ca085eec25241137ef43 --hash=sha256:71e2bd4a1c4188f5c2b8d274da78faab884b59df20df63c34f74aa1813c4427c --hash=sha256:745b57db7758f3ffc05a10254edd3182a2a83402a89c00957a8e8a22f5582823 --hash=sha256:78e9253c3de756b3f6a5174d024c4835acd59eb3f8e2ca13e775dbffe1558f69 --hash=sha256:82199cb78276249796419fe36b7386bd8d2cc3f28b3bc19fe2454fe2e26c4c15 --hash=sha256:8b7fc0cd78ba2f4695fd0a6ad81a19e7e3ab825c31b577f384aa9d7817dc3bef --hash=sha256:8c5acb8dddb0752bf252e01a3035b21443158910ac16a3b0d20e7fed7d534ce5 --hash=sha256:8c942a01d9163e2e5cfb05cb66110121b8d07ad438a17f9e766317bcb62abf73 --hash=sha256:90df94c89a91b7362e1142cbee7568f86514412ab8a2c0d0fca72d7e91b62912 --hash=sha256:970e9173dbd7eba9b4e01aab19215a48ee5dd3f43cef736eebde064a171f89a5 --hash=sha256:977e98a0e0480d3fe292246417239d2d45435904afd6d7332d8455981c408b85 --hash=sha256:b6945942715a034c671b7fc54f9588126b0b8bf23db2696e3ca8328f3ff0ab54 --hash=sha256:b7cd50c38f500bbcc9b6a46643a40e0913673f869315d8e70de0438817cb7773 --hash=sha256:c49f73e61f1f774650a55d221803b101d966ca0c5a2d6d5e4320ec3997489441 --hash=sha256:c66c4906cdbc50e9cba65978823e6e00b45682eb09adbb78c9775b74eb222422 --hash=sha256:c6c4639a9c22230276b7bffb6a850dfc8258a2521305e1faefe804d006b2e532 --hash=sha256:c85bb486e9be652314bb5b9e2e3b0d1b2e643d5eec4992c0fbe8ac71775da739 --hash=sha256:cc829960f34ba36aad4302e78eabf3ef16a3a100863f0d4eeddf30e8a485a03b --hash=sha256:d0e589ae0d55204991450bb5c23f571c64fe43adaa53f93fc902a84c96f52fe1 --hash=sha256:d14f12932a8d620e307f715857107b1d1845cc44fdb5da2bc8e850f5ceba9f87 --hash=sha256:d32530b534e986374fc19eaa77fcb87e8a99e5431499949b828312bdcd20ac52 --hash=sha256:d6658ccc7251a4433eebd89ed2672c2ed96fba367fd25ca9512aa92a4b46c4f1 
--hash=sha256:d91a3ccc7fea94ca0acab82ceb77f396d50a1f67412efe4c526f5d20264e6ecd --hash=sha256:de39db2604ae755316cb5967728f4bea92685884b1e767b7c24e983ef5f771cb --hash=sha256:de425af81b6cea33101ae95ece1f696af39446db9682a0b56daaa48cfc29f38f --hash=sha256:e1578f7eafce927b168752ed7e22646dad6cd9bca673c60bff55889fa236ebf9 --hash=sha256:e298e7e70cf4eb179cc1077be1c725b5fd131ebc81181bf0c03525c8abc297fd --hash=sha256:eab0f6db315fa4d70f1d8ab514e527f0366ec021ff853d7ed6a2d33605cf4b83 --hash=sha256:f26b383144cf2d2c29f01a1e8170f50dacf0eac02d64139dcd709a8ac4eb3cfe", "cycler==0.12.1 --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c", }, requirements) } func TestComfyUIRequirements(t *testing.T) { srcDir := t.TempDir() reqFile := path.Join(srcDir, "requirements.txt") err := os.WriteFile(reqFile, []byte(`torch torchvision torchaudio torchsde einops transformers>=4.49.0 tokenizers>=0.13.3 sentencepiece safetensors>=0.3.0 aiohttp accelerate>=1.1.1 pyyaml Pillow scipy tqdm psutil spandrel soundfile kornia>=0.7.1 websocket-client==1.6.3 diffusers>=0.31.0 av>=14.1.0 comfyui-frontend-package==1.17.11 comfyui-workflow-templates==0.1.3 # ComfyUI-AdvancedLivePortrait dill # Inspire webcolors albumentations==1.4.3 # was-node-suite-comfyui # https://github.com/WASasquatch/was-node-suite-comfyui/blob/main/requirements.txt cmake imageio joblib matplotlib pilgram scikit-learn rembg # ComfyUI_essentials numba # ComfyUI_FizzNodes pandas numexpr # comfyui-reactor-node insightface onnx # ComfyUI-Impact-Pack segment-anything piexif # ComfyUI-Impact-Subpack ultralytics!=8.0.177 # comfyui_segment_anything timm # comfyui_controlnet_aux # https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/requirements.txt importlib_metadata opencv-python-headless>=4.0.1.24 filelock numpy scikit-image python-dateutil mediapipe svglib fvcore yapf omegaconf ftfy addict yacs trimesh[easy] # ComfyUI-KJNodes 
librosa color-matcher # PuLID facexlib # SUPIR open-clip-torch>=2.24.0 pytorch-lightning>=2.2.1 # For train.py and custom loras huggingface_hub[hf-transfer] # ComfyUI-segment-anything-2 iopath`), 0o644) require.NoError(t, err) requirements, err := ReadRequirements(reqFile) require.NoError(t, err) require.Equal(t, []string{ "torch", "torchvision", "torchaudio", "torchsde", "einops", "transformers>=4.49.0", "tokenizers>=0.13.3", "sentencepiece", "safetensors>=0.3.0", "aiohttp", "accelerate>=1.1.1", "pyyaml", "Pillow", "scipy", "tqdm", "psutil", "spandrel", "soundfile", "kornia>=0.7.1", "websocket-client==1.6.3", "diffusers>=0.31.0", "av>=14.1.0", "comfyui-frontend-package==1.17.11", "comfyui-workflow-templates==0.1.3", "dill", "webcolors", "albumentations==1.4.3", "cmake", "imageio", "joblib", "matplotlib", "pilgram", "scikit-learn", "rembg", "numba", "pandas", "numexpr", "insightface", "onnx", "segment-anything", "piexif", "ultralytics!=8.0.177", "timm", "importlib_metadata", "opencv-python-headless>=4.0.1.24", "filelock", "numpy", "scikit-image", "python-dateutil", "mediapipe", "svglib", "fvcore", "yapf", "omegaconf", "ftfy", "addict", "yacs", "trimesh[easy]", "librosa", "color-matcher", "facexlib", "open-clip-torch>=2.24.0", "pytorch-lightning>=2.2.1", "huggingface_hub[hf-transfer]", "iopath", }, requirements) } func TestTensorflowRequirements(t *testing.T) { srcDir := t.TempDir() reqFile := path.Join(srcDir, ".requirements.txt") err := os.WriteFile(reqFile, []byte(`compel==2.0.3 diffusers>=0.27.1 gputil==1.4.0 loguru==0.7.2 opencv-python>=4.9.0.80 pillow>=10.2.0 psutil==6.1.1 replicate>=1.0.4 sentry-sdk[fastapi,loguru]>=2.16.0 antialiased_cnns==0.3 beautifulsoup4==4.13.4 imageio==2.37.0 ipdb==0.13.13 kornia==0.8.1 matplotlib==3.10.3 numpy==1.23.5 opencv_python==4.11.0.86 Pillow==11.2.1 pytorch_lightning==2.3.3 PyYAML==6.0.2 Requests==2.32.3 scipy==1.15.3 scikit-image==0.24.0 tensorflow==2.10.0 tensorlayer==2.2.5 tf_slim==1.1.0 timm==1.0.15 torch==2.0.1 
torchvision==0.15.2 tqdm==4.67.1`), 0o644) require.NoError(t, err) requirements, err := ReadRequirements(reqFile) require.NoError(t, err) require.Equal(t, []string{ "compel==2.0.3", "diffusers>=0.27.1", "gputil==1.4.0", "loguru==0.7.2", "opencv-python>=4.9.0.80", "pillow>=10.2.0", "psutil==6.1.1", "replicate>=1.0.4", "sentry-sdk[fastapi,loguru]>=2.16.0", "antialiased_cnns==0.3", "beautifulsoup4==4.13.4", "imageio==2.37.0", "ipdb==0.13.13", "kornia==0.8.1", "matplotlib==3.10.3", "numpy==1.23.5", "opencv_python==4.11.0.86", "Pillow==11.2.1", "pytorch_lightning==2.3.3", "PyYAML==6.0.2", "Requests==2.32.3", "scipy==1.15.3", "scikit-image==0.24.0", "tensorflow==2.10.0", "tensorlayer==2.2.5", "tf_slim==1.1.0", "timm==1.0.15", "torch==2.0.1", "torchvision==0.15.2", "tqdm==4.67.1", }, requirements) } func TestSplitPinnedPythonRequirement(t *testing.T) { testCases := []struct { input string expectedName string expectedVersion string expectedFindLinks []string expectedExtraIndexURLs []string expectedError bool }{ {"package1==1.0.0", "package1", "1.0.0", nil, nil, false}, {"package1==1.0.0+alpha", "package1", "1.0.0+alpha", nil, nil, false}, {"--find-links=link1 --find-links=link2 package3==3.0.0", "package3", "3.0.0", []string{"link1", "link2"}, nil, false}, {"package4==4.0.0 --extra-index-url=url1 --extra-index-url=url2", "package4", "4.0.0", nil, []string{"url1", "url2"}, false}, {"-f link1 --find-links=link2 package5==5.0.0 --extra-index-url=url1 --extra-index-url=url2", "package5", "5.0.0", []string{"link1", "link2"}, []string{"url1", "url2"}, false}, {"package6 --find-links=link1 --find-links=link2 --extra-index-url=url1 --extra-index-url=url2", "", "", nil, nil, true}, {"invalid package", "", "", nil, nil, true}, {"package8==", "", "", nil, nil, true}, {"==8.0.0", "", "", nil, nil, true}, } for _, tc := range testCases { name, version, findLinks, extraIndexURLs, err := SplitPinnedPythonRequirement(tc.input) if tc.expectedError { require.Error(t, err) } else { 
// SchemaError represents errors during schema generation.
type SchemaError struct {
	Kind    SchemaErrorKind
	Message string

	// cause is the error passed to WrapError, if any. It is exposed through
	// Unwrap so errors.Is and errors.As can see through a SchemaError.
	cause error
}

// Error returns the message, which already includes the text of a wrapped
// error when the SchemaError was built with WrapError.
func (e *SchemaError) Error() string {
	return e.Message
}

// Unwrap returns the error given to WrapError, or nil when there is none.
func (e *SchemaError) Unwrap() error {
	return e.cause
}

// SchemaErrorKind classifies schema generation errors.
type SchemaErrorKind int

// The order of these constants defines their numeric values; append new kinds
// before ErrOther only if nothing depends on the existing numbers.
const (
	ErrParse SchemaErrorKind = iota
	ErrPredictorNotFound
	ErrMethodNotFound
	ErrMissingReturnType
	ErrMissingTypeAnnotation
	ErrUnsupportedType
	ErrDefaultFactoryNotSupported
	ErrInvalidConstraint
	ErrInvalidPredictRef
	ErrOptionalOutput
	ErrConcatIteratorNotStr
	ErrChoicesNotResolvable
	ErrDefaultNotResolvable
	ErrUnresolvableType
	ErrOther
)

// NewError creates a SchemaError with the given kind and message.
func NewError(kind SchemaErrorKind, msg string) *SchemaError {
	return &SchemaError{Kind: kind, Message: msg}
}

// WrapError creates a SchemaError, appending the inner error's message if non-nil.
// A non-nil inner error is also kept as the cause, so it stays reachable with
// errors.Is, errors.As and Unwrap.
func WrapError(kind SchemaErrorKind, msg string, inner error) *SchemaError {
	if inner == nil {
		return &SchemaError{Kind: kind, Message: msg}
	}
	return &SchemaError{
		Kind:    kind,
		Message: fmt.Sprintf("%s: %s", msg, inner.Error()),
		cause:   inner,
	}
}

// errParse reports that the Python source could not be parsed.
func errParse(msg string) error { //nolint:unused // used by generator.go (not yet written)
	return &SchemaError{Kind: ErrParse, Message: fmt.Sprintf("failed to parse Python source: %s", msg)}
}

// errPredictorNotFound reports that the named predictor was not found.
func errPredictorNotFound(name string) error { //nolint:unused // used by generator.go (not yet written)
	return &SchemaError{Kind: ErrPredictorNotFound, Message: fmt.Sprintf("predictor not found: %s", name)}
}

// errMethodNotFound reports that method is missing on class.
func errMethodNotFound(class, method string) error { //nolint:unused // used by generator.go (not yet written)
	return &SchemaError{Kind: ErrMethodNotFound, Message: fmt.Sprintf("%s method not found on %s", method, class)}
}

// errMissingReturnType reports that method has no return type annotation.
func errMissingReturnType(method string) error { //nolint:unused // used by generator.go (not yet written)
	return &SchemaError{Kind: ErrMissingReturnType, Message: fmt.Sprintf("missing return type annotation on %s", method)}
}

// errMissingTypeAnnotation reports that param of method has no type annotation.
func errMissingTypeAnnotation(method, param string) error { //nolint:unused // used by generator.go (not yet written)
	return &SchemaError{Kind: ErrMissingTypeAnnotation, Message: fmt.Sprintf("missing type annotation for parameter '%s' on %s", param, method)}
}
string) error { return &SchemaError{Kind: ErrUnsupportedType, Message: fmt.Sprintf("unsupported type: %s", msg)} } func errDefaultFactoryNotSupported(param string) error { //nolint:unused // used by generator.go (not yet written) return &SchemaError{ Kind: ErrDefaultFactoryNotSupported, Message: fmt.Sprintf("default_factory is not supported in Input() — use a literal default value instead (parameter '%s')", param), } } func errInvalidPredictRef(ref string) error { return &SchemaError{ Kind: ErrInvalidPredictRef, Message: fmt.Sprintf("invalid predict reference '%s' — expected format: file.py:ClassName or file.py:function_name", ref), } } func errOptionalOutput() error { return &SchemaError{Kind: ErrOptionalOutput, Message: "unsupported output type: Optional is not allowed as a return type"} } func errConcatIteratorNotStr(got string) error { return &SchemaError{Kind: ErrConcatIteratorNotStr, Message: fmt.Sprintf("ConcatenateIterator element type must be str, got %s", got)} } func errChoicesNotResolvable(param string) error { //nolint:unused // used by generator.go (not yet written) return &SchemaError{ Kind: ErrChoicesNotResolvable, Message: fmt.Sprintf("choices for parameter '%s' cannot be statically resolved — use a literal list instead (e.g. choices=[\"a\", \"b\"])", param), } } func errUnresolvableImportedType(name, module string) error { return &SchemaError{ Kind: ErrUnresolvableType, Message: fmt.Sprintf( "cannot resolve output type '%s' (imported from '%s') — "+ "external types cannot be statically analyzed. 
"+ "Define it as a BaseModel subclass in your predict file, or provide a .pyi stub", name, module), } } func errUnresolvableType(name string) error { return &SchemaError{ Kind: ErrUnresolvableType, Message: fmt.Sprintf( "cannot resolve output type '%s' — "+ "it is not a primitive type (str, int, float, bool, Path) "+ "and no BaseModel definition was found in the predict file", name), } } func errDefaultNotResolvable(param, expr string) error { //nolint:unused // used by generator.go (not yet written) return &SchemaError{ Kind: ErrDefaultNotResolvable, Message: fmt.Sprintf( "default value for parameter '%s' cannot be statically resolved: `%s`. "+ "Defaults must be literals (string, int, float, bool, None, list) or Input() calls.", param, expr), } } ================================================ FILE: pkg/schema/generator.go ================================================ package schema import ( "encoding/json" "fmt" "os" "strings" ) // Parser is a function that parses source code and extracts predictor info. // This is defined as a type to avoid an import cycle between schema and // schema/python. The concrete implementation is python.ParsePredictor. // // sourceDir is the project root directory, used for resolving cross-file // imports (e.g. "from .types import Output"). Pass "" if unknown. type Parser func(source []byte, predictRef string, mode Mode, sourceDir string) (*PredictorInfo, error) // Generate produces an OpenAPI 3.0.2 JSON schema from a predict/train reference. // // predictRef has the format "module.py:ClassName" (e.g. "predict.py:Predictor"). // sourceDir is the directory containing the source file. // mode selects predict vs train. // parse is the parser implementation (use python.ParsePredictor). // // If the COG_OPENAPI_SCHEMA environment variable is set, its value is treated // as a path to a pre-built JSON schema file. The file contents are returned // directly and no parsing or generation takes place. 
func Generate(predictRef string, sourceDir string, mode Mode, parse Parser) ([]byte, error) { // "Bring your own schema" override if schemaPath := os.Getenv("COG_OPENAPI_SCHEMA"); schemaPath != "" { data, err := os.ReadFile(schemaPath) //nolint:gosec // G703: path from trusted env var if err != nil { return nil, fmt.Errorf("COG_OPENAPI_SCHEMA: failed to read %s: %w", schemaPath, err) } return data, nil } filePath, className, err := parsePredictRef(predictRef) if err != nil { return nil, err } fullPath := filePath if sourceDir != "" { fullPath = sourceDir + "/" + filePath } source, err := os.ReadFile(fullPath) if err != nil { return nil, fmt.Errorf("failed to read predictor source %s: %w", fullPath, err) } return GenerateFromSource(source, className, mode, parse, sourceDir) } // GenerateFromSource produces an OpenAPI 3.0.2 JSON schema from Python source bytes. // // predictRef is the class or function name (e.g. "Predictor" or "predict"). // parse is the parser implementation (use python.ParsePredictor). // sourceDir is the project root for resolving cross-file imports. Pass "" if unknown. // This is the lower-level API — it does not read files or check COG_OPENAPI_SCHEMA. func GenerateFromSource(source []byte, predictRef string, mode Mode, parse Parser, sourceDir string) ([]byte, error) { info, err := parse(source, predictRef, mode, sourceDir) if err != nil { return nil, err } return GenerateOpenAPISchema(info) } // GenerateCombined produces an OpenAPI schema for both predict and train (when // both are configured) and merges them into a single document. If only one mode // is configured, it returns that single schema. // // If the COG_OPENAPI_SCHEMA environment variable is set, its value is treated // as a path to a pre-built JSON schema file and returned directly. 
func GenerateCombined(sourceDir string, predictRef string, trainRef string, parse Parser) ([]byte, error) { // "Bring your own schema" override if schemaPath := os.Getenv("COG_OPENAPI_SCHEMA"); schemaPath != "" { data, err := os.ReadFile(schemaPath) //nolint:gosec // G703: path from trusted env var if err != nil { return nil, fmt.Errorf("COG_OPENAPI_SCHEMA: failed to read %s: %w", schemaPath, err) } return data, nil } if predictRef == "" && trainRef == "" { return nil, fmt.Errorf("no predict or train reference provided") } // Single-mode: just generate the one schema if predictRef == "" { return Generate(trainRef, sourceDir, ModeTrain, parse) } if trainRef == "" { return Generate(predictRef, sourceDir, ModePredict, parse) } // Both modes: generate each and merge predictJSON, err := Generate(predictRef, sourceDir, ModePredict, parse) if err != nil { return nil, fmt.Errorf("predict schema: %w", err) } trainJSON, err := Generate(trainRef, sourceDir, ModeTrain, parse) if err != nil { return nil, fmt.Errorf("train schema: %w", err) } var predictSchema, trainSchema map[string]any if err := json.Unmarshal(predictJSON, &predictSchema); err != nil { return nil, fmt.Errorf("failed to parse predict schema: %w", err) } if err := json.Unmarshal(trainJSON, &trainSchema); err != nil { return nil, fmt.Errorf("failed to parse train schema: %w", err) } merged := MergeSchemas(predictSchema, trainSchema) return json.MarshalIndent(merged, "", " ") } // MergeSchemas merges a predict-mode and train-mode OpenAPI schema into a single // combined schema. The predict schema is used as the base; paths and component // schemas from the train schema are added to it. 
func MergeSchemas(predict, train map[string]any) map[string]any { // Merge paths predictPaths, _ := predict["paths"].(map[string]any) trainPaths, _ := train["paths"].(map[string]any) if predictPaths != nil && trainPaths != nil { for k, v := range trainPaths { if _, exists := predictPaths[k]; !exists { predictPaths[k] = v } } } // Merge component schemas predictComponents, _ := predict["components"].(map[string]any) trainComponents, _ := train["components"].(map[string]any) if predictComponents != nil && trainComponents != nil { predictSchemas, _ := predictComponents["schemas"].(map[string]any) trainSchemas, _ := trainComponents["schemas"].(map[string]any) if predictSchemas != nil && trainSchemas != nil { for k, v := range trainSchemas { if _, exists := predictSchemas[k]; !exists { predictSchemas[k] = v } } } } return predict } // parsePredictRef splits a predict reference like "predict.py:Predictor" into // the file path and class/function name. func parsePredictRef(ref string) (filePath string, name string, err error) { parts := strings.SplitN(ref, ":", 2) if len(parts) != 2 || parts[0] == "" || parts[1] == "" { return "", "", errInvalidPredictRef(ref) } return parts[0], parts[1], nil } ================================================ FILE: pkg/schema/generator_test.go ================================================ package schema import ( "encoding/json" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // mockParser is a test parser that returns a fixed PredictorInfo. func mockParser(source []byte, predictRef string, mode Mode, sourceDir string) (*PredictorInfo, error) { inputs := NewOrderedMap[string, InputField]() inputs.Set("prompt", InputField{ Name: "prompt", Order: 0, FieldType: FieldType{Primitive: TypeString, Repetition: Required}, }) return &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: mode, }, nil } // failParser always returns an error. 
func failParser(_ []byte, _ string, _ Mode, _ string) (*PredictorInfo, error) {
	return nil, NewError(ErrParse, "mock parse failure")
}

// ---------------------------------------------------------------------------
// parsePredictRef
// ---------------------------------------------------------------------------

// TestParsePredictRef checks that well-formed "file:name" references split
// into their two parts and that malformed ones fail with ErrInvalidPredictRef.
func TestParsePredictRef(t *testing.T) {
	tests := []struct {
		input   string
		file    string
		name    string
		wantErr bool
	}{
		{"predict.py:Predictor", "predict.py", "Predictor", false},
		{"src/model.py:MyModel", "src/model.py", "MyModel", false},
		{"train.py:train", "train.py", "train", false},
		{"no_colon", "", "", true},
		{":NoFile", "", "", true},
		{"no_name:", "", "", true},
		{"", "", "", true},
	}

	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			file, name, err := parsePredictRef(tt.input)
			if tt.wantErr {
				require.Error(t, err)
				// The error must be a *SchemaError of the expected kind.
				var se *SchemaError
				require.ErrorAs(t, err, &se)
				assert.Equal(t, ErrInvalidPredictRef, se.Kind)
			} else {
				require.NoError(t, err)
				assert.Equal(t, tt.file, file)
				assert.Equal(t, tt.name, name)
			}
		})
	}
}

// ---------------------------------------------------------------------------
// GenerateFromSource
// ---------------------------------------------------------------------------

// TestGenerateFromSource checks that predict mode yields an OpenAPI 3.0.2
// document whose Input schema contains the mock parser's "prompt" property.
func TestGenerateFromSource(t *testing.T) {
	data, err := GenerateFromSource([]byte("unused"), "Predictor", ModePredict, mockParser, "")
	require.NoError(t, err)

	var spec map[string]any
	require.NoError(t, json.Unmarshal(data, &spec))

	assert.Equal(t, "3.0.2", spec["openapi"])
	props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any)
	assert.Contains(t, props, "prompt")
}

// TestGenerateFromSourceTrainMode checks that train mode emits the
// TrainingInput schema and the POST /trainings path.
func TestGenerateFromSourceTrainMode(t *testing.T) {
	data, err := GenerateFromSource([]byte("unused"), "Trainer", ModeTrain, mockParser, "")
	require.NoError(t, err)

	var spec map[string]any
	require.NoError(t, json.Unmarshal(data, &spec))

	assert.NotNil(t, getPath(spec, "components", "schemas", "TrainingInput"))
	assert.NotNil(t, getPath(spec, "paths", "/trainings", "post"))
}

// TestGenerateFromSourceParseError checks that a parser failure is surfaced
// to the caller with its message intact.
func TestGenerateFromSourceParseError(t *testing.T) {
	_, err := GenerateFromSource([]byte("unused"), "Predictor", ModePredict, failParser, "")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "mock parse failure")
}

// ---------------------------------------------------------------------------
// Generate — file-based
// ---------------------------------------------------------------------------

// TestGenerateReadsFile checks that Generate reads the referenced file from
// sourceDir and produces a schema.
func TestGenerateReadsFile(t *testing.T) {
	dir := t.TempDir()
	err := os.WriteFile(filepath.Join(dir, "predict.py"), []byte("class Predictor: pass"), 0o644)
	require.NoError(t, err)

	data, err := Generate("predict.py:Predictor", dir, ModePredict, mockParser)
	require.NoError(t, err)

	var spec map[string]any
	require.NoError(t, json.Unmarshal(data, &spec))
	assert.Equal(t, "3.0.2", spec["openapi"])
}

// TestGenerateMissingFile checks the error reported when the referenced
// source file does not exist.
func TestGenerateMissingFile(t *testing.T) {
	dir := t.TempDir()
	_, err := Generate("missing.py:Predictor", dir, ModePredict, mockParser)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "failed to read predictor source")
}

// TestGenerateInvalidRef checks that a reference without a colon is rejected
// with ErrInvalidPredictRef.
func TestGenerateInvalidRef(t *testing.T) {
	_, err := Generate("no_colon", ".", ModePredict, mockParser)
	require.Error(t, err)
	var se *SchemaError
	require.ErrorAs(t, err, &se)
	assert.Equal(t, ErrInvalidPredictRef, se.Kind)
}

// ---------------------------------------------------------------------------
// COG_OPENAPI_SCHEMA env var
// ---------------------------------------------------------------------------

// TestGenerateCogOpenAPISchemaEnv checks that a schema file named by
// COG_OPENAPI_SCHEMA is returned byte-for-byte and that the parser is skipped.
func TestGenerateCogOpenAPISchemaEnv(t *testing.T) {
	// Write a pre-built schema file
	dir := t.TempDir()
	schemaContent := `{"openapi": "3.0.2", "info": {"title": "Custom"}}`
	schemaPath := filepath.Join(dir, "custom_schema.json")
	err := os.WriteFile(schemaPath, []byte(schemaContent), 0o644)
	require.NoError(t, err)

	t.Setenv("COG_OPENAPI_SCHEMA", schemaPath)

	// Should return the custom schema without parsing
	// (using failParser to prove parsing is skipped)
	data, err := Generate("predict.py:Predictor", ".", ModePredict, failParser)
	require.NoError(t, err)
	assert.Equal(t, schemaContent, string(data))
}

// TestGenerateCogOpenAPISchemaEnvMissingFile checks the error reported when
// COG_OPENAPI_SCHEMA points at a file that cannot be read.
func TestGenerateCogOpenAPISchemaEnvMissingFile(t *testing.T) {
	t.Setenv("COG_OPENAPI_SCHEMA", "/nonexistent/schema.json")

	_, err := Generate("predict.py:Predictor", ".", ModePredict, mockParser)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "COG_OPENAPI_SCHEMA")
	assert.Contains(t, err.Error(), "failed to read")
}

// TestGenerateCogOpenAPISchemaEnvNotSet checks that an empty
// COG_OPENAPI_SCHEMA value is treated as unset and normal generation runs.
func TestGenerateCogOpenAPISchemaEnvNotSet(t *testing.T) {
	// Ensure env var is not set
	t.Setenv("COG_OPENAPI_SCHEMA", "")

	dir := t.TempDir()
	err := os.WriteFile(filepath.Join(dir, "predict.py"), []byte("class Predictor: pass"), 0o644)
	require.NoError(t, err)

	// Should proceed with normal generation (not use env var)
	data, err := Generate("predict.py:Predictor", dir, ModePredict, mockParser)
	require.NoError(t, err)

	var spec map[string]any
	require.NoError(t, json.Unmarshal(data, &spec))
	assert.Equal(t, "Cog", getPath(spec, "info", "title"))
}

================================================
FILE: pkg/schema/openapi.go
================================================
package schema

import (
	"encoding/json"
	"maps"
	"sort"
)

// GenerateOpenAPISchema produces a complete OpenAPI 3.0.2 specification
// from a PredictorInfo. The returned bytes are compact JSON.
func GenerateOpenAPISchema(info *PredictorInfo) ([]byte, error) {
	spec := buildOpenAPISpec(info)

	// Post-processing: remove title next to $ref, fix nullable anyOf
	// Both passes rewrite spec in place before it is serialized.
	removeTitleNextToRef(spec)
	fixNullableAnyOf(spec)

	return json.Marshal(spec)
}

// buildOpenAPISpec constructs the full OpenAPI 3.0.2 map.
func buildOpenAPISpec(info *PredictorInfo) map[string]any {
	inputSchema, enumSchemas := buildInputSchema(info)
	outputSchema := info.Output.JSONSchema()

	isTrain := info.Mode == ModeTrain

	// Predict and train documents share one layout; only the names declared
	// here (endpoint, schema keys, operation IDs, cancel parameter) differ.
	var (
		endpoint     string
		requestName  string
		responseName string
		cancelEP     string
		summary      string
		description  string
		opID         string
		cancelOpID   string
		cancelParam  string
		inputKey     string
		outputKey    string
	)

	if isTrain {
		endpoint = "/trainings"
		requestName = "TrainingRequest"
		responseName = "TrainingResponse"
		cancelEP = "/trainings/{training_id}/cancel"
		summary = "Train"
		description = "Run a single training on the model"
		opID = "train_trainings_post"
		cancelOpID = "cancel_trainings__training_id__cancel_post"
		cancelParam = "training_id"
		inputKey = "TrainingInput"
		outputKey = "TrainingOutput"
	} else {
		endpoint = "/predictions"
		requestName = "PredictionRequest"
		responseName = "PredictionResponse"
		cancelEP = "/predictions/{prediction_id}/cancel"
		summary = "Predict"
		description = "Run a single prediction on the model"
		opID = "predict_predictions_post"
		cancelOpID = "cancel_predictions__prediction_id__cancel_post"
		cancelParam = "prediction_id"
		inputKey = "Input"
		outputKey = "Output"
	}

	// Build components/schemas
	// The ordered map keeps component schemas in the order they are first Set.
	components := newOrderedMapAny()

	// Input schema
	// buildInputSchema always titles the schema "Input"; retitle it for the
	// current mode so train documents say "TrainingInput".
	inputSchema["title"] = inputKey
	components.Set(inputKey, inputSchema)

	// Output schema
	components.Set(outputKey, outputSchema)

	// Enum schemas for choices
	// NOTE(review): a later Set with an existing key replaces that key's value,
	// so an enum whose name matches the input/output key or one of the fixed
	// schema names set below (e.g. "Status") would collide with it — confirm
	// whether input names can produce such a name.
	for _, es := range enumSchemas {
		components.Set(es.name, es.schema)
	}

	inputRef := "#/components/schemas/" + inputKey
	outputRef := "#/components/schemas/" + outputKey

	// Request schema
	components.Set(requestName, map[string]any{
		"title": requestName,
		"type":  "object",
		"properties": map[string]any{
			"id":    map[string]any{"title": "Id", "type": "string"},
			"input": map[string]any{"$ref": inputRef},
		},
	})

	// Response schema
	components.Set(responseName, map[string]any{
		"title": responseName,
		"type":  "object",
		"properties": map[string]any{
			"input":        map[string]any{"$ref": inputRef},
			"output":       map[string]any{"$ref": outputRef},
			"id":           map[string]any{"title": "Id", "type": "string"},
			"version":      map[string]any{"title": "Version", "type": "string"},
			"created_at":   map[string]any{"title": "Created At", "type": "string", "format": "date-time"},
			"started_at":   map[string]any{"title": "Started At", "type": "string", "format": "date-time"},
			"completed_at": map[string]any{"title": "Completed At", "type": "string", "format": "date-time"},
			"status":       map[string]any{"title": "Status", "type": "string"},
			"error":        map[string]any{"title": "Error", "type": "string"},
			"logs":         map[string]any{"title": "Logs", "type": "string"},
			"metrics":      map[string]any{"title": "Metrics", "type": "object"},
		},
	})

	// Status enum
	components.Set("Status", map[string]any{
		"title":       "Status",
		"description": "An enumeration.",
		"enum":        []any{"starting", "processing", "succeeded", "canceled", "failed"},
		"type":        "string",
	})

	// Validation error schemas
	components.Set("HTTPValidationError", map[string]any{
		"title": "HTTPValidationError",
		"type":  "object",
		"properties": map[string]any{
			"detail": map[string]any{
				"title": "Detail",
				"type":  "array",
				"items": map[string]any{"$ref": "#/components/schemas/ValidationError"},
			},
		},
	})
	components.Set("ValidationError", map[string]any{
		"title":    "ValidationError",
		"required": []any{"loc", "msg", "type"},
		"type":     "object",
		"properties": map[string]any{
			"loc": map[string]any{
				"title": "Location",
				"type":  "array",
				"items": map[string]any{
					"anyOf": []any{
						map[string]any{"type": "string"},
						map[string]any{"type": "integer"},
					},
				},
			},
			"msg":  map[string]any{"title": "Message", "type": "string"},
			"type": map[string]any{"title": "Error Type", "type": "string"},
		},
	})

	requestRef := "#/components/schemas/" + requestName
	responseRef := "#/components/schemas/" + responseName

	// Build paths
	// Paths are emitted in the order they are Set: root, health check, the
	// main endpoint, then its cancel endpoint.
	paths := newOrderedMapAny()

	// Root
	paths.Set("/", map[string]any{
		"get": map[string]any{
			"summary":     "Root",
			"operationId": "root__get",
			"responses": map[string]any{
				"200": map[string]any{
					"description": "Successful Response",
					"content":     map[string]any{"application/json": map[string]any{"schema": map[string]any{}}},
				},
			},
		},
	})

	// Health check
	paths.Set("/health-check", map[string]any{
		"get": map[string]any{
			"summary":     "Healthcheck",
			"operationId": "healthcheck_health_check_get",
			"responses": map[string]any{
				"200": map[string]any{
					"description": "Successful Response",
					"content":     map[string]any{"application/json": map[string]any{"schema": map[string]any{}}},
				},
			},
		},
	})

	// Main endpoint (predict or train)
	paths.Set(endpoint, map[string]any{
		"post": map[string]any{
			"summary":     summary,
			"description": description,
			"operationId": opID,
			"requestBody": map[string]any{
				"content": map[string]any{
					"application/json": map[string]any{
						"schema": map[string]any{"$ref": requestRef},
					},
				},
			},
			"responses": map[string]any{
				"200": map[string]any{
					"description": "Successful Response",
					"content": map[string]any{
						"application/json": map[string]any{
							"schema": map[string]any{"$ref": responseRef},
						},
					},
				},
				"422": map[string]any{
					"description": "Validation Error",
					"content": map[string]any{
						"application/json": map[string]any{
							"schema": map[string]any{"$ref": "#/components/schemas/HTTPValidationError"},
						},
					},
				},
			},
		},
	})

	// Cancel endpoint
	paths.Set(cancelEP, map[string]any{
		"post": map[string]any{
			"summary":     "Cancel",
			"operationId": cancelOpID,
			"parameters": []any{
				map[string]any{
					"required": true,
					"schema":   map[string]any{"title": TitleCaseSingle(cancelParam), "type": "string"},
					"name":     cancelParam,
					"in":       "path",
				},
			},
			"responses": map[string]any{
				"200": map[string]any{
					"description": "Successful Response",
					"content":     map[string]any{"application/json": map[string]any{"schema": map[string]any{}}},
				},
				"422": map[string]any{
					"description": "Validation Error",
					"content": map[string]any{
						"application/json": map[string]any{
							"schema": map[string]any{"$ref": "#/components/schemas/HTTPValidationError"},
						},
					},
				},
			},
		},
	})

	return map[string]any{
		"openapi": "3.0.2",
		"info":    map[string]any{"title": "Cog", "version": "0.1.0"},
		"paths":   paths,
		"components": map[string]any{
			"schemas": components,
		},
	}
}

// enumSchema pairs a name with its schema for choices fields.
type enumSchema struct {
	name   string
	schema map[string]any
}

// buildInputSchema builds the Input schema object and any enum schemas for choices.
//
// The returned schema is always titled "Input" and lists properties in the
// iteration order of info.Inputs. Its "required" key is present only when at
// least one field reports IsRequired(). One enumSchema is returned per field
// that declares choices, in the same iteration order.
func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) {
	properties := newOrderedMapAny()
	var required []string
	var enums []enumSchema

	info.Inputs.Entries(func(name string, field InputField) {
		prop := newOrderedMapAny()

		// x-order for field ordering
		prop.Set("x-order", field.Order)

		if len(field.Choices) > 0 {
			// Choices -> use allOf with $ref to enum schema
			enumName := TitleCaseSingle(name)
			enumType := field.FieldType.Primitive.JSONType()
			typeStr, _ := enumType["type"].(string)
			// Fall back to "string" when the primitive's JSON type has no
			// string-valued "type" entry.
			if typeStr == "" {
				typeStr = "string"
			}

			choiceValues := make([]any, len(field.Choices))
			for i, c := range field.Choices {
				choiceValues[i] = c.ToJSON()
			}

			enums = append(enums, enumSchema{
				name: enumName,
				schema: map[string]any{
					"title":       enumName,
					"description": "An enumeration.",
					"enum":        choiceValues,
					"type":        typeStr,
				},
			})

			prop.Set("allOf", []any{
				map[string]any{"$ref": "#/components/schemas/" + enumName},
			})
		} else {
			// Regular field — inline type
			prop.Set("title", TitleCase(name))
			typeSchema := field.FieldType.JSONType()
			// Merge type schema keys into prop in sorted order for determinism
			keys := make([]string, 0, len(typeSchema))
			for k := range typeSchema {
				keys = append(keys, k)
			}
			sort.Strings(keys)
			for _, k := range keys {
				prop.Set(k, typeSchema[k])
			}
		}

		// Required?
		if field.IsRequired() {
			required = append(required, name)
		}

		// Default value
		if field.Default != nil {
			prop.Set("default", field.Default.ToJSON())
		}

		// Nullable
		if field.FieldType.Repetition == Optional {
			prop.Set("nullable", true)
		}

		// Description
		if field.Description != nil {
			prop.Set("description", *field.Description)
		}

		// Numeric constraints
		if field.GE != nil {
			prop.Set("minimum", *field.GE)
		}
		if field.LE != nil {
			prop.Set("maximum", *field.LE)
		}

		// String constraints
		if field.MinLength != nil {
			prop.Set("minLength", *field.MinLength)
		}
		if field.MaxLength != nil {
			prop.Set("maxLength", *field.MaxLength)
		}
		if field.Regex != nil {
			prop.Set("pattern", *field.Regex)
		}

		// Deprecated
		// Only an explicit true is emitted; nil and false both omit the key.
		if field.Deprecated != nil && *field.Deprecated {
			prop.Set("deprecated", true)
		}

		properties.Set(name, prop)
	})

	inputSchema := map[string]any{
		"title":      "Input",
		"type":       "object",
		"properties": properties,
	}
	if len(required) > 0 {
		inputSchema["required"] = required
	}

	return inputSchema, enums
}

// ---------------------------------------------------------------------------
// Post-processing (mirrors openapi_schema.py fixups)
// ---------------------------------------------------------------------------

// removeTitleNextToRef removes "title" from any map that also has "$ref".
// OpenAPI 3.0 doesn't allow sibling keywords next to $ref.
//
// The walk is recursive over plain maps, ordered maps and []any slices, and
// edits them in place; values of any other type are left alone.
func removeTitleNextToRef(v any) {
	switch val := v.(type) {
	case map[string]any:
		if _, hasRef := val["$ref"]; hasRef {
			delete(val, "title")
		}
		for _, child := range val {
			removeTitleNextToRef(child)
		}
	case *orderedMapAny:
		if _, hasRef := val.Get("$ref"); hasRef {
			val.Delete("title")
		}
		val.Entries(func(_ string, child any) {
			removeTitleNextToRef(child)
		})
	case []any:
		for _, child := range val {
			removeTitleNextToRef(child)
		}
	}
}

// fixNullableAnyOf converts {"anyOf": [{"type": T}, {"type": "null"}]} to
// {"type": T, "nullable": true}. OpenAPI 3.0 uses nullable instead of union-with-null.
func fixNullableAnyOf(v any) { switch val := v.(type) { case map[string]any: // Recurse first for _, child := range val { fixNullableAnyOf(child) } // Check for anyOf with null pattern anyOf, ok := val["anyOf"].([]any) if !ok || len(anyOf) != 2 { return } var nonNull map[string]any hasNull := false for _, variant := range anyOf { m, ok := variant.(map[string]any) if !ok { return } if t, _ := m["type"].(string); t == "null" { hasNull = true } else { nonNull = m } } if hasNull && nonNull != nil { delete(val, "anyOf") maps.Copy(val, nonNull) val["nullable"] = true } case *orderedMapAny: // Recurse first val.Entries(func(_ string, child any) { fixNullableAnyOf(child) }) // Check for anyOf with null pattern anyOfRaw, ok := val.Get("anyOf") if !ok { return } anyOf, ok := anyOfRaw.([]any) if !ok || len(anyOf) != 2 { return } var nonNull map[string]any hasNull := false for _, variant := range anyOf { m, ok := variant.(map[string]any) if !ok { return } if t, _ := m["type"].(string); t == "null" { hasNull = true } else { nonNull = m } } if hasNull && nonNull != nil { val.Delete("anyOf") for k, v := range nonNull { val.Set(k, v) } val.Set("nullable", true) } case []any: for _, child := range val { fixNullableAnyOf(child) } } } // --------------------------------------------------------------------------- // orderedMapAny — ordered map with JSON marshaling that preserves key order. // Used for schema properties where field ordering matters. 
// --------------------------------------------------------------------------- type orderedMapAny struct { keys []string values map[string]any } func newOrderedMapAny() *orderedMapAny { return &orderedMapAny{values: make(map[string]any)} } func (m *orderedMapAny) Set(key string, value any) { if _, exists := m.values[key]; !exists { m.keys = append(m.keys, key) } m.values[key] = value } func (m *orderedMapAny) Get(key string) (any, bool) { v, ok := m.values[key] return v, ok } func (m *orderedMapAny) Delete(key string) { if _, exists := m.values[key]; !exists { return } delete(m.values, key) for i, k := range m.keys { if k == key { m.keys = append(m.keys[:i], m.keys[i+1:]...) break } } } func (m *orderedMapAny) Entries(fn func(key string, value any)) { for _, k := range m.keys { fn(k, m.values[k]) } } // MarshalJSON produces a JSON object with keys in insertion order. func (m *orderedMapAny) MarshalJSON() ([]byte, error) { buf := []byte{'{'} for i, k := range m.keys { if i > 0 { buf = append(buf, ',') } keyBytes, err := json.Marshal(k) if err != nil { return nil, err } buf = append(buf, keyBytes...) buf = append(buf, ':') valBytes, err := json.Marshal(m.values[k]) if err != nil { return nil, err } buf = append(buf, valBytes...) 
} buf = append(buf, '}') return buf, nil } ================================================ FILE: pkg/schema/openapi_test.go ================================================ package schema import ( "encoding/json" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- func simplePredictor() *PredictorInfo { inputs := NewOrderedMap[string, InputField]() inputs.Set("s", InputField{ Name: "s", Order: 0, FieldType: FieldType{Primitive: TypeString, Repetition: Required}, }) return &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } } func ptr[T any](v T) *T { return &v } // parseSpec is a test helper that generates the schema and unmarshals // it into a generic map for assertion. func parseSpec(t *testing.T, info *PredictorInfo) map[string]any { t.Helper() data, err := GenerateOpenAPISchema(info) require.NoError(t, err) var spec map[string]any require.NoError(t, json.Unmarshal(data, &spec)) return spec } func getPath(m map[string]any, keys ...string) any { var cur any = m for _, k := range keys { obj, ok := cur.(map[string]any) if !ok { return nil } cur = obj[k] } return cur } // --------------------------------------------------------------------------- // Tests: Top-level structure // --------------------------------------------------------------------------- func TestGeneratesValidOpenAPI(t *testing.T) { spec := parseSpec(t, simplePredictor()) assert.Equal(t, "3.0.2", spec["openapi"]) assert.Equal(t, "Cog", getPath(spec, "info", "title")) assert.Equal(t, "0.1.0", getPath(spec, "info", "version")) } func TestPredictEndpoints(t *testing.T) { spec := parseSpec(t, simplePredictor()) // Root assert.NotNil(t, getPath(spec, "paths", "/", "get")) // Health check assert.NotNil(t, getPath(spec, "paths", "/health-check", "get")) // Predictions post 
:= getPath(spec, "paths", "/predictions", "post") require.NotNil(t, post) postMap := post.(map[string]any) assert.Equal(t, "Predict", postMap["summary"]) assert.Equal(t, "predict_predictions_post", postMap["operationId"]) // Cancel assert.NotNil(t, getPath(spec, "paths", "/predictions/{prediction_id}/cancel", "post")) } func TestTrainEndpoints(t *testing.T) { info := simplePredictor() info.Mode = ModeTrain spec := parseSpec(t, info) post := getPath(spec, "paths", "/trainings", "post") require.NotNil(t, post) postMap := post.(map[string]any) assert.Equal(t, "Train", postMap["summary"]) assert.Equal(t, "train_trainings_post", postMap["operationId"]) // Cancel cancel := getPath(spec, "paths", "/trainings/{training_id}/cancel", "post") require.NotNil(t, cancel) // Schema keys use TrainingInput/TrainingOutput assert.NotNil(t, getPath(spec, "components", "schemas", "TrainingInput")) assert.NotNil(t, getPath(spec, "components", "schemas", "TrainingOutput")) assert.NotNil(t, getPath(spec, "components", "schemas", "TrainingRequest")) assert.NotNil(t, getPath(spec, "components", "schemas", "TrainingResponse")) } // --------------------------------------------------------------------------- // Tests: Fixed components // --------------------------------------------------------------------------- func TestFixedComponentSchemas(t *testing.T) { spec := parseSpec(t, simplePredictor()) schemas := getPath(spec, "components", "schemas").(map[string]any) // PredictionRequest req := schemas["PredictionRequest"].(map[string]any) assert.Equal(t, "PredictionRequest", req["title"]) props := req["properties"].(map[string]any) assert.Equal(t, "#/components/schemas/Input", getPath(props, "input", "$ref")) assert.Equal(t, "string", getPath(props, "id", "type")) // PredictionResponse resp := schemas["PredictionResponse"].(map[string]any) assert.Equal(t, "PredictionResponse", resp["title"]) respProps := resp["properties"].(map[string]any) assert.Equal(t, "#/components/schemas/Input", 
getPath(respProps, "input", "$ref")) assert.Equal(t, "#/components/schemas/Output", getPath(respProps, "output", "$ref")) // Status status := schemas["Status"].(map[string]any) assert.Equal(t, "string", status["type"]) enum := status["enum"].([]any) assert.Contains(t, enum, "starting") assert.Contains(t, enum, "succeeded") // Validation errors assert.NotNil(t, schemas["HTTPValidationError"]) assert.NotNil(t, schemas["ValidationError"]) } // --------------------------------------------------------------------------- // Tests: Input schema // --------------------------------------------------------------------------- func TestInputRequiredField(t *testing.T) { spec := parseSpec(t, simplePredictor()) input := getPath(spec, "components", "schemas", "Input").(map[string]any) required := input["required"].([]any) assert.Contains(t, required, "s") } func TestInputOptionalFieldNotRequired(t *testing.T) { inputs := NewOrderedMap[string, InputField]() inputs.Set("name", InputField{ Name: "name", Order: 0, FieldType: FieldType{Primitive: TypeString, Repetition: Optional}, Default: &DefaultValue{Kind: DefaultNone}, }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) input := getPath(spec, "components", "schemas", "Input").(map[string]any) // Should not have required since there's a default assert.Nil(t, input["required"]) // Should have nullable props := input["properties"].(map[string]any) nameField := props["name"].(map[string]any) assert.Equal(t, true, nameField["nullable"]) } func TestInputDefaultValue(t *testing.T) { inputs := NewOrderedMap[string, InputField]() inputs.Set("count", InputField{ Name: "count", Order: 0, FieldType: FieldType{Primitive: TypeInteger, Repetition: Required}, Default: &DefaultValue{Kind: DefaultInt, Int: 42}, }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) props := getPath(spec, "components", "schemas", 
"Input", "properties").(map[string]any) countField := props["count"].(map[string]any) // JSON numbers unmarshal as float64 assert.Equal(t, float64(42), countField["default"]) } func TestInputDescription(t *testing.T) { inputs := NewOrderedMap[string, InputField]() inputs.Set("text", InputField{ Name: "text", Order: 0, FieldType: FieldType{Primitive: TypeString, Repetition: Required}, Description: ptr("The input text"), }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any) textField := props["text"].(map[string]any) assert.Equal(t, "The input text", textField["description"]) } func TestInputNumericConstraints(t *testing.T) { inputs := NewOrderedMap[string, InputField]() inputs.Set("temperature", InputField{ Name: "temperature", Order: 0, FieldType: FieldType{Primitive: TypeFloat, Repetition: Required}, GE: ptr(0.0), LE: ptr(1.0), }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any) tempField := props["temperature"].(map[string]any) assert.Equal(t, float64(0), tempField["minimum"]) assert.Equal(t, float64(1), tempField["maximum"]) } func TestInputStringConstraints(t *testing.T) { inputs := NewOrderedMap[string, InputField]() inputs.Set("name", InputField{ Name: "name", Order: 0, FieldType: FieldType{Primitive: TypeString, Repetition: Required}, MinLength: ptr[uint64](1), MaxLength: ptr[uint64](100), Regex: ptr("^[a-z]+$"), }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any) nameField := props["name"].(map[string]any) assert.Equal(t, float64(1), nameField["minLength"]) assert.Equal(t, float64(100), 
nameField["maxLength"]) assert.Equal(t, "^[a-z]+$", nameField["pattern"]) } func TestInputDeprecated(t *testing.T) { inputs := NewOrderedMap[string, InputField]() inputs.Set("old_param", InputField{ Name: "old_param", Order: 0, FieldType: FieldType{Primitive: TypeString, Repetition: Required}, Deprecated: ptr(true), }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any) field := props["old_param"].(map[string]any) assert.Equal(t, true, field["deprecated"]) } func TestInputXOrder(t *testing.T) { inputs := NewOrderedMap[string, InputField]() inputs.Set("first", InputField{ Name: "first", Order: 0, FieldType: FieldType{Primitive: TypeString, Repetition: Required}, }) inputs.Set("second", InputField{ Name: "second", Order: 1, FieldType: FieldType{Primitive: TypeInteger, Repetition: Required}, }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any) assert.Equal(t, float64(0), props["first"].(map[string]any)["x-order"]) assert.Equal(t, float64(1), props["second"].(map[string]any)["x-order"]) } func TestInputRepeatedType(t *testing.T) { inputs := NewOrderedMap[string, InputField]() inputs.Set("items", InputField{ Name: "items", Order: 0, FieldType: FieldType{Primitive: TypeString, Repetition: Repeated}, }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any) itemsField := props["items"].(map[string]any) assert.Equal(t, "array", itemsField["type"]) items := itemsField["items"].(map[string]any) assert.Equal(t, "string", items["type"]) } // --------------------------------------------------------------------------- // 
Tests: Choices / Enums // --------------------------------------------------------------------------- func TestChoicesGenerateEnum(t *testing.T) { inputs := NewOrderedMap[string, InputField]() inputs.Set("color", InputField{ Name: "color", Order: 0, FieldType: FieldType{Primitive: TypeString, Repetition: Required}, Choices: []DefaultValue{ {Kind: DefaultString, Str: "red"}, {Kind: DefaultString, Str: "blue"}, }, }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) // Enum schema created schemas := getPath(spec, "components", "schemas").(map[string]any) colorEnum := schemas["Color"].(map[string]any) assert.Equal(t, "Color", colorEnum["title"]) assert.Equal(t, "string", colorEnum["type"]) assert.Equal(t, []any{"red", "blue"}, colorEnum["enum"]) // Property uses allOf $ref props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any) colorProp := props["color"].(map[string]any) allOf := colorProp["allOf"].([]any) assert.Len(t, allOf, 1) ref := allOf[0].(map[string]any) assert.Equal(t, "#/components/schemas/Color", ref["$ref"]) } func TestIntegerChoices(t *testing.T) { inputs := NewOrderedMap[string, InputField]() inputs.Set("size", InputField{ Name: "size", Order: 0, FieldType: FieldType{Primitive: TypeInteger, Repetition: Required}, Choices: []DefaultValue{ {Kind: DefaultInt, Int: 256}, {Kind: DefaultInt, Int: 512}, {Kind: DefaultInt, Int: 1024}, }, }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) schemas := getPath(spec, "components", "schemas").(map[string]any) sizeEnum := schemas["Size"].(map[string]any) assert.Equal(t, "integer", sizeEnum["type"]) // JSON numbers are float64 assert.Equal(t, []any{float64(256), float64(512), float64(1024)}, sizeEnum["enum"]) } // --------------------------------------------------------------------------- // Tests: Output types // 
--------------------------------------------------------------------------- func TestOutputSingle(t *testing.T) { spec := parseSpec(t, simplePredictor()) output := getPath(spec, "components", "schemas", "Output").(map[string]any) assert.Equal(t, "Output", output["title"]) assert.Equal(t, "string", output["type"]) } func TestOutputList(t *testing.T) { inputs := NewOrderedMap[string, InputField]() info := &PredictorInfo{ Inputs: inputs, Output: SchemaArrayOf(SchemaPrim(TypeString)), Mode: ModePredict, } spec := parseSpec(t, info) output := getPath(spec, "components", "schemas", "Output").(map[string]any) assert.Equal(t, "Output", output["title"]) assert.Equal(t, "array", output["type"]) items := output["items"].(map[string]any) assert.Equal(t, "string", items["type"]) } func TestOutputIterator(t *testing.T) { inputs := NewOrderedMap[string, InputField]() info := &PredictorInfo{ Inputs: inputs, Output: SchemaIteratorOf(SchemaPrim(TypeString)), Mode: ModePredict, } spec := parseSpec(t, info) output := getPath(spec, "components", "schemas", "Output").(map[string]any) assert.Equal(t, "array", output["type"]) assert.Equal(t, "iterator", output["x-cog-array-type"]) } func TestOutputConcatenateIterator(t *testing.T) { inputs := NewOrderedMap[string, InputField]() info := &PredictorInfo{ Inputs: inputs, Output: SchemaConcatIteratorOf(), Mode: ModePredict, } spec := parseSpec(t, info) output := getPath(spec, "components", "schemas", "Output").(map[string]any) assert.Equal(t, "array", output["type"]) assert.Equal(t, "iterator", output["x-cog-array-type"]) assert.Equal(t, "concatenate", output["x-cog-array-display"]) } func TestOutputObject(t *testing.T) { inputs := NewOrderedMap[string, InputField]() fields := NewOrderedMap[string, SchemaField]() fields.Set("name", SchemaField{ Type: SchemaPrim(TypeString), Required: true, }) fields.Set("score", SchemaField{ Type: SchemaPrim(TypeFloat), Required: true, }) fields.Set("notes", SchemaField{ Type: SchemaType{Kind: SchemaPrimitive, 
Primitive: TypeString, Nullable: true}, Required: false, Default: &DefaultValue{Kind: DefaultNone}, }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaObjectOf(fields), Mode: ModePredict, } spec := parseSpec(t, info) output := getPath(spec, "components", "schemas", "Output").(map[string]any) assert.Equal(t, "object", output["type"]) props := output["properties"].(map[string]any) // name nameField := props["name"].(map[string]any) assert.Equal(t, "string", nameField["type"]) assert.Equal(t, "Name", nameField["title"]) // score scoreField := props["score"].(map[string]any) assert.Equal(t, "number", scoreField["type"]) // notes — nullable notesField := props["notes"].(map[string]any) assert.Equal(t, true, notesField["nullable"]) // Required should include name and score but not notes required := output["required"].([]any) assert.Contains(t, required, "name") assert.Contains(t, required, "score") assert.NotContains(t, required, "notes") } func TestOutputPath(t *testing.T) { inputs := NewOrderedMap[string, InputField]() info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypePath), Mode: ModePredict, } spec := parseSpec(t, info) output := getPath(spec, "components", "schemas", "Output").(map[string]any) assert.Equal(t, "string", output["type"]) assert.Equal(t, "uri", output["format"]) } // --------------------------------------------------------------------------- // Tests: Post-processing // --------------------------------------------------------------------------- func TestRemoveTitleNextToRef(t *testing.T) { schema := map[string]any{ "title": "Foo", "$ref": "#/components/schemas/Bar", } removeTitleNextToRef(schema) assert.Nil(t, schema["title"]) assert.Equal(t, "#/components/schemas/Bar", schema["$ref"]) } func TestRemoveTitleNextToRefNested(t *testing.T) { schema := map[string]any{ "properties": map[string]any{ "inner": map[string]any{ "title": "Inner", "$ref": "#/components/schemas/Foo", }, }, } removeTitleNextToRef(schema) inner := 
schema["properties"].(map[string]any)["inner"].(map[string]any) assert.Nil(t, inner["title"]) } func TestFixNullableAnyOf(t *testing.T) { schema := map[string]any{ "anyOf": []any{ map[string]any{"type": "string"}, map[string]any{"type": "null"}, }, } fixNullableAnyOf(schema) assert.Nil(t, schema["anyOf"]) assert.Equal(t, "string", schema["type"]) assert.Equal(t, true, schema["nullable"]) } func TestFixNullableAnyOfNoOp(t *testing.T) { // anyOf with no null should be left alone schema := map[string]any{ "anyOf": []any{ map[string]any{"type": "string"}, map[string]any{"type": "integer"}, }, } fixNullableAnyOf(schema) assert.NotNil(t, schema["anyOf"]) assert.Nil(t, schema["nullable"]) } // --------------------------------------------------------------------------- // Tests: Title case helpers // --------------------------------------------------------------------------- func TestTitleCaseWords(t *testing.T) { assert.Equal(t, "Hello World", TitleCase("hello_world")) assert.Equal(t, "Segmented Image", TitleCase("segmented_image")) assert.Equal(t, "Name", TitleCase("name")) } func TestTitleCaseSingleWord(t *testing.T) { assert.Equal(t, "Prediction_id", TitleCaseSingle("prediction_id")) assert.Equal(t, "Color", TitleCaseSingle("color")) assert.Equal(t, "", TitleCaseSingle("")) } // --------------------------------------------------------------------------- // Tests: JSON output is valid and parseable // --------------------------------------------------------------------------- func TestOutputIsValidJSON(t *testing.T) { data, err := GenerateOpenAPISchema(simplePredictor()) require.NoError(t, err) var parsed any require.NoError(t, json.Unmarshal(data, &parsed)) assert.NotNil(t, parsed) } // --------------------------------------------------------------------------- // Tests: Multiple inputs with various types // --------------------------------------------------------------------------- func TestMultipleInputTypes(t *testing.T) { inputs := NewOrderedMap[string, 
InputField]() inputs.Set("text", InputField{ Name: "text", Order: 0, FieldType: FieldType{Primitive: TypeString, Repetition: Required}, }) inputs.Set("count", InputField{ Name: "count", Order: 1, FieldType: FieldType{Primitive: TypeInteger, Repetition: Required}, Default: &DefaultValue{Kind: DefaultInt, Int: 10}, }) inputs.Set("image", InputField{ Name: "image", Order: 2, FieldType: FieldType{Primitive: TypePath, Repetition: Required}, }) inputs.Set("flag", InputField{ Name: "flag", Order: 3, FieldType: FieldType{Primitive: TypeBool, Repetition: Required}, Default: &DefaultValue{Kind: DefaultBool, Bool: false}, }) inputs.Set("secret_key", InputField{ Name: "secret_key", Order: 4, FieldType: FieldType{Primitive: TypeSecret, Repetition: Optional}, Default: &DefaultValue{Kind: DefaultNone}, }) info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any) // text - string textField := props["text"].(map[string]any) assert.Equal(t, "string", textField["type"]) // count - integer with default countField := props["count"].(map[string]any) assert.Equal(t, "integer", countField["type"]) assert.Equal(t, float64(10), countField["default"]) // image - path (URI) imageField := props["image"].(map[string]any) assert.Equal(t, "string", imageField["type"]) assert.Equal(t, "uri", imageField["format"]) // flag - boolean flagField := props["flag"].(map[string]any) assert.Equal(t, "boolean", flagField["type"]) // secret_key - secret secretField := props["secret_key"].(map[string]any) assert.Equal(t, "string", secretField["type"]) assert.Equal(t, "password", secretField["format"]) assert.Equal(t, true, secretField["writeOnly"]) assert.Equal(t, true, secretField["x-cog-secret"]) assert.Equal(t, true, secretField["nullable"]) // Only text and image should be required (count has default, flag has default, secret has default) required := 
getPath(spec, "components", "schemas", "Input", "required").([]any) assert.Contains(t, required, "text") assert.Contains(t, required, "image") assert.NotContains(t, required, "count") assert.NotContains(t, required, "flag") assert.NotContains(t, required, "secret_key") } // --------------------------------------------------------------------------- // Tests: Edge cases // --------------------------------------------------------------------------- func TestNoInputs(t *testing.T) { inputs := NewOrderedMap[string, InputField]() info := &PredictorInfo{ Inputs: inputs, Output: SchemaPrim(TypeString), Mode: ModePredict, } spec := parseSpec(t, info) input := getPath(spec, "components", "schemas", "Input").(map[string]any) assert.Equal(t, "object", input["type"]) // required should not be present when there are no required fields assert.Nil(t, input["required"]) } func TestOutputObjectNoFields(t *testing.T) { inputs := NewOrderedMap[string, InputField]() info := &PredictorInfo{ Inputs: inputs, Output: SchemaType{Kind: SchemaObject}, Mode: ModePredict, } spec := parseSpec(t, info) output := getPath(spec, "components", "schemas", "Output").(map[string]any) assert.Equal(t, "object", output["type"]) } // --------------------------------------------------------------------------- // Tests: orderedMapAny JSON output preserves insertion order // --------------------------------------------------------------------------- func TestOrderedMapAnyJSON(t *testing.T) { m := newOrderedMapAny() m.Set("z", 1) m.Set("a", 2) m.Set("m", 3) data, err := json.Marshal(m) require.NoError(t, err) assert.Equal(t, `{"z":1,"a":2,"m":3}`, string(data)) } func TestOrderedMapAnyDelete(t *testing.T) { m := newOrderedMapAny() m.Set("a", 1) m.Set("b", 2) m.Set("c", 3) m.Delete("b") data, err := json.Marshal(m) require.NoError(t, err) assert.Equal(t, `{"a":1,"c":3}`, string(data)) } ================================================ FILE: pkg/schema/python/parser.go 
================================================ // Package python implements a tree-sitter based Python parser for extracting // Cog predictor signatures. It walks the concrete syntax tree to extract // imports, class definitions, function parameters with type annotations and // default values, and Input() call keyword arguments. // // This parser is Python-specific. Future languages (e.g. Node.js) would get // their own parser package under pkg/schema/. package python import ( "context" "errors" "fmt" "os" "path/filepath" "strconv" "strings" sitter "github.com/smacker/go-tree-sitter" "github.com/smacker/go-tree-sitter/python" "github.com/replicate/cog/pkg/schema" ) // ParsePredictor parses Python source and extracts predictor information. // predictRef is the class or function name (e.g. "Predictor" or "predict"). // mode controls whether we look for predict or train method. // sourceDir is the project root for resolving cross-file imports. Pass "" if unknown. func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDir string) (*schema.PredictorInfo, error) { parser := sitter.NewParser() parser.SetLanguage(python.GetLanguage()) tree, err := parser.ParseCtx(context.Background(), nil, source) if err != nil { return nil, schema.WrapError(schema.ErrParse, "tree-sitter parse failed", err) } root := tree.RootNode() // 1. Collect imports imports := collectImports(root, source) // 2. Collect module-level variable assignments moduleScope := collectModuleScope(root, source) // 3. Collect BaseModel subclasses (local file first, then cross-file) modelClasses := collectModelClasses(root, source, imports) if sourceDir != "" { resolveExternalModels(sourceDir, imports, modelClasses) } // 4. Collect Input() references from class attributes and static methods inputRegistry := collectInputRegistry(root, source, imports, moduleScope) // 5. 
Find the target predict/train function methodName := "predict" if mode == schema.ModeTrain { methodName = "train" } funcNode, err := findTargetFunction(root, source, predictRef, methodName) if err != nil { return nil, err } // 6. Check if method (has self first param) paramsNode := funcNode.ChildByFieldName("parameters") if paramsNode == nil { return nil, schema.WrapError(schema.ErrParse, "function has no parameters node", nil) } isMethod := firstParamIsSelf(paramsNode, source) // 7. Extract parameters inputs, err := extractInputs(paramsNode, source, methodName, isMethod, imports, inputRegistry, moduleScope) if err != nil { return nil, err } // 8. Extract return type returnAnn := funcNode.ChildByFieldName("return_type") if returnAnn == nil { return nil, schema.WrapError(schema.ErrMissingReturnType, methodName, nil) } returnTypeAnn, err := parseTypeAnnotation(returnAnn, source) if err != nil { return nil, err } output, err := schema.ResolveSchemaType(returnTypeAnn, imports, modelClasses) if err != nil { return nil, err } return &schema.PredictorInfo{ Inputs: inputs, Output: output, Mode: mode, }, nil } // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- // namedChildren returns all named children of a node. func namedChildren(n *sitter.Node) []*sitter.Node { count := int(n.NamedChildCount()) result := make([]*sitter.Node, 0, count) for i := range count { result = append(result, n.NamedChild(i)) } return result } // allChildren returns all children (named and anonymous) of a node. func allChildren(n *sitter.Node) []*sitter.Node { count := int(n.ChildCount()) result := make([]*sitter.Node, 0, count) for i := range count { result = append(result, n.Child(i)) } return result } // content returns the source text for a node. 
func content(n *sitter.Node, source []byte) string { return n.Content(source) } // --------------------------------------------------------------------------- // Import collection // --------------------------------------------------------------------------- func collectImports(root *sitter.Node, source []byte) *schema.ImportContext { ctx := schema.NewImportContext() for _, child := range namedChildren(root) { if child.Type() == "import_from_statement" { parseImportFrom(child, source, ctx) } } // Always include Python builtins for _, builtin := range []string{"str", "int", "float", "bool", "list", "dict", "set"} { if _, ok := ctx.Names.Get(builtin); !ok { ctx.Names.Set(builtin, schema.ImportEntry{Module: "builtins", Original: builtin}) } } if _, ok := ctx.Names.Get("None"); !ok { ctx.Names.Set("None", schema.ImportEntry{Module: "builtins", Original: "None"}) } return ctx } func parseImportFrom(node *sitter.Node, source []byte, ctx *schema.ImportContext) { moduleNode := node.ChildByFieldName("module_name") if moduleNode == nil { return } module := content(moduleNode, source) for _, child := range allChildren(node) { switch child.Type() { case "dotted_name": // Single import: `from X import name` // Skip if this is the module_name itself if child.StartByte() != moduleNode.StartByte() { name := content(child, source) ctx.Names.Set(name, schema.ImportEntry{Module: module, Original: name}) } case "aliased_import": // Single aliased import: `from X import name as alias` origNode := child.ChildByFieldName("name") aliasNode := child.ChildByFieldName("alias") orig := "" if origNode != nil { orig = content(origNode, source) } alias := orig if aliasNode != nil { alias = content(aliasNode, source) } ctx.Names.Set(alias, schema.ImportEntry{Module: module, Original: orig}) case "import_list": for _, importChild := range allChildren(child) { switch importChild.Type() { case "dotted_name": name := content(importChild, source) ctx.Names.Set(name, schema.ImportEntry{Module: module, 
Original: name}) case "aliased_import": origNode := importChild.ChildByFieldName("name") aliasNode := importChild.ChildByFieldName("alias") orig := "" if origNode != nil { orig = content(origNode, source) } alias := orig if aliasNode != nil { alias = content(aliasNode, source) } ctx.Names.Set(alias, schema.ImportEntry{Module: module, Original: orig}) } } } } } // --------------------------------------------------------------------------- // Module scope collection // --------------------------------------------------------------------------- type moduleScope map[string]schema.DefaultValue func collectModuleScope(root *sitter.Node, source []byte) moduleScope { scope := make(moduleScope) for _, child := range namedChildren(root) { var assign *sitter.Node if child.Type() == "expression_statement" { if child.NamedChildCount() == 1 { inner := child.NamedChild(0) if inner.Type() == "assignment" { assign = inner } } } else if child.Type() == "assignment" { assign = child } if assign == nil { continue } left := assign.ChildByFieldName("left") if left == nil || left.Type() != "identifier" { continue } name := content(left, source) right := assign.ChildByFieldName("right") if right == nil { continue } if val, ok := parseDefaultValue(right, source); ok { scope[name] = val } } return scope } // resolveDefaultExpr tries to resolve an expression to a DefaultValue by // literal parsing, then falling back to module scope lookup for identifiers. func resolveDefaultExpr(node *sitter.Node, source []byte, scope moduleScope) (schema.DefaultValue, bool) { if val, ok := parseDefaultValue(node, source); ok { return val, true } if node.Type() == "identifier" { name := content(node, source) if val, ok := scope[name]; ok { return val, true } } return schema.DefaultValue{}, false } // resolveChoicesExpr tries to statically resolve a choices= expression. 
func resolveChoicesExpr(node *sitter.Node, source []byte, scope moduleScope) ([]schema.DefaultValue, bool) { switch node.Type() { case "list": return parseListLiteral(node, source) case "identifier": name := content(node, source) val, ok := scope[name] if !ok { return nil, false } if val.Kind == schema.DefaultList { return val.List, true } return nil, false case "call": return resolveChoicesCall(node, source, scope) case "binary_operator": // Only handle + (list concatenation) hasPlus := false for _, c := range allChildren(node) { if !c.IsNamed() && content(c, source) == "+" { hasPlus = true break } } if !hasPlus { return nil, false } left := node.ChildByFieldName("left") right := node.ChildByFieldName("right") if left == nil || right == nil { return nil, false } leftItems, ok := resolveChoicesExpr(left, source, scope) if !ok { return nil, false } rightItems, ok := resolveChoicesExpr(right, source, scope) if !ok { return nil, false } return append(leftItems, rightItems...), true } return nil, false } // resolveChoicesCall resolves list(X.keys()) or list(X.values()). 
func resolveChoicesCall(node *sitter.Node, source []byte, scope moduleScope) ([]schema.DefaultValue, bool) { funcNode := node.ChildByFieldName("function") if funcNode == nil || content(funcNode, source) != "list" { return nil, false } args := node.ChildByFieldName("arguments") if args == nil { return nil, false } // Find the single positional argument var arg *sitter.Node for _, c := range namedChildren(args) { arg = c break } if arg == nil || arg.Type() != "call" { return nil, false } innerFunc := arg.ChildByFieldName("function") if innerFunc == nil || innerFunc.Type() != "attribute" { return nil, false } obj := innerFunc.ChildByFieldName("object") attr := innerFunc.ChildByFieldName("attribute") if obj == nil || attr == nil || obj.Type() != "identifier" { return nil, false } varName := content(obj, source) methodName := content(attr, source) dictVal, ok := scope[varName] if !ok || dictVal.Kind != schema.DefaultDict { return nil, false } switch methodName { case "keys": return dictVal.DictKeys, true case "values": return dictVal.DictVals, true } return nil, false } // --------------------------------------------------------------------------- // BaseModel subclass collection // --------------------------------------------------------------------------- func collectModelClasses(root *sitter.Node, source []byte, imports *schema.ImportContext) schema.ModelClassMap { models := schema.NewOrderedMap[string, []schema.ModelField]() for _, child := range namedChildren(root) { classNode := unwrapClass(child) if classNode == nil { continue } nameNode := classNode.ChildByFieldName("name") if nameNode == nil { continue } className := content(nameNode, source) if !inheritsFromBaseModel(classNode, source, imports) { continue } fields := extractClassAnnotations(classNode, source) models.Set(className, fields) } return models } // resolveExternalModels looks at imports that brought in names not yet in // modelClasses, attempts to find the corresponding .py file on disk, parses // 
it, and merges any BaseModel subclasses into modelClasses. // // This handles every local import permutation: // // from .types import Output → /types.py // from types import Output → /types.py // from models.output import Result → /models/output.py // from .models.output import Result → /models/output.py // from my_app.types import Foo → /my_app/types.py // // Non-local imports (stdlib, pip packages) are skipped because the file // won't exist on disk. func resolveExternalModels(sourceDir string, imports *schema.ImportContext, models schema.ModelClassMap) { // Track which modules we've already tried so we don't re-parse. tried := make(map[string]bool) imports.Names.Entries(func(localName string, entry schema.ImportEntry) { // Already resolved locally — skip. if _, ok := models.Get(localName); ok { return } module := entry.Module if !tried[module] { tried[module] = true // Skip known non-local modules. if isKnownExternalModule(module) { return } // Convert module path to filesystem path and try to find it. pyPath := moduleToFilePath(module) if pyPath == "" { return } fullPath := filepath.Join(sourceDir, pyPath) source, err := os.ReadFile(fullPath) if err != nil { if errors.Is(err, os.ErrNotExist) { // File doesn't exist — it's an external package, not local. return } fmt.Fprintf(os.Stderr, "cog: warning: failed to read %q: %v\n", fullPath, err) return } // Parse the file and extract BaseModel subclasses. parser := sitter.NewParser() parser.SetLanguage(python.GetLanguage()) tree, err := parser.ParseCtx(context.Background(), nil, source) if err != nil { fmt.Fprintf(os.Stderr, "cog: warning: failed to parse %q: %v\n", fullPath, err) return } fileImports := collectImports(tree.RootNode(), source) fileModels := collectModelClasses(tree.RootNode(), source, fileImports) // Merge discovered models into the caller's map. 
fileModels.Entries(func(name string, fields []schema.ModelField) { if _, exists := models.Get(name); !exists { models.Set(name, fields) } }) } // Handle aliases: "from X import MyOutput as Output" // localName is "Output", entry.Original is "MyOutput". // If we resolved "MyOutput" from the file, also register it under "Output". if localName != entry.Original { if fields, ok := models.Get(entry.Original); ok { if _, exists := models.Get(localName); !exists { models.Set(localName, fields) } } } }) } // moduleToFilePath converts a Python module path to a relative .py file path. // // ".types" → "types.py" // "types" → "types.py" // ".models.output" → "models/output.py" // "models.output" → "models/output.py" // "cog" → "cog.py" (will fail os.ReadFile → skipped) func moduleToFilePath(module string) string { // Strip leading dots (relative import markers). clean := strings.TrimLeft(module, ".") if clean == "" { return "" } // Replace dots with path separators. parts := strings.Split(clean, ".") return filepath.Join(parts...) + ".py" } // isKnownExternalModule returns true for modules that are definitely not // local project files — stdlib, well-known packages, etc. func isKnownExternalModule(module string) bool { // Extract the top-level package name. 
top := module if i := strings.Index(module, "."); i > 0 { top = module[:i] } top = strings.TrimLeft(top, ".") switch top { case "builtins", "typing", "typing_extensions", "collections", "abc", "enum", "dataclasses", "os", "sys", "io", "json", "re", "math", "pathlib", "functools", "itertools", "contextlib", "concurrent", "asyncio", "multiprocessing", "threading", "logging", "warnings", "unittest", "pytest", "numpy", "torch", "tensorflow", "jax", "scipy", "sklearn", "transformers", "diffusers", "accelerate", "safetensors", "PIL", "cv2", "skimage", "requests", "httpx", "aiohttp", "fastapi", "flask", "pydantic", "cog": return true } return false } func unwrapClass(node *sitter.Node) *sitter.Node { if node.Type() == "class_definition" { return node } if node.Type() == "decorated_definition" { for _, c := range namedChildren(node) { if c.Type() == "class_definition" { return c } } } return nil } func unwrapFunction(node *sitter.Node) *sitter.Node { if node.Type() == "function_definition" { return node } if node.Type() == "decorated_definition" { for _, c := range namedChildren(node) { if c.Type() == "function_definition" { return c } } } return nil } func inheritsFromBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool { supers := classNode.ChildByFieldName("superclasses") if supers == nil { return false } for _, child := range allChildren(supers) { switch child.Type() { case "identifier": name := content(child, source) if imports.IsBaseModel(name) || name == "BaseModel" { return true } case "attribute": // Handle dotted access: pydantic.BaseModel, cog.BaseModel text := content(child, source) if strings.HasSuffix(text, ".BaseModel") { return true } } } return false } func extractClassAnnotations(classNode *sitter.Node, source []byte) []schema.ModelField { body := classNode.ChildByFieldName("body") if body == nil { return nil } var fields []schema.ModelField for _, child := range namedChildren(body) { node := child if child.Type() == 
"expression_statement" && child.NamedChildCount() == 1 { node = child.NamedChild(0) } switch node.Type() { case "assignment": if f, ok := parseAnnotatedAssignment(node, source); ok { fields = append(fields, f) } case "type": if f, ok := parseBareAnnotation(node, source); ok { fields = append(fields, f) } } } return fields } func parseAnnotatedAssignment(node *sitter.Node, source []byte) (schema.ModelField, bool) { left := node.ChildByFieldName("left") typeNode := node.ChildByFieldName("type") if left == nil || typeNode == nil || left.Type() != "identifier" { return schema.ModelField{}, false } name := content(left, source) typeAnn, err := parseTypeAnnotation(typeNode, source) if err != nil { return schema.ModelField{}, false } var def *schema.DefaultValue if right := node.ChildByFieldName("right"); right != nil { if v, ok := parseDefaultValue(right, source); ok { def = &v } } return schema.ModelField{Name: name, Type: typeAnn, Default: def}, true } func parseBareAnnotation(node *sitter.Node, source []byte) (schema.ModelField, bool) { text := strings.TrimSpace(content(node, source)) parts := strings.SplitN(text, ":", 2) if len(parts) != 2 { return schema.ModelField{}, false } name := strings.TrimSpace(parts[0]) typeStr := strings.TrimSpace(parts[1]) if name == "" || (name[0] != '_' && (name[0] < 'a' || name[0] > 'z') && (name[0] < 'A' || name[0] > 'Z')) { return schema.ModelField{}, false } typeAnn, ok := parseTypeFromString(typeStr) if !ok { return schema.ModelField{}, false } return schema.ModelField{Name: name, Type: typeAnn, Default: nil}, true } func parseTypeFromString(s string) (schema.TypeAnnotation, bool) { s = strings.TrimSpace(s) if s == "" { return schema.TypeAnnotation{}, false } // Union: X | Y if strings.Contains(s, "|") { parts := strings.Split(s, "|") var members []schema.TypeAnnotation for _, p := range parts { m, ok := parseTypeFromString(strings.TrimSpace(p)) if !ok { return schema.TypeAnnotation{}, false } members = append(members, m) } if 
len(members) >= 2 { return schema.TypeAnnotation{Kind: schema.TypeAnnotUnion, Args: members}, true } return schema.TypeAnnotation{}, false } // Generic: X[Y] or X[Y, Z] bracketPos := strings.Index(s, "[") if bracketPos >= 0 && strings.HasSuffix(s, "]") { outer := strings.TrimSpace(s[:bracketPos]) innerStr := s[bracketPos+1 : len(s)-1] // Split on top-level commas (handles Union[str, None], etc.) parts := splitTopLevelCommas(innerStr) var args []schema.TypeAnnotation for _, p := range parts { arg, ok := parseTypeFromString(strings.TrimSpace(p)) if !ok { return schema.TypeAnnotation{}, false } args = append(args, arg) } if len(args) == 0 { return schema.TypeAnnotation{}, false } return schema.TypeAnnotation{Kind: schema.TypeAnnotGeneric, Name: outer, Args: args}, true } // Simple identifier for _, c := range s { if (c < 'a' || c > 'z') && (c < 'A' || c > 'Z') && (c < '0' || c > '9') && c != '_' { return schema.TypeAnnotation{}, false } } return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: s}, true } // splitTopLevelCommas splits a string on commas that are not nested inside brackets. // e.g. 
"str, None" → ["str", "None"], "List[str], None" → ["List[str]", "None"] func splitTopLevelCommas(s string) []string { var parts []string depth := 0 start := 0 for i, c := range s { switch c { case '[': depth++ case ']': depth-- case ',': if depth == 0 { parts = append(parts, s[start:i]) start = i + 1 } } } parts = append(parts, s[start:]) return parts } // --------------------------------------------------------------------------- // InputRegistry — resolves ClassName.attr and ClassName.method(args) // --------------------------------------------------------------------------- type inputCallInfo struct { Default *schema.DefaultValue Description *string GE *float64 LE *float64 MinLength *uint64 MaxLength *uint64 Regex *string Choices []schema.DefaultValue Deprecated *bool } type inputMethodInfo struct { ParamNames []string BaseInfo inputCallInfo } type inputRegistry struct { Attributes map[string]inputCallInfo Methods map[string]inputMethodInfo } func newInputRegistry() *inputRegistry { return &inputRegistry{ Attributes: make(map[string]inputCallInfo), Methods: make(map[string]inputMethodInfo), } } func collectInputRegistry(root *sitter.Node, source []byte, imports *schema.ImportContext, scope moduleScope) *inputRegistry { registry := newInputRegistry() for _, child := range namedChildren(root) { classNode := unwrapClass(child) if classNode == nil { continue } nameNode := classNode.ChildByFieldName("name") if nameNode == nil { continue } className := content(nameNode, source) body := classNode.ChildByFieldName("body") if body == nil { continue } for _, stmt := range namedChildren(body) { inner := stmt if stmt.Type() == "expression_statement" && stmt.NamedChildCount() == 1 { inner = stmt.NamedChild(0) } if inner.Type() == "assignment" { collectInputAttribute(className, inner, source, imports, scope, registry) } if funcNode := unwrapFunction(inner); funcNode != nil { collectInputMethod(className, funcNode, source, imports, scope, registry) } } } return registry } 
func collectInputAttribute(className string, assignment *sitter.Node, source []byte, imports *schema.ImportContext, scope moduleScope, registry *inputRegistry) { left := assignment.ChildByFieldName("left") if left == nil || left.Type() != "identifier" { return } attrName := content(left, source) right := assignment.ChildByFieldName("right") if right == nil || !isInputCall(right, source, imports) { return } key := className + "." + attrName info, err := parseInputCall(right, source, key, scope) if err != nil { return } registry.Attributes[key] = info } func collectInputMethod(className string, funcNode *sitter.Node, source []byte, imports *schema.ImportContext, scope moduleScope, registry *inputRegistry) { nameNode := funcNode.ChildByFieldName("name") if nameNode == nil { return } methodName := content(nameNode, source) params := funcNode.ChildByFieldName("parameters") if params == nil { return } var paramNames []string for _, param := range allChildren(params) { switch param.Type() { case "identifier": name := content(param, source) if name != "self" && name != "cls" { paramNames = append(paramNames, name) } case "typed_parameter": // typed_parameter has no "name" field; first identifier child is the name for j := 0; j < int(param.NamedChildCount()); j++ { c := param.NamedChild(j) if c.Type() == "identifier" { name := content(c, source) if name != "self" && name != "cls" { paramNames = append(paramNames, name) } break } } case "typed_default_parameter", "default_parameter": if n := param.ChildByFieldName("name"); n != nil { name := content(n, source) if name != "self" && name != "cls" { paramNames = append(paramNames, name) } } } } body := funcNode.ChildByFieldName("body") if body == nil { return } inputCall := findReturnInputCall(body, source, imports) if inputCall == nil { return } key := className + "." 
+ methodName info, err := parseInputCall(inputCall, source, key, scope) if err != nil { return } registry.Methods[key] = inputMethodInfo{ParamNames: paramNames, BaseInfo: info} } func findReturnInputCall(body *sitter.Node, source []byte, imports *schema.ImportContext) *sitter.Node { for _, child := range namedChildren(body) { if child.Type() == "return_statement" { if child.NamedChildCount() > 0 { expr := child.NamedChild(0) if isInputCall(expr, source, imports) { return expr } } } } return nil } func resolveInputReference(node *sitter.Node, source []byte, registry *inputRegistry) (inputCallInfo, bool) { switch node.Type() { case "attribute": text := content(node, source) info, ok := registry.Attributes[text] return info, ok case "call": funcNode := node.ChildByFieldName("function") if funcNode == nil || funcNode.Type() != "attribute" { return inputCallInfo{}, false } key := content(funcNode, source) methodInfo, ok := registry.Methods[key] if !ok { return inputCallInfo{}, false } resolved := methodInfo.BaseInfo args := node.ChildByFieldName("arguments") if args == nil { return resolved, true } // Build param_name -> call-site value map argValues := make(map[string]*sitter.Node) positionalIdx := 0 for _, arg := range namedChildren(args) { if arg.Type() == "keyword_argument" { nameNode := arg.ChildByFieldName("name") valNode := arg.ChildByFieldName("value") if nameNode != nil && valNode != nil { argValues[content(nameNode, source)] = valNode } } else if positionalIdx < len(methodInfo.ParamNames) { argValues[methodInfo.ParamNames[positionalIdx]] = arg positionalIdx++ } } // Override with call-site values for paramName, callNode := range argValues { switch paramName { case "default": if val, ok := parseDefaultValue(callNode, source); ok { resolved.Default = &val } case "description": if s, ok := parseStringLiteral(callNode, source); ok { resolved.Description = &s } case "ge": if n, ok := parseNumberLiteral(callNode, source); ok { resolved.GE = &n } case "le": if n, ok 
:= parseNumberLiteral(callNode, source); ok { resolved.LE = &n } } } return resolved, true } return inputCallInfo{}, false } // --------------------------------------------------------------------------- // Target function finding // --------------------------------------------------------------------------- func findTargetFunction(root *sitter.Node, source []byte, predictRef, methodName string) (*sitter.Node, error) { // First: look for a class with this name for _, child := range namedChildren(root) { classNode := unwrapClass(child) if classNode == nil { continue } nameNode := classNode.ChildByFieldName("name") if nameNode != nil && content(nameNode, source) == predictRef { return findMethodInClass(classNode, source, predictRef, methodName) } } // Second: look for standalone function for _, child := range namedChildren(root) { funcNode := unwrapFunction(child) if funcNode == nil { continue } nameNode := funcNode.ChildByFieldName("name") if nameNode != nil { name := content(nameNode, source) if name == predictRef || name == methodName { return funcNode, nil } } } return nil, schema.WrapError(schema.ErrPredictorNotFound, predictRef, nil) } func findMethodInClass(classNode *sitter.Node, source []byte, className, methodName string) (*sitter.Node, error) { body := classNode.ChildByFieldName("body") if body == nil { return nil, schema.WrapError(schema.ErrParse, fmt.Sprintf("class %s has no body", className), nil) } for _, child := range namedChildren(body) { funcNode := unwrapFunction(child) if funcNode == nil { continue } nameNode := funcNode.ChildByFieldName("name") if nameNode != nil && content(nameNode, source) == methodName { return funcNode, nil } } return nil, schema.WrapError(schema.ErrMethodNotFound, fmt.Sprintf("%s.%s not found", className, methodName), nil) } // --------------------------------------------------------------------------- // Parameter extraction // --------------------------------------------------------------------------- func 
firstParamIsSelf(params *sitter.Node, source []byte) bool { for _, child := range allChildren(params) { if child.Type() == "identifier" { return content(child, source) == "self" } } return false } func extractInputs( paramsNode *sitter.Node, source []byte, methodName string, skipSelf bool, imports *schema.ImportContext, registry *inputRegistry, scope moduleScope, ) (*schema.OrderedMap[string, schema.InputField], error) { inputs := schema.NewOrderedMap[string, schema.InputField]() order := 0 seenSelf := false for _, child := range allChildren(paramsNode) { switch child.Type() { case "identifier": if !seenSelf && skipSelf { name := content(child, source) if name == "self" { seenSelf = true continue } } case "typed_parameter": input, err := parseTypedParameter(child, source, order, methodName, imports) if err != nil { return nil, err } inputs.Set(input.Name, input) order++ case "typed_default_parameter": input, err := parseTypedDefaultParameter(child, source, order, methodName, imports, registry, scope) if err != nil { return nil, err } inputs.Set(input.Name, input) order++ case "default_parameter": nameNode := child.ChildByFieldName("name") paramName := "" if nameNode != nil { paramName = content(nameNode, source) } return nil, schema.WrapError(schema.ErrMissingTypeAnnotation, fmt.Sprintf("parameter '%s' on %s has no type annotation", paramName, methodName), nil) } } return inputs, nil } func parseTypedParameter(node *sitter.Node, source []byte, order int, methodName string, imports *schema.ImportContext) (schema.InputField, error) { // typed_parameter has no "name" field in the Python grammar. 
// Structure is: identifier ":" type var name string var typeNode *sitter.Node for i := 0; i < int(node.NamedChildCount()); i++ { c := node.NamedChild(i) switch c.Type() { case "identifier": if name == "" { name = content(c, source) } case "type": typeNode = c } } if name == "" { return schema.InputField{}, schema.WrapError(schema.ErrParse, "typed_parameter has no identifier", nil) } if typeNode == nil { return schema.InputField{}, schema.WrapError(schema.ErrMissingTypeAnnotation, fmt.Sprintf("parameter '%s' on %s has no type annotation", name, methodName), nil) } typeAnn, err := parseTypeAnnotation(typeNode, source) if err != nil { return schema.InputField{}, err } fieldType, err := schema.ResolveFieldType(typeAnn, imports) if err != nil { return schema.InputField{}, err } return schema.InputField{ Name: name, Order: order, FieldType: fieldType, }, nil } func parseTypedDefaultParameter( node *sitter.Node, source []byte, order int, methodName string, imports *schema.ImportContext, registry *inputRegistry, scope moduleScope, ) (schema.InputField, error) { nameNode := node.ChildByFieldName("name") if nameNode == nil { return schema.InputField{}, schema.WrapError(schema.ErrParse, "typed_default_parameter has no name", nil) } name := content(nameNode, source) typeNode := node.ChildByFieldName("type") if typeNode == nil { return schema.InputField{}, schema.WrapError(schema.ErrMissingTypeAnnotation, fmt.Sprintf("parameter '%s' on %s has no type annotation", name, methodName), nil) } typeAnn, err := parseTypeAnnotation(typeNode, source) if err != nil { return schema.InputField{}, err } fieldType, err := schema.ResolveFieldType(typeAnn, imports) if err != nil { return schema.InputField{}, err } valNode := node.ChildByFieldName("value") if valNode != nil { // 1. 
Direct Input() call if isInputCall(valNode, source, imports) { info, err := parseInputCall(valNode, source, name, scope) if err != nil { return schema.InputField{}, err } return schema.InputField{ Name: name, Order: order, FieldType: fieldType, Default: info.Default, Description: info.Description, GE: info.GE, LE: info.LE, MinLength: info.MinLength, MaxLength: info.MaxLength, Regex: info.Regex, Choices: info.Choices, Deprecated: info.Deprecated, }, nil } // 2. Reference to Input() via class attribute or static method if info, ok := resolveInputReference(valNode, source, registry); ok { return schema.InputField{ Name: name, Order: order, FieldType: fieldType, Default: info.Default, Description: info.Description, GE: info.GE, LE: info.LE, MinLength: info.MinLength, MaxLength: info.MaxLength, Regex: info.Regex, Choices: info.Choices, Deprecated: info.Deprecated, }, nil } // 3. Plain default — must be statically resolvable if def, ok := resolveDefaultExpr(valNode, source, scope); ok { return schema.InputField{ Name: name, Order: order, FieldType: fieldType, Default: &def, }, nil } // Can't resolve — hard error valText := content(valNode, source) return schema.InputField{}, schema.WrapError(schema.ErrDefaultNotResolvable, fmt.Sprintf("parameter '%s': default `%s` cannot be statically resolved", name, valText), nil) } // No default — required parameter return schema.InputField{ Name: name, Order: order, FieldType: fieldType, }, nil } // --------------------------------------------------------------------------- // Type annotation parsing // --------------------------------------------------------------------------- func parseTypeAnnotation(node *sitter.Node, source []byte) (schema.TypeAnnotation, error) { // Unwrap `type` wrapper node n := node if n.Type() == "type" && n.NamedChildCount() > 0 { n = n.NamedChild(0) } switch n.Type() { case "identifier": return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: content(n, source)}, nil case "subscript": value := 
n.ChildByFieldName("value") if value == nil { return schema.TypeAnnotation{}, schema.WrapError(schema.ErrParse, "subscript has no value", nil) } outer := content(value, source) var args []schema.TypeAnnotation for _, child := range namedChildren(n) { // Skip the outer identifier (the value field) if child.StartByte() == value.StartByte() { continue } arg, err := parseTypeAnnotation(child, source) if err != nil { return schema.TypeAnnotation{}, err } args = append(args, arg) } if len(args) == 0 { return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: outer}, nil } return schema.TypeAnnotation{Kind: schema.TypeAnnotGeneric, Name: outer, Args: args}, nil case "binary_operator": left := n.ChildByFieldName("left") right := n.ChildByFieldName("right") if left == nil || right == nil { return schema.TypeAnnotation{}, schema.WrapError(schema.ErrParse, "binary_operator missing operand", nil) } // Check that operator is | isUnion := false for _, c := range allChildren(n) { if !c.IsNamed() && content(c, source) == "|" { isUnion = true break } } if !isUnion { return schema.TypeAnnotation{}, errUnsupported("non-union binary operator in type annotation") } leftAnn, err := parseTypeAnnotation(left, source) if err != nil { return schema.TypeAnnotation{}, err } rightAnn, err := parseTypeAnnotation(right, source) if err != nil { return schema.TypeAnnotation{}, err } // Flatten nested unions var members []schema.TypeAnnotation if leftAnn.Kind == schema.TypeAnnotUnion { members = append(members, leftAnn.Args...) } else { members = append(members, leftAnn) } if rightAnn.Kind == schema.TypeAnnotUnion { members = append(members, rightAnn.Args...) 
} else { members = append(members, rightAnn) } return schema.TypeAnnotation{Kind: schema.TypeAnnotUnion, Args: members}, nil case "none": return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: "None"}, nil case "attribute": return schema.TypeAnnotation{Kind: schema.TypeAnnotSimple, Name: content(n, source)}, nil case "string", "concatenated_string": text := content(n, source) inner := strings.TrimLeft(text, "\"'") inner = strings.TrimRight(inner, "\"'") if ann, ok := parseTypeFromString(inner); ok { return ann, nil } return schema.TypeAnnotation{}, errUnsupported(fmt.Sprintf("string annotation: %s", text)) default: text := content(n, source) if ann, ok := parseTypeFromString(text); ok { return ann, nil } return schema.TypeAnnotation{}, errUnsupported(fmt.Sprintf("%s: %s", n.Type(), text)) } } func errUnsupported(msg string) error { return &schema.SchemaError{Kind: schema.ErrUnsupportedType, Message: fmt.Sprintf("unsupported type: %s", msg)} } // --------------------------------------------------------------------------- // Input() call parsing // --------------------------------------------------------------------------- func isInputCall(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { if node.Type() != "call" { return false } funcNode := node.ChildByFieldName("function") if funcNode == nil { return false } name := content(funcNode, source) if name == "Input" { return true } if e, ok := imports.Names.Get(name); ok { return e.Module == "cog" && e.Original == "Input" } return false } func parseInputCall(node *sitter.Node, source []byte, paramName string, scope moduleScope) (inputCallInfo, error) { var info inputCallInfo args := node.ChildByFieldName("arguments") if args == nil { return info, nil } for _, child := range namedChildren(args) { if child.Type() != "keyword_argument" { continue } keyNode := child.ChildByFieldName("name") valNode := child.ChildByFieldName("value") if keyNode == nil || valNode == nil { continue } key := 
content(keyNode, source) switch key { case "default": val, ok := resolveDefaultExpr(valNode, source, scope) if !ok { none := schema.DefaultValue{Kind: schema.DefaultNone} val = none } info.Default = &val case "default_factory": return inputCallInfo{}, schema.WrapError(schema.ErrDefaultFactoryNotSupported, fmt.Sprintf("parameter '%s': default_factory is not supported in static schema generation", paramName), nil) case "description": if s, ok := parseStringLiteral(valNode, source); ok { info.Description = &s } case "ge": if n, ok := parseNumberLiteral(valNode, source); ok { info.GE = &n } case "le": if n, ok := parseNumberLiteral(valNode, source); ok { info.LE = &n } case "min_length": if n, ok := parseNumberLiteral(valNode, source); ok { u := uint64(n) info.MinLength = &u } case "max_length": if n, ok := parseNumberLiteral(valNode, source); ok { u := uint64(n) info.MaxLength = &u } case "regex": if s, ok := parseStringLiteral(valNode, source); ok { info.Regex = &s } case "choices": if items, ok := parseListLiteral(valNode, source); ok { info.Choices = items } else if items, ok := resolveChoicesExpr(valNode, source, scope); ok { info.Choices = items } else { return inputCallInfo{}, schema.WrapError(schema.ErrChoicesNotResolvable, fmt.Sprintf("parameter '%s': choices expression cannot be statically resolved", paramName), nil) } case "deprecated": if b, ok := parseBoolLiteral(valNode, source); ok { info.Deprecated = &b } } } return info, nil } // --------------------------------------------------------------------------- // Literal parsing // --------------------------------------------------------------------------- func parseDefaultValue(node *sitter.Node, source []byte) (schema.DefaultValue, bool) { switch node.Type() { case "none": return schema.DefaultValue{Kind: schema.DefaultNone}, true case "true": return schema.DefaultValue{Kind: schema.DefaultBool, Bool: true}, true case "false": return schema.DefaultValue{Kind: schema.DefaultBool, Bool: false}, true case 
"integer": text := content(node, source) n, err := strconv.ParseInt(text, 0, 64) if err != nil { return schema.DefaultValue{}, false } return schema.DefaultValue{Kind: schema.DefaultInt, Int: n}, true case "float": text := content(node, source) f, err := strconv.ParseFloat(text, 64) if err != nil { return schema.DefaultValue{}, false } return schema.DefaultValue{Kind: schema.DefaultFloat, Float: f}, true case "string", "concatenated_string": s, ok := parseStringLiteral(node, source) if !ok { return schema.DefaultValue{}, false } return schema.DefaultValue{Kind: schema.DefaultString, Str: s}, true case "list": items, ok := parseListLiteral(node, source) if !ok { return schema.DefaultValue{}, false } return schema.DefaultValue{Kind: schema.DefaultList, List: items}, true case "dictionary": keys, vals, ok := parseDictLiteral(node, source) if !ok { return schema.DefaultValue{}, false } return schema.DefaultValue{Kind: schema.DefaultDict, DictKeys: keys, DictVals: vals}, true case "set": items, ok := parseSetLiteral(node, source) if !ok { return schema.DefaultValue{}, false } return schema.DefaultValue{Kind: schema.DefaultSet, List: items}, true case "unary_operator": text := strings.TrimSpace(content(node, source)) if n, err := strconv.ParseInt(text, 0, 64); err == nil { return schema.DefaultValue{Kind: schema.DefaultInt, Int: n}, true } if f, err := strconv.ParseFloat(text, 64); err == nil { return schema.DefaultValue{Kind: schema.DefaultFloat, Float: f}, true } return schema.DefaultValue{}, false case "tuple": var items []schema.DefaultValue for _, child := range namedChildren(node) { if val, ok := parseDefaultValue(child, source); ok { items = append(items, val) } } return schema.DefaultValue{Kind: schema.DefaultList, List: items}, true } return schema.DefaultValue{}, false } func parseStringLiteral(node *sitter.Node, source []byte) (string, bool) { text := content(node, source) if strings.HasPrefix(text, `"""`) || strings.HasPrefix(text, `'''`) { if len(text) >= 6 
{ return text[3 : len(text)-3], true } return "", false } if strings.HasPrefix(text, `"`) || strings.HasPrefix(text, `'`) { if len(text) >= 2 { return text[1 : len(text)-1], true } return "", false } if strings.HasPrefix(text, `r"`) || strings.HasPrefix(text, `r'`) { if len(text) >= 3 { return text[2 : len(text)-1], true } return "", false } return "", false } func parseNumberLiteral(node *sitter.Node, source []byte) (float64, bool) { text := strings.TrimSpace(content(node, source)) f, err := strconv.ParseFloat(text, 64) if err != nil { return 0, false } return f, true } func parseBoolLiteral(node *sitter.Node, source []byte) (bool, bool) { switch node.Type() { case "true": return true, true case "false": return false, true } text := content(node, source) switch text { case "True": return true, true case "False": return false, true } return false, false } func parseListLiteral(node *sitter.Node, source []byte) ([]schema.DefaultValue, bool) { if node.Type() != "list" { return nil, false } var items []schema.DefaultValue for _, child := range namedChildren(node) { val, ok := parseDefaultValue(child, source) if !ok { return nil, false } items = append(items, val) } return items, true } func parseDictLiteral(node *sitter.Node, source []byte) ([]schema.DefaultValue, []schema.DefaultValue, bool) { if node.Type() != "dictionary" { return nil, nil, false } var keys, vals []schema.DefaultValue for _, child := range namedChildren(node) { if child.Type() == "pair" { keyNode := child.ChildByFieldName("key") valNode := child.ChildByFieldName("value") if keyNode == nil || valNode == nil { continue } k, ok1 := parseDefaultValue(keyNode, source) v, ok2 := parseDefaultValue(valNode, source) if ok1 && ok2 { keys = append(keys, k) vals = append(vals, v) } } } return keys, vals, true } func parseSetLiteral(node *sitter.Node, source []byte) ([]schema.DefaultValue, bool) { if node.Type() != "set" { return nil, false } var items []schema.DefaultValue for _, child := range 
namedChildren(node) { val, ok := parseDefaultValue(child, source) if !ok { return nil, false } items = append(items, val) } return items, true } ================================================ FILE: pkg/schema/python/parser_fuzz_test.go ================================================ package python import ( "testing" schema "github.com/replicate/cog/pkg/schema" ) // FuzzParsePredictor feeds arbitrary bytes as Python source to the parser. // The parser should never panic regardless of input — it may return errors. func FuzzParsePredictor(f *testing.F) { // Seed corpus: valid and invalid Python snippets. f.Add([]byte(` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> str: return x `), "Predictor", uint8(0)) f.Add([]byte(` from cog import BasePredictor from pydantic import BaseModel class Output(BaseModel): text: str score: float = 0.0 class Predictor(BasePredictor): def predict(self, x: str) -> Output: pass `), "Predictor", uint8(0)) f.Add([]byte(` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> dict[str, list[int]]: return {} `), "Predictor", uint8(0)) f.Add([]byte(` from cog import BasePredictor, ConcatenateIterator class Predictor(BasePredictor): def predict(self, x: str) -> ConcatenateIterator[str]: yield x `), "Predictor", uint8(0)) // Training mode. f.Add([]byte(` from cog import BasePredictor class Predictor(BasePredictor): def train(self, x: str) -> str: return x `), "Predictor", uint8(1)) // No predictor class at all. f.Add([]byte(`print("hello")`), "Predictor", uint8(0)) // Empty source. f.Add([]byte{}, "Predictor", uint8(0)) // Garbage bytes. f.Add([]byte{0xff, 0xfe, 0x00, 0x01, 0x80, 0x90}, "Predictor", uint8(0)) f.Fuzz(func(t *testing.T, source []byte, predictRef string, modeRaw uint8) { mode := schema.ModePredict if modeRaw%2 == 1 { mode = schema.ModeTrain } // Must not panic regardless of input. 
_, _ = ParsePredictor(source, predictRef, mode, "") }) } // FuzzParseTypeAnnotation exercises the type annotation parser with // arbitrary annotation strings embedded in a predict signature. func FuzzParseTypeAnnotation(f *testing.F) { types := []string{ "str", "int", "float", "bool", "Path", "dict", "dict[str, int]", "dict[str, list[str]]", "list[str]", "list[dict[str, float]]", "Optional[str]", "Optional[dict[str, int]]", "Iterator[str]", "ConcatenateIterator[str]", "dict[str, dict[str, dict[str, int]]]", "Any", "None", "list", } for _, typ := range types { f.Add(typ) } f.Fuzz(func(t *testing.T, typeName string) { // Build a minimal predict.py with the fuzzed return type. source := []byte("from cog import BasePredictor\nfrom typing import *\nclass Predictor(BasePredictor):\n def predict(self, x: str) -> " + typeName + ":\n pass\n") // Must not panic. _, _ = ParsePredictor(source, "Predictor", schema.ModePredict, "") }) } ================================================ FILE: pkg/schema/python/parser_test.go ================================================ package python import ( "errors" "os" "path/filepath" "testing" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/schema" ) // helper that parses in predict mode and fails on error. func parse(t *testing.T, source, predictRef string) *schema.PredictorInfo { t.Helper() info, err := ParsePredictor([]byte(source), predictRef, schema.ModePredict, "") require.NoError(t, err) return info } // helper to parse and expect an error. 
func parseErr(t *testing.T, source, predictRef string, mode schema.Mode) *schema.SchemaError { t.Helper() _, err := ParsePredictor([]byte(source), predictRef, mode, "") require.Error(t, err) var se *schema.SchemaError require.True(t, errors.As(err, &se), "expected *schema.SchemaError, got %T: %v", err, err) return se } // --------------------------------------------------------------------------- // Basic predictor tests // --------------------------------------------------------------------------- func TestSimpleStringPredictor(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, s: str) -> str: return "hello " + s ` info := parse(t, source, "Predictor") require.Equal(t, 1, info.Inputs.Len()) s, ok := info.Inputs.Get("s") require.True(t, ok) require.Equal(t, schema.TypeString, s.FieldType.Primitive) require.Equal(t, schema.Required, s.FieldType.Repetition) require.Nil(t, s.Default) require.True(t, s.IsRequired()) require.Equal(t, schema.SchemaPrimitive, info.Output.Kind) require.Equal(t, schema.TypeString, info.Output.Primitive) } func TestMultipleInputsWithDefaults(t *testing.T) { source := ` from cog import BasePredictor, Input, Path class Predictor(BasePredictor): def predict( self, image: Path = Input(description="Grayscale input image"), scale: float = Input(description="Factor to scale image by", ge=0, le=10, default=1.5), ) -> Path: pass ` info := parse(t, source, "Predictor") require.Equal(t, 2, info.Inputs.Len()) image, ok := info.Inputs.Get("image") require.True(t, ok) require.Equal(t, schema.TypePath, image.FieldType.Primitive) require.Nil(t, image.Default) require.NotNil(t, image.Description) require.Equal(t, "Grayscale input image", *image.Description) require.True(t, image.IsRequired()) scale, ok := info.Inputs.Get("scale") require.True(t, ok) require.Equal(t, schema.TypeFloat, scale.FieldType.Primitive) require.NotNil(t, scale.Default) require.Equal(t, schema.DefaultFloat, scale.Default.Kind) 
require.Equal(t, 1.5, scale.Default.Float) require.NotNil(t, scale.GE) require.Equal(t, 0.0, *scale.GE) require.NotNil(t, scale.LE) require.Equal(t, 10.0, *scale.LE) require.False(t, scale.IsRequired()) } // --------------------------------------------------------------------------- // Optional / union inputs // --------------------------------------------------------------------------- func TestOptionalInputPipeNone(t *testing.T) { source := ` from cog import BasePredictor, Input, Path class Predictor(BasePredictor): def predict( self, test_image: Path | None = Input(description="Test image", default=None), ) -> Path: pass ` info := parse(t, source, "Predictor") img, ok := info.Inputs.Get("test_image") require.True(t, ok) require.Equal(t, schema.Optional, img.FieldType.Repetition) require.Equal(t, schema.TypePath, img.FieldType.Primitive) require.NotNil(t, img.Default) require.Equal(t, schema.DefaultNone, img.Default.Kind) } func TestOptionalInputTyping(t *testing.T) { source := ` from typing import Optional from cog import BasePredictor, Input class Predictor(BasePredictor): def predict( self, name: Optional[str] = Input(default=None), ) -> str: pass ` info := parse(t, source, "Predictor") name, ok := info.Inputs.Get("name") require.True(t, ok) require.Equal(t, schema.Optional, name.FieldType.Repetition) require.Equal(t, schema.TypeString, name.FieldType.Primitive) } // --------------------------------------------------------------------------- // List inputs // --------------------------------------------------------------------------- func TestListInput(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, paths: list[str] = Input(description="Paths")) -> str: pass ` info := parse(t, source, "Predictor") paths, ok := info.Inputs.Get("paths") require.True(t, ok) require.Equal(t, schema.Repeated, paths.FieldType.Repetition) require.Equal(t, schema.TypeString, paths.FieldType.Primitive) } // 
--------------------------------------------------------------------------- // Choices // --------------------------------------------------------------------------- func TestChoicesLiteralList(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, color: str = Input(choices=["red", "green", "blue"])) -> str: pass ` info := parse(t, source, "Predictor") color, ok := info.Inputs.Get("color") require.True(t, ok) require.Len(t, color.Choices, 3) require.Equal(t, "red", color.Choices[0].Str) require.Equal(t, "green", color.Choices[1].Str) require.Equal(t, "blue", color.Choices[2].Str) } func TestChoicesModuleLevelListVar(t *testing.T) { source := ` from cog import BasePredictor, Input MY_CHOICES = ["x", "y", "z"] class Predictor(BasePredictor): def predict(self, v: str = Input(choices=MY_CHOICES)) -> str: pass ` info := parse(t, source, "Predictor") v, ok := info.Inputs.Get("v") require.True(t, ok) require.Len(t, v.Choices, 3) require.Equal(t, "x", v.Choices[0].Str) require.Equal(t, "y", v.Choices[1].Str) require.Equal(t, "z", v.Choices[2].Str) } func TestChoicesListDictKeys(t *testing.T) { source := ` from cog import BasePredictor, Input ASPECT_RATIOS = { "1:1": (1024, 1024), "16:9": (1344, 768), "2:3": (832, 1216), } class Predictor(BasePredictor): def predict(self, ar: str = Input(choices=list(ASPECT_RATIOS.keys()), default="1:1")) -> str: pass ` info := parse(t, source, "Predictor") ar, ok := info.Inputs.Get("ar") require.True(t, ok) require.Len(t, ar.Choices, 3) require.Equal(t, "1:1", ar.Choices[0].Str) require.Equal(t, "16:9", ar.Choices[1].Str) require.Equal(t, "2:3", ar.Choices[2].Str) } func TestChoicesListDictValues(t *testing.T) { source := ` from cog import BasePredictor, Input LABELS = {"fast": "Fast Mode", "slow": "Slow Mode"} class Predictor(BasePredictor): def predict(self, m: str = Input(choices=list(LABELS.values()))) -> str: pass ` info := parse(t, source, "Predictor") m, ok := 
info.Inputs.Get("m") require.True(t, ok) require.Len(t, m.Choices, 2) require.Equal(t, "Fast Mode", m.Choices[0].Str) require.Equal(t, "Slow Mode", m.Choices[1].Str) } func TestChoicesDictKeysPlusLiteral(t *testing.T) { source := ` from cog import BasePredictor, Input SIZES = {"small": 256, "large": 1024} class Predictor(BasePredictor): def predict(self, s: str = Input(choices=list(SIZES.keys()) + ["custom"])) -> str: pass ` info := parse(t, source, "Predictor") s, ok := info.Inputs.Get("s") require.True(t, ok) require.Len(t, s.Choices, 3) require.Equal(t, "small", s.Choices[0].Str) require.Equal(t, "large", s.Choices[1].Str) require.Equal(t, "custom", s.Choices[2].Str) } func TestChoicesIntegerDictKeys(t *testing.T) { source := ` from cog import BasePredictor, Input STEP_LABELS = {1: "one", 2: "two", 4: "four"} class Predictor(BasePredictor): def predict(self, steps: int = Input(choices=list(STEP_LABELS.keys()), default=1)) -> str: pass ` info := parse(t, source, "Predictor") steps, ok := info.Inputs.Get("steps") require.True(t, ok) require.Len(t, steps.Choices, 3) require.Equal(t, schema.DefaultInt, steps.Choices[0].Kind) require.Equal(t, int64(1), steps.Choices[0].Int) require.Equal(t, int64(2), steps.Choices[1].Int) require.Equal(t, int64(4), steps.Choices[2].Int) } func TestChoicesConcatTwoVars(t *testing.T) { source := ` from cog import BasePredictor, Input BASE = ["a", "b"] EXTRA = ["c"] class Predictor(BasePredictor): def predict(self, x: str = Input(choices=BASE + EXTRA)) -> str: pass ` info := parse(t, source, "Predictor") x, ok := info.Inputs.Get("x") require.True(t, ok) require.Len(t, x.Choices, 3) require.Equal(t, "a", x.Choices[0].Str) require.Equal(t, "b", x.Choices[1].Str) require.Equal(t, "c", x.Choices[2].Str) } // --------------------------------------------------------------------------- // Choices error cases // --------------------------------------------------------------------------- func TestChoicesVarNotAListErrors(t *testing.T) { source 
:= ` from cog import BasePredictor, Input NOT_A_LIST = "oops" class Predictor(BasePredictor): def predict(self, x: str = Input(choices=NOT_A_LIST)) -> str: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrChoicesNotResolvable, se.Kind) } func TestChoicesUndefinedVarErrors(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, x: str = Input(choices=DOES_NOT_EXIST)) -> str: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrChoicesNotResolvable, se.Kind) } func TestChoicesArbitraryCallErrors(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, x: str = Input(choices=get_choices())) -> str: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrChoicesNotResolvable, se.Kind) } func TestChoicesListComprehensionErrors(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, x: str = Input(choices=[f"{i}x" for i in range(5)])) -> str: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrChoicesNotResolvable, se.Kind) } func TestChoicesErrorIncludesParamName(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, my_param: str = Input(choices=some_func())) -> str: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Contains(t, se.Message, "my_param") } func TestChoicesNestedVarNotInScope(t *testing.T) { source := ` from cog import BasePredictor, Input def helper(): NESTED = ["a", "b"] class Predictor(BasePredictor): def predict(self, x: str = Input(choices=NESTED)) -> str: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrChoicesNotResolvable, se.Kind) } // 
--------------------------------------------------------------------------- // Standalone function // --------------------------------------------------------------------------- func TestStandaloneFunction(t *testing.T) { source := ` from cog import Input def predict(text: str = Input(default="world")) -> str: return f"hello {text}" ` info := parse(t, source, "predict") require.Equal(t, 1, info.Inputs.Len()) text, ok := info.Inputs.Get("text") require.True(t, ok) require.NotNil(t, text.Default) require.Equal(t, schema.DefaultString, text.Default.Kind) require.Equal(t, "world", text.Default.Str) } // --------------------------------------------------------------------------- // Output types // --------------------------------------------------------------------------- func TestIteratorOutput(t *testing.T) { source := ` from typing import Iterator from cog import BasePredictor class Predictor(BasePredictor): def predict(self, count: int) -> Iterator[str]: for i in range(count): yield f"chunk {i}" ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaIterator, info.Output.Kind) require.NotNil(t, info.Output.Elem) require.Equal(t, schema.SchemaPrimitive, info.Output.Elem.Kind) require.Equal(t, schema.TypeString, info.Output.Elem.Primitive) } func TestConcatenateIteratorOutput(t *testing.T) { source := ` from cog import BasePredictor, ConcatenateIterator class Predictor(BasePredictor): def predict(self, prompt: str) -> ConcatenateIterator[str]: yield "hello " yield "world" ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaConcatIterator, info.Output.Kind) require.NotNil(t, info.Output.Elem) require.Equal(t, schema.TypeString, info.Output.Elem.Primitive) } func TestConcatenateIteratorNotStrErrors(t *testing.T) { source := ` from cog import BasePredictor, ConcatenateIterator class Predictor(BasePredictor): def predict(self, n: int) -> ConcatenateIterator[int]: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) 
require.Equal(t, schema.ErrConcatIteratorNotStr, se.Kind) } func TestListOutput(t *testing.T) { source := ` from cog import BasePredictor, Path class Predictor(BasePredictor): def predict(self, n: int) -> list[Path]: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaArray, info.Output.Kind) require.NotNil(t, info.Output.Items) require.Equal(t, schema.SchemaPrimitive, info.Output.Items.Kind) require.Equal(t, schema.TypePath, info.Output.Items.Primitive) } func TestBaseModelOutput(t *testing.T) { source := ` from cog import BasePredictor, BaseModel class ModelOutput(BaseModel): text: str score: float class Predictor(BasePredictor): def predict(self, prompt: str) -> ModelOutput: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.NotNil(t, info.Output.Fields) require.Equal(t, 2, info.Output.Fields.Len()) text, ok := info.Output.Fields.Get("text") require.True(t, ok) require.Equal(t, schema.SchemaPrimitive, text.Type.Kind) require.Equal(t, schema.TypeString, text.Type.Primitive) score, ok := info.Output.Fields.Get("score") require.True(t, ok) require.Equal(t, schema.SchemaPrimitive, score.Type.Kind) require.Equal(t, schema.TypeFloat, score.Type.Primitive) } func TestOptionalOutputErrors(t *testing.T) { source := ` from typing import Optional from cog import BasePredictor class Predictor(BasePredictor): def predict(self, s: str) -> Optional[str]: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrOptionalOutput, se.Kind) } func TestOptionalOutputPipeNoneErrors(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, s: str) -> str | None: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrOptionalOutput, se.Kind) } // --------------------------------------------------------------------------- // Train mode // 
--------------------------------------------------------------------------- func TestTrainMode(t *testing.T) { source := ` from cog import Input, Path def train(n: int) -> Path: pass ` info, err := ParsePredictor([]byte(source), "train", schema.ModeTrain, "") require.NoError(t, err) require.Equal(t, schema.ModeTrain, info.Mode) require.Equal(t, 1, info.Inputs.Len()) } // --------------------------------------------------------------------------- // Non-BasePredictor class (just has predict method) // --------------------------------------------------------------------------- func TestNonBasePredictor(t *testing.T) { source := ` from cog import Input class Predictor: def predict(self, text: str = Input(default="hello")) -> str: return f"hello {text}" ` info := parse(t, source, "Predictor") require.Equal(t, 1, info.Inputs.Len()) text, ok := info.Inputs.Get("text") require.True(t, ok) require.NotNil(t, text.Default) require.Equal(t, "hello", text.Default.Str) } // --------------------------------------------------------------------------- // default_factory hard error // --------------------------------------------------------------------------- func TestDefaultFactoryError(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, items: list[str] = Input(default_factory=list)) -> str: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrDefaultFactoryNotSupported, se.Kind) } // --------------------------------------------------------------------------- // Module-scope default resolution // --------------------------------------------------------------------------- func TestDefaultModuleLevelStringInInput(t *testing.T) { source := ` from cog import BasePredictor, Input DEFAULT_RATIO = "1:1" class Predictor(BasePredictor): def predict(self, ar: str = Input(default=DEFAULT_RATIO)) -> str: pass ` info := parse(t, source, "Predictor") ar, ok := info.Inputs.Get("ar") 
require.True(t, ok) require.NotNil(t, ar.Default) require.Equal(t, schema.DefaultString, ar.Default.Kind) require.Equal(t, "1:1", ar.Default.Str) } func TestDefaultModuleLevelIntInInput(t *testing.T) { source := ` from cog import BasePredictor, Input DEFAULT_STEPS = 50 class Predictor(BasePredictor): def predict(self, steps: int = Input(default=DEFAULT_STEPS)) -> str: pass ` info := parse(t, source, "Predictor") steps, ok := info.Inputs.Get("steps") require.True(t, ok) require.NotNil(t, steps.Default) require.Equal(t, schema.DefaultInt, steps.Default.Kind) require.Equal(t, int64(50), steps.Default.Int) } func TestDefaultModuleLevelListInInput(t *testing.T) { source := ` from cog import BasePredictor, Input DEFAULT_TAGS = ["a", "b"] class Predictor(BasePredictor): def predict(self, tags: list[str] = Input(default=DEFAULT_TAGS)) -> str: pass ` info := parse(t, source, "Predictor") tags, ok := info.Inputs.Get("tags") require.True(t, ok) require.NotNil(t, tags.Default) require.Equal(t, schema.DefaultList, tags.Default.Kind) require.Len(t, tags.Default.List, 2) require.Equal(t, "a", tags.Default.List[0].Str) require.Equal(t, "b", tags.Default.List[1].Str) } func TestDefaultModuleLevelVarPlain(t *testing.T) { source := ` from cog import BasePredictor MY_DEFAULT = "hello" class Predictor(BasePredictor): def predict(self, text: str = MY_DEFAULT) -> str: pass ` info := parse(t, source, "Predictor") text, ok := info.Inputs.Get("text") require.True(t, ok) require.NotNil(t, text.Default) require.Equal(t, schema.DefaultString, text.Default.Kind) require.Equal(t, "hello", text.Default.Str) } func TestDefaultUndefinedVarPlainErrors(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, text: str = UNDEFINED_VAR) -> str: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Contains(t, se.Message, "cannot be statically resolved") } // --------------------------------------------------------------------------- 
// InputRegistry — class attribute reference // --------------------------------------------------------------------------- func TestInputRegistryAttribute(t *testing.T) { source := ` from dataclasses import dataclass from cog import BasePredictor, Input RATIOS = {"1:1": (1024, 1024), "16:9": (1344, 768)} @dataclass(frozen=True) class Inputs: ar = Input(description="Aspect ratio", choices=list(RATIOS.keys()), default="1:1") class Predictor(BasePredictor): def predict(self, ar: str = Inputs.ar) -> str: pass ` info := parse(t, source, "Predictor") ar, ok := info.Inputs.Get("ar") require.True(t, ok) require.NotNil(t, ar.Description) require.Equal(t, "Aspect ratio", *ar.Description) require.Len(t, ar.Choices, 2) require.Equal(t, "1:1", ar.Choices[0].Str) require.Equal(t, "16:9", ar.Choices[1].Str) require.NotNil(t, ar.Default) require.Equal(t, "1:1", ar.Default.Str) } // --------------------------------------------------------------------------- // InputRegistry — static method reference // --------------------------------------------------------------------------- func TestInputRegistryMethod(t *testing.T) { source := ` from dataclasses import dataclass from cog import BasePredictor, Input @dataclass(frozen=True) class Inputs: @staticmethod def guidance(default: float) -> Input: return Input(description="Guidance scale", ge=0.0, le=20.0, default=default) class Predictor(BasePredictor): def predict(self, guidance_scale: float = Inputs.guidance(7.5)) -> str: pass ` info := parse(t, source, "Predictor") gs, ok := info.Inputs.Get("guidance_scale") require.True(t, ok) require.NotNil(t, gs.Description) require.Equal(t, "Guidance scale", *gs.Description) require.NotNil(t, gs.GE) require.Equal(t, 0.0, *gs.GE) require.NotNil(t, gs.LE) require.Equal(t, 20.0, *gs.LE) require.NotNil(t, gs.Default) require.Equal(t, schema.DefaultFloat, gs.Default.Kind) require.Equal(t, 7.5, gs.Default.Float) } // --------------------------------------------------------------------------- // Error 
cases: missing annotations, predictor not found, etc. // --------------------------------------------------------------------------- func TestPredictorNotFound(t *testing.T) { source := ` from cog import BasePredictor class Other(BasePredictor): def predict(self, s: str) -> str: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrPredictorNotFound, se.Kind) } func TestMethodNotFound(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def setup(self): pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrMethodNotFound, se.Kind) } func TestMissingReturnType(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, s: str): pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrMissingReturnType, se.Kind) } func TestMissingTypeAnnotation(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, s="hello") -> str: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrMissingTypeAnnotation, se.Kind) } // --------------------------------------------------------------------------- // All input types // --------------------------------------------------------------------------- func TestAllPrimitiveInputTypes(t *testing.T) { tests := []struct { name string pyType string expected schema.PrimitiveType }{ {"str", "str", schema.TypeString}, {"int", "int", schema.TypeInteger}, {"float", "float", schema.TypeFloat}, {"bool", "bool", schema.TypeBool}, {"Path", "Path", schema.TypePath}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { source := ` from cog import BasePredictor, Path class Predictor(BasePredictor): def predict(self, x: ` + tt.pyType + `) -> str: pass ` info := parse(t, source, "Predictor") x, ok := info.Inputs.Get("x") require.True(t, ok) require.Equal(t, 
tt.expected, x.FieldType.Primitive) require.Equal(t, schema.Required, x.FieldType.Repetition) }) } } // --------------------------------------------------------------------------- // Input() with constraints // --------------------------------------------------------------------------- func TestInputConstraints(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict( self, text: str = Input( description="Input text", min_length=1, max_length=100, regex="^[a-z]+$", ), ) -> str: pass ` info := parse(t, source, "Predictor") text, ok := info.Inputs.Get("text") require.True(t, ok) require.NotNil(t, text.Description) require.Equal(t, "Input text", *text.Description) require.NotNil(t, text.MinLength) require.Equal(t, uint64(1), *text.MinLength) require.NotNil(t, text.MaxLength) require.Equal(t, uint64(100), *text.MaxLength) require.NotNil(t, text.Regex) require.Equal(t, "^[a-z]+$", *text.Regex) } // --------------------------------------------------------------------------- // Negative numbers and booleans as defaults // --------------------------------------------------------------------------- func TestNegativeNumberDefault(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, temp: float = Input(default=-1.5)) -> str: pass ` info := parse(t, source, "Predictor") temp, ok := info.Inputs.Get("temp") require.True(t, ok) require.NotNil(t, temp.Default) require.Equal(t, schema.DefaultFloat, temp.Default.Kind) require.Equal(t, -1.5, temp.Default.Float) } func TestBoolDefault(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, flag: bool = Input(default=True)) -> str: pass ` info := parse(t, source, "Predictor") flag, ok := info.Inputs.Get("flag") require.True(t, ok) require.NotNil(t, flag.Default) require.Equal(t, schema.DefaultBool, flag.Default.Kind) require.True(t, flag.Default.Bool) } // 
--------------------------------------------------------------------------- // Parameter ordering // --------------------------------------------------------------------------- func TestParameterOrdering(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, alpha: str, beta: int, gamma: float = Input(default=1.0)) -> str: pass ` info := parse(t, source, "Predictor") require.Equal(t, 3, info.Inputs.Len()) // Check insertion order keys := info.Inputs.Keys() require.Equal(t, "alpha", keys[0]) require.Equal(t, "beta", keys[1]) require.Equal(t, "gamma", keys[2]) alpha, _ := info.Inputs.Get("alpha") require.Equal(t, 0, alpha.Order) beta, _ := info.Inputs.Get("beta") require.Equal(t, 1, beta.Order) gamma, _ := info.Inputs.Get("gamma") require.Equal(t, 2, gamma.Order) } // --------------------------------------------------------------------------- // Deprecated flag // --------------------------------------------------------------------------- func TestDeprecatedInput(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, old_param: str = Input(deprecated=True, default="x")) -> str: pass ` info := parse(t, source, "Predictor") old, ok := info.Inputs.Get("old_param") require.True(t, ok) require.NotNil(t, old.Deprecated) require.True(t, *old.Deprecated) } // --------------------------------------------------------------------------- // File type (deprecated alias for Path) // --------------------------------------------------------------------------- func TestFileType(t *testing.T) { source := ` from cog import BasePredictor, File class Predictor(BasePredictor): def predict(self, f: File) -> str: pass ` info := parse(t, source, "Predictor") f, ok := info.Inputs.Get("f") require.True(t, ok) require.Equal(t, schema.TypeFile, f.FieldType.Primitive) } // --------------------------------------------------------------------------- // Secret type // 
--------------------------------------------------------------------------- func TestSecretType(t *testing.T) { source := ` from cog import BasePredictor, Secret class Predictor(BasePredictor): def predict(self, token: Secret) -> str: pass ` info := parse(t, source, "Predictor") token, ok := info.Inputs.Get("token") require.True(t, ok) require.Equal(t, schema.TypeSecret, token.FieldType.Primitive) } // --------------------------------------------------------------------------- // Multiple classes — finds the right one // --------------------------------------------------------------------------- func TestMultipleClassesFindsTarget(t *testing.T) { source := ` from cog import BasePredictor, BaseModel class Output(BaseModel): text: str class Helper: pass class Predictor(BasePredictor): def predict(self, s: str) -> str: pass ` info := parse(t, source, "Predictor") require.Equal(t, 1, info.Inputs.Len()) require.Equal(t, schema.SchemaPrimitive, info.Output.Kind) } // --------------------------------------------------------------------------- // BaseModel with defaults // --------------------------------------------------------------------------- func TestBaseModelOutputWithDefaults(t *testing.T) { source := ` from cog import BasePredictor, BaseModel class Result(BaseModel): text: str confidence: float = 0.0 class Predictor(BasePredictor): def predict(self, s: str) -> Result: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) conf, ok := info.Output.Fields.Get("confidence") require.True(t, ok) require.NotNil(t, conf.Default) require.Equal(t, schema.DefaultFloat, conf.Default.Kind) require.Equal(t, 0.0, conf.Default.Float) } // --------------------------------------------------------------------------- // Pydantic BaseModel output // --------------------------------------------------------------------------- func TestPydanticBaseModelOutput(t *testing.T) { source := ` from pydantic import BaseModel as PydanticBaseModel from 
cog import BasePredictor class Result(PydanticBaseModel): name: str score: float tags: list[str] class Predictor(BasePredictor): def predict(self, name: str) -> Result: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.NotNil(t, info.Output.Fields) require.Equal(t, 3, info.Output.Fields.Len()) name, ok := info.Output.Fields.Get("name") require.True(t, ok) require.Equal(t, schema.SchemaPrimitive, name.Type.Kind) require.Equal(t, schema.TypeString, name.Type.Primitive) score, ok := info.Output.Fields.Get("score") require.True(t, ok) require.Equal(t, schema.SchemaPrimitive, score.Type.Kind) require.Equal(t, schema.TypeFloat, score.Type.Primitive) } func TestPydanticBaseModelDottedOutput(t *testing.T) { source := ` import pydantic from cog import BasePredictor class Result(pydantic.BaseModel): text: str class Predictor(BasePredictor): def predict(self, s: str) -> Result: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) text, ok := info.Output.Fields.Get("text") require.True(t, ok) require.Equal(t, schema.SchemaPrimitive, text.Type.Kind) require.Equal(t, schema.TypeString, text.Type.Primitive) } func TestPydanticBaseModelDirectImport(t *testing.T) { source := ` from pydantic import BaseModel from cog import BasePredictor class Output(BaseModel): value: int class Predictor(BasePredictor): def predict(self, x: int) -> Output: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) val, ok := info.Output.Fields.Get("value") require.True(t, ok) require.Equal(t, schema.SchemaPrimitive, val.Type.Kind) require.Equal(t, schema.TypeInteger, val.Type.Primitive) } // --------------------------------------------------------------------------- // Unparameterized dict/list output (opaque JSON) // --------------------------------------------------------------------------- func TestDictOutput(t *testing.T) { source := ` from cog 
import BasePredictor, Input, Path class Predictor(BasePredictor): def predict(self, image: Path = Input(description="Image")) -> dict: return {"class": "hotdog", "score": 0.95} ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaAny, info.Output.Kind) } func TestParameterizedDictOutput(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, text: str = Input(description="Text")) -> dict[str, dict[str, str]]: return {"inputs": {"text": text}} ` info := parse(t, source, "Predictor") // dict[str, dict[str, str]] → SchemaDict with nested SchemaDict value type require.Equal(t, schema.SchemaDict, info.Output.Kind) require.NotNil(t, info.Output.ValueType) require.Equal(t, schema.SchemaDict, info.Output.ValueType.Kind) require.NotNil(t, info.Output.ValueType.ValueType) require.Equal(t, schema.SchemaPrimitive, info.Output.ValueType.ValueType.Kind) require.Equal(t, schema.TypeString, info.Output.ValueType.ValueType.Primitive) } func TestBareListOutput(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, s: str) -> list: return [1, 2, 3] ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaArray, info.Output.Kind) require.NotNil(t, info.Output.Items) require.Equal(t, schema.SchemaAny, info.Output.Items.Kind) } // --------------------------------------------------------------------------- // No-input predictor (only self) // --------------------------------------------------------------------------- func TestNoInputPredictor(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self) -> str: return "hello" ` info := parse(t, source, "Predictor") require.Equal(t, 0, info.Inputs.Len()) require.Equal(t, schema.SchemaPrimitive, info.Output.Kind) } // --------------------------------------------------------------------------- // Falsy defaults (False, 0, 0.0, "") // 
--------------------------------------------------------------------------- func TestDefaultFalse(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, flag: bool = False) -> str: pass ` info := parse(t, source, "Predictor") f, ok := info.Inputs.Get("flag") require.True(t, ok) require.NotNil(t, f.Default) require.Equal(t, schema.DefaultBool, f.Default.Kind) require.Equal(t, false, f.Default.Bool) require.False(t, f.IsRequired()) } func TestDefaultZeroInt(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, count: int = 0) -> str: pass ` info := parse(t, source, "Predictor") c, ok := info.Inputs.Get("count") require.True(t, ok) require.NotNil(t, c.Default) require.Equal(t, schema.DefaultInt, c.Default.Kind) require.Equal(t, int64(0), c.Default.Int) require.False(t, c.IsRequired()) } func TestDefaultZeroFloat(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, weight: float = 0.0) -> str: pass ` info := parse(t, source, "Predictor") w, ok := info.Inputs.Get("weight") require.True(t, ok) require.NotNil(t, w.Default) require.Equal(t, schema.DefaultFloat, w.Default.Kind) require.Equal(t, 0.0, w.Default.Float) require.False(t, w.IsRequired()) } func TestDefaultEmptyString(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, text: str = "") -> str: pass ` info := parse(t, source, "Predictor") text, ok := info.Inputs.Get("text") require.True(t, ok) require.NotNil(t, text.Default) require.Equal(t, schema.DefaultString, text.Default.Kind) require.Equal(t, "", text.Default.Str) require.False(t, text.IsRequired()) } func TestDefaultNegativeInt(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, offset: int = -1) -> str: pass ` info := parse(t, source, "Predictor") o, ok := info.Inputs.Get("offset") 
require.True(t, ok) require.NotNil(t, o.Default) require.Equal(t, schema.DefaultInt, o.Default.Kind) require.Equal(t, int64(-1), o.Default.Int) } // --------------------------------------------------------------------------- // Async iterators // --------------------------------------------------------------------------- func TestAsyncIteratorOutput(t *testing.T) { source := ` from typing import AsyncIterator from cog import BasePredictor class Predictor(BasePredictor): async def predict(self, s: str) -> AsyncIterator[str]: yield s ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaIterator, info.Output.Kind) require.NotNil(t, info.Output.Elem) require.Equal(t, schema.SchemaPrimitive, info.Output.Elem.Kind) require.Equal(t, schema.TypeString, info.Output.Elem.Primitive) } func TestAsyncConcatenateIteratorOutput(t *testing.T) { source := ` from cog import BasePredictor, ConcatenateIterator class Predictor(BasePredictor): async def predict(self, s: str) -> ConcatenateIterator[str]: yield s ` // Note: AsyncConcatenateIterator is also valid via typing import, // but ConcatenateIterator in async context works the same way info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaConcatIterator, info.Output.Kind) require.NotNil(t, info.Output.Elem) require.Equal(t, schema.TypeString, info.Output.Elem.Primitive) } // --------------------------------------------------------------------------- // typing.List and typing.Union syntax // --------------------------------------------------------------------------- func TestTypingListCapitalL(t *testing.T) { source := ` from typing import List from cog import BasePredictor class Predictor(BasePredictor): def predict(self, items: List[str]) -> str: pass ` info := parse(t, source, "Predictor") items, ok := info.Inputs.Get("items") require.True(t, ok) require.Equal(t, schema.TypeString, items.FieldType.Primitive) require.Equal(t, schema.Repeated, items.FieldType.Repetition) } func 
TestTypingUnionStrNone(t *testing.T) { source := ` from typing import Union from cog import BasePredictor class Predictor(BasePredictor): def predict(self, text: Union[str, None] = None) -> str: pass ` info := parse(t, source, "Predictor") text, ok := info.Inputs.Get("text") require.True(t, ok) require.Equal(t, schema.TypeString, text.FieldType.Primitive) require.Equal(t, schema.Optional, text.FieldType.Repetition) require.False(t, text.IsRequired()) } // --------------------------------------------------------------------------- // All-optional inputs (no required array) // --------------------------------------------------------------------------- func TestAllOptionalInputs(t *testing.T) { source := ` from cog import BasePredictor, Input class Predictor(BasePredictor): def predict(self, a: str = "x", b: int = Input(default=5)) -> str: pass ` info := parse(t, source, "Predictor") require.Equal(t, 2, info.Inputs.Len()) a, ok := info.Inputs.Get("a") require.True(t, ok) require.False(t, a.IsRequired()) b, ok := info.Inputs.Get("b") require.True(t, ok) require.False(t, b.IsRequired()) } // --------------------------------------------------------------------------- // list[Path] as input // --------------------------------------------------------------------------- func TestListPathInput(t *testing.T) { source := ` from cog import BasePredictor, Path class Predictor(BasePredictor): def predict(self, files: list[Path]) -> str: pass ` info := parse(t, source, "Predictor") files, ok := info.Inputs.Get("files") require.True(t, ok) require.Equal(t, schema.TypePath, files.FieldType.Primitive) require.Equal(t, schema.Repeated, files.FieldType.Repetition) } // --------------------------------------------------------------------------- // Recursive / nested output types // --------------------------------------------------------------------------- func TestDictStrStrOutput(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def 
predict(self, x: str) -> dict[str, str]: return {"key": "value"} ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaDict, info.Output.Kind) require.NotNil(t, info.Output.ValueType) require.Equal(t, schema.SchemaPrimitive, info.Output.ValueType.Kind) require.Equal(t, schema.TypeString, info.Output.ValueType.Primitive) } func TestDictStrIntOutput(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> dict[str, int]: return {"count": 42} ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaDict, info.Output.Kind) require.NotNil(t, info.Output.ValueType) require.Equal(t, schema.SchemaPrimitive, info.Output.ValueType.Kind) require.Equal(t, schema.TypeInteger, info.Output.ValueType.Primitive) } func TestNestedDictOutput(t *testing.T) { // dict[str, dict[str, str]] source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> dict[str, dict[str, str]]: return {"outer": {"inner": "value"}} ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaDict, info.Output.Kind) require.NotNil(t, info.Output.ValueType) require.Equal(t, schema.SchemaDict, info.Output.ValueType.Kind) require.NotNil(t, info.Output.ValueType.ValueType) require.Equal(t, schema.SchemaPrimitive, info.Output.ValueType.ValueType.Kind) require.Equal(t, schema.TypeString, info.Output.ValueType.ValueType.Primitive) } func TestDictOfListOutput(t *testing.T) { // dict[str, list[int]] source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> dict[str, list[int]]: return {"numbers": [1, 2, 3]} ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaDict, info.Output.Kind) require.NotNil(t, info.Output.ValueType) require.Equal(t, schema.SchemaArray, info.Output.ValueType.Kind) require.NotNil(t, info.Output.ValueType.Items) require.Equal(t, schema.SchemaPrimitive, info.Output.ValueType.Items.Kind) 
require.Equal(t, schema.TypeInteger, info.Output.ValueType.Items.Primitive) } func TestListOfDictOutput(t *testing.T) { // list[dict[str, str]] source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> list[dict[str, str]]: return [{"key": "value"}] ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaArray, info.Output.Kind) require.NotNil(t, info.Output.Items) require.Equal(t, schema.SchemaDict, info.Output.Items.Kind) require.NotNil(t, info.Output.Items.ValueType) require.Equal(t, schema.SchemaPrimitive, info.Output.Items.ValueType.Kind) require.Equal(t, schema.TypeString, info.Output.Items.ValueType.Primitive) } func TestListOfListOutput(t *testing.T) { // list[list[float]] source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> list[list[float]]: return [[1.0, 2.0], [3.0]] ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaArray, info.Output.Kind) require.NotNil(t, info.Output.Items) require.Equal(t, schema.SchemaArray, info.Output.Items.Kind) require.NotNil(t, info.Output.Items.Items) require.Equal(t, schema.SchemaPrimitive, info.Output.Items.Items.Kind) require.Equal(t, schema.TypeFloat, info.Output.Items.Items.Primitive) } func TestTripleNestedDictOutput(t *testing.T) { // dict[str, dict[str, dict[str, int]]] source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> dict[str, dict[str, dict[str, int]]]: return {"a": {"b": {"c": 1}}} ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaDict, info.Output.Kind) level2 := info.Output.ValueType require.NotNil(t, level2) require.Equal(t, schema.SchemaDict, level2.Kind) level3 := level2.ValueType require.NotNil(t, level3) require.Equal(t, schema.SchemaDict, level3.Kind) leaf := level3.ValueType require.NotNil(t, leaf) require.Equal(t, schema.SchemaPrimitive, leaf.Kind) require.Equal(t, schema.TypeInteger, leaf.Primitive) } func 
TestListOfDictOfListOutput(t *testing.T) { // list[dict[str, list[str]]] source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> list[dict[str, list[str]]]: return [{"tags": ["a", "b"]}] ` info := parse(t, source, "Predictor") // list[...] require.Equal(t, schema.SchemaArray, info.Output.Kind) // dict[str, ...] dictType := info.Output.Items require.NotNil(t, dictType) require.Equal(t, schema.SchemaDict, dictType.Kind) // list[str] innerList := dictType.ValueType require.NotNil(t, innerList) require.Equal(t, schema.SchemaArray, innerList.Kind) // str require.NotNil(t, innerList.Items) require.Equal(t, schema.SchemaPrimitive, innerList.Items.Kind) require.Equal(t, schema.TypeString, innerList.Items.Primitive) } func TestIteratorOfDictOutput(t *testing.T) { // Iterator[dict[str, str]] — iterator yielding dicts source := ` from typing import Iterator from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> Iterator[dict[str, str]]: yield {"key": "value"} ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaIterator, info.Output.Kind) require.NotNil(t, info.Output.Elem) require.Equal(t, schema.SchemaDict, info.Output.Elem.Kind) require.NotNil(t, info.Output.Elem.ValueType) require.Equal(t, schema.SchemaPrimitive, info.Output.Elem.ValueType.Kind) require.Equal(t, schema.TypeString, info.Output.Elem.ValueType.Primitive) } func TestIteratorOfListOutput(t *testing.T) { // Iterator[list[int]] — iterator yielding lists source := ` from typing import Iterator from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> Iterator[list[int]]: yield [1, 2, 3] ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaIterator, info.Output.Kind) require.NotNil(t, info.Output.Elem) require.Equal(t, schema.SchemaArray, info.Output.Elem.Kind) require.NotNil(t, info.Output.Elem.Items) require.Equal(t, schema.SchemaPrimitive, 
info.Output.Elem.Items.Kind) require.Equal(t, schema.TypeInteger, info.Output.Elem.Items.Primitive) } func TestDictOfPathOutput(t *testing.T) { // dict[str, Path] — dict with file URIs as values source := ` from cog import BasePredictor, Path class Predictor(BasePredictor): def predict(self, x: str) -> dict[str, Path]: return {"file": Path("output.png")} ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaDict, info.Output.Kind) require.NotNil(t, info.Output.ValueType) require.Equal(t, schema.SchemaPrimitive, info.Output.ValueType.Kind) require.Equal(t, schema.TypePath, info.Output.ValueType.Primitive) } func TestListOfPathOutput(t *testing.T) { // list[Path] source := ` from cog import BasePredictor, Path class Predictor(BasePredictor): def predict(self, x: str) -> list[Path]: return [Path("a.png"), Path("b.png")] ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaArray, info.Output.Kind) require.NotNil(t, info.Output.Items) require.Equal(t, schema.SchemaPrimitive, info.Output.Items.Kind) require.Equal(t, schema.TypePath, info.Output.Items.Primitive) } // --------------------------------------------------------------------------- // Unresolvable output type errors // --------------------------------------------------------------------------- func TestUnresolvableImportedTypeError(t *testing.T) { source := ` from some_random_package import WeirdType from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> WeirdType: return None ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrUnresolvableType, se.Kind) require.Contains(t, se.Message, "WeirdType") require.Contains(t, se.Message, "some_random_package") require.Contains(t, se.Message, ".pyi stub") } func TestUnresolvableUndefinedTypeError(t *testing.T) { source := ` from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> MysteryType: return None ` se := parseErr(t, source, 
"Predictor", schema.ModePredict) require.Equal(t, schema.ErrUnresolvableType, se.Kind) require.Contains(t, se.Message, "MysteryType") require.Contains(t, se.Message, "not a primitive type") require.Contains(t, se.Message, "BaseModel") } func TestUnresolvableDottedImportTypeError(t *testing.T) { source := ` from transformers import AutoTokenizer from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> AutoTokenizer: return None ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrUnresolvableType, se.Kind) require.Contains(t, se.Message, "AutoTokenizer") require.Contains(t, se.Message, "transformers") } func TestUnresolvableTypeTorchTensor(t *testing.T) { source := ` from torch import Tensor from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> Tensor: return None ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrUnresolvableType, se.Kind) require.Contains(t, se.Message, "Tensor") require.Contains(t, se.Message, "torch") } func TestUnresolvableTypeNumpyArray(t *testing.T) { source := ` from numpy import ndarray from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> ndarray: return None ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrUnresolvableType, se.Kind) require.Contains(t, se.Message, "ndarray") require.Contains(t, se.Message, "numpy") } func TestDictWithUnresolvableValueTypeErrors(t *testing.T) { // Regression: dict[str, Tensor] used to silently collapse to SchemaAny. // Now it propagates the error from the value type resolution. 
source := ` from torch import Tensor from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> dict[str, Tensor]: return {} ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrUnresolvableType, se.Kind) require.Contains(t, se.Message, "Tensor") } func TestModelFieldDictWithUnresolvableValueTypeErrors(t *testing.T) { // Same bug but inside a BaseModel field. source := ` from torch import Tensor from pydantic import BaseModel from cog import BasePredictor class Result(BaseModel): tensors: dict[str, Tensor] class Predictor(BasePredictor): def predict(self, x: str) -> Result: pass ` se := parseErr(t, source, "Predictor", schema.ModePredict) require.Equal(t, schema.ErrUnresolvableType, se.Kind) require.Contains(t, se.Message, "Tensor") } // --------------------------------------------------------------------------- // Pydantic output still works after migration // --------------------------------------------------------------------------- func TestPydanticV1CompatOutput(t *testing.T) { source := ` from pydantic.v1 import BaseModel from cog import BasePredictor class Result(BaseModel): text: str score: float class Predictor(BasePredictor): def predict(self, x: str) -> Result: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.NotNil(t, info.Output.Fields) require.Equal(t, 2, info.Output.Fields.Len()) text, ok := info.Output.Fields.Get("text") require.True(t, ok) require.Equal(t, schema.SchemaPrimitive, text.Type.Kind) require.Equal(t, schema.TypeString, text.Type.Primitive) score, ok := info.Output.Fields.Get("score") require.True(t, ok) require.Equal(t, schema.SchemaPrimitive, score.Type.Kind) require.Equal(t, schema.TypeFloat, score.Type.Primitive) } func TestPydanticOutputWithOptionalField(t *testing.T) { source := ` from pydantic import BaseModel from typing import Optional from cog import BasePredictor class Result(BaseModel): text: str 
error: Optional[str] = None class Predictor(BasePredictor): def predict(self, x: str) -> Result: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 2, info.Output.Fields.Len()) text, ok := info.Output.Fields.Get("text") require.True(t, ok) require.True(t, text.Required) errField, ok := info.Output.Fields.Get("error") require.True(t, ok) require.True(t, errField.Type.Nullable) } func TestPydanticOutputDefaultedFieldNotNullable(t *testing.T) { // Regression: a field with a default but NOT Optional must NOT be nullable. // Previously !Required was incorrectly mapped to nullable in JSON Schema. source := ` from pydantic import BaseModel from cog import BasePredictor class Result(BaseModel): text: str debug: bool = False count: int = 0 class Predictor(BasePredictor): def predict(self, x: str) -> Result: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) // text: required, not nullable text, ok := info.Output.Fields.Get("text") require.True(t, ok) require.True(t, text.Required) require.False(t, text.Type.Nullable) // debug: has default so not required, but NOT nullable (not Optional) debug, ok := info.Output.Fields.Get("debug") require.True(t, ok) require.False(t, debug.Required, "defaulted field should not be required") require.False(t, debug.Type.Nullable, "non-Optional defaulted field must not be nullable") // count: same — defaulted, not nullable count, ok := info.Output.Fields.Get("count") require.True(t, ok) require.False(t, count.Required) require.False(t, count.Type.Nullable) // Verify JSON Schema output doesn't include "nullable" for these fields js := info.Output.JSONSchema() props, ok := js["properties"].(map[string]any) require.True(t, ok) debugProp, ok := props["debug"].(map[string]any) require.True(t, ok) _, hasNullable := debugProp["nullable"] require.False(t, hasNullable, "JSON Schema for defaulted non-Optional field must not have 
nullable") } func TestPydanticOutputOptionalFieldNullable(t *testing.T) { source := ` from pydantic import BaseModel from typing import Optional from cog import BasePredictor class Result(BaseModel): text: str error: Optional[str] = None class Predictor(BasePredictor): def predict(self, x: str) -> Result: pass ` info := parse(t, source, "Predictor") errField, ok := info.Output.Fields.Get("error") require.True(t, ok) require.True(t, errField.Type.Nullable, "Optional field should be nullable") require.False(t, errField.Required, "Optional field with default should not be required") // Verify JSON Schema output includes "nullable" for Optional field js := info.Output.JSONSchema() props, ok := js["properties"].(map[string]any) require.True(t, ok) errProp, ok := props["error"].(map[string]any) require.True(t, ok) require.Equal(t, true, errProp["nullable"]) } func TestPydanticOutputWithListField(t *testing.T) { source := ` from pydantic import BaseModel from cog import BasePredictor class Result(BaseModel): tags: list[str] scores: list[float] class Predictor(BasePredictor): def predict(self, x: str) -> Result: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 2, info.Output.Fields.Len()) tags, ok := info.Output.Fields.Get("tags") require.True(t, ok) require.Equal(t, schema.SchemaArray, tags.Type.Kind) require.NotNil(t, tags.Type.Items) require.Equal(t, schema.TypeString, tags.Type.Items.Primitive) } func TestPydanticOutputWithDictField(t *testing.T) { source := ` from pydantic import BaseModel from cog import BasePredictor class Result(BaseModel): metadata: dict[str, int] nested: dict[str, list[str]] class Predictor(BasePredictor): def predict(self, x: str) -> Result: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 2, info.Output.Fields.Len()) // metadata: dict[str, int] metadata, ok := info.Output.Fields.Get("metadata") 
require.True(t, ok) require.Equal(t, schema.SchemaDict, metadata.Type.Kind) require.NotNil(t, metadata.Type.ValueType) require.Equal(t, schema.SchemaPrimitive, metadata.Type.ValueType.Kind) require.Equal(t, schema.TypeInteger, metadata.Type.ValueType.Primitive) // nested: dict[str, list[str]] nested, ok := info.Output.Fields.Get("nested") require.True(t, ok) require.Equal(t, schema.SchemaDict, nested.Type.Kind) require.NotNil(t, nested.Type.ValueType) require.Equal(t, schema.SchemaArray, nested.Type.ValueType.Kind) require.NotNil(t, nested.Type.ValueType.Items) require.Equal(t, schema.TypeString, nested.Type.ValueType.Items.Primitive) } func TestPydanticOutputWithOptionalDictField(t *testing.T) { source := ` from typing import Optional from pydantic import BaseModel from cog import BasePredictor class Result(BaseModel): data: Optional[dict[str, float]] class Predictor(BasePredictor): def predict(self, x: str) -> Result: pass ` info := parse(t, source, "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) data, ok := info.Output.Fields.Get("data") require.True(t, ok) require.Equal(t, schema.SchemaDict, data.Type.Kind) require.True(t, data.Type.Nullable) require.False(t, data.Required) require.NotNil(t, data.Type.ValueType) require.Equal(t, schema.TypeFloat, data.Type.ValueType.Primitive) } // --------------------------------------------------------------------------- // Cross-file model resolution // --------------------------------------------------------------------------- // writeFile is a test helper that creates a file in dir with the given content. func writeFile(t *testing.T, dir, name, content string) { t.Helper() full := filepath.Join(dir, name) require.NoError(t, os.MkdirAll(filepath.Dir(full), 0o755)) require.NoError(t, os.WriteFile(full, []byte(content), 0o644)) } // parseFile is a test helper that parses a file from disk with sourceDir context. 
func parseFile(t *testing.T, dir, filename, predictRef string) *schema.PredictorInfo { t.Helper() source, err := os.ReadFile(filepath.Join(dir, filename)) require.NoError(t, err) info, err := ParsePredictor(source, predictRef, schema.ModePredict, dir) require.NoError(t, err) return info } func TestCrossFileBaseModelSameDir(t *testing.T) { // from types import Output — Output defined in types.py in same dir dir := t.TempDir() writeFile(t, dir, "types.py", ` from pydantic import BaseModel class Output(BaseModel): text: str score: float `) writeFile(t, dir, "predict.py", ` from cog import BasePredictor from types import Output class Predictor(BasePredictor): def predict(self, x: str) -> Output: pass `) info := parseFile(t, dir, "predict.py", "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 2, info.Output.Fields.Len()) text, ok := info.Output.Fields.Get("text") require.True(t, ok) require.Equal(t, schema.TypeString, text.Type.Primitive) score, ok := info.Output.Fields.Get("score") require.True(t, ok) require.Equal(t, schema.TypeFloat, score.Type.Primitive) } func TestCrossFileRelativeImport(t *testing.T) { // from .types import Output — relative dot import dir := t.TempDir() writeFile(t, dir, "types.py", ` from cog import BaseModel class Output(BaseModel): label: str confidence: float `) writeFile(t, dir, "predict.py", ` from cog import BasePredictor from .types import Output class Predictor(BasePredictor): def predict(self, x: str) -> Output: pass `) info := parseFile(t, dir, "predict.py", "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 2, info.Output.Fields.Len()) label, ok := info.Output.Fields.Get("label") require.True(t, ok) require.Equal(t, schema.TypeString, label.Type.Primitive) } func TestCrossFileSubpackageImport(t *testing.T) { // from models.output import Result — nested package dir := t.TempDir() writeFile(t, dir, "models/output.py", ` from pydantic import BaseModel class 
Result(BaseModel): answer: str tokens: int `) writeFile(t, dir, "predict.py", ` from cog import BasePredictor from models.output import Result class Predictor(BasePredictor): def predict(self, x: str) -> Result: pass `) info := parseFile(t, dir, "predict.py", "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 2, info.Output.Fields.Len()) answer, ok := info.Output.Fields.Get("answer") require.True(t, ok) require.Equal(t, schema.TypeString, answer.Type.Primitive) tokens, ok := info.Output.Fields.Get("tokens") require.True(t, ok) require.Equal(t, schema.TypeInteger, tokens.Type.Primitive) } func TestCrossFileRelativeSubpackage(t *testing.T) { // from .models.output import Result — relative + nested dir := t.TempDir() writeFile(t, dir, "models/output.py", ` from pydantic import BaseModel class Result(BaseModel): name: str `) writeFile(t, dir, "predict.py", ` from cog import BasePredictor from .models.output import Result class Predictor(BasePredictor): def predict(self, x: str) -> Result: pass `) info := parseFile(t, dir, "predict.py", "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 1, info.Output.Fields.Len()) name, ok := info.Output.Fields.Get("name") require.True(t, ok) require.Equal(t, schema.TypeString, name.Type.Primitive) } func TestCrossFileMultipleModelsFromSameFile(t *testing.T) { // Two BaseModel classes in the same external file dir := t.TempDir() writeFile(t, dir, "schema_types.py", ` from pydantic import BaseModel class Metadata(BaseModel): version: str author: str class Prediction(BaseModel): result: str score: float `) writeFile(t, dir, "predict.py", ` from cog import BasePredictor from schema_types import Prediction class Predictor(BasePredictor): def predict(self, x: str) -> Prediction: pass `) info := parseFile(t, dir, "predict.py", "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 2, info.Output.Fields.Len()) result, ok := 
info.Output.Fields.Get("result") require.True(t, ok) require.Equal(t, schema.TypeString, result.Type.Primitive) score, ok := info.Output.Fields.Get("score") require.True(t, ok) require.Equal(t, schema.TypeFloat, score.Type.Primitive) } func TestCrossFileWithOptionalField(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "output.py", ` from typing import Optional from pydantic import BaseModel class Output(BaseModel): text: str error: Optional[str] = None debug: bool = False `) writeFile(t, dir, "predict.py", ` from cog import BasePredictor from output import Output class Predictor(BasePredictor): def predict(self, x: str) -> Output: pass `) info := parseFile(t, dir, "predict.py", "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 3, info.Output.Fields.Len()) text, ok := info.Output.Fields.Get("text") require.True(t, ok) require.True(t, text.Required) errField, ok := info.Output.Fields.Get("error") require.True(t, ok) require.True(t, errField.Type.Nullable) debug, ok := info.Output.Fields.Get("debug") require.True(t, ok) require.NotNil(t, debug.Default) require.Equal(t, schema.DefaultBool, debug.Default.Kind) require.Equal(t, false, debug.Default.Bool) } func TestCrossFileAliasedImport(t *testing.T) { // from output_types import MyOutput as Output dir := t.TempDir() writeFile(t, dir, "output_types.py", ` from pydantic import BaseModel class MyOutput(BaseModel): value: int `) writeFile(t, dir, "predict.py", ` from cog import BasePredictor from output_types import MyOutput as Output class Predictor(BasePredictor): def predict(self, x: str) -> Output: pass `) info := parseFile(t, dir, "predict.py", "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 1, info.Output.Fields.Len()) val, ok := info.Output.Fields.Get("value") require.True(t, ok) require.Equal(t, schema.TypeInteger, val.Type.Primitive) } func TestCrossFileExternalPackageStillErrors(t *testing.T) { // Importing from a package that doesn't 
exist locally should still error dir := t.TempDir() writeFile(t, dir, "predict.py", ` from transformers import AutoModelForSequenceClassification from cog import BasePredictor class Predictor(BasePredictor): def predict(self, x: str) -> AutoModelForSequenceClassification: pass `) source, err := os.ReadFile(filepath.Join(dir, "predict.py")) require.NoError(t, err) _, err = ParsePredictor(source, "Predictor", schema.ModePredict, dir) require.Error(t, err) var se *schema.SchemaError require.True(t, errors.As(err, &se)) require.Equal(t, schema.ErrUnresolvableType, se.Kind) require.Contains(t, se.Message, "transformers") } func TestCrossFileLocalPrecedesExternal(t *testing.T) { // A local file shadows an external package name. // E.g. user has a local "utils.py" and does "from utils import Output" dir := t.TempDir() writeFile(t, dir, "utils.py", ` from cog import BaseModel class Output(BaseModel): msg: str `) writeFile(t, dir, "predict.py", ` from cog import BasePredictor from utils import Output class Predictor(BasePredictor): def predict(self, x: str) -> Output: pass `) info := parseFile(t, dir, "predict.py", "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 1, info.Output.Fields.Len()) msg, ok := info.Output.Fields.Get("msg") require.True(t, ok) require.Equal(t, schema.TypeString, msg.Type.Primitive) } func TestCrossFileListFieldInExternalModel(t *testing.T) { dir := t.TempDir() writeFile(t, dir, "types.py", ` from pydantic import BaseModel class Output(BaseModel): tags: list[str] scores: list[float] `) writeFile(t, dir, "predict.py", ` from cog import BasePredictor from types import Output class Predictor(BasePredictor): def predict(self, x: str) -> Output: pass `) info := parseFile(t, dir, "predict.py", "Predictor") require.Equal(t, schema.SchemaObject, info.Output.Kind) require.Equal(t, 2, info.Output.Fields.Len()) tags, ok := info.Output.Fields.Get("tags") require.True(t, ok) require.Equal(t, schema.SchemaArray, tags.Type.Kind) 
require.Equal(t, schema.TypeString, tags.Type.Items.Primitive) } func TestCrossFileEndToEndSchemaGeneration(t *testing.T) { // Full end-to-end: Generate() reads predict.py from disk, // resolves Output from types.py, and produces valid OpenAPI JSON. dir := t.TempDir() writeFile(t, dir, "types.py", ` from pydantic import BaseModel class Output(BaseModel): text: str score: float `) writeFile(t, dir, "predict.py", ` from cog import BasePredictor from types import Output class Predictor(BasePredictor): def predict(self, prompt: str) -> Output: pass `) data, err := schema.Generate("predict.py:Predictor", dir, schema.ModePredict, ParsePredictor) require.NoError(t, err) require.Contains(t, string(data), `"openapi"`) require.Contains(t, string(data), `"Output"`) require.Contains(t, string(data), `"text"`) require.Contains(t, string(data), `"score"`) require.Contains(t, string(data), `"object"`) } ================================================ FILE: pkg/schema/schema_type.go ================================================ package schema import "fmt" // SchemaType is a recursive algebraic data type representing any type that // can appear in a Cog predictor's output (or, in the future, input) position. // // It replaces the flat OutputType/PrimitiveType system with a composable // type tree that can represent dict[str, list[int]], nested BaseModel // subclasses, TypedDicts, and types resolved from .pyi stubs — all without // running Python. type SchemaType struct { Kind SchemaTypeKind // Primitive: for Kind=SchemaPrimitive — one of the base scalar types. Primitive PrimitiveType // Array: for Kind=SchemaArray — the element type. Items *SchemaType // Dict: for Kind=SchemaDict — key and value types. // KeyType is always string in JSON Schema, but we track it for completeness. KeyType *SchemaType ValueType *SchemaType // Object: for Kind=SchemaObject — named fields with types and defaults. 
Fields *OrderedMap[string, SchemaField] // Iterator/ConcatIterator: for Kind=SchemaIterator|SchemaConcatIterator. // The yielded element type. Elem *SchemaType // Nullable: wraps any type to allow null. Nullable bool } // SchemaTypeKind tags the active variant in SchemaType. type SchemaTypeKind int const ( // SchemaPrimitive is a scalar type: bool, int, float, str, Path, File, Secret. SchemaPrimitive SchemaTypeKind = iota // SchemaAny is an opaque JSON value (unparameterized dict, Any, etc). SchemaAny // SchemaArray is a homogeneous list/array. SchemaArray // SchemaDict is a string-keyed dictionary with a typed value. SchemaDict // SchemaObject is a product type with named fields (BaseModel, TypedDict, dataclass). SchemaObject // SchemaIterator is a cog Iterator[T] — array with x-cog-array-type=iterator. SchemaIterator // SchemaConcatIterator is a cog ConcatenateIterator[str] — streaming text. SchemaConcatIterator ) // SchemaField is a named field within a SchemaObject. type SchemaField struct { Type SchemaType Default *DefaultValue Required bool } // JSONSchema converts a SchemaType to its JSON Schema representation. // This is used for the "Output" component in the OpenAPI spec. 
func (s SchemaType) JSONSchema() map[string]any { return s.jsonSchema(true) } func (s SchemaType) jsonSchema(topLevel bool) map[string]any { result := s.coreSchema() if topLevel { result["title"] = "Output" } if s.Nullable { result["nullable"] = true } return result } func (s SchemaType) coreSchema() map[string]any { switch s.Kind { case SchemaPrimitive: return s.Primitive.JSONType() case SchemaAny: return map[string]any{"type": "object"} case SchemaArray: items := map[string]any{"type": "object"} if s.Items != nil { items = s.Items.jsonSchema(false) } result := map[string]any{ "type": "array", "items": items, } return result case SchemaDict: result := map[string]any{"type": "object"} if s.ValueType != nil { result["additionalProperties"] = s.ValueType.jsonSchema(false) } return result case SchemaObject: if s.Fields == nil { return map[string]any{"type": "object"} } properties := make(map[string]any) var required []string s.Fields.Entries(func(name string, field SchemaField) { prop := field.Type.jsonSchema(false) prop["title"] = TitleCase(name) if field.Type.Nullable { prop["nullable"] = true } if field.Required && field.Default == nil { required = append(required, name) } properties[name] = prop }) result := map[string]any{ "type": "object", "properties": properties, } if len(required) > 0 { result["required"] = required } return result case SchemaIterator: items := map[string]any{"type": "object"} if s.Elem != nil { items = s.Elem.jsonSchema(false) } return map[string]any{ "type": "array", "items": items, "x-cog-array-type": "iterator", } case SchemaConcatIterator: items := map[string]any{"type": "object"} if s.Elem != nil { items = s.Elem.jsonSchema(false) } return map[string]any{ "type": "array", "items": items, "x-cog-array-type": "iterator", "x-cog-array-display": "concatenate", } } return map[string]any{"type": "object"} } // --------------------------------------------------------------------------- // Constructors — convenience functions for building 
SchemaType values. // --------------------------------------------------------------------------- // SchemaPrim creates a primitive SchemaType. func SchemaPrim(p PrimitiveType) SchemaType { return SchemaType{Kind: SchemaPrimitive, Primitive: p} } // SchemaAnyType creates an opaque JSON object type. func SchemaAnyType() SchemaType { return SchemaType{Kind: SchemaAny} } // SchemaArrayOf creates an array type with the given element type. func SchemaArrayOf(elem SchemaType) SchemaType { return SchemaType{Kind: SchemaArray, Items: &elem} } // SchemaDictOf creates a dict type with string keys and the given value type. func SchemaDictOf(value SchemaType) SchemaType { k := SchemaPrim(TypeString) return SchemaType{Kind: SchemaDict, KeyType: &k, ValueType: &value} } // SchemaIteratorOf creates an iterator type with the given element type. func SchemaIteratorOf(elem SchemaType) SchemaType { return SchemaType{Kind: SchemaIterator, Elem: &elem} } // SchemaConcatIteratorOf creates a concatenate iterator type (always str). func SchemaConcatIteratorOf() SchemaType { elem := SchemaPrim(TypeString) return SchemaType{Kind: SchemaConcatIterator, Elem: &elem} } // SchemaObjectOf creates an object type from an ordered map of fields. func SchemaObjectOf(fields *OrderedMap[string, SchemaField]) SchemaType { return SchemaType{Kind: SchemaObject, Fields: fields} } // --------------------------------------------------------------------------- // ResolveSchemaType — recursive output type resolver (replaces ResolveOutputType). // --------------------------------------------------------------------------- // ResolveSchemaType resolves a Python type annotation into a SchemaType. 
// Unlike the legacy ResolveOutputType, this handles arbitrary nesting: // // dict[str, list[dict[str, int]]] → SchemaDict{ValueType: SchemaArray{Items: SchemaDict{...}}} // list[dict[str, str]] → SchemaArray{Items: SchemaDict{ValueType: SchemaPrim(TypeString)}} // // It also resolves BaseModel subclasses and cog iterators. func ResolveSchemaType(ann TypeAnnotation, ctx *ImportContext, models ModelClassMap) (SchemaType, error) { switch ann.Kind { case TypeAnnotSimple: return resolveSimpleSchemaType(ann, ctx, models) case TypeAnnotGeneric: return resolveGenericSchemaType(ann, ctx, models) case TypeAnnotUnion: return resolveUnionSchemaType(ann) } return SchemaType{}, errUnsupportedType("unknown type annotation") } func resolveSimpleSchemaType(ann TypeAnnotation, ctx *ImportContext, models ModelClassMap) (SchemaType, error) { // Check for BaseModel subclass if fields, ok := models.Get(ann.Name); ok { return resolveModelToSchemaType(fields, ctx, models) } // Unparameterized dict → opaque JSON object if ann.Name == "Any" || ann.Name == "dict" || ann.Name == "Dict" { return SchemaAnyType(), nil } // Unparameterized list → array of opaque objects if ann.Name == "list" || ann.Name == "List" { return SchemaArrayOf(SchemaAnyType()), nil } prim, ok := PrimitiveFromName(ann.Name) if !ok { // Check if this name was imported from an external package if entry, imported := ctx.Names.Get(ann.Name); imported { return SchemaType{}, errUnresolvableImportedType(ann.Name, entry.Module) } return SchemaType{}, errUnresolvableType(ann.Name) } return SchemaPrim(prim), nil } func resolveGenericSchemaType(ann TypeAnnotation, ctx *ImportContext, models ModelClassMap) (SchemaType, error) { outer := ann.Name // dict[K, V] — recursively resolve value type if outer == "dict" || outer == "Dict" { if len(ann.Args) == 2 { valType, err := ResolveSchemaType(ann.Args[1], ctx, models) if err != nil { return SchemaType{}, fmt.Errorf("resolving dict value type: %w", err) } return SchemaDictOf(valType), nil 
} // Bare dict (no type args) → opaque if len(ann.Args) == 0 { return SchemaAnyType(), nil } return SchemaType{}, errUnsupportedType("dict expects 0 or 2 type arguments") } // Optional[X] → rejected as output type (nullable outputs not supported) if outer == "Optional" { if len(ann.Args) != 1 { return SchemaType{}, errUnsupportedType("Optional expects exactly 1 type argument") } // Optional is not allowed as an output type return SchemaType{}, errOptionalOutput() } // Union[X, Y] → delegate if outer == "Union" { return resolveUnionSchemaType(TypeAnnotation{Kind: TypeAnnotUnion, Args: ann.Args}) } // list[X] / List[X] if outer == "List" || outer == "list" { if len(ann.Args) != 1 { return SchemaType{}, errUnsupportedType("list expects exactly 1 type argument") } elemType, err := ResolveSchemaType(ann.Args[0], ctx, models) if err != nil { return SchemaType{}, err } return SchemaArrayOf(elemType), nil } // Cog iterators — single type arg, recursively resolved (supports nested types) if outer == "Iterator" || outer == "AsyncIterator" { if len(ann.Args) != 1 { return SchemaType{}, errUnsupportedType("Iterator expects exactly 1 type argument") } elemType, err := ResolveSchemaType(ann.Args[0], ctx, models) if err != nil { return SchemaType{}, err } return SchemaIteratorOf(elemType), nil } if outer == "ConcatenateIterator" || outer == "AsyncConcatenateIterator" { if len(ann.Args) != 1 { return SchemaType{}, errUnsupportedType("ConcatenateIterator expects exactly 1 type argument") } inner := ann.Args[0] if inner.Kind != TypeAnnotSimple { return SchemaType{}, errUnsupportedType("ConcatenateIterator element type must be a simple type") } prim, ok := PrimitiveFromName(inner.Name) if !ok || prim != TypeString { return SchemaType{}, errConcatIteratorNotStr(inner.Name) } return SchemaConcatIteratorOf(), nil } return SchemaType{}, errUnsupportedType(fmt.Sprintf("%s[...] 
is not a supported output type", outer)) } func resolveUnionSchemaType(ann TypeAnnotation) (SchemaType, error) { if _, ok := UnwrapOptional(ann); ok { return SchemaType{}, errOptionalOutput() } return SchemaType{}, errUnsupportedType("union types are not supported as output") } // resolveModelToSchemaType converts a BaseModel's fields into a SchemaObject. // Fields are resolved via resolveFieldSchemaType which supports the full recursive // SchemaType system (dict[str, list[int]], nested BaseModels, etc.) plus Optional // wrapping (which is valid for fields but not for top-level output types). func resolveModelToSchemaType(modelFields []ModelField, ctx *ImportContext, models ModelClassMap) (SchemaType, error) { fields := NewOrderedMap[string, SchemaField]() for _, f := range modelFields { st, required, err := resolveFieldSchemaType(f.Type, ctx, models) if err != nil { return SchemaType{}, fmt.Errorf("field %q: %w", f.Name, err) } if f.Default != nil { required = false } fields.Set(f.Name, SchemaField{ Type: st, Default: f.Default, Required: required, }) } return SchemaObjectOf(fields), nil } // resolveFieldSchemaType resolves a type annotation for a model field. // Unlike ResolveSchemaType (which rejects Optional as a top-level output), // this allows Optional[X] and Union[X, None] for fields, setting Nullable. 
func resolveFieldSchemaType(ann TypeAnnotation, ctx *ImportContext, models ModelClassMap) (SchemaType, bool, error) { if inner, ok := UnwrapOptional(ann); ok { st, err := ResolveSchemaType(inner, ctx, models) if err != nil { return SchemaType{}, false, err } st.Nullable = true return st, false, nil } st, err := ResolveSchemaType(ann, ctx, models) if err != nil { return SchemaType{}, false, err } return st, true, nil } ================================================ FILE: pkg/schema/schema_type_fuzz_test.go ================================================ package schema import ( "testing" ) // FuzzResolveSchemaType builds arbitrary TypeAnnotation trees from fuzz input // and verifies that ResolveSchemaType never panics. func FuzzResolveSchemaType(f *testing.F) { // Seed corpus — known-good and known-tricky inputs. seeds := []TypeAnnotation{ {Kind: TypeAnnotSimple, Name: "str"}, {Kind: TypeAnnotSimple, Name: "int"}, {Kind: TypeAnnotSimple, Name: "dict"}, {Kind: TypeAnnotSimple, Name: "list"}, {Kind: TypeAnnotSimple, Name: "Any"}, {Kind: TypeAnnotSimple, Name: "UnknownType"}, {Kind: TypeAnnotSimple, Name: ""}, {Kind: TypeAnnotGeneric, Name: "dict", Args: []TypeAnnotation{ {Kind: TypeAnnotSimple, Name: "str"}, {Kind: TypeAnnotSimple, Name: "int"}, }}, {Kind: TypeAnnotGeneric, Name: "list", Args: []TypeAnnotation{ {Kind: TypeAnnotSimple, Name: "str"}, }}, {Kind: TypeAnnotGeneric, Name: "Optional", Args: []TypeAnnotation{ {Kind: TypeAnnotSimple, Name: "str"}, }}, {Kind: TypeAnnotGeneric, Name: "Iterator", Args: []TypeAnnotation{ {Kind: TypeAnnotGeneric, Name: "dict", Args: []TypeAnnotation{ {Kind: TypeAnnotSimple, Name: "str"}, {Kind: TypeAnnotGeneric, Name: "list", Args: []TypeAnnotation{ {Kind: TypeAnnotSimple, Name: "int"}, }}, }}, }}, {Kind: TypeAnnotUnion, Args: []TypeAnnotation{ {Kind: TypeAnnotSimple, Name: "str"}, {Kind: TypeAnnotSimple, Name: "None"}, }}, } // Add byte-encoded seeds. 
for _, s := range seeds { b := encodeAnnotation(s) f.Add(b) } ctx := NewImportContext() models := NewOrderedMap[string, []ModelField]() f.Fuzz(func(t *testing.T, data []byte) { ann, _ := decodeAnnotation(data, 0, 0) // Must not panic regardless of input. st, err := ResolveSchemaType(ann, ctx, models) if err == nil { // If resolution succeeded, JSONSchema must not panic. _ = st.JSONSchema() } }) } // FuzzJSONSchema constructs random SchemaType trees and ensures // JSONSchema() never panics. func FuzzJSONSchema(f *testing.F) { f.Add([]byte{0}) f.Add([]byte{1}) f.Add([]byte{2, 0, 3, 's', 't', 'r'}) f.Add([]byte{3, 2, 0, 3, 's', 't', 'r'}) f.Add([]byte{4, 1, 2, 0, 3, 'i', 'n', 't'}) f.Fuzz(func(t *testing.T, data []byte) { st, _ := decodeSchemaType(data, 0, 0) // Must not panic. _ = st.JSONSchema() _ = st.jsonSchema(false) }) } // --------------------------------------------------------------------------- // Annotation encoder/decoder — deterministic mapping from bytes to trees. // --------------------------------------------------------------------------- const maxFuzzDepth = 8 // encodeAnnotation serializes a TypeAnnotation to bytes. func encodeAnnotation(ann TypeAnnotation) []byte { buf := append([]byte{byte(ann.Kind), byte(len(ann.Name))}, []byte(ann.Name)...) buf = append(buf, byte(len(ann.Args))) for _, a := range ann.Args { buf = append(buf, encodeAnnotation(a)...) } return buf } // decodeAnnotation deserializes bytes into a TypeAnnotation tree. // Returns the annotation and number of bytes consumed. func decodeAnnotation(data []byte, offset int, depth int) (TypeAnnotation, int) { if depth > maxFuzzDepth || offset >= len(data) { return TypeAnnotation{Kind: TypeAnnotSimple, Name: "str"}, offset } kind := TypeAnnotationKind(data[offset] % 3) offset++ // Read name length and name. 
nameLen := 0 if offset < len(data) { nameLen = int(data[offset]) % 32 // cap name length offset++ } if offset+nameLen > len(data) { nameLen = len(data) - offset } name := string(data[offset : offset+nameLen]) offset += nameLen // Read args count. numArgs := 0 if offset < len(data) { numArgs = int(data[offset]) % 4 // cap at 3 args offset++ } var args []TypeAnnotation for i := 0; i < numArgs && offset < len(data); i++ { arg, newOffset := decodeAnnotation(data, offset, depth+1) args = append(args, arg) offset = newOffset } return TypeAnnotation{Kind: kind, Name: name, Args: args}, offset } // decodeSchemaType builds a SchemaType tree from bytes. func decodeSchemaType(data []byte, offset int, depth int) (SchemaType, int) { if depth > maxFuzzDepth || offset >= len(data) { return SchemaPrim(TypeString), offset } kind := SchemaTypeKind(data[offset] % 7) offset++ switch kind { case SchemaPrimitive: prim := PrimitiveType(0) if offset < len(data) { prim = PrimitiveType(data[offset] % 9) offset++ } st := SchemaPrim(prim) if offset < len(data) && data[offset]%2 == 1 { st.Nullable = true } if offset < len(data) { offset++ } return st, offset case SchemaAny: return SchemaAnyType(), offset case SchemaArray: items, newOffset := decodeSchemaType(data, offset, depth+1) return SchemaArrayOf(items), newOffset case SchemaDict: val, newOffset := decodeSchemaType(data, offset, depth+1) return SchemaDictOf(val), newOffset case SchemaObject: numFields := 0 if offset < len(data) { numFields = int(data[offset]) % 5 offset++ } fields := NewOrderedMap[string, SchemaField]() for i := 0; i < numFields && offset < len(data); i++ { nameLen := int(data[offset]) % 8 offset++ if offset+nameLen > len(data) { nameLen = len(data) - offset } name := string(data[offset : offset+nameLen]) offset += nameLen ft, newOffset := decodeSchemaType(data, offset, depth+1) required := false if newOffset < len(data) { required = data[newOffset]%2 == 0 newOffset++ } fields.Set(name, SchemaField{Type: ft, Required: 
required}) offset = newOffset } return SchemaObjectOf(fields), offset case SchemaIterator: elem, newOffset := decodeSchemaType(data, offset, depth+1) return SchemaIteratorOf(elem), newOffset case SchemaConcatIterator: return SchemaConcatIteratorOf(), offset default: return SchemaPrim(TypeString), offset } } ================================================ FILE: pkg/schema/types.go ================================================ package schema import ( "encoding/json" "fmt" "strings" ) // Mode selects whether to extract predict or train signatures. type Mode int const ( ModePredict Mode = iota ModeTrain ) // PrimitiveType maps Python types to JSON Schema types. type PrimitiveType int const ( TypeBool PrimitiveType = iota TypeFloat TypeInteger TypeString TypePath // cog.Path — {"type":"string","format":"uri"} TypeFile // cog.File (deprecated) — same wire format as Path TypeSecret // cog.Secret — write-only, masked TypeAny // typing.Any or unresolved ) // JSONType returns the JSON Schema fragment for this primitive. func (p PrimitiveType) JSONType() map[string]any { switch p { case TypeBool: return map[string]any{"type": "boolean"} case TypeFloat: return map[string]any{"type": "number"} case TypeInteger: return map[string]any{"type": "integer"} case TypeString: return map[string]any{"type": "string"} case TypePath, TypeFile: return map[string]any{"type": "string", "format": "uri"} case TypeSecret: return map[string]any{"type": "string", "format": "password", "writeOnly": true, "x-cog-secret": true} case TypeAny: return map[string]any{"type": "object"} default: return map[string]any{"type": "object"} } } func (p PrimitiveType) String() string { names := [...]string{"bool", "float", "int", "str", "Path", "File", "Secret", "Any"} if int(p) < len(names) { return names[p] } return "unknown" } // PrimitiveFromName resolves a simple type name to a PrimitiveType. 
func PrimitiveFromName(name string) (PrimitiveType, bool) { switch name { case "bool": return TypeBool, true case "float": return TypeFloat, true case "int": return TypeInteger, true case "str": return TypeString, true case "Path": return TypePath, true case "File": return TypeFile, true case "Secret": return TypeSecret, true case "Any": return TypeAny, true default: return 0, false } } // Repetition describes cardinality of a field. type Repetition int const ( Required Repetition = iota Optional Repeated // list[X] ) // FieldType combines a primitive type with its cardinality. type FieldType struct { Primitive PrimitiveType Repetition Repetition } // JSONType returns the JSON Schema fragment for this field type. func (ft FieldType) JSONType() map[string]any { if ft.Repetition == Repeated { return map[string]any{ "type": "array", "items": ft.Primitive.JSONType(), } } return ft.Primitive.JSONType() } // DefaultValue represents a statically-parsed Python literal. type DefaultValue struct { Kind DefaultKind Bool bool Int int64 Float float64 Str string List []DefaultValue DictKeys []DefaultValue // parallel with DictVals DictVals []DefaultValue } // DefaultKind tags the active field in DefaultValue. type DefaultKind int const ( DefaultNone DefaultKind = iota DefaultBool DefaultInt DefaultFloat DefaultString DefaultList DefaultDict DefaultSet ) // ToJSON converts a DefaultValue to its JSON representation. 
func (d DefaultValue) ToJSON() any { switch d.Kind { case DefaultNone: return nil case DefaultBool: return d.Bool case DefaultInt: return d.Int case DefaultFloat: return d.Float case DefaultString: return d.Str case DefaultList, DefaultSet: items := make([]any, len(d.List)) for i, v := range d.List { items[i] = v.ToJSON() } return items case DefaultDict: m := make(map[string]any, len(d.DictKeys)) for i := range d.DictKeys { key := fmt.Sprintf("%v", d.DictKeys[i].ToJSON()) if d.DictKeys[i].Kind == DefaultString { key = d.DictKeys[i].Str } m[key] = d.DictVals[i].ToJSON() } return m default: return nil } } // MarshalJSON implements json.Marshaler for DefaultValue. func (d DefaultValue) MarshalJSON() ([]byte, error) { return json.Marshal(d.ToJSON()) } // InputField represents one parameter of predict/train. type InputField struct { Name string Order int FieldType FieldType Default *DefaultValue Description *string GE *float64 LE *float64 MinLength *uint64 MaxLength *uint64 Regex *string Choices []DefaultValue Deprecated *bool } // IsRequired returns true if this field is required in the schema. func (f *InputField) IsRequired() bool { return f.Default == nil && (f.FieldType.Repetition == Required || f.FieldType.Repetition == Repeated) } // PredictorInfo is the top-level extraction result. type PredictorInfo struct { Inputs *OrderedMap[string, InputField] Output SchemaType Mode Mode } // TypeAnnotation is a parsed Python type annotation (intermediate, before resolution). type TypeAnnotation struct { Kind TypeAnnotationKind Name string // for Simple Args []TypeAnnotation // for Generic (outer=Name, args=Args) or Union (members=Args) } // TypeAnnotationKind tags the variant. type TypeAnnotationKind int const ( TypeAnnotSimple TypeAnnotationKind = iota TypeAnnotGeneric TypeAnnotUnion ) // ImportContext tracks what names are imported from which modules. 
type ImportContext struct { // Names maps local name → (module, original_name) Names *OrderedMap[string, ImportEntry] } // ImportEntry records where a name was imported from. type ImportEntry struct { Module string Original string } // NewImportContext creates an empty ImportContext. func NewImportContext() *ImportContext { return &ImportContext{Names: NewOrderedMap[string, ImportEntry]()} } // IsCogType returns true if name was imported from the "cog" module. func (ctx *ImportContext) IsCogType(name string) bool { if e, ok := ctx.Names.Get(name); ok { return e.Module == "cog" } return false } // IsTypingType returns true if name was imported from "typing" or "typing_extensions". func (ctx *ImportContext) IsTypingType(name string) bool { if e, ok := ctx.Names.Get(name); ok { return e.Module == "typing" || e.Module == "typing_extensions" } return false } // IsBaseModel returns true if name resolves to cog.BaseModel or pydantic.BaseModel. func (ctx *ImportContext) IsBaseModel(name string) bool { if e, ok := ctx.Names.Get(name); ok { return (e.Module == "cog" || e.Module == "pydantic" || e.Module == "pydantic.v1") && e.Original == "BaseModel" } return false } // IsBasePredictor returns true if name resolves to cog.BasePredictor. func (ctx *ImportContext) IsBasePredictor(name string) bool { if e, ok := ctx.Names.Get(name); ok { return e.Module == "cog" && e.Original == "BasePredictor" } return false } // ResolveFieldType resolves a TypeAnnotation into a FieldType. 
func ResolveFieldType(ann TypeAnnotation, ctx *ImportContext) (FieldType, error) { switch ann.Kind { case TypeAnnotSimple: prim, ok := PrimitiveFromName(ann.Name) if !ok { return FieldType{}, errUnsupportedType(ann.Name) } return FieldType{Primitive: prim, Repetition: Required}, nil case TypeAnnotGeneric: outer := ann.Name if outer == "Optional" { if len(ann.Args) != 1 { return FieldType{}, errUnsupportedType(fmt.Sprintf("Optional expects exactly 1 type argument, got %d", len(ann.Args))) } inner, err := ResolveFieldType(ann.Args[0], ctx) if err != nil { return FieldType{}, err } return FieldType{Primitive: inner.Primitive, Repetition: Optional}, nil } if outer == "Union" { // typing.Union[X, Y] → treat as union type return ResolveFieldType(TypeAnnotation{Kind: TypeAnnotUnion, Args: ann.Args}, ctx) } if outer == "List" || outer == "list" { if len(ann.Args) != 1 { return FieldType{}, errUnsupportedType(fmt.Sprintf("List expects exactly 1 type argument, got %d", len(ann.Args))) } inner, err := ResolveFieldType(ann.Args[0], ctx) if err != nil { return FieldType{}, err } if inner.Repetition != Required { return FieldType{}, errUnsupportedType("nested generics like List[Optional[X]] are not supported") } return FieldType{Primitive: inner.Primitive, Repetition: Repeated}, nil } return FieldType{}, errUnsupportedType(fmt.Sprintf("%s[...] is not a supported input type", outer)) case TypeAnnotUnion: if inner, ok := UnwrapOptional(ann); ok { ft, err := ResolveFieldType(inner, ctx) if err != nil { return FieldType{}, err } return FieldType{Primitive: ft.Primitive, Repetition: Optional}, nil } return FieldType{}, errUnsupportedType("union types other than X | None are not supported") } return FieldType{}, errUnsupportedType("unknown type annotation") } // UnwrapOptional checks if a type annotation represents Optional[X] or Union[X, None]. // If so, it returns the inner type and true. Otherwise it returns the original and false. 
func UnwrapOptional(ann TypeAnnotation) (TypeAnnotation, bool) { // Optional[X] if ann.Kind == TypeAnnotGeneric && ann.Name == "Optional" && len(ann.Args) == 1 { return ann.Args[0], true } // Union[X, None] or X | None args := ann.Args if (ann.Kind == TypeAnnotGeneric && ann.Name == "Union" || ann.Kind == TypeAnnotUnion) && len(args) == 2 { for i := range args { if args[i].Kind == TypeAnnotSimple && args[i].Name == "None" { return args[1-i], true } } } return ann, false } // ModelClassMap maps class names to their fields. type ModelClassMap = *OrderedMap[string, []ModelField] // ModelField is a field extracted from a BaseModel subclass. type ModelField struct { Name string Type TypeAnnotation Default *DefaultValue } // TitleCase converts snake_case to Title Case. func TitleCase(s string) string { parts := strings.Split(s, "_") for i, p := range parts { if len(p) > 0 { parts[i] = strings.ToUpper(p[:1]) + p[1:] } } return strings.Join(parts, " ") } // TitleCaseSingle title-cases a single word (first letter uppercase). func TitleCaseSingle(s string) string { if len(s) == 0 { return s } return strings.ToUpper(s[:1]) + s[1:] } // OrderedMap is a simple insertion-ordered map. type OrderedMap[K comparable, V any] struct { keys []K values map[K]V } // NewOrderedMap creates a new empty OrderedMap. func NewOrderedMap[K comparable, V any]() *OrderedMap[K, V] { return &OrderedMap[K, V]{values: make(map[K]V)} } // Set inserts or updates a key-value pair. func (m *OrderedMap[K, V]) Set(key K, value V) { if _, exists := m.values[key]; !exists { m.keys = append(m.keys, key) } m.values[key] = value } // Get returns the value for a key and whether it exists. func (m *OrderedMap[K, V]) Get(key K) (V, bool) { v, ok := m.values[key] return v, ok } // Keys returns keys in insertion order. func (m *OrderedMap[K, V]) Keys() []K { return m.keys } // Len returns the number of entries. 
func (m *OrderedMap[K, V]) Len() int { return len(m.keys) } // Entries iterates over key-value pairs in insertion order. func (m *OrderedMap[K, V]) Entries(fn func(key K, value V)) { for _, k := range m.keys { fn(k, m.values[k]) } } ================================================ FILE: pkg/update/state.go ================================================ package update import ( "encoding/json" "os" "path/filepath" "time" "github.com/mitchellh/go-homedir" "github.com/replicate/cog/pkg/util/console" "github.com/replicate/cog/pkg/util/files" ) type state struct { Message string `json:"message"` LastChecked time.Time `json:"lastChecked"` Version string `json:"version"` } // loadState loads the update check state from disk, returning defaults if it does not exist func loadState() (*state, error) { state := state{} p, err := statePath() if err != nil { return nil, err } exists, err := files.Exists(p) if err != nil { return nil, err } if !exists { return &state, nil } text, err := os.ReadFile(p) if err != nil { console.Debugf("Failed to read %s: %s", p, err) return &state, nil } err = json.Unmarshal(text, &state) if err != nil { return nil, err } return &state, nil } // writeState saves analytics state to disk func writeState(s *state) error { statePath, err := statePath() if err != nil { return err } bytes, err := json.MarshalIndent(s, "", " ") if err != nil { return err } dir := filepath.Dir(statePath) if err := os.MkdirAll(dir, 0o700); err != nil { return err } err = os.WriteFile(statePath, bytes, 0o600) if err != nil { return err } return nil } func userDir() (string, error) { return homedir.Expand("~/.config/cog") } func statePath() (string, error) { dir, err := userDir() if err != nil { return "", err } return filepath.Join(dir, "update-state.json"), nil } ================================================ FILE: pkg/update/update.go ================================================ package update import ( "context" "encoding/json" "errors" "fmt" "net/http" "os" 
"runtime" "time" "github.com/replicate/cog/pkg/global" "github.com/replicate/cog/pkg/util/console" ) func isUpdateEnabled() bool { return os.Getenv("COG_NO_UPDATE_CHECK") == "" } // DisplayAndCheckForRelease will display an update message if an update is available and will check for a new update in the background // The result of that check will then be displayed the next time the user runs Cog // Returns errors which the caller is assumed to ignore so as not to break the client func DisplayAndCheckForRelease(ctx context.Context) error { if !isUpdateEnabled() { return fmt.Errorf("update check disabled") } s, err := loadState() if err != nil { return err } if s.Version != global.Version { console.Debugf("Resetting update message because Cog has been upgraded") return writeState(&state{Message: "", LastChecked: time.Now(), Version: global.Version}) } if time.Since(s.LastChecked) > time.Hour { startCheckingForRelease(ctx) } if s.Message != "" { console.Info(s.Message) console.Info("") } return nil } func startCheckingForRelease(ctx context.Context) { go func() { console.Debugf("Checking for updates...") ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() switch r, err := checkForRelease(ctx); { case err == nil: if r == nil { break } if err := writeState(&state{Message: r.Message, LastChecked: time.Now(), Version: global.Version}); err != nil { console.Debugf("Failed to write state: %s", err) } console.Debugf("result of update check: %v", r.Message) case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): break default: console.Debugf("failed querying for new release: %v", err) } }() } type updateCheckResponse struct { Message string `json:"message"` } func checkForRelease(ctx context.Context) (*updateCheckResponse, error) { req, err := http.NewRequestWithContext(ctx, "GET", "https://update.cog.run/v1/check", nil) if err != nil { return nil, err } req.Header.Add("Accept", "application/json") q := req.URL.Query() q.Add("version", 
global.Version) q.Add("commit", global.Commit) q.Add("os", runtime.GOOS) q.Add("arch", runtime.GOARCH) req.URL.RawQuery = q.Encode() resp, err := http.DefaultClient.Do(req) //nolint:gosec // G704: URL is built from hardcoded base + version params if err != nil { return nil, err } defer resp.Body.Close() var response updateCheckResponse if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { return &response, err } return &response, nil } ================================================ FILE: pkg/util/console/console.go ================================================ // Package console provides a standard interface for user- and machine-interface with the console package console import ( "fmt" "math" "os" "strings" "sync" "unicode/utf8" "github.com/logrusorgru/aurora" "github.com/mattn/go-isatty" "golang.org/x/term" ) // ShouldUseColor returns true if color output should be enabled, based on // environment detection. It checks (in order): // - NO_COLOR env var is set and non-empty → no color // - COG_NO_COLOR env var is set and non-empty → no color // - TERM=dumb → no color // - stderr is not a TTY → no color // // This follows the NO_COLOR standard (https://no-color.org/) and common CLI // conventions. The --no-color flag is handled separately at the CLI layer. func ShouldUseColor() bool { if os.Getenv("NO_COLOR") != "" { return false } if os.Getenv("COG_NO_COLOR") != "" { return false } if os.Getenv("TERM") == "dumb" { return false } fd := os.Stderr.Fd() if !isatty.IsTerminal(fd) && !isatty.IsCygwinTerminal(fd) { return false } return true } // Style controls the icon/color used for a log line, independent of level. type Style int const ( // StyleDefault uses the default icon for the log level. StyleDefault Style = iota // StyleSuccess uses a green ✓ icon. StyleSuccess ) // Console represents a standardized interface for console UI. 
// It is designed to abstract:
// - Writing main output
// - Giving information to user
// - Console user interface elements (progress, interactive prompts, etc)
// - Switching between human and machine modes for these things (e.g. don't display progress bars or colors in logs, don't prompt for input when in a script)
type Console struct {
	// Color enables the styled icon prefix and ANSI formatting of messages.
	Color bool
	// IsMachine is not read by any method visible here.
	// NOTE(review): presumably marks non-interactive use; confirm against callers.
	IsMachine bool
	// Level is the minimum level that is printed; lower-level messages are dropped.
	Level Level
	// mu serializes writes so that concurrently logged messages do not interleave.
	mu sync.Mutex
}

// Debug prints a verbose debugging message, that is not displayed by default to the user.
func (c *Console) Debug(msg string) {
	c.log(DebugLevel, msg)
}

// Info tells the user what's going on.
func (c *Console) Info(msg string) {
	c.log(InfoLevel, msg)
}

// Success tells the user something completed successfully.
// Displays at info level with a green ✓ prefix.
func (c *Console) Success(msg string) {
	c.logStyled(InfoLevel, StyleSuccess, msg)
}

// Warn tells the user that something might break.
func (c *Console) Warn(msg string) {
	c.log(WarnLevel, msg)
}

// Error tells the user that something is broken.
func (c *Console) Error(msg string) {
	c.log(ErrorLevel, msg)
}

// Fatal logs msg at fatal level and then terminates the process with exit
// status 1.
func (c *Console) Fatal(msg string) {
	c.log(FatalLevel, msg)
	os.Exit(1)
}

// Debugf is like Debug, but builds the message with fmt.Sprintf.
func (c *Console) Debugf(msg string, v ...any) {
	c.log(DebugLevel, fmt.Sprintf(msg, v...))
}

// Infof is like Info, but builds the message with fmt.Sprintf.
func (c *Console) Infof(msg string, v ...any) {
	c.log(InfoLevel, fmt.Sprintf(msg, v...))
}

// Successf is like Success, but builds the message with fmt.Sprintf.
func (c *Console) Successf(msg string, v ...any) {
	c.logStyled(InfoLevel, StyleSuccess, fmt.Sprintf(msg, v...))
}

// Warnf is like Warn, but builds the message with fmt.Sprintf.
func (c *Console) Warnf(msg string, v ...any) {
	c.log(WarnLevel, fmt.Sprintf(msg, v...))
}

// Errorf is like Error, but builds the message with fmt.Sprintf.
func (c *Console) Errorf(msg string, v ...any) {
	c.log(ErrorLevel, fmt.Sprintf(msg, v...))
}

// Fatalf is like Fatal, but builds the message with fmt.Sprintf. It terminates
// the process with exit status 1 after logging.
func (c *Console) Fatalf(msg string, v ...any) {
	c.log(FatalLevel, fmt.Sprintf(msg, v...))
	os.Exit(1)
}

// InfoUnformatted writes a message to stderr without any prefix.
Useful for conversational // or interactive output (e.g. login prompts) where the icon prefix would be noise. // Displayed at info level. Long lines are wrapped to terminal width when stderr is a TTY. func (c *Console) InfoUnformatted(msg string) { if InfoLevel < c.Level { return } termWidth := stderrTerminalWidth() c.mu.Lock() defer c.mu.Unlock() for line := range strings.SplitSeq(msg, "\n") { if termWidth > 0 { wrapped := wrapLine(line, termWidth) for _, wl := range wrapped { fmt.Fprintln(os.Stderr, wl) } continue } fmt.Fprintln(os.Stderr, line) } } // InfoUnformattedf writes a formatted message to stderr without any prefix. func (c *Console) InfoUnformattedf(msg string, v ...any) { c.InfoUnformatted(fmt.Sprintf(msg, v...)) } // Output a string to stdout. Useful for printing primary output of a command, or the output of a subcommand. // A newline is added to the string. func (c *Console) Output(s string) { c.mu.Lock() defer c.mu.Unlock() _, _ = fmt.Fprintln(os.Stdout, s) } // Bold applies bold formatting to a string when color is enabled. // Use this to highlight dynamic values (image names, paths, URLs) in log messages. func (c *Console) Bold(s string) string { if c.Color { return aurora.Bold(s).String() } return s } func (c *Console) log(level Level, msg string) { c.logStyled(level, StyleDefault, msg) } func (c *Console) logStyled(level Level, style Style, msg string) { if level < c.Level { return } prompt := "" // promptWidth is the visual width of the prompt (excluding ANSI codes). 
promptWidth := 0 if c.Color { switch style { case StyleSuccess: prompt = " " + aurora.Bold(aurora.Green("✔ ")).String() promptWidth = 4 // " ✔ " default: switch level { case DebugLevel, InfoLevel: prompt = " " + aurora.Faint("⚙ ").String() promptWidth = 4 // " ⚙ " case WarnLevel: prompt = " " + aurora.Bold(aurora.Yellow("⚠ ")).String() promptWidth = 4 // " ⚠ " case ErrorLevel, FatalLevel: prompt = " " + aurora.Bold(aurora.Red("✗ ")).String() promptWidth = 4 // " ✗ " } } } termWidth := stderrTerminalWidth() c.mu.Lock() defer c.mu.Unlock() for line := range strings.SplitSeq(msg, "\n") { if line == "" && (level == DebugLevel || level == InfoLevel) { fmt.Fprintln(os.Stderr) continue } if c.Color && level == DebugLevel { line = aurora.Faint(line).String() } // Wrap long lines to terminal width. if termWidth > 0 && promptWidth > 0 { maxWidth := termWidth - promptWidth if maxWidth > 0 { wrapped := wrapLine(line, maxWidth) for _, wl := range wrapped { fmt.Fprintln(os.Stderr, prompt+wl) } continue } } fmt.Fprintln(os.Stderr, prompt+line) } } // stderrTerminalWidth returns the terminal width of stderr, or 0 if stderr // is not a terminal or the width cannot be determined. func stderrTerminalWidth() int { fd := os.Stderr.Fd() if !isatty.IsTerminal(fd) && !isatty.IsCygwinTerminal(fd) { return 0 } if fd > math.MaxInt { return 0 } w, _, err := term.GetSize(int(fd)) //nolint:gosec // bounded above if err != nil || w <= 0 { return 0 } return w } // wrapLine wraps a single line of text to the given width, breaking on word // boundaries where possible. It operates on the visible text (which may contain // ANSI escape codes — these are counted as zero-width for wrapping purposes). func wrapLine(line string, maxWidth int) []string { if visibleWidth(line) <= maxWidth { return []string{line} } var lines []string for len(line) > 0 { if visibleWidth(line) <= maxWidth { lines = append(lines, line) break } // Find the byte position where we exceed maxWidth visible chars. 
cutByte := findCutPoint(line, maxWidth) // Try to break at a space before the cut point. breakAt := strings.LastIndex(line[:cutByte], " ") if breakAt <= 0 { // No good break point; hard-break at cutByte. breakAt = cutByte } lines = append(lines, line[:breakAt]) line = strings.TrimLeft(line[breakAt:], " ") } return lines } // visibleWidth returns the number of visible characters in a string, // ignoring ANSI escape sequences. func visibleWidth(s string) int { width := 0 inEscape := false for _, r := range s { if inEscape { if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { inEscape = false } continue } if r == '\x1b' { inEscape = true continue } width += utf8.RuneLen(r) // approximate: 1 for ASCII, may differ for wide chars if r > 127 { width = width - utf8.RuneLen(r) + 1 // count non-ASCII runes as width 1 } } return width } // findCutPoint returns the byte index in s where the visible width reaches maxWidth. func findCutPoint(s string, maxWidth int) int { width := 0 inEscape := false for i, r := range s { if inEscape { if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { inEscape = false } continue } if r == '\x1b' { inEscape = true continue } width++ if width >= maxWidth { return i + utf8.RuneLen(r) } } return len(s) } ================================================ FILE: pkg/util/console/formatting.go ================================================ package console import ( "time" "github.com/xeonx/timeago" ) func FormatTime(t time.Time) string { return timeago.English.Format(t) } ================================================ FILE: pkg/util/console/global.go ================================================ package console import ( "os" "github.com/mattn/go-isatty" ) // ConsoleInstance is the global instance of console, so we don't have to pass it around everywhere var ConsoleInstance = &Console{ Color: ShouldUseColor(), Level: InfoLevel, IsMachine: false, } // SetLevel sets log level func SetLevel(level Level) { ConsoleInstance.Level = level } // 
// SetColor sets whether to print colors
func SetColor(color bool) {
	ConsoleInstance.Color = color
}

// Debug logs a debug-level message on the global console.
func Debug(msg string) {
	ConsoleInstance.Debug(msg)
}

// Info logs an info-level message on the global console.
func Info(msg string) {
	ConsoleInstance.Info(msg)
}

// Success logs a success message on the global console.
func Success(msg string) {
	ConsoleInstance.Success(msg)
}

// Warn logs a warn-level message on the global console.
func Warn(msg string) {
	ConsoleInstance.Warn(msg)
}

// Error logs an error-level message on the global console.
func Error(msg string) {
	ConsoleInstance.Error(msg)
}

// Fatal logs a fatal-level message on the global console, which then exits
// the process.
func Fatal(msg string) {
	ConsoleInstance.Fatal(msg)
}

// Debugf logs a formatted debug-level message on the global console.
func Debugf(msg string, v ...any) {
	ConsoleInstance.Debugf(msg, v...)
}

// Infof logs a formatted info-level message on the global console.
func Infof(msg string, v ...any) {
	ConsoleInstance.Infof(msg, v...)
}

// Successf logs a formatted success message on the global console.
func Successf(msg string, v ...any) {
	ConsoleInstance.Successf(msg, v...)
}

// Warnf logs a formatted warn-level message on the global console.
func Warnf(msg string, v ...any) {
	ConsoleInstance.Warnf(msg, v...)
}

// Errorf logs a formatted error-level message on the global console.
func Errorf(msg string, v ...any) {
	ConsoleInstance.Errorf(msg, v...)
}

// Fatalf logs a formatted fatal-level message on the global console, which
// then exits the process.
func Fatalf(msg string, v ...any) {
	ConsoleInstance.Fatalf(msg, v...)
}

// InfoUnformatted writes to stderr without prefix. Useful for interactive/conversational output.
func InfoUnformatted(msg string) {
	ConsoleInstance.InfoUnformatted(msg)
}

// InfoUnformattedf writes to stderr without prefix, with formatting.
func InfoUnformattedf(msg string, v ...any) {
	ConsoleInstance.InfoUnformattedf(msg, v...)
}

// Output a line to stdout. Useful for printing primary output of a command, or the output of a subcommand.
func Output(s string) {
	ConsoleInstance.Output(s)
}

// Bold applies bold formatting to a string when color is enabled.
func Bold(s string) string {
	return ConsoleInstance.Bold(s)
}

// IsTTY checks if a file is a TTY or not. E.g.
IsTTY(os.Stdin) func IsTTY(f *os.File) bool { return isatty.IsTerminal(f.Fd()) } ================================================ FILE: pkg/util/console/interactive.go ================================================ package console import ( "bufio" "fmt" "io" "os" "slices" "strings" ) type Interactive struct { Prompt string Default string Options []string Required bool } func (i Interactive) Read() (string, error) { if i.Default != "" && i.Options != nil && !slices.Contains(i.Options, i.Default) { panic("Default is not an option") } parens := "" if i.Required { parens += "required" } if i.Default != "" { if parens != "" { parens += ", " } parens += "default: " + i.Default } if i.Options != nil { if parens != "" { parens += ", " } parens += "options: " + strings.Join(i.Options, ", ") } if parens != "" { parens = " (" + parens + ")" } for { fmt.Printf("%s%s: ", i.Prompt, parens) reader := bufio.NewReader(os.Stdin) text, err := reader.ReadString('\n') if err != nil { return "", err } text = strings.TrimSpace(text) if text == "" && i.Default != "" { text = i.Default } if i.Required && text == "" { Warn("Please enter a value") continue } if !i.Required && text == "" { return "", nil } if i.Options != nil { if !slices.Contains(i.Options, text) { Warnf("%s is not a valid option", text) continue } } return text, nil } } type InteractiveBool struct { Prompt string Default bool // NonDefaultFlag is the flag to suggest passing to do the thing which isn't default when running inside a script NonDefaultFlag string } func (i InteractiveBool) Read() (bool, error) { defaults := "y/N" if i.Default { defaults = "Y/n" } for { fmt.Printf("%s (%s) ", i.Prompt, defaults) reader := bufio.NewReader(os.Stdin) text, err := reader.ReadString('\n') if err != nil { // Only translate error if a flag is set if err == io.EOF && i.NonDefaultFlag != "" { return false, fmt.Errorf("stdin is closed. 
If you're running in a script, you need to pass the '%s' option", i.NonDefaultFlag) } return false, err } text = strings.ToLower(strings.TrimSpace(text)) if text == "yes" || text == "y" { return true, nil } if text == "no" || text == "n" { return false, nil } if text == "" { return i.Default, nil } Warn("Please enter 'y' or 'n'") } } ================================================ FILE: pkg/util/console/levels.go ================================================ package console // Mostly lifted from https://github.com/apex/log/blob/master/levels.go import ( "errors" "strings" ) // ErrInvalidLevel is returned if the severity level is invalid. var ErrInvalidLevel = errors.New("invalid level") // Level of severity. type Level int // Log levels. const ( InvalidLevel Level = iota - 1 DebugLevel InfoLevel WarnLevel ErrorLevel FatalLevel ) var levelNames = [...]string{ DebugLevel: "debug", InfoLevel: "info", WarnLevel: "warn", ErrorLevel: "error", FatalLevel: "fatal", } var levelStrings = map[string]Level{ "debug": DebugLevel, "info": InfoLevel, "warn": WarnLevel, "warning": WarnLevel, "error": ErrorLevel, "fatal": FatalLevel, } // String implementation. func (l Level) String() string { return levelNames[l] } // ParseLevel parses level string. func ParseLevel(s string) (Level, error) { l, ok := levelStrings[strings.ToLower(s)] if !ok { return InvalidLevel, ErrInvalidLevel } return l, nil } // MustParseLevel parses level string or panics. 
func MustParseLevel(s string) Level { l, err := ParseLevel(s) if err != nil { panic("invalid log level") } return l } ================================================ FILE: pkg/util/console/term.go ================================================ package console import ( "os" "github.com/moby/term" ) // IsTerminal returns true if we're in a terminal and a user is interacting with us func IsTerminal() bool { return term.IsTerminal(os.Stdin.Fd()) } // GetWidth returns the width of the terminal (from stderr -- stdout might be piped) // // Returns 0 if we're not in a terminal func GetWidth() (uint16, error) { fd := os.Stderr.Fd() if term.IsTerminal(fd) { ws, err := term.GetWinsize(fd) if err != nil { return 0, err } return ws.Width, nil } return 0, nil } ================================================ FILE: pkg/util/env.go ================================================ package util import ( "os" "github.com/replicate/cog/pkg/util/console" ) // GetEnvOrDefault returns an environment variable or a default if either the environment variable // does not exist or fails to parse using the specified conversionFunc function func GetEnvOrDefault[T any](key string, defaultVal T, conversionFunc func(string) (T, error)) T { val, exists := os.LookupEnv(key) if exists { v, err := conversionFunc(val) if err == nil { return v } else { console.Warnf("Failed to convert env var %s to expected type. Continuing with default. 
Error: %v", key, err) } } return defaultVal } ================================================ FILE: pkg/util/errors.go ================================================ package util import "fmt" // WrapError is just a shortcut for using fmt.Errorf // to wrap an error with a message func WrapError(err error, message string) error { if err == nil { return nil } return fmt.Errorf("%s: %w", message, err) } ================================================ FILE: pkg/util/files/files.go ================================================ package files import ( "errors" "fmt" "io" "os" "path" "strings" "github.com/mitchellh/go-homedir" "github.com/vincent-petithory/dataurl" "golang.org/x/sys/unix" r8_path "github.com/replicate/cog/pkg/path" "github.com/replicate/cog/pkg/util/mime" ) var ( ErrorFailedToSplitDataURL = errors.New("Failed to split data URL into 2 parts") ) func Exists(path string) (bool, error) { if _, err := os.Stat(path); err == nil { return true, nil } else if os.IsNotExist(err) { return false, nil } else { return false, fmt.Errorf("Failed to determine if %s exists: %w", path, err) } } func IsEmpty(path string) (bool, error) { entries, err := os.ReadDir(path) if err != nil { if errors.Is(err, os.ErrNotExist) { return true, nil } return false, err } return len(entries) == 0, nil } func IsDir(path string) (bool, error) { file, err := os.Stat(path) if err != nil { return false, err } return file.Mode().IsDir(), nil } func IsExecutable(path string) bool { return unix.Access(path, unix.X_OK) == nil } func CopyFile(src string, dest string) error { in, err := os.Open(src) if err != nil { return fmt.Errorf("Failed to open %s while copying to %s: %w", src, dest, err) } defer in.Close() out, err := os.Create(dest) if err != nil { return fmt.Errorf("Failed to create %s while copying %s: %w", dest, src, err) } defer out.Close() _, err = io.Copy(out, in) if err != nil { return fmt.Errorf("Failed to copy %s to %s: %w", src, dest, err) } return out.Close() } func 
WriteIfDifferent(file, content string) error { if _, err := os.Stat(file); err == nil { bs, err := os.ReadFile(file) if err != nil { return err } if string(bs) == content { return nil } } else if !errors.Is(err, os.ErrNotExist) { return err } // Write out a new requirements file err := os.WriteFile(file, []byte(content), 0o644) if err != nil { return err } return nil } func WriteDataURLToFile(url string, destination string) (string, error) { if strings.HasPrefix(url, "data:None;base64") { url = strings.Replace(url, "data:None;base64", "data:;base64", 1) } dataurlObj, err := dataurl.DecodeString(url) if err != nil { // Attempt to fallback to binary base64 file decode. parts := strings.SplitN(url, ",", 2) if len(parts) != 2 { return "", ErrorFailedToSplitDataURL } base64Data := parts[1] url = "data:;base64," + base64Data dataurlObj, err = dataurl.DecodeString(url) if err != nil { return "", fmt.Errorf("Failed to decode data URL: %w", err) } } output := dataurlObj.Data ext := path.Ext(destination) dir := path.Dir(destination) name := r8_path.TrimExt(path.Base(destination)) // Check if ext is an integer, in which case ignore it... 
if r8_path.IsExtInteger(ext) { ext = "" name = path.Base(destination) } if ext == "" { ext = mime.ExtensionByType(dataurlObj.ContentType()) } path, err := WriteFile(output, path.Join(dir, name+ext)) if err != nil { return "", err } return path, nil } func WriteFile(output []byte, outputPath string) (string, error) { outputPath, err := homedir.Expand(outputPath) if err != nil { return "", err } // Write to file outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) if err != nil { return "", err } if _, err := outFile.Write(output); err != nil { return "", err } if err := outFile.Close(); err != nil { return "", err } return outputPath, nil } ================================================ FILE: pkg/util/files/files_test.go ================================================ package files import ( "os" "path/filepath" "testing" "github.com/stretchr/testify/require" ) func TestIsExecutable(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test-file") err := os.WriteFile(path, []byte{}, 0o644) require.NoError(t, err) require.False(t, IsExecutable(path)) require.NoError(t, os.Chmod(path, 0o744)) require.True(t, IsExecutable(path)) } func TestWriteBadlyFormattedBase64DataURI(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test-file") _, err := WriteDataURLToFile("data:None;base64,SGVsbG8gVGhlcmU=", path) require.NoError(t, err) } func TestWriteNotRecognisedBase64DataURL(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test-file") _, err := WriteDataURLToFile("data:None;model/gltf-binary,SGVsbG8gVGhlcmU=", path) require.NoError(t, err) } ================================================ FILE: pkg/util/hash.go ================================================ package util import ( "bytes" "crypto/sha256" "encoding/hex" "errors" "fmt" "io" "os" ) var ( ErrInvalidRange = errors.New("Invalid byte range provided for file") ) func SHA256HashFile(path string) (string, error) { hash := sha256.New() file, err := 
os.Open(path) if err != nil { return "", err } defer file.Close() if _, err := io.Copy(hash, file); err != nil { return "", err } return hex.EncodeToString(hash.Sum(nil)), nil } func SHA256HashFileWithSaltAndRange(path string, start int, end int, salt string) (string, error) { hash := sha256.New() length := end - start if length < 0 { return "", ErrInvalidRange } file, err := os.Open(path) if err != nil { return "", err } defer file.Close() fileInfo, err := file.Stat() if err != nil { return "", err } if fileInfo.Size() < int64(end) { return "", ErrInvalidRange } _, err = file.Seek(int64(start), 0) if err != nil { return "", fmt.Errorf("failed to open file pointer %s: %w", path, err) } buf := make([]byte, length) n, err := file.Read(buf) if err != nil { return "", err } buf = buf[:n] var hashInput []byte hashInput = append(hashInput, buf...) hashInput = append(hashInput, []byte(salt)...) if _, err := io.Copy(hash, bytes.NewReader(hashInput)); err != nil { return "", err } return hex.EncodeToString(hash.Sum(nil)), nil } ================================================ FILE: pkg/util/hash_test.go ================================================ package util import ( "os" "path/filepath" "testing" "github.com/stretchr/testify/require" ) func TestHash(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test.tmp") d1 := []byte("hello\ngo\n") err := os.WriteFile(path, d1, 0o644) require.NoError(t, err) sha256, err := SHA256HashFile(path) require.NoError(t, err) require.Equal(t, "43d250d92b5dbb47f75208de8e9a9a321d23e85eed0dc3d5dfa83bc3cc5aa68c", sha256) } func TestHashFileWithSaltAndRange(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test.tmp") d1 := []byte("hello\nreplicate\nhello\n") err := os.WriteFile(path, d1, 0o644) require.NoError(t, err) _, err = SHA256HashFileWithSaltAndRange(path, 0, 60, "go\n") require.Error(t, err) _, err = SHA256HashFileWithSaltAndRange(path, 23, 1, "go\n") require.Error(t, err) sha256, err := 
SHA256HashFileWithSaltAndRange(path, 0, 6, "go\n") require.NoError(t, err) require.Equal(t, "43d250d92b5dbb47f75208de8e9a9a321d23e85eed0dc3d5dfa83bc3cc5aa68c", sha256) sha256, err = SHA256HashFileWithSaltAndRange(path, 16, 22, "go\n") require.NoError(t, err) require.Equal(t, "43d250d92b5dbb47f75208de8e9a9a321d23e85eed0dc3d5dfa83bc3cc5aa68c", sha256) } ================================================ FILE: pkg/util/mime/mime.go ================================================ package mime import ( "mime" "strings" ) var typeToExtension = map[string]string{ "application/epub+zip": ".epub", "application/gzip": ".gz", "application/java-archive": ".jar", "application/json": ".json", "application/jsonl": ".jsonl", "application/ld+json": ".jsonld", "application/msword": ".doc", "application/octet-stream": ".bin", "application/ogg": ".ogx", "application/pdf": ".pdf", "application/rtf": ".rtf", "application/vnd.amazon.ebook": ".azw", "application/vnd.apple.installer+xml": ".mpkg", "application/vnd.ms-excel": ".xls", "application/vnd.ms-fontobject": ".eot", "application/vnd.ms-powerpoint": ".ppt", "application/vnd.oasis.opendocument.presentation": ".odp", "application/vnd.oasis.opendocument.spreadsheet": ".ods", "application/vnd.oasis.opendocument.text": ".odt", "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", "application/vnd.rar": ".rar", "application/vnd.visio": ".vsd", "application/x-7z-compressed": ".7z", "application/x-abiword": ".abw", "application/x-bzip": ".bz", "application/x-bzip2": ".bz2", "application/x-cdf": ".cda", "application/x-csh": ".csh", "application/x-freearc": ".arc", "application/x-httpd-php": ".php", "application/x-ndjson": ".ndjson", "application/x-sh": ".sh", "application/x-shockwave-flash": ".swf", "application/x-tar": ".tar", "application/xhtml+xml": ".xhtml", "application/xml": ".xml", "application/zip": ".zip", "audio/aac": 
".aac", "audio/midi audio/x-midi": ".midi", "audio/mpeg": ".mp3", "audio/ogg": ".oga", "audio/opus": ".opus", "audio/wav": ".wav", "audio/webm": ".weba", "font/otf": ".otf", "font/ttf": ".ttf", "font/woff": ".woff", "font/woff2": ".woff2", "image/bmp": ".bmp", "image/x-ms-bmp": ".bmp", "image/gif": ".gif", "image/jpeg": ".jpg", "image/png": ".png", "image/svg+xml": ".svg", "image/tiff": ".tiff", "image/vnd.microsoft.icon": ".ico", "image/webp": ".webp", "model/gltf-binary": ".glb", "model/mtl": ".mtl", "model/obj": ".obj", "text/calendar": ".ics", "text/css": ".css", "text/csv": ".csv", "text/html": ".html", "text/javascript": ".js", "text/markdown": ".md", "text/plain": ".txt", "video/3gpp": ".3gp", "video/3gpp2": ".3gp2", "video/mp2t": ".ts", "video/mp4": ".mp4", "video/mpeg": ".mpeg", "video/ogg": ".ogv", "video/webm": ".webm", "video/x-msvideo": ".avi", } var extensionToType = map[string]string{} func init() { for typ, ext := range typeToExtension { extensionToType[ext] = typ } } // ExtensionByType returns the file extension associated with the media type typ. // When typ has no associated extension, ExtensionByType returns an empty string. func ExtensionByType(typ string) string { // Lookup extension from pre-defined map ext := typeToExtension[typ] // Fall back to mime.ExtensionsByType if ext == "" { extensions, _ := mime.ExtensionsByType(typ) if len(extensions) > 0 { ext = extensions[0] } } return ext } // TypeByExtension returns the media type associated with the file extension ext. // The extension ext should begin with a leading dot, as in ".json" // When ext has no associated type, TypeByExtension returns "application/octet-stream" func TypeByExtension(ext string) string { if !strings.HasPrefix(ext, ".") { ext = "." 
+ ext } // Lookup type from pre-defined map typ := extensionToType[ext] // Fall back to mime.TypeByExtension if typ == "" { typ = mime.TypeByExtension(ext) } // Default to "application/octet-stream" if typ == "" { typ = "application/octet-stream" } return typ } ================================================ FILE: pkg/util/mime/mime_test.go ================================================ package mime import ( "testing" "github.com/stretchr/testify/require" ) func TestExtensionByType(t *testing.T) { require.Equal(t, ".txt", ExtensionByType("text/plain")) require.Equal(t, ".jpg", ExtensionByType("image/jpeg")) require.Equal(t, ".png", ExtensionByType("image/png")) require.Equal(t, ".obj", ExtensionByType("model/obj")) require.Equal(t, ".json", ExtensionByType("application/json")) require.Equal(t, "", ExtensionByType("asdfasdf")) } func TestTypeByExtension(t *testing.T) { require.Equal(t, "text/plain", TypeByExtension(".txt")) require.Equal(t, "image/jpeg", TypeByExtension(".jpg")) require.Equal(t, "image/png", TypeByExtension(".png")) require.Equal(t, "model/obj", TypeByExtension(".obj")) require.Equal(t, "application/json", TypeByExtension(".json")) require.Equal(t, "application/octet-stream", TypeByExtension(".asdfasdf")) } ================================================ FILE: pkg/util/net.go ================================================ package util import ( "fmt" "math/rand" "net" "time" ) // PickFreePort returns a TCP port in [min,max] that's not in use on the 127.0.0.1 interface. // Note that there's a small chance of a race condition when a port is considered free at the // time of the call, but not free when something tries to use it. This is good enough for dev // and test code though. 
// PickFreePort returns a TCP port in [minPort, maxPort] that is not in use on
// the 127.0.0.1 interface.
//
// The range must lie within 1024-65535 (unprivileged, valid TCP ports) and
// minPort must not exceed maxPort; otherwise an error is returned. Up to 20
// randomly chosen ports are probed, and an error is returned when none of them
// could be bound.
//
// Note that there's a small chance of a race condition when a port is
// considered free at the time of the call, but not free when something tries
// to use it. This is good enough for dev and test code though.
func PickFreePort(minPort, maxPort int) (int, error) {
	const (
		lowestAllowedPort = 1024  // stay out of the privileged range
		highestTCPPort    = 65535 // TCP port numbers are 16 bits wide
		maxAttempts       = 20    // bound the search to avoid infinite loops
	)

	if minPort < lowestAllowedPort || maxPort > highestTCPPort || minPort > maxPort {
		return 0, fmt.Errorf("invalid port range")
	}

	rng := rand.New(rand.NewSource(time.Now().UnixNano())) // #nosec G404 - using math/rand is fine for test port selection
	for attempt := 0; attempt < maxAttempts; attempt++ {
		p := rng.Intn(maxPort-minPort+1) + minPort
		l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", p))
		if err != nil {
			// Port is taken (or otherwise unusable); try another one.
			continue
		}
		_ = l.Close() // best effort: the listener only existed to probe the port
		return p, nil // looks free
	}
	return 0, fmt.Errorf("could not find free port in range %d-%d", minPort, maxPort)
}
childPath = path + "." + key } sourceKVNodeChild, ok1 := map1[key] destinationKVNodeChild, ok2 := map2[key] switch { case !ok1: // We need to remove this node NewContent := []*yaml.Node{} for _, node := range destinationNode.Content { if node == destinationKVNodeChild[0] || node == destinationKVNodeChild[1] { continue } NewContent = append(NewContent, node) } destinationNode.Content = NewContent case !ok2: // We need to add this node destinationNode.Content = append(destinationNode.Content, sourceKVNodeChild...) default: err := traverseAndCompare(sourceKVNodeChild[1], destinationKVNodeChild[1], childPath) if err != nil { return err } } } case yaml.SequenceNode: sourceLen := len(sourceNode.Content) destinationLen := len(destinationNode.Content) maxLen := max(destinationLen, sourceLen) for i := range maxLen { childPath := fmt.Sprintf("%s[%d]", path, i) if i >= destinationLen { destinationNode.Content = append(destinationNode.Content, sourceNode.Content[i]) } else if i < sourceLen { err := traverseAndCompare(sourceNode.Content[i], destinationNode.Content[i], childPath) if err != nil { return err } } } } return nil } func mapNodeToMap(node *yaml.Node) map[string][]*yaml.Node { result := make(map[string][]*yaml.Node) for i := 0; i < len(node.Content); i += 2 { keyNode := node.Content[i] valueNode := node.Content[i+1] result[keyNode.Value] = []*yaml.Node{keyNode, valueNode} } return result } func getAllKeys(map1, map2 map[string][]*yaml.Node) []string { keys := make(map[string]bool) for key := range map1 { keys[key] = true } for key := range map2 { keys[key] = true } var keyList []string for key := range keys { keyList = append(keyList, key) } return keyList } func nodeKindToString(kind yaml.Kind) string { switch kind { case yaml.ScalarNode: return "Scalar" case yaml.MappingNode: return "Mapping" case yaml.SequenceNode: return "Sequence" default: return "Unknown" } } ================================================ FILE: pkg/util/overwrite_yaml_test.go 
================================================ package util import ( "testing" "github.com/stretchr/testify/require" ) /* func TestOverwriteYAML(t *testing.T) { var yamlData1 = `build: command: "build.sh" image: "my-image" predict: "predict.py" train: "train.py" concurrency: max: 5 environment: - "VAR1=value1" - "VAR2=value2" ` var yamlData2 = `build: command: "build_new.sh" image: "new-image" predict: "new_predict.py" concurrency: max: 10 environment: - "VAR1=new_value1" - "VAR3=value3" ` content, err := OverwriteYAML([]byte(yamlData1), []byte(yamlData2)) require.NoError(t, err) require.Equal(t, yamlData1, string(content)) } */ func TestOverwriteYAMLWithComments(t *testing.T) { var sourceYaml = `build: command: "build_new.sh" image: "new-image" predict: "new_predict.py" concurrency: max: 10 environment: - "VAR1=new_value1" - "VAR3=value3" ` var destinationYaml = `# This here is a YAML Comment build: command: "build.sh" image: "my-image" predict: "predict.py" train: "train.py" concurrency: max: 5 environment: - "VAR1=value1" - "VAR2=value2" ` expected := `# This here is a YAML Comment build: command: "build_new.sh" image: "new-image" predict: "new_predict.py" concurrency: max: 10 environment: - "VAR1=new_value1" - "VAR3=value3" ` content, err := OverwriteYAML([]byte(sourceYaml), []byte(destinationYaml)) require.NoError(t, err) require.Equal(t, expected, string(content)) } func TestOverwriteYAMLWithLineComments(t *testing.T) { var sourceYaml = `build: command: "build_new.sh" image: "new-image" predict: "new_predict.py" concurrency: max: 10 environment: - "VAR1=new_value1" - "VAR3=value3" ` var destinationYaml = `# This here is a YAML Comment build: # And we put this comment here for good measure command: "build.sh" image: "my-image" predict: "predict.py" train: "train.py" concurrency: max: 5 environment: - "VAR1=value1" - "VAR2=value2" ` expected := `# This here is a YAML Comment build: # And we put this comment here for good measure command: "build_new.sh" image: 
"new-image" predict: "new_predict.py" concurrency: max: 10 environment: - "VAR1=new_value1" - "VAR3=value3" ` content, err := OverwriteYAML([]byte(sourceYaml), []byte(destinationYaml)) require.NoError(t, err) require.Equal(t, expected, string(content)) } func TestStep1XYaml(t *testing.T) { var sourceYaml = `build: gpu: true system_packages: - "libgl1-mesa-glx" - "libglib2.0-0" python_version: "3.11" python_requirements: requirements.txt predict: "predict.py:Predictor" ` var destinationYaml = `# Configuration for Cog ⚙️ # Reference: https://cog.run/yaml build: # set to true if your model requires a GPU gpu: true # a list of ubuntu apt packages to install system_packages: - "libgl1-mesa-glx" - "libglib2.0-0" # python version in the form '3.11' or '3.11.4' python_version: "3.11" # path to a Python requirements.txt file python_requirements: requirements.txt # commands run after the environment is setup run: - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)" - chmod +x /usr/local/bin/pget # predict.py defines how predictions are run on your model predict: "predict.py:Predictor"` expected := `# Configuration for Cog ⚙️ # Reference: https://cog.run/yaml build: # set to true if your model requires a GPU gpu: true # a list of ubuntu apt packages to install system_packages: - "libgl1-mesa-glx" - "libglib2.0-0" # python version in the form '3.11' or '3.11.4' python_version: "3.11" # path to a Python requirements.txt file python_requirements: requirements.txt # predict.py defines how predictions are run on your model predict: "predict.py:Predictor" ` content, err := OverwriteYAML([]byte(sourceYaml), []byte(destinationYaml)) require.NoError(t, err) require.Equal(t, expected, string(content)) } ================================================ FILE: pkg/util/platform.go ================================================ package util // IsAppleSiliconMac returns whether the current machine is an Apple silicon 
// RingBufferWriter is a writer that writes to an underlying writer and also
// maintains a ring buffer of the last N bytes that the underlying writer
// accepted. It is safe for concurrent use.
type RingBufferWriter struct {
	writer io.Writer
	buffer []byte
	size   int
	pos    int  // index in buffer where the next byte will be stored
	full   bool // true once the buffer has been filled at least once
	mu     sync.Mutex
}

// NewRingBufferWriter creates a new RingBufferWriter that writes to w and
// maintains a buffer of the last size bytes. A size of zero (or less) disables
// buffering: writes are still forwarded to w and String returns "".
func NewRingBufferWriter(w io.Writer, size int) *RingBufferWriter {
	if size < 0 {
		size = 0
	}
	return &RingBufferWriter{
		writer: w,
		buffer: make([]byte, size),
		size:   size,
	}
}

// Write implements the io.Writer interface. p is forwarded to the underlying
// writer, and the bytes that writer reports as written are recorded in the
// ring buffer, including a partial write that ended in an error. The
// underlying writer's result is returned unchanged.
func (w *RingBufferWriter) Write(p []byte) (n int, err error) {
	w.mu.Lock()
	defer w.mu.Unlock()

	n, err = w.writer.Write(p)

	// Only remember what actually reached the underlying writer; clamp in case
	// a misbehaving writer reports an out-of-range count.
	accepted := n
	if accepted < 0 {
		accepted = 0
	}
	if accepted > len(p) {
		accepted = len(p)
	}
	w.record(p[:accepted])

	return n, err
}

// record stores p in the ring buffer, overwriting the oldest bytes when the
// buffer overflows. The caller must hold w.mu.
func (w *RingBufferWriter) record(p []byte) {
	if w.size == 0 || len(p) == 0 {
		return
	}

	// A chunk at least as large as the buffer replaces its whole content.
	if len(p) >= w.size {
		copy(w.buffer, p[len(p)-w.size:])
		w.pos = 0
		w.full = true
		return
	}

	// Fill up to the end of the buffer, then wrap around for the remainder.
	k := copy(w.buffer[w.pos:], p)
	copy(w.buffer, p[k:])

	next := w.pos + len(p)
	if next >= w.size {
		w.full = true
	}
	w.pos = next % w.size
}

// String returns the contents of the ring buffer as a string, oldest byte
// first. Because the buffer is byte-based, the result may begin in the middle
// of a multi-byte UTF-8 sequence.
func (w *RingBufferWriter) String() string {
	w.mu.Lock()
	defer w.mu.Unlock()

	// If the buffer has never been filled, everything written so far lives in
	// buffer[:pos].
	if !w.full {
		return string(w.buffer[:w.pos])
	}
	// Otherwise the oldest byte sits at pos: stitch the two halves together.
	return string(w.buffer[w.pos:]) + string(w.buffer[:w.pos])
}
time.Duration) error { start := time.Now() console.Debugf("Waiting for %s to become accessible", url) for { now := time.Now() if now.Sub(start) > timeout { return fmt.Errorf("Timed out") } time.Sleep(100 * time.Millisecond) resp, err := http.Get(url) //#nosec G107 if err != nil { continue } if resp.StatusCode != http.StatusOK { continue } console.Debugf("Got successful response from %s", url) return nil } } func PortIsOpen(port int) bool { conn, err := net.DialTimeout("tcp", net.JoinHostPort("", strconv.Itoa(port)), 100*time.Millisecond) if conn != nil { _ = conn.Close() } return err == nil } ================================================ FILE: pkg/util/shell/pipes.go ================================================ package shell import ( "bufio" "io" ) type PipeFunc func() (io.ReadCloser, error) type LogFunc func(args ...any) func PipeTo(pf PipeFunc, lf LogFunc) (done chan struct{}, err error) { done = make(chan struct{}) pipe, err := pf() if err != nil { return nil, err } scanner := bufio.NewScanner(pipe) go func() { for scanner.Scan() { line := scanner.Text() lf(line) } done <- struct{}{} }() return done, nil } ================================================ FILE: pkg/util/version/version.go ================================================ package version import ( "fmt" "strconv" "strings" ) type Version struct { Major int Minor int Patch *int Metadata string } func NewVersion(s string) (version *Version, err error) { plusParts := strings.SplitN(s, "+", 2) number := plusParts[0] parts := strings.Split(number, ".") if len(parts) > 3 { return nil, fmt.Errorf("Version must not have more than 3 parts: %s", s) } version = new(Version) version.Major, err = strconv.Atoi(parts[0]) if err != nil { return nil, fmt.Errorf("Invalid major version %s: %w", parts[0], err) } if len(parts) >= 2 { version.Minor, err = strconv.Atoi(parts[1]) if err != nil { return nil, fmt.Errorf("Invalid minor version %s: %w", parts[1], err) } } if len(parts) >= 3 { patch, err := 
strconv.Atoi(parts[2]) if err != nil { return nil, fmt.Errorf("Invalid patch version %s: %w", parts[2], err) } // We assign a pointer here to handle cases where the patch version is not // explicitly assigned and we need to compare versions without patches to // versions with patches. version.Patch = new(int) *version.Patch = patch } if len(plusParts) == 2 { version.Metadata = plusParts[1] } return version, nil } func MustVersion(s string) *Version { version, err := NewVersion(s) if err != nil { panic(fmt.Sprintf("%s", err)) } return version } func (v *Version) Greater(other *Version) bool { switch { case v.Major > other.Major: return true case v.Major == other.Major && v.Minor > other.Minor: return true case v.Major == other.Major && v.Minor == other.Minor && v.PatchVersion() > other.PatchVersion(): return true default: return false } } func (v *Version) Equal(other *Version) bool { return v.Major == other.Major && v.Minor == other.Minor && v.PatchVersion() == other.PatchVersion() && v.Metadata == other.Metadata } func (v *Version) GreaterOrEqual(other *Version) bool { return v.Greater(other) || v.Equal(other) } func (v *Version) EqualMinor(other *Version) bool { return v.Major == other.Major && v.Minor == other.Minor } func (v *Version) HasPatch() bool { return v.Patch != nil } func (v *Version) PatchVersion() int { if v.Patch == nil { return 0 } return *v.Patch } func Equal(v1 string, v2 string) bool { return MustVersion(v1).Equal(MustVersion(v2)) } func EqualMinor(v1 string, v2 string) bool { return MustVersion(v1).EqualMinor(MustVersion(v2)) } func Greater(v1 string, v2 string) bool { return MustVersion(v1).Greater(MustVersion(v2)) } func GreaterOrEqual(v1 string, v2 string) bool { leftVersion, err := NewVersion(v1) if err != nil { return v1 == v2 } rightVersion, err := NewVersion(v2) if err != nil { return v1 == v2 } return leftVersion.GreaterOrEqual(rightVersion) } func (v *Version) Matches(other *Version) bool { switch { case v.Major != other.Major: return 
// StripModifier removes the modifier from a version string by returning
// everything before the first "+" (for example "2.3.1+cu118" becomes "2.3.1").
// A string without a "+" is returned unchanged.
func StripModifier(v string) string {
	base, _, _ := strings.Cut(v, "+")
	return base
}
// Sentinel errors returned (usually wrapped with the HTTP status code) by the
// web client.
var (
	// ErrorBadResponseNewVersionEndpoint is wrapped with the status code by
	// PostNewVersion when the new version endpoint answers with a status other
	// than 201 Created or 400 Bad Request.
	ErrorBadResponseNewVersionEndpoint = errors.New("Bad response from new version endpoint")
	// ErrorBadResponsePushStartEndpoint is wrapped with the status code by
	// PostPushStart when the push start endpoint does not answer 200 OK.
	ErrorBadResponsePushStartEndpoint = errors.New("Bad response from push start endpoint")
	// ErrorBadResponseInitiateChallengeEndpoint is wrapped with the status code
	// by doSingleFileChallenge when the file challenge endpoint does not answer
	// 200 OK.
	ErrorBadResponseInitiateChallengeEndpoint = errors.New("Bad response from start file challenge endpoint")
	// ErrorNoSuchDigest signals a digest mismatch.
	// NOTE(review): no use of this error is visible in this file section;
	// confirm where it is returned before relying on it.
	ErrorNoSuchDigest = errors.New("No digest submitted matches the digest requested")
)
// Version is the JSON payload that PostNewVersion sends to the new version
// endpoint. It is assembled by versionFromManifest from the inspected docker
// image plus the files, weights and challenge answers supplied by the caller.
type Version struct {
	// Annotations holds the labels of the inspected image.
	Annotations map[string]string `json:"annotations"`
	// CogConfig is the cog config decoded from the image's cog config label.
	CogConfig config.Config `json:"cog_config"`
	// CogVersion is the value of the image's cog version label.
	CogVersion string `json:"cog_version"`
	// OpenAPISchema is the schema decoded from the image's OpenAPI schema label.
	OpenAPISchema map[string]any `json:"openapi_schema"`
	// RuntimeConfig lists the weights, files and environment of the version.
	RuntimeConfig RuntimeConfig `json:"runtime_config"`
	// Virtual is always set to true by versionFromManifest.
	Virtual bool `json:"virtual"`
	// PushID is copied from the "run.cog.push_id" image label when present.
	PushID string `json:"push_id"`
	// Challenges carries the file challenge answers, with "sha256:"-prefixed
	// digests so they match the digests in RuntimeConfig.
	Challenges []FileChallengeAnswer `json:"file_challenges"`
}
err := json.Marshal(jsonBody) if err != nil { return util.WrapError(err, "failed to marshal JSON for build start") } url := webBaseURL() url.Path = pushStartURLPath req, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), bytes.NewReader(jsonData)) if err != nil { return err } resp, err := c.client.Do(req) //nolint:gosec // G704: URL from configured endpoint if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return util.WrapError(ErrorBadResponsePushStartEndpoint, strconv.Itoa(resp.StatusCode)) } return nil } func (c *Client) PostNewVersion(ctx context.Context, image string, weights []File, files []File, fileChallenges []FileChallengeAnswer) error { version, err := c.versionFromManifest(ctx, image, weights, files, fileChallenges) if err != nil { return util.WrapError(err, "failed to build new version from manifest") } jsonData, err := json.Marshal(version) if err != nil { return util.WrapError(err, "failed to marshal JSON for new version") } versionUrl, err := newVersionURL(image) if err != nil { return err } req, err := http.NewRequestWithContext(ctx, http.MethodPost, versionUrl.String(), bytes.NewReader(jsonData)) if err != nil { return err } resp, err := c.client.Do(req) //nolint:gosec // G704: URL from configured endpoint if err != nil { return err } defer resp.Body.Close() decoder := json.NewDecoder(resp.Body) if resp.StatusCode != http.StatusCreated { if resp.StatusCode == http.StatusBadRequest { var versionErrors VersionErrors err = decoder.Decode(&versionErrors) if err != nil { return err } return errors.New(versionErrors.Detail) } return util.WrapError(ErrorBadResponseNewVersionEndpoint, strconv.Itoa(resp.StatusCode)) } var versionCreate VersionCreate err = decoder.Decode(&versionCreate) if err != nil { return err } console.Infof("New Version: %s", versionCreate.Version) return nil } func (c *Client) FetchAPIToken(ctx context.Context, entity string) (string, error) { tokenUrl := tokenURL(entity) 
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenUrl.String(), nil) if err != nil { return "", err } tokenResp, err := c.client.Do(req) //nolint:gosec // G704: URL from configured endpoint if err != nil { return "", err } defer tokenResp.Body.Close() if tokenResp.StatusCode != http.StatusOK { return "", fmt.Errorf("Bad response: %s attempting to exchange tokens", strconv.Itoa(tokenResp.StatusCode)) } var tokenData TokenData err = json.NewDecoder(tokenResp.Body).Decode(&tokenData) if err != nil { return "", err } return tokenData.Keys.Cog.Key, nil } func (c *Client) versionFromManifest(ctx context.Context, image string, weights []File, files []File, fileChallenges []FileChallengeAnswer) (*Version, error) { manifest, err := c.dockerCommand.Inspect(ctx, image) if err != nil { return nil, util.WrapError(err, "failed to inspect docker image") } cogConfig, err := readCogConfig(manifest) if err != nil { return nil, err } var openAPISchema map[string]any err = json.Unmarshal([]byte(manifest.Config.Labels[command.CogOpenAPISchemaLabelKey]), &openAPISchema) if err != nil { return nil, util.WrapError(err, "failed to get OpenAPI schema from docker image") } predictCode, err := stripCodeFromStub(cogConfig, true) if err != nil { return nil, err } trainCode, err := stripCodeFromStub(cogConfig, false) if err != nil { return nil, err } var cogGPU int if cogConfig.Build.GPU { cogGPU = 1 } cogVersion := "" torchVersion := "" cudaVersion := "" cudnnVersion := "" pythonVersion := "" for _, env := range manifest.Config.Env { envName, envValue, found := strings.Cut(env, "=") if !found { continue } switch envName { case command.R8CogVersionEnvVarName: cogVersion = envValue case command.R8TorchVersionEnvVarName: torchVersion = envValue case command.R8CudaVersionEnvVarName: cudaVersion = envValue case command.R8CudnnVersionEnvVarName: cudnnVersion = envValue case command.R8PythonVersionEnvVarName: pythonVersion = envValue } } env := Env{ CogGpu: strconv.Itoa(cogGPU), 
CogPredictTypeStub: cogConfig.Predict, CogTrainTypeStub: cogConfig.Train, CogPredictCodeStrip: predictCode, CogTrainCodeStrip: trainCode, R8CogVersion: cogVersion, R8CudaVersion: cudaVersion, R8CudnnVersion: cudnnVersion, R8PythonVersion: pythonVersion, R8TorchVersion: torchVersion, } prefixedFiles := make([]File, len(files)) for i, file := range files { prefixedFiles[i] = File{ Path: file.Path, Digest: "sha256:" + file.Digest, Size: file.Size, } } prefixedWeights := make([]File, len(weights)) for i, file := range weights { prefixedWeights[i] = File{ Path: file.Path, Digest: "sha256:" + file.Digest, Size: file.Size, } } // Digests should match whatever digest we are sending in as the // runtime config digests for i, challenge := range fileChallenges { fileChallenges[i] = FileChallengeAnswer{ Digest: fmt.Sprintf("sha256:%s", challenge.Digest), Hash: challenge.Hash, ChallengeID: challenge.ChallengeID, } } runtimeConfig := RuntimeConfig{ Weights: prefixedWeights, Files: prefixedFiles, Env: env, } version := Version{ Annotations: manifest.Config.Labels, CogConfig: *cogConfig, CogVersion: manifest.Config.Labels[command.CogVersionLabelKey], OpenAPISchema: openAPISchema, RuntimeConfig: runtimeConfig, Virtual: true, Challenges: fileChallenges, } if pushID, ok := manifest.Config.Labels["run.cog.push_id"]; ok { version.PushID = pushID } return &version, nil } func (c *Client) InitiateAndDoFileChallenge(ctx context.Context, weights []File, files []File) ([]FileChallengeAnswer, error) { var challengeAnswers []FileChallengeAnswer var mu sync.Mutex var wg errgroup.Group for _, item := range files { wg.Go(func() error { answer, err := c.doSingleFileChallenge(ctx, item, "files") if err != nil { return util.WrapError(err, fmt.Sprintf("do file challenge for digest %s", item.Digest)) } mu.Lock() challengeAnswers = append(challengeAnswers, answer) mu.Unlock() return nil }) } for _, item := range weights { wg.Go(func() error { answer, err := c.doSingleFileChallenge(ctx, item, 
"weights") if err != nil { return util.WrapError(err, fmt.Sprintf("do file challenge for digest %s", item.Digest)) } mu.Lock() challengeAnswers = append(challengeAnswers, answer) mu.Unlock() return nil }) } if err := wg.Wait(); err != nil { return nil, util.WrapError(err, "do file challenges") } return challengeAnswers, nil } // doSingleFileChallenge does a single file challenge. This is expected to be called in a goroutine. func (c *Client) doSingleFileChallenge(ctx context.Context, file File, fileType string) (FileChallengeAnswer, error) { initiateChallengePath := webBaseURL() initiateChallengePath.Path = startChallengeURLPath answer := FileChallengeAnswer{} jsonData, err := json.Marshal(FileChallengeRequest{ Digest: file.Digest, FileType: fileType, }) if err != nil { return answer, util.WrapError(err, "encode request JSON") } req, err := http.NewRequestWithContext(ctx, http.MethodPost, initiateChallengePath.String(), bytes.NewReader(jsonData)) if err != nil { return answer, util.WrapError(err, "build HTTP request") } resp, err := c.client.Do(req) //nolint:gosec // G704: URL from configured endpoint if err != nil { return answer, util.WrapError(err, "do HTTP request") } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return answer, util.WrapError(ErrorBadResponseInitiateChallengeEndpoint, strconv.Itoa(resp.StatusCode)) } var challenge FileChallenge err = json.NewDecoder(resp.Body).Decode(&challenge) if err != nil { return answer, util.WrapError(err, "decode response body") } ans, err := util.SHA256HashFileWithSaltAndRange(file.Path, challenge.Start, challenge.End, challenge.Salt) if err != nil { return answer, util.WrapError(err, "hash file") } return FileChallengeAnswer{ Digest: file.Digest, Hash: ans, ChallengeID: challenge.ID, }, nil } func newVersionURL(image string) (url.URL, error) { imageComponents := strings.Split(image, "/") newVersionUrl := webBaseURL() if len(imageComponents) != 3 || imageComponents[0] != global.ReplicateRegistryHost { 
return newVersionUrl, r8_errors.ErrorBadRegistryURL } newVersionUrl.Path = strings.Join([]string{"", "api", "models", imageComponents[1], imageComponents[2], "versions"}, "/") return newVersionUrl, nil } func tokenURL(entity string) url.URL { newVersionUrl := webBaseURL() newVersionUrl.Path = strings.Join([]string{"", "api", "token", entity}, "/") return newVersionUrl } func webBaseURL() url.URL { return url.URL{ Scheme: env.SchemeFromEnvironment(), Host: env.WebHostFromEnvironment(), } } func codeFileName(cogConfig *config.Config, isPredict bool) (string, error) { var stubComponents []string if isPredict { if cogConfig.Predict == "" { return "", nil } stubComponents = strings.Split(cogConfig.Predict, ":") } else { if cogConfig.Train == "" { return "", nil } stubComponents = strings.Split(cogConfig.Train, ":") } if len(stubComponents) < 2 { return "", errors.New("Code stub components has less than 2 entries.") } return stubComponents[0], nil } func readCode(cogConfig *config.Config, isPredict bool) (string, string, error) { codeFile, err := codeFileName(cogConfig, isPredict) if err != nil { return "", codeFile, err } if codeFile == "" { return "", "", nil } b, err := os.ReadFile(codeFile) if err != nil { return "", codeFile, err } return string(b), codeFile, nil } func stripCodeFromStub(cogConfig *config.Config, isPredict bool) (string, error) { // TODO: We should attempt to strip the code here, in python this is done like so: // from cog.code_xforms import strip_model_source_code // code = strip_model_source_code( // util.read_file(os.path.join(fs, 'src', base_file)), // [base_class], // ['predict', 'train'], // ) // Currently the behavior of the code strip attempts to strip, and if it can't it // loads the whole file in. Here we just load the whole file in. // We should figure out a way to call cog python from here to fulfill this. // It could be a good idea to do this in the layer functions where we do pip freeze // et al. 
code, _, err := readCode(cogConfig, isPredict) return code, err } func readCogConfig(manifest *image.InspectResponse) (*config.Config, error) { var cogConfig config.Config err := json.Unmarshal([]byte(manifest.Config.Labels[command.CogConfigLabelKey]), &cogConfig) if err != nil { return nil, util.WrapError(err, "failed to get cog config from docker image") } return &cogConfig, nil } ================================================ FILE: pkg/web/client_test.go ================================================ package web import ( "encoding/json" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker/dockertest" "github.com/replicate/cog/pkg/env" ) func TestPostNewVersion(t *testing.T) { // Setup mock http server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { output := "{\"version\":\"user/test:53c740f17ce88a61c3da5b0c20e48fd48e2da537c3a1276dec63ab11fbad6bcb\"}" w.WriteHeader(http.StatusCreated) w.Write([]byte(output)) })) defer server.Close() url, err := url.Parse(server.URL) require.NoError(t, err) t.Setenv(env.SchemeEnvVarName, url.Scheme) t.Setenv(env.WebHostEnvVarName, url.Host) dir := t.TempDir() // Create mock predict predictPyPath := filepath.Join(dir, "predict.py") handle, err := os.Create(predictPyPath) require.NoError(t, err) handle.WriteString("import cog") dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"beautifulsoup4==4.12.3\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}" // Setup mock command command := dockertest.NewMockCommand() client := NewClient(command, http.DefaultClient) err = client.PostNewVersion(t.Context(), "r8.im/user/test", []File{}, []File{}, nil) require.NoError(t, err) } func TestVersionFromManifest(t *testing.T) { 
// Setup mock command command := dockertest.NewMockCommand() // Create mock predict dir := t.TempDir() predictPyPath := filepath.Join(dir, "predict.py") handle, err := os.Create(predictPyPath) require.NoError(t, err) handle.WriteString("import cog") dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"beautifulsoup4==4.12.3\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}" dockertest.MockOpenAPISchema = "{\"test\": true}" client := NewClient(command, http.DefaultClient) version, err := client.versionFromManifest(t.Context(), "r8.im/user/test", []File{}, []File{}, nil) require.NoError(t, err) var openAPISchema map[string]any err = json.Unmarshal([]byte(dockertest.MockOpenAPISchema), &openAPISchema) require.NoError(t, err) var cogConfig config.Config err = json.Unmarshal([]byte(dockertest.MockCogConfig), &cogConfig) require.NoError(t, err) require.Equal(t, openAPISchema, version.OpenAPISchema) require.Equal(t, cogConfig, version.CogConfig) } func TestVersionURLErrorWithoutR8IMPrefix(t *testing.T) { _, err := newVersionURL("docker.com/thing/thing") require.Error(t, err) } func TestVersionURLErrorWithout3Components(t *testing.T) { _, err := newVersionURL("username/test") require.Error(t, err) } func TestDoFileChallenge(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test.tmp") d1 := []byte("hello\nreplicate\nhello\n") err := os.WriteFile(path, d1, 0o644) require.NoError(t, err) path2 := filepath.Join(dir, "test2.tmp") d2 := []byte("hello\nreplicate\nhello\n") err = os.WriteFile(path2, d2, 0o644) require.NoError(t, err) files := []File{ { Path: path, Digest: "abc", Size: 22, }, } weights := []File{ { Path: path, Digest: "def", Size: 22, }, } abcChallenge := FileChallenge{ ID: "abc", Digest: "abc", Start: 0, End: 6, Salt: "go\n", } defChallenge := FileChallenge{ ID: "def", Digest: "def", Start: 16, End: 22, Salt: "go\n", } // Setup mock http server 
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) var challengeRequest FileChallengeRequest // Ignore errors - make sure the test is set up correctly json.NewDecoder(r.Body).Decode(&challengeRequest) if challengeRequest.Digest == "abc" { body, _ := json.Marshal(abcChallenge) w.Write(body) } else { body, _ := json.Marshal(defChallenge) w.Write(body) } })) defer server.Close() url, err := url.Parse(server.URL) require.NoError(t, err) t.Setenv(env.SchemeEnvVarName, url.Scheme) t.Setenv(env.WebHostEnvVarName, url.Host) // Setup mock command command := dockertest.NewMockCommand() client := NewClient(command, http.DefaultClient) response, err := client.InitiateAndDoFileChallenge(t.Context(), weights, files) require.NoError(t, err) assert.ElementsMatch(t, response, []FileChallengeAnswer{ { ChallengeID: "abc", Digest: "abc", Hash: "43d250d92b5dbb47f75208de8e9a9a321d23e85eed0dc3d5dfa83bc3cc5aa68c", }, { ChallengeID: "def", Digest: "def", Hash: "43d250d92b5dbb47f75208de8e9a9a321d23e85eed0dc3d5dfa83bc3cc5aa68c", }, }) } func TestFetchToken(t *testing.T) { // Setup mock http server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/api/token/user": // Mock token exchange response //nolint:gosec tokenResponse := `{ "keys": { "cog": { "key": "test-api-token", "expires_at": "2024-12-31T23:59:59Z" } } }` w.WriteHeader(http.StatusOK) w.Write([]byte(tokenResponse)) default: w.WriteHeader(http.StatusNotFound) } })) defer server.Close() url, err := url.Parse(server.URL) require.NoError(t, err) t.Setenv(env.SchemeEnvVarName, url.Scheme) t.Setenv(env.WebHostEnvVarName, url.Host) // Setup mock command command := dockertest.NewMockCommand() client := NewClient(command, http.DefaultClient) token, err := client.FetchAPIToken(t.Context(), "user") require.NoError(t, err) require.Equal(t, "test-api-token", token) } 
================================================ FILE: pkg/weights/manifest.go ================================================ package weights import ( "encoding/binary" "encoding/hex" "encoding/json" "fmt" "hash/crc32" "io" "os" "path" ) // Manifest contains metadata about weights files in a model type Manifest struct { Files map[string]Metadata `json:"files"` } // Metadata contains information about a file type Metadata struct { // CRC32 is the CRC32 checksum of the file encoded as a hexadecimal string CRC32 string `json:"crc32"` } // NewManifest creates a new manifest func NewManifest() *Manifest { return &Manifest{} } // LoadManifest loads a manifest from a file func LoadManifest(filename string) (*Manifest, error) { if _, err := os.Stat(filename); err != nil { return nil, err } file, err := os.Open(filename) if err != nil { return nil, err } defer file.Close() m := &Manifest{} decoder := json.NewDecoder(file) if err := decoder.Decode(m); err != nil { return nil, err } return m, nil } // Save saves a manifest to a file func (m *Manifest) Save(filename string) error { if err := os.MkdirAll(path.Dir(filename), 0o755); err != nil { return err } file, err := os.Create(filename) if err != nil { return err } defer file.Close() encoder := json.NewEncoder(file) return encoder.Encode(m) } // Equal compares the files in two manifests for strict equality func (m *Manifest) Equal(other *Manifest) bool { if len(m.Files) != len(other.Files) { return false } for path, crc32 := range m.Files { if otherCrc32, ok := other.Files[path]; !ok || otherCrc32 != crc32 { return false } } return true } // AddFile adds a file to the manifest, calculating its CRC32 checksum func (m *Manifest) AddFile(path string) error { crc32Algo := crc32.NewIEEE() // generate checksum of file file, err := os.Open(path) if err != nil { return fmt.Errorf("failed to open file %s: %w", path, err) } defer file.Close() _, err = io.Copy(crc32Algo, file) if err != nil { return fmt.Errorf("failed to generate 
checksum of file %s: %w", path, err) } checksum := crc32Algo.Sum32() // encode checksum as hexadecimal string bytes := make([]byte, 4) binary.LittleEndian.PutUint32(bytes, checksum) encoded := hex.EncodeToString(bytes) if m.Files == nil { m.Files = make(map[string]Metadata) } m.Files[path] = Metadata{ CRC32: encoded, } return nil } ================================================ FILE: pkg/weights/weights.go ================================================ package weights import ( "os" "path/filepath" "slices" "sort" "strings" ) var prefixesToIgnore = []string{".cog", ".git", "__pycache__"} var suffixesToIgnore = []string{ ".py", ".ipynb", ".whl", // Python projects ".jpg", ".jpeg", ".png", ".webp", ".svg", ".gif", ".avif", ".heic", // images ".mp4", ".mov", ".avi", ".wmv", ".mkv", ".webm", // videos ".mp3", ".wav", ".ogg", ".flac", ".aac", ".m4a", // audio files ".log", // logs } // FileWalker is a function type that walks the file tree rooted at root, calling walkFn for each file or directory in the tree, including root. type FileWalker func(root string, walkFn filepath.WalkFunc) error func FindWeights(fw FileWalker) ([]string, []string, error) { var files []string var codeFiles []string err := fw(".", func(path string, info os.FileInfo, err error) error { if err != nil { return err } if info.IsDir() { return nil } if isGitFile(path) { return nil } if isCodeFile(path) { codeFiles = append(codeFiles, path) return nil } if info.Size() < sizeThreshold { return nil } if isNonModelFiles(path) { return nil } files = append(files, path) return nil }) if err != nil { return nil, nil, err } // by sorting the files by levels, we can filter out directories that are prefixes of other directories // e.g. 
/a/b/ is a prefix of /a/b/c/, so we can filter out /a/b/c/ sortFilesByLevels(files) dirs, rootFiles := getDirsAndRootfiles(files) dirs = filterDirsContainingCode(dirs, codeFiles) return dirs, rootFiles, nil } func isNonModelFiles(path string) bool { for _, prefix := range prefixesToIgnore { if strings.HasPrefix(path, prefix) { return true } } for _, suffix := range suffixesToIgnore { if strings.HasSuffix(path, suffix) { return true } } return false } const sizeThreshold = 10 * 1024 * 1024 // 10MB func sortFilesByLevels(files []string) { sort.Slice(files, func(i, j int) bool { list1 := strings.Split(files[i], "/") list2 := strings.Split(files[j], "/") if len(list1) != len(list2) { return len(list1) < len(list2) } for k := range list1 { if list1[k] != list2[k] { return list1[k] < list2[k] } } return false }) } // isCodeFile detects if a given path is a code file based on whether the file path ends with ".py" or ".ipynb" func isCodeFile(path string) bool { ext := filepath.Ext(path) return ext == ".py" || ext == ".ipynb" } func isGitFile(path string) bool { dir, _ := filepath.Split(path) folders := strings.Split(filepath.Clean(dir), string(filepath.Separator)) return slices.Contains(folders, ".git") } // filterDirsContainingCode filters out directories that contain code files. // If a dir is a prefix for any given codeFiles, it will be filtered out. 
func filterDirsContainingCode(dirs []string, codeFiles []string) []string { filteredDirs := make([]string, 0, len(dirs)) // Filter out directories that are prefixes of code directories for _, dir := range dirs { isPrefix := false for _, codeFile := range codeFiles { if strings.HasPrefix(codeFile, dir) { isPrefix = true break } } if !isPrefix { filteredDirs = append(filteredDirs, dir) } } return filteredDirs } func getDirsAndRootfiles(files []string) ([]string, []string) { // get all the directories that contain model weights files // remove sub-directories if their parent directory is already in the list var dirs []string // for large model files in root directory, we should not add the "." to dirs var rootFiles []string for _, f := range files { dir := filepath.Dir(f) if dir == "." || dir == "/" { rootFiles = append(rootFiles, f) continue } if hasParent(dir, dirs) { continue } dirs = append(dirs, dir) } return dirs, rootFiles } func hasParent(dir string, dirs []string) bool { for _, d := range dirs { parent := d + string(filepath.Separator) child := dir + string(filepath.Separator) if strings.HasPrefix(child, parent) { return true } } return false } ================================================ FILE: pkg/weights/weights_test.go ================================================ package weights import ( "os" "path/filepath" "testing" "time" "github.com/stretchr/testify/require" ) // mockFileInfo is a test type to mock os.FileInfo type mockFileInfo struct { size int64 } func (mfi mockFileInfo) Size() int64 { return mfi.size } func (mfi mockFileInfo) Name() string { return "" } func (mfi mockFileInfo) Mode() os.FileMode { return 0 } func (mfi mockFileInfo) ModTime() time.Time { return time.Time{} } func (mfi mockFileInfo) IsDir() bool { return false } func (mfi mockFileInfo) Sys() any { return nil } // Test case for root directory with large and small model files func TestRootDirModelFiles(t *testing.T) { mockFileWalker := func(root string, walkFn filepath.WalkFunc) 
error { sizes := []int64{sizeThreshold, sizeThreshold, sizeThreshold - 1} for i, path := range []string{"large-a", "large-b", "small"} { walkFn(path, mockFileInfo{size: sizes[i]}, nil) } return nil } dirs, rootFiles, err := FindWeights(mockFileWalker) require.NoError(t, err) require.Equal(t, []string{"large-a", "large-b"}, rootFiles) require.Empty(t, dirs) } // Test case for sub directory with large and small model files func TestSubDirModelFiles(t *testing.T) { mockFileWalker := func(root string, walkFn filepath.WalkFunc) error { sizes := []int64{sizeThreshold, sizeThreshold, sizeThreshold - 1} for i, path := range []string{"models/large-a", "models/large-b", "models/small"} { walkFn(path, mockFileInfo{size: sizes[i]}, nil) } return nil } dirs, rootFiles, err := FindWeights(mockFileWalker) require.NoError(t, err) require.Empty(t, rootFiles) require.Equal(t, []string{"models"}, dirs) } // Test case for both root and sub directory with large model files func TestRootAndSubDirModelFiles(t *testing.T) { mockFileWalker := func(root string, walkFn filepath.WalkFunc) error { sizes := []int64{sizeThreshold, sizeThreshold} for i, path := range []string{"root-large", "models/large-a"} { walkFn(path, mockFileInfo{size: sizes[i]}, nil) } return nil } dirs, rootFiles, err := FindWeights(mockFileWalker) require.NoError(t, err) require.Equal(t, []string{"root-large"}, rootFiles) require.Equal(t, []string{"models"}, dirs) } // Test case for root directory with both large model and code files func TestRootDirLargeModelAndCodeFiles(t *testing.T) { mockFileWalker := func(root string, walkFn filepath.WalkFunc) error { sizes := []int64{sizeThreshold, 1024} for i, path := range []string{"root-large", "predict.py"} { walkFn(path, mockFileInfo{size: sizes[i]}, nil) } return nil } dirs, rootFiles, err := FindWeights(mockFileWalker) require.NoError(t, err) require.Equal(t, []string{"root-large"}, rootFiles) require.Empty(t, dirs) } // Test case for sub directory with both large model and 
code files func TestSubDirLargeModelAndCodeFiles(t *testing.T) { mockFileWalker := func(root string, walkFn filepath.WalkFunc) error { sizes := []int64{sizeThreshold, 1024} for i, path := range []string{"models/root-large", "models/predict.py"} { walkFn(path, mockFileInfo{size: sizes[i]}, nil) } return nil } dirs, rootFiles, err := FindWeights(mockFileWalker) require.NoError(t, err) require.Empty(t, rootFiles) require.Empty(t, dirs) } // Test case for sub-directory with code files under large model directory func TestSubDirLargeModelDirWithCodeFiles(t *testing.T) { mockFileWalker := func(root string, walkFn filepath.WalkFunc) error { sizes := []int64{sizeThreshold, 1024} for i, path := range []string{"models/root-large", "models/code/predict.py"} { walkFn(path, mockFileInfo{size: sizes[i]}, nil) } return nil } dirs, rootFiles, err := FindWeights(mockFileWalker) require.NoError(t, err) require.Empty(t, rootFiles) require.Empty(t, dirs) } // Test case for sorting for model directories func TestDirSorting(t *testing.T) { mockFileWalker := func(root string, walkFn filepath.WalkFunc) error { sizes := []int64{sizeThreshold, sizeThreshold, sizeThreshold} for i, path := range []string{"models2/b/large", "models2/a/large", "models/large"} { walkFn(path, mockFileInfo{size: sizes[i]}, nil) } return nil } dirs, rootFiles, err := FindWeights(mockFileWalker) require.NoError(t, err) require.Empty(t, rootFiles) require.Equal(t, []string{"models", "models2/a", "models2/b"}, dirs) } // Test case for merging sub-directories with large models func TestSubDirMerge(t *testing.T) { mockFileWalker := func(root string, walkFn filepath.WalkFunc) error { sizes := []int64{sizeThreshold, sizeThreshold, sizeThreshold} for i, path := range []string{"models/b/large", "models/a/large", "models/large"} { walkFn(path, mockFileInfo{size: sizes[i]}, nil) } return nil } dirs, rootFiles, err := FindWeights(mockFileWalker) require.NoError(t, err) require.Empty(t, rootFiles) require.Equal(t, 
[]string{"models"}, dirs) } // Test case for ignoring files within a .git directory func TestIgnoreGitFiles(t *testing.T) { mockFileWalker := func(root string, walkFn filepath.WalkFunc) error { sizes := []int64{sizeThreshold, sizeThreshold, 1024} for i, path := range []string{".git/root-large", "root-large", "predict.py"} { walkFn(path, mockFileInfo{size: sizes[i]}, nil) } return nil } dirs, rootFiles, err := FindWeights(mockFileWalker) require.NoError(t, err) require.Equal(t, []string{"root-large"}, rootFiles) require.Empty(t, dirs) } ================================================ FILE: pkg/wheels/wheels.go ================================================ // Package wheels provides configuration for sourcing cog and coglet wheels. package wheels import ( "fmt" "os" "path/filepath" "regexp" "sort" "strings" "github.com/replicate/cog/pkg/global" cogversion "github.com/replicate/cog/pkg/util/version" ) var semverPreReleaseRe = regexp.MustCompile(`-alpha(\d+)|-beta(\d+)|-rc(\d+)|-dev(\d*)`) // pep440PreReleaseRe matches PEP 440 pre-release identifiers (a1, b2, rc1, .dev1) var pep440PreReleaseRe = regexp.MustCompile(`\d(a|b|rc|\.dev)\d`) // IsPreRelease returns true if the version string contains a pre-release identifier // in either semver (-alpha1, -beta2, -rc1, -dev1) or PEP 440 (a1, b2, rc1, .dev1) format. func IsPreRelease(version string) bool { return semverPreReleaseRe.MatchString(version) || pep440PreReleaseRe.MatchString(version) } // MinimumSDKVersion is the minimum cog SDK version that can be explicitly requested. // Versions older than this lack features required by the current CLI. const MinimumSDKVersion = "0.16.0" // BaseVersionRe extracts the MAJOR.MINOR.PATCH prefix, ignoring pre-release suffixes. var BaseVersionRe = regexp.MustCompile(`^(\d+\.\d+\.\d+)`) // ValidateSDKVersion checks that a PyPI WheelConfig does not request a version // older than MinimumSDKVersion. Non-PyPI sources, unpinned versions, and nil // configs are always valid. 
func ValidateSDKVersion(config *WheelConfig, label string) error { if config == nil || config.Source != WheelSourcePyPI || config.Version == "" { return nil } base := config.Version if m := BaseVersionRe.FindString(base); m != "" { base = m } reqVer, err := cogversion.NewVersion(base) if err != nil { return nil // unparseable — let pip catch real problems } minVer := cogversion.MustVersion(MinimumSDKVersion) if reqVer.GreaterOrEqual(minVer) { return nil } return fmt.Errorf("%s version %s is below the minimum required version %s", label, config.Version, MinimumSDKVersion) } // WheelSource represents the source type for the wheel to install type WheelSource int const ( // WheelSourcePyPI installs from PyPI (default for released builds) WheelSourcePyPI WheelSource = iota // WheelSourceURL uses a custom URL WheelSourceURL // WheelSourceFile uses a local file path WheelSourceFile ) // String returns the string representation of the WheelSource func (s WheelSource) String() string { switch s { case WheelSourcePyPI: return "pypi" case WheelSourceURL: return "url" case WheelSourceFile: return "file" default: return "unknown" } } // WheelConfig represents the configuration for which wheel to install type WheelConfig struct { // Source indicates where the wheel comes from Source WheelSource // URL is set when Source is WheelSourceURL URL string // Path is set when Source is WheelSourceFile (absolute path) Path string // Version is set when Source is WheelSourcePyPI (optional, empty = latest) Version string } // CogSDKWheelEnvVar is the environment variable name for cog SDK wheel selection const CogSDKWheelEnvVar = "COG_SDK_WHEEL" // CogletWheelEnvVar is the environment variable name for coglet wheel selection const CogletWheelEnvVar = "COGLET_WHEEL" // ParseWheelValue parses a wheel env var value and returns the appropriate WheelConfig. 
// Supported values: // - "pypi" - Install from PyPI (latest version) // - "pypi:0.12.0" - Install specific version from PyPI // - "https://..." or "http://..." - Direct wheel URL // - "/path/to/file.whl" or "relative/path" - Local file or directory (resolved to abspath) // // Paths that point to directories are resolved later by the Resolve functions, // which glob for the appropriate wheel inside the directory. // // Returns nil if the value is empty (caller should use auto-detection). func ParseWheelValue(value string) *WheelConfig { value = strings.TrimSpace(value) if value == "" { return nil } // "pypi" or "pypi:version" requests PyPI if strings.EqualFold(value, "pypi") { return &WheelConfig{Source: WheelSourcePyPI} } if strings.HasPrefix(strings.ToLower(value), "pypi:") { // Extract version after "pypi:" prefix, preserving original case return &WheelConfig{Source: WheelSourcePyPI, Version: value[5:]} } // Check for URL (http:// or https://) if strings.HasPrefix(value, "https://") || strings.HasPrefix(value, "http://") { return &WheelConfig{Source: WheelSourceURL, URL: value} } // Treat everything else as a file/directory path - resolve to absolute absPath, err := filepath.Abs(value) if err != nil { absPath = value } return &WheelConfig{Source: WheelSourceFile, Path: absPath} } var executablePath = os.Executable var evalSymlinks = filepath.EvalSymlinks // goarchToWheelPlatform maps GOARCH values to wheel filename platform substrings. 
func goarchToWheelPlatform(goarch string) string { switch goarch { case "amd64": return "x86_64" case "arm64": return "aarch64" default: return "" } } func bestWheelMatch(matches []string, platform string) string { if len(matches) == 0 { return "" } if platform != "" { platStr := goarchToWheelPlatform(platform) if platStr != "" { var filtered []string for _, match := range matches { base := filepath.Base(match) if strings.Contains(base, platStr) || strings.Contains(base, "-none-any") { filtered = append(filtered, match) } } matches = filtered } } if len(matches) == 0 { return "" } sort.Strings(matches) return matches[len(matches)-1] } // distFromExecutable returns the dist/ directory relative to the running cog // binary, if it appears to be in a goreleaser output layout (dist/go//cog). // Returns empty string if the path cannot be determined. func distFromExecutable() string { exePath, err := executablePath() if err != nil { return "" } exePath, err = evalSymlinks(exePath) if err != nil { return "" } distDir := filepath.Clean(filepath.Join(filepath.Dir(exePath), "..", "..")) if info, err := os.Stat(distDir); err == nil && info.IsDir() { return distDir } return "" } // findWheelInAutoDetectDist checks ./dist and dist relative to the cog executable. // Returns the absolute path if found, empty string otherwise. func findWheelInAutoDetectDist(pattern string, platform string) string { matches, _ := filepath.Glob(filepath.Join("dist", pattern)) if best := bestWheelMatch(matches, platform); best != "" { absPath, _ := filepath.Abs(best) if absPath != "" { return absPath } return best } if distDir := distFromExecutable(); distDir != "" { matches, _ = filepath.Glob(filepath.Join(distDir, pattern)) if best := bestWheelMatch(matches, platform); best != "" { return best } } return "" } // DetectLocalSDKVersion checks dist/ (CWD and executable-relative) for a cog // SDK wheel and extracts the version from its filename. Returns empty string if // no local wheel is found. 
func DetectLocalSDKVersion() string { path := findWheelInAutoDetectDist("cog-*.whl", "") if path == "" { return "" } // Wheel filename format: cog----.whl base := filepath.Base(path) if !strings.HasPrefix(base, "cog-") { return "" } rest := strings.TrimPrefix(base, "cog-") if idx := strings.Index(rest, "-"); idx > 0 { return rest[:idx] } return "" } // resolveWheelPath resolves a wheel path that may be a file or directory. // If path is a directory, globs for pattern inside it, filtering by platform if non-empty. // If path is a file, returns it directly. func resolveWheelPath(path string, pattern string, platform string, envVar string) (string, error) { info, err := os.Stat(path) //nolint:gosec // G703: path from build config, not user input if err != nil { return "", fmt.Errorf("%s: path not found: %s", envVar, path) } if !info.IsDir() { return path, nil } matches, _ := filepath.Glob(filepath.Join(path, pattern)) if len(matches) == 0 { return "", fmt.Errorf("%s: no wheel matching '%s' found in %s\n\nTo build the wheel, run: mise run build:sdk (for cog) or mise run build:coglet (for coglet)", envVar, pattern, path) } // Filter by platform if specified platStr := goarchToWheelPlatform(platform) if platStr != "" { var filtered []string for _, m := range matches { base := filepath.Base(m) if strings.Contains(base, platStr) || strings.Contains(base, "-none-any") { filtered = append(filtered, m) } } if len(filtered) == 0 { return "", fmt.Errorf("%s: no wheel for platform %s found in %s (found %d for other platforms)", envVar, platform, path, len(matches)) } matches = filtered } if len(matches) > 1 { return "", fmt.Errorf("%s: multiple wheels matching '%s' in %s — specify the exact file path", envVar, pattern, path) } return matches[0], nil } // ResolveCogWheel resolves the WheelConfig for the cog SDK. // // Parameters: // - envValue: value of COG_SDK_WHEEL env var (empty string if not set) // - version: the CLI version (e.g. 
"dev", "0.17.0", "0.17.0-alpha1") // // Resolution order: // 1. envValue (if non-empty, explicit override) // 2. Auto-detect: check dist/cog-*.whl (for development builds only) // 3. Default: PyPI latest (use build.sdk_version in cog.yaml to pin) func ResolveCogWheel(envValue string, version string) (*WheelConfig, error) { // Check explicit env var first if config := ParseWheelValue(envValue); config != nil { if config.Source == WheelSourceFile { // cog SDK is pure Python (py3-none-any), no platform filtering needed resolved, err := resolveWheelPath(config.Path, "cog-*.whl", "", CogSDKWheelEnvVar) if err != nil { return nil, err } config.Path = resolved } return config, nil } isDev := version == "dev" || strings.Contains(version, "-dev") || strings.Contains(version, "+") // Auto-detect for dev builds: check ./dist or executable-relative dist if isDev { if path := findWheelInAutoDetectDist("cog-*.whl", ""); path != "" { return &WheelConfig{Source: WheelSourceFile, Path: path}, nil } } // Default: PyPI (always latest; use sdk_version in cog.yaml to pin) return &WheelConfig{Source: WheelSourcePyPI}, nil } // GetCogWheelConfig is a convenience wrapper that reads COG_SDK_WHEEL from the environment // and version from global.Version. func GetCogWheelConfig() (*WheelConfig, error) { return ResolveCogWheel(os.Getenv(CogSDKWheelEnvVar), global.Version) } // ResolveCogletWheel resolves the WheelConfig for coglet. // // targetArch is the GOARCH of the Docker build target (e.g. "amd64", "arm64"). // It is used to select the correct platform-specific wheel from dist/. // // Resolution order: // 1. envValue (COGLET_WHEEL) if non-empty — explicit override // 2. Auto-detect: check ./dist for coglet-*.whl (development builds only) // 3. Default: PyPI latest (use COGLET_WHEEL=pypi:x.y.z to pin) // // Coglet is always required. Returns a valid config or an error. // The platform parameter is a GOARCH value (e.g. 
"amd64", "arm64") used to select // the correct platform-specific wheel from a directory. Pass "" to skip filtering. func ResolveCogletWheel(envValue string, version string, platform string) (*WheelConfig, error) { // Check explicit env var first if config := ParseWheelValue(envValue); config != nil { if config.Source == WheelSourceFile { resolved, err := resolveWheelPath(config.Path, "coglet-*.whl", platform, CogletWheelEnvVar) if err != nil { return nil, err } config.Path = resolved } return config, nil } isDev := version == "dev" || strings.Contains(version, "-dev") || strings.Contains(version, "+") // Auto-detect for dev builds: check ./dist or executable-relative dist if isDev { if path := findWheelInAutoDetectDist("coglet-*.whl", platform); path != "" { return &WheelConfig{Source: WheelSourceFile, Path: path}, nil } } // Default: PyPI (always latest; use COGLET_WHEEL=pypi:x.y.z to pin) return &WheelConfig{Source: WheelSourcePyPI}, nil } // GetCogletWheelConfig is a convenience wrapper that reads COGLET_WHEEL from the environment // and version from global.Version. targetArch is the GOARCH of the Docker build target // (e.g. "amd64", "arm64") used to select the correct platform-specific wheel. func GetCogletWheelConfig(targetArch string) (*WheelConfig, error) { return ResolveCogletWheel(os.Getenv(CogletWheelEnvVar), global.Version, targetArch) } // SemverToPEP440 converts a semver pre-release version to PEP 440 format. // e.g. 
"0.17.0-alpha1" -> "0.17.0a1", "0.17.0-beta2" -> "0.17.0b2", // "0.17.0-rc1" -> "0.17.0rc1", "0.17.0-dev1" -> "0.17.0.dev1" // Stable versions pass through unchanged: "0.17.0" -> "0.17.0" func SemverToPEP440(version string) string { return semverPreReleaseRe.ReplaceAllStringFunc(version, func(match string) string { match = strings.TrimPrefix(match, "-") match = strings.Replace(match, "alpha", "a", 1) match = strings.Replace(match, "beta", "b", 1) // rc stays as rc in PEP 440 // dev -> .dev (PEP 440 uses dot separator) if strings.HasPrefix(match, "dev") { return "." + match } return match }) } // PyPIPackageURL returns the pip install specifier for a PyPI package. // If version is empty, returns just the package name (latest). // Otherwise returns "package==version" with the version converted to PEP 440. func (c *WheelConfig) PyPIPackageURL(packageName string) string { if c.Version == "" { return packageName } return packageName + "==" + SemverToPEP440(c.Version) } ================================================ FILE: pkg/wheels/wheels_test.go ================================================ package wheels import ( "os" "path/filepath" "testing" "github.com/stretchr/testify/require" ) func TestCogSDKWheelEnvVarName(t *testing.T) { require.Equal(t, "COG_SDK_WHEEL", CogSDKWheelEnvVar) } func TestWheelSourceString(t *testing.T) { tests := []struct { source WheelSource expected string }{ {WheelSourcePyPI, "pypi"}, {WheelSourceURL, "url"}, {WheelSourceFile, "file"}, {WheelSource(99), "unknown"}, } for _, tt := range tests { t.Run(tt.expected, func(t *testing.T) { require.Equal(t, tt.expected, tt.source.String()) }) } } func TestParseWheelValue(t *testing.T) { tests := []struct { name string input string expected *WheelConfig }{ // Empty/nil cases { name: "empty string returns nil", input: "", expected: nil, }, { name: "whitespace only returns nil", input: " ", expected: nil, }, // PyPI values { name: "pypi keyword", input: "pypi", expected: &WheelConfig{Source: 
WheelSourcePyPI}, }, { name: "pypi uppercase", input: "PYPI", expected: &WheelConfig{Source: WheelSourcePyPI}, }, { name: "pypi with version", input: "pypi:0.12.0", expected: &WheelConfig{Source: WheelSourcePyPI, Version: "0.12.0"}, }, { name: "pypi with version uppercase", input: "PYPI:1.0.0", expected: &WheelConfig{Source: WheelSourcePyPI, Version: "1.0.0"}, }, // relative directory paths (e.g. "dist") are resolved to absolute { name: "dist as relative path", input: "dist", expected: &WheelConfig{ Source: WheelSourceFile, // Path will be converted to absolute }, }, { name: "dist uppercase as relative path", input: "DIST", expected: &WheelConfig{ Source: WheelSourceFile, // Path will be converted to absolute }, }, // URLs { name: "https URL", input: "https://example.com/wheel.whl", expected: &WheelConfig{ Source: WheelSourceURL, URL: "https://example.com/wheel.whl", }, }, { name: "http URL", input: "http://example.com/wheel.whl", expected: &WheelConfig{ Source: WheelSourceURL, URL: "http://example.com/wheel.whl", }, }, { name: "github release URL", input: "https://github.com/replicate/cog/releases/download/v0.1.0/cog-0.1.0-py3-none-any.whl", expected: &WheelConfig{ Source: WheelSourceURL, URL: "https://github.com/replicate/cog/releases/download/v0.1.0/cog-0.1.0-py3-none-any.whl", }, }, // File paths { name: "absolute path", input: "/path/to/wheel.whl", expected: &WheelConfig{ Source: WheelSourceFile, Path: "/path/to/wheel.whl", }, }, { name: "relative path with ./", input: "./dist/wheel.whl", expected: &WheelConfig{ Source: WheelSourceFile, // Path will be converted to absolute }, }, { name: "relative path without ./", input: "path/to/wheel.whl", expected: &WheelConfig{ Source: WheelSourceFile, // Path will be converted to absolute }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := ParseWheelValue(tt.input) if tt.expected == nil { require.Nil(t, result) } else { require.NotNil(t, result) require.Equal(t, tt.expected.Source, 
result.Source) require.Equal(t, tt.expected.URL, result.URL) // For relative paths, just verify they're converted to absolute if tt.expected.Path == "" && result.Source == WheelSourceFile { require.True(t, filepath.IsAbs(result.Path), "path should be absolute: %s", result.Path) } else { require.Equal(t, tt.expected.Path, result.Path) } require.Equal(t, tt.expected.Version, result.Version) } }) } } func TestGetCogWheelConfig(t *testing.T) { // Create temp dir for file path tests and to avoid auto-detect from repo root tmpDir := t.TempDir() wheelFile := filepath.Join(tmpDir, "custom.whl") require.NoError(t, os.WriteFile(wheelFile, []byte("fake wheel"), 0o600)) // Change to temp dir and clear REPO_ROOT to prevent auto-detection from repo dist/ origDir, err := os.Getwd() require.NoError(t, err) require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() t.Setenv("REPO_ROOT", "") tests := []struct { name string envValue string globalVersion string expectedSource WheelSource expectedPath string expectedURL string expectedVer string }{ // Release build defaults to PyPI latest (no version pin) { name: "release build defaults to PyPI latest", envValue: "", globalVersion: "0.12.0", expectedSource: WheelSourcePyPI, expectedVer: "", }, // Dev build with explicit pypi (auto-detection tested separately in TestGetCogWheelConfigAutoDetect) { name: "dev build defaults to PyPI without version", envValue: "pypi", globalVersion: "dev", expectedSource: WheelSourcePyPI, expectedVer: "", }, // Snapshot build with explicit pypi { name: "snapshot build defaults to PyPI without version", envValue: "pypi", globalVersion: "0.16.12-dev+g6793b492", expectedSource: WheelSourcePyPI, expectedVer: "", }, // Explicit pypi override { name: "explicit pypi", envValue: "pypi", globalVersion: "dev", expectedSource: WheelSourcePyPI, expectedVer: "", }, { name: "explicit pypi with version", envValue: "pypi:0.11.0", globalVersion: "0.12.0", expectedSource: WheelSourcePyPI, 
expectedVer: "0.11.0", }, // URL override { name: "URL override", envValue: "https://example.com/custom.whl", globalVersion: "0.12.0", expectedSource: WheelSourceURL, expectedURL: "https://example.com/custom.whl", }, // File path override (use the real temp file) { name: "file path override", envValue: wheelFile, globalVersion: "0.12.0", expectedSource: WheelSourceFile, expectedPath: wheelFile, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := ResolveCogWheel(tt.envValue, tt.globalVersion) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, tt.expectedSource, result.Source) require.Equal(t, tt.expectedURL, result.URL) require.Equal(t, tt.expectedPath, result.Path) require.Equal(t, tt.expectedVer, result.Version) }) } } func TestGetCogWheelConfigErrors(t *testing.T) { // Test error cases for wheel config t.Run("file not found", func(t *testing.T) { t.Setenv(CogSDKWheelEnvVar, "/nonexistent/path/wheel.whl") _, err := GetCogWheelConfig() require.Error(t, err) require.Contains(t, err.Error(), "path not found") }) } func TestGetCogWheelConfigAutoDetect(t *testing.T) { // Create a temp directory with a wheel file tmpDir := t.TempDir() distDir := filepath.Join(tmpDir, "dist") require.NoError(t, os.MkdirAll(distDir, 0o750)) wheelPath := filepath.Join(distDir, "cog-0.1.0-py3-none-any.whl") require.NoError(t, os.WriteFile(wheelPath, []byte("fake wheel content"), 0o600)) // Change to temp dir origDir, err := os.Getwd() require.NoError(t, err) require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() // Test auto-detection in dev mode result, err := ResolveCogWheel("", "dev") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, WheelSourceFile, result.Source) require.Contains(t, result.Path, "cog-0.1.0-py3-none-any.whl") // Test that release mode does NOT auto-detect (and has no version pin) result, err = ResolveCogWheel("", "0.12.0") require.NoError(t, err) require.NotNil(t, 
result) require.Equal(t, WheelSourcePyPI, result.Source) require.Equal(t, "", result.Version) } func TestResolveCogWheelUsesExecutableDist(t *testing.T) { rootDir := t.TempDir() distDir := filepath.Join(rootDir, "dist") require.NoError(t, os.MkdirAll(distDir, 0o750)) wheelPath := filepath.Join(distDir, "cog-0.1.0-py3-none-any.whl") require.NoError(t, os.WriteFile(wheelPath, []byte("fake wheel content"), 0o600)) fakeExe := filepath.Join(distDir, "go", "linux_amd64", "cog") require.NoError(t, os.MkdirAll(filepath.Dir(fakeExe), 0o750)) require.NoError(t, os.WriteFile(fakeExe, []byte("binary"), 0o600)) origExecutablePath := executablePath origEvalSymlinks := evalSymlinks defer func() { executablePath = origExecutablePath evalSymlinks = origEvalSymlinks }() executablePath = func() (string, error) { return fakeExe, nil } evalSymlinks = func(path string) (string, error) { return path, nil } cwd := t.TempDir() origDir, err := os.Getwd() require.NoError(t, err) require.NoError(t, os.Chdir(cwd)) defer func() { require.NoError(t, os.Chdir(origDir)) }() result, err := ResolveCogWheel("", "dev") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, WheelSourceFile, result.Source) require.Contains(t, result.Path, "cog-0.1.0-py3-none-any.whl") } func TestGetCogletWheelConfig(t *testing.T) { // Change to temp dir to prevent auto-detection from repo dist/ tmpDir := t.TempDir() origDir, err := os.Getwd() require.NoError(t, err) require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() tests := []struct { name string envValue string globalVersion string expectedSource WheelSource expectedPath string expectedURL string expectedVer string }{ // Default: coglet from PyPI latest (release build, no version pin) { name: "release default uses PyPI latest", envValue: "", globalVersion: "0.12.0", expectedSource: WheelSourcePyPI, expectedVer: "", }, { name: "dev default falls back to PyPI without version", envValue: "", globalVersion: "dev", 
expectedSource: WheelSourcePyPI, expectedVer: "", }, // Explicit pypi { name: "explicit pypi", envValue: "pypi", globalVersion: "0.12.0", expectedSource: WheelSourcePyPI, expectedVer: "", }, { name: "explicit pypi with version", envValue: "pypi:0.11.0", globalVersion: "0.12.0", expectedSource: WheelSourcePyPI, expectedVer: "0.11.0", }, // URL override { name: "URL override", envValue: "https://example.com/coglet.whl", globalVersion: "0.12.0", expectedSource: WheelSourceURL, expectedURL: "https://example.com/coglet.whl", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result, err := ResolveCogletWheel(tt.envValue, tt.globalVersion, "amd64") require.NoError(t, err) require.NotNil(t, result) require.Equal(t, tt.expectedSource, result.Source) require.Equal(t, tt.expectedURL, result.URL) require.Equal(t, tt.expectedPath, result.Path) require.Equal(t, tt.expectedVer, result.Version) }) } } func TestPyPIPackageURL(t *testing.T) { tests := []struct { name string config *WheelConfig packageName string expected string }{ { name: "no version", config: &WheelConfig{Source: WheelSourcePyPI}, packageName: "cog", expected: "cog", }, { name: "with version", config: &WheelConfig{Source: WheelSourcePyPI, Version: "0.12.0"}, packageName: "cog", expected: "cog==0.12.0", }, { name: "coglet with version", config: &WheelConfig{Source: WheelSourcePyPI, Version: "0.1.0"}, packageName: "coglet", expected: "coglet==0.1.0", }, { name: "alpha pre-release converted to PEP 440", config: &WheelConfig{Source: WheelSourcePyPI, Version: "0.17.0-alpha1"}, packageName: "cog", expected: "cog==0.17.0a1", }, { name: "beta pre-release converted to PEP 440", config: &WheelConfig{Source: WheelSourcePyPI, Version: "0.17.0-beta2"}, packageName: "cog", expected: "cog==0.17.0b2", }, { name: "rc pre-release converted to PEP 440", config: &WheelConfig{Source: WheelSourcePyPI, Version: "1.0.0-rc1"}, packageName: "cog", expected: "cog==1.0.0rc1", }, { name: "dev pre-release converted to PEP 
440", config: &WheelConfig{Source: WheelSourcePyPI, Version: "0.17.0-dev1"}, packageName: "cog", expected: "cog==0.17.0.dev1", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := tt.config.PyPIPackageURL(tt.packageName) require.Equal(t, tt.expected, result) }) } } func TestIsPreRelease(t *testing.T) { tests := []struct { version string expected bool }{ // semver format {"0.17.0-alpha1", true}, {"0.17.0-beta2", true}, {"0.17.0-rc1", true}, {"0.17.0-dev1", true}, // PEP 440 format {"0.17.0a1", true}, {"0.17.0b2", true}, {"0.17.0rc1", true}, {"0.17.0.dev1", true}, // stable {"0.17.0", false}, {"1.0.0", false}, } for _, tt := range tests { t.Run(tt.version, func(t *testing.T) { require.Equal(t, tt.expected, IsPreRelease(tt.version)) }) } } func TestValidateSDKVersion(t *testing.T) { tests := []struct { name string config *WheelConfig label string expectErr bool errMsg string }{ {name: "exact minimum valid", config: &WheelConfig{Source: WheelSourcePyPI, Version: "0.16.0"}, label: "cog"}, {name: "above minimum valid", config: &WheelConfig{Source: WheelSourcePyPI, Version: "0.17.0"}, label: "cog"}, {name: "nil config valid", config: nil, label: "cog"}, {name: "no version pin valid", config: &WheelConfig{Source: WheelSourcePyPI, Version: ""}, label: "cog"}, {name: "URL source not checked", config: &WheelConfig{Source: WheelSourceURL, URL: "https://example.com/old.whl"}, label: "cog"}, {name: "file source not checked", config: &WheelConfig{Source: WheelSourceFile, Path: "/tmp/old.whl"}, label: "cog"}, { name: "below minimum errors", config: &WheelConfig{Source: WheelSourcePyPI, Version: "0.15.0"}, label: "cog", expectErr: true, errMsg: "cog version 0.15.0 is below the minimum required version 0.16.0", }, { name: "pre-release of old version errors", config: &WheelConfig{Source: WheelSourcePyPI, Version: "0.15.0-rc1"}, label: "cog", expectErr: true, errMsg: "cog version 0.15.0-rc1 is below the minimum required version 0.16.0", }, } for _, tt := range 
tests { t.Run(tt.name, func(t *testing.T) { err := ValidateSDKVersion(tt.config, tt.label) if tt.expectErr { require.Error(t, err) require.Equal(t, tt.errMsg, err.Error()) } else { require.NoError(t, err) } }) } } func TestPyPIPackageURLPreRelease(t *testing.T) { cfg := &WheelConfig{Source: WheelSourcePyPI, Version: "0.17.0-alpha1"} require.Equal(t, "cog==0.17.0a1", cfg.PyPIPackageURL("cog")) } ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools", "setuptools_scm[toml]"] build-backend = "setuptools.build_meta" [project] name = "cog" description = "Containers for machine learning" readme = "README.md" authors = [{ name = "Replicate", email = "team@replicate.com" }] license.file = "LICENSE" urls."Source" = "https://github.com/replicate/cog" classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", ] requires-python = ">=3.10" dependencies = [ "typing_extensions>=4.0", "pyyaml>=6.0", "structlog>=21.0.0", "requests>=2.25.0", "coglet>=0.1.0,<1.0", ] dynamic = ["version"] [dependency-groups] dev = [ "build>=1.2.2.post1", "ruff", "setuptools-scm>=8.2.0", ] test = [ "pytest", "pytest-timeout", "pytest-xdist", "pytest-cov", ] [tool.setuptools_scm] write_to = "python/cog/_version.py" [tool.pyright] # TODO: remove this and bring the codebase inline with the current default strictParameterNoneValue = false # legacy behavior, fixed in PEP688 disableBytesTypePromotions = true include = ["python"] exclude = ["python/tests"] reportMissingParameterType = "error" reportUnknownLambdaType = "error" reportUnnecessaryIsInstance = "warning" reportUnnecessaryComparison = "warning" reportUnnecessaryContains = "warning" reportMissingTypeArgument = "error" reportUnusedExpression = "warning" [tool.pytest.ini_options] asyncio_default_fixture_loop_scope = "function" 
[tool.setuptools] include-package-data = false [tool.setuptools.packages.find] where = ["python"] include = ["cog*"] exclude = ["tests*"] [tool.pylint.main] disable = [ "C0114", # Missing module docstring "C0115", # Missing class docstring "C0116", # Missing function or method docstring "C0301", # Line too long "C0413", # Import should be placed at the top of the module "R0903", # Too few public methods "W0622", # Redefining built-in ] good-names = ["id", "input"] ignore-paths = ["python/cog/_version.py", "python/tests"] [tool.ruff] exclude = ["python/cog/_version.py"] lint.select = [ "E", # pycodestyle error "F", # Pyflakes "I", # isort "W", # pycodestyle warning "S", # flake8-bandit "B", # flake8-bugbear "ANN", # flake8-annotations ] lint.ignore = [ "E501", # Line too long "S101", # Use of `assert` detected" "S113", # Probable use of requests call without timeout "B008", # Do not perform function call in argument defaults "ANN001", # Missing type annotation for function argument "ANN002", # Missing type annotation for `*args` "ANN003", # Missing type annotation for `**kwargs` "ANN401", # Dynamically typed expressions are disallowed ] extend-exclude = [ "python/tests/server/fixtures/*", "crates/coglet-python/**/*.pyi", "crates/coglet-python/scripts/*", ] src = ["python"] [tool.ruff.lint.per-file-ignores] "python/cog/server/http.py" = [ "S104", # Possible binding to all interfaces ] "python/tests/*" = [ "S101", # Use of assert "S104", # Possible binding to all interfaces "S105", # Possible hardcoded password "S106", # Possible hardcoded password in argument "S108", # Probable insecure usage of temp file "S110", # try-except-pass "S301", # pickle can be unsafe when used to deserialize untrusted data "S603", # subprocess call — tests use subprocess for isolation "ANN", # Type annotations not required in tests "B011", # Do not assert False ] "crates/coglet-python/tests/*" = [ "S101", # Use of assert "S104", # Possible binding to all interfaces "S110", # 
try-except-pass "S603", # subprocess call "ANN", # Type annotations not required in tests "B904", # raise from ] "tools/test-harness/*" = [ "S310", # URL open — URLs are hardcoded to https:// "S603", # subprocess call — harness invokes cog/docker/git by design "S607", # partial executable path — cog/docker/git resolved via PATH "ANN", # Type annotations not required in tooling ] ================================================ FILE: python/cog/.gitignore ================================================ /_version.py ================================================ FILE: python/cog/__init__.py ================================================ """ Cog SDK: Define machine learning models with standard Python. This package provides the core types and classes for building Cog predictors. Example: from cog import BasePredictor, Input, Path class Predictor(BasePredictor): def setup(self): # Load model weights self.model = load_model() def predict( self, prompt: str = Input(description="Input prompt"), image: Path = Input(description="Input image"), ) -> str: return self.model.generate(prompt, image) """ import sys as _sys from coglet import CancelationException as CancelationException from ._version import __version__ from .input import FieldInfo, Input from .model import BaseModel from .predictor import BasePredictor from .types import ( AsyncConcatenateIterator, ConcatenateIterator, File, Path, Secret, URLFile, URLPath, ) # --------------------------------------------------------------------------- # Backwards-compatibility shim: ExperimentalFeatureWarning # # This class was removed when the Python HTTP server was replaced by coglet. # Existing models import it to suppress warnings, e.g.: # # from cog import ExperimentalFeatureWarning # warnings.filterwarnings("ignore", category=ExperimentalFeatureWarning) # # The shim keeps those models working. The stderr message is printed # directly so it cannot be swallowed by warnings.filterwarnings("ignore"). 
# --------------------------------------------------------------------------- class _ExperimentalFeatureWarning(FutureWarning): """Deprecated: ExperimentalFeatureWarning is no longer used by Cog. This class exists only for backwards compatibility. Remove the import and any associated ``warnings.filterwarnings(...)`` calls from your code. """ pass def __getattr__(name: str) -> object: if name == "ExperimentalFeatureWarning": print( "cog: ExperimentalFeatureWarning is deprecated and will be removed in a " "future release. Remove `ExperimentalFeatureWarning` from your imports " "and any associated `warnings.filterwarnings(...)` calls.", file=_sys.stderr, ) # Cache in module namespace so __getattr__ is not called again and # the deprecation message prints at most once. globals()["ExperimentalFeatureWarning"] = _ExperimentalFeatureWarning return _ExperimentalFeatureWarning if name == "emit_metric": print( "cog: emit_metric() is deprecated and will be removed in a future release. " "Use current_scope().record_metric(name, value) instead.", file=_sys.stderr, ) def emit_metric(name: str, value: float) -> None: # noqa: A002 — name is the metric name here, not the module attr current_scope().record_metric(name, value) # type: ignore[attr-defined] # Cache so __getattr__ is not called again — the deprecation message # prints at most once (on first import), not on every call. globals()["emit_metric"] = emit_metric return emit_metric raise AttributeError(f"module 'cog' has no attribute {name!r}") def current_scope() -> object: """Get the current prediction scope for recording metrics. Returns a Scope object with a ``metrics`` attribute for recording prediction metrics. Outside a prediction context, returns a no-op scope that silently ignores all operations (never ``None``). 
Example:: from cog import current_scope scope = current_scope() scope.record_metric("temperature", 0.7) scope.metrics["token_count"] = 42 scope.metrics.record("logprobs", -1.2, mode="append") """ import coglet return coglet._sdk.current_scope() # type: ignore[attr-defined] # PyO3 native submodule __all__ = [ # Version "__version__", # Core classes "BasePredictor", "BaseModel", # Input "Input", "FieldInfo", # Types "Path", "Secret", "File", "URLFile", "URLPath", "ConcatenateIterator", "AsyncConcatenateIterator", # Exceptions "CancelationException", # Metrics "current_scope", # Deprecated compat shims "ExperimentalFeatureWarning", "emit_metric", ] ================================================ FILE: python/cog/_adt.py ================================================ """ Internal ADT (Abstract Data Types) for predictor introspection. This module defines the type system used internally for introspecting predictor inputs and outputs, generating OpenAPI schemas, and validating input values. 
""" import dataclasses import os import typing from dataclasses import dataclass from enum import Enum, auto from typing import Any, Callable, Dict, List, Optional, Set, Union from .coder import Coder from .types import File, Path, Secret def _type_name(tpe: Any) -> str: """Get a human-readable name for a type.""" try: return tpe.__name__ except AttributeError: return str(tpe) def _is_union(tpe: type) -> bool: """Check if a type is a Union type.""" if typing.get_origin(tpe) is Union: return True # Python 3.10+ has UnionType for X | Y syntax from types import UnionType if typing.get_origin(tpe) is UnionType: return True return False class PrimitiveType(Enum): """Primitive types supported by Cog.""" BOOL = auto() FLOAT = auto() INTEGER = auto() STRING = auto() PATH = auto() FILE = auto() # Deprecated, use PATH SECRET = auto() ANY = auto() CUSTOM = auto() @staticmethod def _python_type() -> Dict["PrimitiveType", type | Any]: return { PrimitiveType.BOOL: bool, PrimitiveType.FLOAT: float, PrimitiveType.INTEGER: int, PrimitiveType.STRING: str, PrimitiveType.PATH: Path, PrimitiveType.FILE: File, PrimitiveType.SECRET: Secret, PrimitiveType.ANY: Any, PrimitiveType.CUSTOM: Any, } @staticmethod def _json_type() -> Dict["PrimitiveType", str]: return { PrimitiveType.BOOL: "boolean", PrimitiveType.FLOAT: "number", PrimitiveType.INTEGER: "integer", PrimitiveType.STRING: "string", PrimitiveType.PATH: "string", PrimitiveType.FILE: "string", PrimitiveType.SECRET: "string", PrimitiveType.ANY: "object", PrimitiveType.CUSTOM: "object", } @staticmethod def _adt_type() -> Dict[type | Any, "PrimitiveType"]: return { bool: PrimitiveType.BOOL, float: PrimitiveType.FLOAT, int: PrimitiveType.INTEGER, str: PrimitiveType.STRING, Path: PrimitiveType.PATH, File: PrimitiveType.FILE, Secret: PrimitiveType.SECRET, Any: PrimitiveType.ANY, } @staticmethod def from_type(tpe: type | Any) -> "PrimitiveType": """Determine the PrimitiveType for a given Python type.""" if match := 
PrimitiveType._adt_type().get(tpe): return match try: if tpe is os.PathLike or ( isinstance(tpe, type) and issubclass(tpe, os.PathLike) # type: ignore[arg-type] ): return PrimitiveType.PATH except TypeError: # issubclass raises TypeError for non-class types pass return PrimitiveType.CUSTOM def normalize(self, value: Any) -> Any: """Normalize a value to this primitive type.""" pt = PrimitiveType._python_type()[self] tpe = type(value) if self is PrimitiveType.CUSTOM: return value elif self is PrimitiveType.ANY: return value elif self is PrimitiveType.FILE: # For File inputs, convert URL strings to file-like objects immediately # using File.validate() - the worker won't need to do any conversion import io if isinstance(value, io.IOBase): return value # URL string or data URI - validate to file-like object return File.validate(value) elif self is PrimitiveType.PATH: # Convert strings/URLs to Path or URLPath objects if isinstance(value, Path): return value return Path.validate(value) elif self is PrimitiveType.SECRET: # Convert strings to Secret objects if isinstance(value, Secret): return value return Secret(value) else: # Handle enums by extracting their value if issubclass(tpe, Enum): if not issubclass(tpe, pt): raise ValueError( f"enum {_type_name(tpe)} is used as {_type_name(pt)} " "but does not extend it" ) value = value.value v = pt(value) # For numeric types, allow string coercion (e.g., "3" -> 3) # but verify the conversion is valid (not lossy for floats) if v != value: # Allow string to numeric conversion if isinstance(value, str) and pt in (int, float): return v # Allow int to float conversion if isinstance(value, int) and pt is float: return v raise ValueError(f"failed to normalize value as {_type_name(pt)}") return v def python_type_name(self) -> str: """Get the Python type name for this primitive.""" return _type_name(PrimitiveType._python_type()[self]) def json_type(self) -> Dict[str, Any]: """Get the JSON Schema type for this primitive.""" jt: Dict[str, 
@dataclass(frozen=True)
class FieldType:
    """Type information for an input/output field.

    Combines a primitive element type with its repetition (required,
    optional or repeated) and, for custom types, the coder used to
    encode/decode values.
    """

    # Element type of the field.
    primitive: PrimitiveType
    # Whether the field is a single value, an optional value or a list.
    repetition: Repetition
    # Coder for CUSTOM primitives; None for every built-in primitive.
    coder: Optional[Coder]

    @staticmethod
    def from_type(tpe: type) -> "FieldType":
        """Create a FieldType from a Python type annotation.

        Args:
            tpe: The annotation to analyse (bare or parameterised).

        Returns:
            The FieldType describing ``tpe``.

        Raises:
            ValueError: For nested generics inside List/Optional, unions
                other than ``Optional[X]``, or types with no known coder.
        """
        origin = typing.get_origin(tpe)

        # Handle bare collection types
        if tpe is list:
            tpe = List[Any]
            origin = typing.get_origin(tpe)
        elif tpe is dict:
            tpe = Dict[str, Any]
            origin = typing.get_origin(tpe)
        elif tpe is set:
            tpe = Set[Any]
            origin = typing.get_origin(tpe)

        if origin is dict:
            # dict / Dict[K, V] → opaque JSON object, consistent with the
            # static Go schema generator's SchemaAnyType().
            return FieldType(
                primitive=PrimitiveType.ANY,
                repetition=Repetition.REQUIRED,
                coder=None,
            )

        if origin in (list, List):
            t_args = typing.get_args(tpe)
            if t_args:
                if len(t_args) != 1:
                    raise ValueError("List must have one type argument")
                elem_t = t_args[0]
                nested_t = typing.get_origin(elem_t)
                if nested_t is not None:
                    raise ValueError(
                        f"List cannot have nested type {_type_name(nested_t)}"
                    )
            else:
                elem_t = Any
            repetition = Repetition.REPEATED
        elif _is_union(tpe):
            t_args = typing.get_args(tpe)
            # Only Optional[X] (a two-member union containing None) is supported.
            if not (len(t_args) == 2 and type(None) in t_args):
                raise ValueError(f"unsupported union type {tpe}")
            elem_t = t_args[0] if t_args[1] is type(None) else t_args[1]
            nested_t = typing.get_origin(elem_t)
            if nested_t is not None:
                raise ValueError(
                    f"Optional cannot have nested type {_type_name(nested_t)}"
                )
            repetition = Repetition.OPTIONAL
        else:
            elem_t = tpe
            repetition = Repetition.REQUIRED

        cog_t = PrimitiveType.from_type(elem_t)
        coder = None
        if cog_t is PrimitiveType.CUSTOM:
            coder = Coder.lookup(elem_t)
            if coder is None:
                raise ValueError(f"unsupported Cog type {_type_name(elem_t)}")

        return FieldType(primitive=cog_t, repetition=repetition, coder=coder)

    def normalize(self, value: Any) -> Any:
        """Normalize a value according to this field type.

        Raises:
            ValueError: If a repeated field receives a str/bytes value, or if
                the primitive rejects an element.
        """
        if self.repetition is Repetition.REQUIRED:
            return self.primitive.normalize(value)
        elif self.repetition is Repetition.OPTIONAL:
            return None if value is None else self.primitive.normalize(value)
        elif self.repetition is Repetition.REPEATED:
            # A string is iterable, so without this guard "123" passed to a
            # List[int] field would silently become [1, 2, 3].
            if isinstance(value, (str, bytes, bytearray)):
                raise ValueError(
                    f"expected a list of values, got {_type_name(type(value))}"
                )
            return [self.primitive.normalize(v) for v in value]
        return value

    def json_type(self) -> Dict[str, Any]:
        """Get the JSON Schema type for this field."""
        if self.repetition is Repetition.REPEATED:
            return {"type": "array", "items": self.primitive.json_type()}
        return self.primitive.json_type()

    def json_encode(self, value: Any) -> Any:
        """Encode a value for JSON serialization.

        ``None`` is passed through for optional fields, mirroring
        ``normalize``; encoding it would otherwise fail (e.g. ``float(None)``).
        """
        if value is None and self.repetition is Repetition.OPTIONAL:
            return None
        f: Callable[[Any], Any] = self.primitive.json_encode
        if self.primitive is PrimitiveType.CUSTOM:
            assert self.coder is not None
            f = self.coder.encode
        if self.repetition is Repetition.REPEATED:
            return [f(x) for x in value]
        return f(value)

    def json_decode(self, value: Any) -> Any:
        """Decode a value from JSON.

        Only CUSTOM primitives are decoded (through their coder); everything
        else, and ``None`` for optional fields, is returned unchanged.
        """
        if self.primitive is not PrimitiveType.CUSTOM:
            return value
        if value is None and self.repetition is Repetition.OPTIONAL:
            return None
        assert self.coder is not None
        f = self.coder.decode
        if self.repetition is Repetition.REPEATED:
            return [f(x) for x in value]
        return f(value)

    def python_type_name(self) -> str:
        """Get the Python type name for this field."""
        if self.repetition is Repetition.REQUIRED:
            return self.primitive.python_type_name()
        elif self.repetition is Repetition.OPTIONAL:
            return f"Optional[{self.primitive.python_type_name()}]"
        elif self.repetition is Repetition.REPEATED:
            return f"List[{self.primitive.python_type_name()}]"
        return self.primitive.python_type_name()
def json_encode(self, value: Any) -> Any:
    """Turn a predictor output into a JSON-serializable value.

    When a coder is attached it does all the work (element-wise for LIST
    outputs). Otherwise the value goes through ``_transform`` in JSON mode,
    and OBJECT outputs are additionally flattened from a dataclass instance
    into a plain dict of its fields.

    Raises:
        ValueError: If an OBJECT output is not a dataclass instance.
    """
    coder = self.coder
    if coder is None:
        encoded = self._transform(value, json=True)
        if self.kind is not OutputKind.OBJECT:
            return encoded
        # Expand the dataclass into a name -> value mapping.
        cls = type(encoded)
        if not dataclasses.is_dataclass(cls):
            raise ValueError(f"{cls} is not a dataclass")
        return {
            fld.name: getattr(encoded, fld.name)
            for fld in dataclasses.fields(encoded)
        }

    if self.kind is OutputKind.LIST:
        return [coder.encode(item) for item in value]
    return coder.encode(value)
ft.repetition is not Repetition.OPTIONAL: raise ValueError(f"missing value for output field: {name}") setattr(value, name, f(v)) return value raise RuntimeError(f"unsupported output kind {self.kind}") @dataclass(frozen=True) class PredictorInfo: """Complete type information for a predictor.""" module_name: str predictor_name: str inputs: Dict[str, InputField] output: OutputType ================================================ FILE: python/cog/_inspector.py ================================================ """ Internal inspector for predictor introspection. This module provides functions to inspect predictor classes and functions, extract input/output type information, and validate inputs. """ import importlib import inspect import re import sys import typing from dataclasses import MISSING, Field from enum import Enum from types import ModuleType from typing import Any, AsyncIterator, Callable, Dict, Iterator, Type from . import _adt as adt from .coder import Coder from .input import FieldInfo from .model import BaseModel from .types import AsyncConcatenateIterator, ConcatenateIterator try: from pydantic import ( # pyright: ignore[reportMissingImports] BaseModel as PydanticBaseModel, ) except ImportError: PydanticBaseModel = None # type: ignore[assignment,misc] def _check_parent(child: type, parent: type) -> bool: """Check if a type has a parent in its MRO.""" return any(c is parent for c in inspect.getmro(child)) def _type_name(tpe: Any) -> str: """Get a human-readable name for a type.""" try: return tpe.__name__ except AttributeError: return str(tpe) def _validate_setup(f: Callable[..., Any]) -> None: """Validate a predictor's setup method.""" if not inspect.isfunction(f): raise ValueError("setup is not a function") spec = inspect.getfullargspec(f) if spec.args[:1] != ["self"]: raise ValueError("setup() must have 'self' as first argument") non_default_args = spec.args if spec.defaults is not None: non_default_args = non_default_args[: -len(spec.defaults)] 
extra_args = [a for a in non_default_args if a not in {"self", "weights"}] if extra_args: raise ValueError(f"unexpected setup() arguments: {', '.join(extra_args)}") if spec.varargs is not None: raise ValueError("setup() must not have *args") if spec.varkw is not None: raise ValueError("setup() must not have **kwargs") if spec.kwonlyargs: raise ValueError("setup() must not have keyword-only args") if spec.kwonlydefaults: raise ValueError("setup() must not have keyword-only defaults") if spec.annotations.get("return") is not None: raise ValueError("setup() must return None") def _validate_predict(f: Callable[..., Any], f_name: str, is_class_fn: bool) -> None: """Validate a predictor's predict method.""" if not inspect.isfunction(f): raise ValueError(f"{f_name} is not a function") spec = inspect.getfullargspec(f) if is_class_fn and spec.args[:1] != ["self"]: raise ValueError(f"{f_name}() must have 'self' as first argument") if spec.varargs is not None: raise ValueError(f"{f_name}() must not have *args") if spec.varkw is not None: raise ValueError(f"{f_name}() must not have **kwargs") if spec.kwonlyargs: raise ValueError(f"{f_name}() must not have keyword-only args") if spec.kwonlydefaults: raise ValueError(f"{f_name}() must not have keyword-only defaults") if spec.annotations.get("return") is None: raise ValueError(f"{f_name}() must have a return type annotation") def _validate_input_constraints( name: str, ft: adt.FieldType, field_info: FieldInfo ) -> None: """Validate that FieldInfo constraints are compatible with the field type.""" cog_t = ft.primitive in_repr = f"{name}: {ft.python_type_name()}" # Extract actual default for validation defaults = [] if field_info.default is not None: if isinstance(field_info.default, Field): if field_info.default.default_factory is not MISSING: actual_default = field_info.default.default_factory() elif field_info.default.default is not MISSING: actual_default = field_info.default.default else: actual_default = None else: 
actual_default = field_info.default if actual_default is not None: if ft.repetition is adt.Repetition.REPEATED: defaults = ft.normalize(actual_default) else: defaults = [ft.normalize(actual_default)] numeric_types = {adt.PrimitiveType.FLOAT, adt.PrimitiveType.INTEGER} # Validate ge/le constraints if field_info.ge is not None or field_info.le is not None: if cog_t not in numeric_types: raise ValueError(f"incompatible input type for ge/le: {in_repr}") if defaults: if field_info.ge is not None and not all( x >= field_info.ge for x in defaults ): raise ValueError( f"invalid default for {in_repr}: must be at minimum {field_info.ge}" ) if field_info.le is not None and not all( x <= field_info.le for x in defaults ): raise ValueError( f"invalid default for {in_repr}: must be at maximum {field_info.le}" ) # Validate min_length/max_length constraints if field_info.min_length is not None or field_info.max_length is not None: if cog_t is not adt.PrimitiveType.STRING: raise ValueError( f"incompatible input type for min_length/max_length: {in_repr}" ) if defaults: if field_info.min_length is not None and not all( len(x) >= field_info.min_length for x in defaults ): raise ValueError( f"default conflicts with min_length={field_info.min_length} for input: {in_repr}" ) if field_info.max_length is not None and not all( len(x) <= field_info.max_length for x in defaults ): raise ValueError( f"default conflicts with max_length={field_info.max_length} for input: {in_repr}" ) # Validate regex constraint if field_info.regex is not None: if cog_t is not adt.PrimitiveType.STRING: raise ValueError(f"incompatible input type for regex: {in_repr}") if defaults: regex = re.compile(field_info.regex) if not all(regex.match(x) for x in defaults): raise ValueError(f"default not a regex match for input: {in_repr}") # Validate choices constraint if field_info.choices is not None: choice_types = {adt.PrimitiveType.INTEGER, adt.PrimitiveType.STRING} if cog_t not in choice_types: raise 
def _create_input_field(
    order: int, name: str, tpe: type, field_info: Any
) -> adt.InputField:
    """Build the InputField for one predict() parameter.

    Args:
        order: Position of the parameter in the signature.
        name: Parameter name.
        tpe: Resolved type annotation of the parameter.
        field_info: ``None`` (no default), a ``FieldInfo`` describing the
            input, or a raw default value such as ``"world"`` or ``42``.

    Returns:
        The InputField with its default and choices normalized.

    Raises:
        ValueError: If the annotation is not a supported input type, or a
            constraint/default check fails.
    """
    try:
        field_type = adt.FieldType.from_type(tpe)
    except (ValueError, AssertionError) as exc:
        raise ValueError(f"invalid input field {name}: {exc}") from exc

    common = {"name": name, "order": order, "type": field_type}

    # No default at all.
    if field_info is None:
        return adt.InputField(**common)

    # A bare default value rather than a FieldInfo declaration.
    if not isinstance(field_info, FieldInfo):
        return adt.InputField(default=field_type.normalize(field_info), **common)

    _validate_input_constraints(name, field_type, field_info)

    declared = field_info.default
    if isinstance(declared, Field):
        if declared.default_factory is not MISSING:
            # Keep the Field itself so the factory can be invoked per call.
            default = declared
        elif declared.default is not MISSING:
            default = field_type.normalize(declared.default)
        else:
            default = None
    elif declared is None:
        default = None
    else:
        default = field_type.normalize(declared)

    if field_info.choices is None:
        choices = None
    else:
        choices = list(map(field_type.primitive.normalize, field_info.choices))

    lower = None if field_info.ge is None else float(field_info.ge)
    upper = None if field_info.le is None else float(field_info.le)

    return adt.InputField(
        default=default,
        description=field_info.description,
        ge=lower,
        le=upper,
        min_length=field_info.min_length,
        max_length=field_info.max_length,
        regex=field_info.regex,
        choices=choices,
        deprecated=field_info.deprecated,
        **common,
    )
in concat_iters: kind = adt.OutputKind.CONCAT_ITERATOR t_args = typing.get_args(tpe) if len(t_args) != 1: raise ValueError("iterator type must have a type argument") ft = adt.FieldType.from_type(t_args[0]) if ft.repetition is not adt.Repetition.REQUIRED: raise ValueError("iterator element type must not be Optional or List") if ft.primitive is not adt.PrimitiveType.STRING: raise ValueError(f"{_type_name(tpe)} must have str element") else: ft = adt.FieldType.from_type(tpe) if ft.repetition is adt.Repetition.OPTIONAL: raise ValueError("output must not be Optional") if ft.repetition == adt.Repetition.REQUIRED: kind = adt.OutputKind.SINGLE elif ft.repetition == adt.Repetition.REPEATED: kind = adt.OutputKind.LIST else: raise RuntimeError(f"unexpected repetition: {ft.repetition}") return adt.OutputType(kind=kind, type=ft.primitive, coder=ft.coder) def _create_predictor_info( module_name: str, predictor_name: str, f: Callable[..., Any], f_name: str, is_class_fn: bool, ) -> adt.PredictorInfo: """Create PredictorInfo from a predict function.""" _validate_predict(f, f_name, is_class_fn) spec = inspect.getfullargspec(f) # Use get_type_hints to resolve string annotations (from __future__ import annotations) try: type_hints = typing.get_type_hints(f) except Exception: # Fall back to raw annotations if get_type_hints fails type_hints = spec.annotations # Skip 'self' for class methods names = spec.args[1:] if is_class_fn else spec.args defaults = list(spec.defaults) if spec.defaults else [] field_infos = [None] * (len(names) - len(defaults)) + defaults inputs: Dict[str, adt.InputField] = {} for i, (name, field_info) in enumerate(zip(names, field_infos, strict=False)): tpe = type_hints.get(name) if tpe is None: raise ValueError(f"missing type annotation for input: {name}") inputs[name] = _create_input_field(i, name, tpe, field_info) return_type = type_hints.get("return", spec.annotations.get("return")) if return_type is None: raise ValueError("missing return type annotation for 
predict method") output = _create_output_type(return_type) return adt.PredictorInfo(module_name, predictor_name, inputs, output) def _unwrap(f: Callable[..., Any]) -> Callable[..., Any]: """Unwrap decorated functions to get the original function.""" g = f while hasattr(g, "__closure__") and g.__closure__ is not None: cs = [ c.cell_contents for c in g.__closure__ if inspect.isfunction(c.cell_contents) ] if len(cs) > 1: raise ValueError(f"unable to inspect function decorator: {f}") if len(cs) == 0: return g g = cs[0] return g def _is_coder(cls: Type[Any]) -> bool: """Check if a class is a Coder subclass.""" return inspect.isclass(cls) and cls is not Coder and _check_parent(cls, Coder) def _find_coders(module: ModuleType) -> None: """Find and register coders defined in a module.""" # Direct imports: from cog.coders.some_coder import SomeCoder for _, c in inspect.getmembers(module, _is_coder): Coder.register(c) # Module imports: from cog.coders import some_coders for _, m in inspect.getmembers(module, inspect.ismodule): for _, c in inspect.getmembers(m, _is_coder): Coder.register(c) def create_predictor(module_name: str, predictor_name: str) -> adt.PredictorInfo: """ Create PredictorInfo by inspecting a predictor class or function. 
def check_input(
    inputs: Dict[str, adt.InputField], values: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Validate and normalize input values against InputField definitions.

    Provided values are normalized through their field type, missing values
    are filled from defaults (or rejected when the field is required), and
    every resulting value is checked against the field's ge/le, length,
    regex and choices constraints. Unknown input names are ignored with a
    printed warning.

    Args:
        inputs: Dictionary of InputField definitions
        values: Dictionary of input values to validate

    Returns:
        Dictionary of normalized input values

    Raises:
        ValueError: If a value cannot be normalized, a required field is
            missing, or a constraint is violated.
    """
    kwargs: Dict[str, Any] = {}

    # Process provided values
    for name, value in values.items():
        input_field = inputs.get(name)
        if input_field is None:
            print(f"WARNING unknown input field ignored: {name}")
        else:
            try:
                kwargs[name] = input_field.type.normalize(value)
            except ValueError as e:
                # Reformat normalize errors to use "field: message" format
                # and avoid leaking user input values
                msg = str(e)
                if "failed to normalize value" in msg:
                    # Extract just the type name without the value
                    if " as " in msg:
                        type_name = msg.split(" as ", 1)[1]
                        raise ValueError(
                            f"{name}: Invalid value for type {type_name}"
                        ) from None
                    raise ValueError(f"{name}: Invalid value") from None
                # For other normalize errors, prepend field name
                raise ValueError(f"{name}: {msg}") from None

    # Apply defaults for missing values
    for name, input_field in inputs.items():
        if name not in kwargs:
            if isinstance(input_field.default, Field):
                if input_field.default.default_factory is not MISSING:
                    kwargs[name] = input_field.default.default_factory()
                elif input_field.default.default is not MISSING:
                    kwargs[name] = input_field.default.default
                else:
                    if input_field.type.repetition is not adt.Repetition.OPTIONAL:
                        raise ValueError(f"{name}: Field required")
                    kwargs[name] = None
            elif input_field.default is not None:
                default = input_field.default
                # The stored default is shared by every call. Hand out a
                # shallow copy of mutable containers so that code mutating
                # its input cannot alter the default seen by later calls.
                if isinstance(default, (list, dict, set)):
                    default = default.copy()
                kwargs[name] = default
            else:
                if input_field.type.repetition is not adt.Repetition.OPTIONAL:
                    raise ValueError(f"{name}: Field required")
                kwargs[name] = None

        # Validate constraints
        v = kwargs[name]
        values_to_check = []
        if input_field.type.repetition is adt.Repetition.REQUIRED:
            values_to_check = [v]
        elif input_field.type.repetition is adt.Repetition.OPTIONAL:
            values_to_check = [] if v is None else [v]
        elif input_field.type.repetition is adt.Repetition.REPEATED:
            values_to_check = v

        if input_field.ge is not None:
            if not all(x >= input_field.ge for x in values_to_check):
                # Whole-number bounds are shown without a trailing ".0".
                raise ValueError(
                    f"{name} fails constraint >= {int(input_field.ge) if input_field.ge == int(input_field.ge) else input_field.ge}"
                )
        if input_field.le is not None:
            if not all(x <= input_field.le for x in values_to_check):
                raise ValueError(
                    f"{name} fails constraint <= {int(input_field.le) if input_field.le == int(input_field.le) else input_field.le}"
                )
        if input_field.min_length is not None:
            if not all(len(x) >= input_field.min_length for x in values_to_check):
                raise ValueError(
                    f"{name} fails constraint len() >= {input_field.min_length}"
                )
        if input_field.max_length is not None:
            if not all(len(x) <= input_field.max_length for x in values_to_check):
                raise ValueError(
                    f"{name} fails constraint len() <= {input_field.max_length}"
                )
        if input_field.regex is not None:
            p = re.compile(input_field.regex)
            if not all(p.match(x) is not None for x in values_to_check):
                raise ValueError(f"{name} does not match regex {input_field.regex!r}")
        if input_field.choices is not None:
            if not all(x in input_field.choices for x in values_to_check):
                raise ValueError(
                    f"{name} does not match choices {input_field.choices!r}"
                )

    return kwargs
def to_json_input(predictor: adt.PredictorInfo) -> Dict[str, Any]:
    """
    Build the OpenAPI ``Input`` object schema from a predictor's inputs.

    Args:
        predictor: Predictor info whose ``inputs`` mapping (name -> input
            field) is turned into schema properties.

    Returns:
        A dict with ``properties``, ``type`` and ``title`` keys, plus a
        ``required`` list when at least one input is mandatory.
    """
    # (attribute on the input field, OpenAPI keyword), in emission order.
    constraint_keywords = (
        ("description", "description"),
        ("ge", "minimum"),
        ("le", "maximum"),
        ("min_length", "minLength"),
        ("max_length", "maxLength"),
        ("regex", "pattern"),
        ("deprecated", "deprecated"),
    )
    # Without a default, these repetitions make an input mandatory:
    #   name: type = Input()        -> required
    #   name: list[type] = Input()  -> required
    # Optional[type] is never required, and any declared default makes the
    # input non-required.
    mandatory_repetitions = {adt.Repetition.REQUIRED, adt.Repetition.REPEATED}

    properties: Dict[str, Any] = {}
    required_names = []

    for name, spec in predictor.inputs.items():
        spec_type = spec.type
        entry: Dict[str, Any] = {"x-order": spec.order}

        if spec.choices is None:
            entry["title"] = name.replace("_", " ").title()
            entry.update(spec_type.json_type())
        else:
            # Inputs with choices point at their enum component instead of
            # carrying an inline type.
            entry["allOf"] = [{"$ref": f"#/components/schemas/{name}"}]

        declared = spec.default
        if declared is None:
            if spec_type.repetition in mandatory_repetitions:
                required_names.append(name)
        else:
            # A dataclass Field wraps the real default (or a factory for it).
            if not isinstance(declared, Field):
                effective = declared
            elif declared.default_factory is not MISSING:
                effective = declared.default_factory()
            elif declared.default is not MISSING:
                effective = declared.default
            else:
                effective = None
            if effective is not None:
                entry["default"] = spec_type.json_encode(
                    spec_type.normalize(effective)
                )

        # Optional types are nullable
        if spec_type.repetition is adt.Repetition.OPTIONAL:
            entry["nullable"] = True

        for attr, keyword in constraint_keywords:
            constraint = getattr(spec, attr)
            if constraint is not None:
                entry[keyword] = constraint

        properties[name] = entry

    schema: Dict[str, Any] = {
        "properties": properties,
        "type": "object",
        "title": "Input",
    }
    if required_names:
        schema["required"] = required_names
    return schema


def to_json_enums(predictor: adt.PredictorInfo) -> Dict[str, Any]:
    """
    Build one enum component schema per input that declares ``choices``.

    Returns:
        Mapping of input name to its enum schema; inputs without choices
        are left out.
    """
    enum_schemas = {}
    for name, spec in predictor.inputs.items():
        if spec.choices is not None:
            entry = {
                "title": name,
                "description": "An enumeration.",
                "enum": spec.choices,
            }
            entry.update(spec.type.primitive.json_type())
            enum_schemas[name] = entry
    return enum_schemas


def to_json_output(predictor: adt.PredictorInfo) -> Dict[str, Any]:
    """Return the OpenAPI schema describing the predictor's output type."""
    output_type = predictor.output
    return output_type.json_type()
Args: predictor: The predictor info to generate schema from mode: The prediction mode (Mode.PREDICT or Mode.TRAIN) """ # Determine routes and schema names based on mode if mode == Mode.TRAIN: main_route = "/trainings" cancel_route = "/trainings/{training_id}/cancel" request_schema = "TrainingRequest" response_schema = "TrainingResponse" id_param_name = "training_id" id_param_title = "Training Id" summary = "Train" description = "Run a training session" operation_id = "train_trainings_post" cancel_operation_id = "cancel_trainings__training_id__cancel_post" else: main_route = "/predictions" cancel_route = "/predictions/{prediction_id}/cancel" request_schema = "PredictionRequest" response_schema = "PredictionResponse" id_param_name = "prediction_id" id_param_title = "Prediction Id" summary = "Predict" description = "Run a single prediction on the model" operation_id = "predict_predictions_post" cancel_operation_id = "cancel_predictions__prediction_id__cancel_post" # Base OpenAPI schema structure schema: Dict[str, Any] = { "openapi": "3.0.2", "info": {"title": "Cog", "version": "0.1.0"}, "paths": { "/": { "get": { "summary": "Root", "operationId": "root__get", "responses": { "200": { "description": "Successful Response", "content": {"application/json": {"schema": {}}}, } }, } }, "/health-check": { "get": { "summary": "Healthcheck", "operationId": "healthcheck_health_check_get", "responses": { "200": { "description": "Successful Response", "content": {"application/json": {"schema": {}}}, } }, } }, main_route: { "post": { "summary": summary, "description": description, "operationId": operation_id, "requestBody": { "content": { "application/json": { "schema": { "$ref": f"#/components/schemas/{request_schema}" } } } }, "responses": { "200": { "description": "Successful Response", "content": { "application/json": { "schema": { "$ref": f"#/components/schemas/{response_schema}" } } }, }, "422": { "description": "Validation Error", "content": { "application/json": { "schema": 
{ "$ref": "#/components/schemas/HTTPValidationError" } } }, }, }, } }, cancel_route: { "post": { "summary": "Cancel", "operationId": cancel_operation_id, "parameters": [ { "required": True, "schema": {"title": id_param_title, "type": "string"}, "name": id_param_name, "in": "path", } ], "responses": { "200": { "description": "Successful Response", "content": {"application/json": {"schema": {}}}, }, "422": { "description": "Validation Error", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/HTTPValidationError" } } }, }, }, } }, }, "components": { "schemas": { "HTTPValidationError": { "title": "HTTPValidationError", "type": "object", "properties": { "detail": { "title": "Detail", "type": "array", "items": {"$ref": "#/components/schemas/ValidationError"}, } }, }, request_schema: { "title": request_schema, "type": "object", "properties": { "id": {"title": "Id", "type": "string"}, "input": {"$ref": "#/components/schemas/Input"}, }, }, response_schema: { "title": response_schema, "type": "object", "properties": { "input": {"$ref": "#/components/schemas/Input"}, "output": {"$ref": "#/components/schemas/Output"}, "id": {"title": "Id", "type": "string"}, "version": {"title": "Version", "type": "string"}, "created_at": { "title": "Created At", "type": "string", "format": "date-time", }, "started_at": { "title": "Started At", "type": "string", "format": "date-time", }, "completed_at": { "title": "Completed At", "type": "string", "format": "date-time", }, "status": {"title": "Status", "type": "string"}, "error": {"title": "Error", "type": "string"}, "logs": {"title": "Logs", "type": "string"}, "metrics": {"title": "Metrics", "type": "object"}, }, }, "Status": { "title": "Status", "description": "An enumeration.", "enum": [ "starting", "processing", "succeeded", "canceled", "failed", ], "type": "string", }, "ValidationError": { "title": "ValidationError", "required": ["loc", "msg", "type"], "type": "object", "properties": { "loc": { "title": 
"Location", "type": "array", "items": { "anyOf": [{"type": "string"}, {"type": "integer"}] }, }, "msg": {"title": "Message", "type": "string"}, "type": {"title": "Error Type", "type": "string"}, }, }, } }, } # Add Input, Output, and enum schemas schema["components"]["schemas"]["Input"] = to_json_input(predictor) schema["components"]["schemas"]["Output"] = to_json_output(predictor) schema["components"]["schemas"].update(to_json_enums(predictor)) return schema ================================================ FILE: python/cog/coder.py ================================================ """ Cog SDK Coder system for custom type encoding/decoding. This module provides the Coder base class for defining custom type serialization between Python types and JSON. """ from abc import abstractmethod from typing import Any, Dict, Optional, Set, Type class Coder: """ Base class for custom type encoders/decoders. Implement this to add support for custom types in predictor inputs/outputs. Register your coder with Coder.register() to make it available. Example: from cog import Coder from myapp import MyCustomType class MyCustomCoder(Coder): @staticmethod def factory(tpe: Type) -> Optional["MyCustomCoder"]: if tpe is MyCustomType: return MyCustomCoder() return None def encode(self, value: MyCustomType) -> dict: return {"data": value.to_dict()} def decode(self, value: dict) -> MyCustomType: return MyCustomType.from_dict(value["data"]) # Register the coder Coder.register(MyCustomCoder) """ _coders: Set[Type["Coder"]] = set() @staticmethod def register(coder: Type["Coder"]) -> None: """ Register a coder class for custom type handling. Args: coder: A Coder subclass to register. """ Coder._coders.add(coder) @staticmethod def lookup(tpe: type | Any) -> Optional["Coder"]: """ Find a coder that can handle the given type. Args: tpe: The type to find a coder for. Returns: A Coder instance if one is found, None otherwise. 
""" for cls in Coder._coders: c = cls.factory(tpe) if c is not None: return c return None @staticmethod @abstractmethod def factory(tpe: Type[Any]) -> Optional["Coder"]: """ Factory method to create a coder for a given type. Override this to check if your coder can handle the type and return an instance if so. Args: tpe: The type to potentially handle. Returns: A Coder instance if this coder can handle the type, None otherwise. """ pass @abstractmethod def encode(self, x: Any) -> Dict[str, Any]: """ Encode a value to a JSON-serializable dictionary. Args: x: The value to encode. Returns: A dictionary representation of the value. """ pass @abstractmethod def decode(self, x: Dict[str, Any]) -> Any: """ Decode a dictionary back to the original type. Args: x: The dictionary to decode. Returns: The decoded value. """ pass ================================================ FILE: python/cog/command/__init__.py ================================================ """Cog CLI command modules.""" ================================================ FILE: python/cog/command/openapi_schema.py ================================================ """ python -m cog.command.openapi_schema This prints a JSON object describing the OpenAPI schema of the model. Schema is generated by introspecting the predictor's type annotations without starting the HTTP server. """ import importlib.util import json import os import sys from typing import Any, Dict from .._inspector import create_predictor from .._schemas import to_json_schema from ..config import Config from ..errors import ConfigDoesNotExist from ..mode import Mode from ..suppress_output import suppress_output def _load_module_from_ref(ref: str) -> tuple[str, str]: """Load a predictor module from a ref like 'predict.py:Predictor' or 'my-subdir/predict.py:Predictor'. Uses spec_from_file_location to load the module by file path, which handles subdirectory predictors correctly (unlike import_module which requires the module to be on sys.path). 
def remove_title_next_to_ref(
    schema_node: Any,
) -> Any:
    """
    Strip 'title' from every schema object that also carries a '$ref'.

    Works around a non-compliance issue in FastAPI's OpenAPI schema
    generation. The schema is edited in place, walking nested dicts and
    lists, and the same node is returned.
    """
    if isinstance(schema_node, dict):
        if "$ref" in schema_node:
            schema_node.pop("title", None)
        for child in schema_node.values():
            remove_title_next_to_ref(child)
    elif isinstance(schema_node, list):
        for position in range(len(schema_node)):
            schema_node[position] = remove_title_next_to_ref(schema_node[position])
    return schema_node


def fix_nullable_anyof(schema_node: Any) -> None:
    """
    Rewrite a two-way anyOf with a null branch as ``nullable: true``.

    FastAPI generates: {"anyOf": [{"type": "string"}, {"type": "null"}]}
    OpenAPI 3.0 wants: {"type": "string", "nullable": true}

    The schema is edited in place, walking nested dicts and lists.
    """
    if isinstance(schema_node, list):
        for element in schema_node:
            fix_nullable_anyof(element)
        return
    if not isinstance(schema_node, dict):
        return

    def is_null_branch(branch: Any) -> bool:
        return isinstance(branch, dict) and branch.get("type") == "null"

    branches = schema_node.get("anyOf")
    if isinstance(branches, list) and len(branches) == 2:
        null_branches = [b for b in branches if is_null_branch(b)]
        real_branches = [b for b in branches if not is_null_branch(b)]
        # Only the exact "one real type or null" shape is rewritten.
        if len(null_branches) == 1 and len(real_branches) == 1:
            real = real_branches[0]
            if isinstance(real, dict):
                # Replace anyOf with the non-null type + nullable
                del schema_node["anyOf"]
                schema_node.update(real)
                schema_node["nullable"] = True

    for child in schema_node.values():
        fix_nullable_anyof(child)
""" import os from typing import Any, Optional import yaml from .errors import ConfigDoesNotExist from .mode import Mode COG_YAML_FILE = "cog.yaml" COG_PREDICT_TYPE_STUB_ENV_VAR = "COG_PREDICT_TYPE_STUB" COG_TRAIN_TYPE_STUB_ENV_VAR = "COG_TRAIN_TYPE_STUB" class Config: """A class for reading the cog.yaml properties.""" def __init__(self, config: Optional[dict[str, Any]] = None) -> None: self._config = config @property def _cog_config(self) -> dict[str, Any]: config = self._config if config is None: config_path = os.path.abspath(COG_YAML_FILE) try: with open(config_path, encoding="utf-8") as handle: config = yaml.safe_load(handle) except FileNotFoundError as e: raise ConfigDoesNotExist( f"Could not find {config_path}", ) from e self._config = config return config @property def predictor_predict_ref(self) -> Optional[str]: env_val = os.environ.get(COG_PREDICT_TYPE_STUB_ENV_VAR) if env_val: return env_val return self._cog_config.get(str(Mode.PREDICT)) @property def predictor_train_ref(self) -> Optional[str]: env_val = os.environ.get(COG_TRAIN_TYPE_STUB_ENV_VAR) if env_val: return env_val return self._cog_config.get(str(Mode.TRAIN)) def get_predictor_ref(self, mode: Mode) -> str: predictor_ref = None if mode == Mode.PREDICT: predictor_ref = self.predictor_predict_ref elif mode == Mode.TRAIN: predictor_ref = self.predictor_train_ref if predictor_ref is None: raise ValueError( f"Can't run predictions: '{mode}' option not found in cog.yaml" ) return predictor_ref ================================================ FILE: python/cog/errors.py ================================================ class CogError(Exception): """Base class for all Cog errors.""" class ConfigDoesNotExist(CogError): """Exception raised when a cog.yaml does not exist.""" class PredictorNotSet(CogError): """Exception raised when 'predict' is not set in cog.yaml when it needs to be.""" ================================================ FILE: python/cog/input.py ================================================ 
@dataclass(frozen=True)
class FieldInfo:
    """
    Immutable record of the metadata attached to one predictor input.

    Instances are normally produced by Input() rather than built by hand.
    """

    # Default value; None means no default was declared.
    default: Any = None
    # Human-readable description of the input.
    description: Optional[str] = None
    # Inclusive numeric lower / upper bounds.
    ge: Optional[Union[int, float]] = None
    le: Optional[Union[int, float]] = None
    # Length bounds for string inputs.
    min_length: Optional[int] = None
    max_length: Optional[int] = None
    # Regular expression pattern for string inputs.
    regex: Optional[str] = None
    # Allowed values.
    choices: Optional[List[Union[str, int]]] = None
    # Whether the input is deprecated.
    deprecated: Optional[bool] = None


def Input(
    default: Any = None,
    *,
    default_factory: Optional[Callable[..., Any]] = None,
    description: Optional[str] = None,
    ge: Optional[Union[int, float]] = None,
    le: Optional[Union[int, float]] = None,
    min_length: Optional[int] = None,
    max_length: Optional[int] = None,
    regex: Optional[str] = None,
    choices: Optional[List[Union[str, int]]] = None,
    deprecated: Optional[bool] = None,
) -> Any:
    """
    Describe a predictor parameter: its default, constraints and metadata.

    Example::

        from cog import BasePredictor, Input

        class Predictor(BasePredictor):
            def predict(
                self,
                prompt: str = Input(description="The input prompt"),
                temperature: float = Input(default=0.7, ge=0.0, le=2.0),
                max_tokens: int = Input(default=100, ge=1, le=4096),
            ) -> str:
                ...

    Args:
        default: Default value for the field. Must be an immutable literal value.
        default_factory: Not supported; passing anything but None raises.
        description: Human-readable description of the input.
        ge: Minimum value (greater than or equal) for numeric inputs.
        le: Maximum value (less than or equal) for numeric inputs.
        min_length: Minimum length for string inputs.
        max_length: Maximum length for string inputs.
        regex: Regular expression pattern for string inputs.
        choices: List of allowed values.
        deprecated: Whether the input is deprecated.

    Returns:
        A FieldInfo instance containing the field metadata.

    Raises:
        TypeError: If ``default_factory`` is given.
    """
    if default_factory is not None:
        raise TypeError(
            "default_factory is not supported in Input(). "
            "Use a literal default value instead: Input(default=...). "
            "Mutable defaults like lists should use immutable alternatives "
            "(e.g. a comma-separated string) or be constructed in predict()."
        )

    metadata = dict(
        default=default,
        description=description,
        ge=ge,
        le=le,
        min_length=min_length,
        max_length=max_length,
        regex=regex,
        choices=choices,
        deprecated=deprecated,
    )
    return FieldInfo(**metadata)


class Mode(Enum):
    """Enumeration over the different prediction modes."""

    PREDICT = "predict"
    TRAIN = "train"

    def __str__(self) -> str:
        # Render as the bare value ("predict" / "train"), not "Mode.PREDICT".
        return str(self.value)


class BaseModel:
    """
    Base class for structured output types.

    Every subclass is turned into a dataclass automatically, so an output
    schema can be declared with plain annotated attributes.

    Example:
        from cog import BaseModel

        class Output(BaseModel):
            text: str
            confidence: float

        # Use as return type
        def predict(self, prompt: str) -> Output:
            return Output(text="hello", confidence=0.9)

    Pass auto_dataclass=False to opt out and manage dataclass behaviour
    yourself:

        class ManualModel(BaseModel, auto_dataclass=False):
            # You must apply @dataclass yourself or handle initialization
            pass
    """

    def __init_subclass__(
        cls,
        *,
        auto_dataclass: bool = True,
        init: bool = True,
        **kwargs: object,
    ) -> None:
        """
        Validate the new subclass's bases and, unless opted out, make it a
        dataclass.

        Args:
            auto_dataclass: If True, automatically apply @dataclass to the class.
            init: If True (and auto_dataclass=True), generate __init__.
            **kwargs: Additional arguments passed to @dataclass.

        Raises:
            TypeError: If the first base is not a BaseModel, or a secondary
                base is a dataclass.
            ValueError: If auto_dataclass=False is requested on top of a
                parent that was auto-dataclassed.
            RuntimeError: If the parent carries no auto_dataclass marker.
        """
        # BaseModel's own parent is `object`, which accepts no keyword
        # arguments; the kwargs are meant for dataclass() only.
        super().__init_subclass__()

        primary = cls.__bases__[0]
        if not issubclass(primary, BaseModel):
            raise TypeError(
                f'Primary base class of "{cls.__name__}" must inherit from BaseModel'
            )

        if not auto_dataclass:
            # Opting out under an auto-dataclassed parent is rejected.
            try:
                parent_is_auto = (
                    primary != BaseModel
                    and primary.__auto_dataclass is True  # type: ignore[attr-defined]
                )
            except AttributeError:
                raise RuntimeError(
                    f'Primary base class of "{cls.__name__}" is a child of a child '
                    "of `BaseModel`, but `auto_dataclass` tracking does not exist. "
                    "This is likely a bug or other programming error."
                ) from None
            if parent_is_auto:
                raise ValueError(
                    f'Primary base class of "{cls.__name__}" '
                    f'("{cls.__bases__[0].__name__}") has auto_dataclass=True, '
                    f'but "{cls.__name__}" has auto_dataclass=False. '
                    "This creates broken field inheritance."
                )

        offending = next(
            (base for base in cls.__bases__[1:] if is_dataclass(base)), None
        )
        if offending is not None:
            raise TypeError(
                f'Cannot mixin dataclass "{offending.__name__}" while inheriting '
                "from `BaseModel`"
            )

        if auto_dataclass:
            # Every auto-managed child is dataclass'd so that dataclass
            # inheritance is handled consistently.
            dataclass(init=init, **kwargs)(cls)  # type: ignore[call-overload]
            cls.__auto_dataclass = True  # type: ignore[attr-defined]
        else:
            # With manual handling, dataclass semantics are the user's
            # responsibility from here on.
            cls.__auto_dataclass = False  # type: ignore[attr-defined]
def load_predictor_from_ref(ref: str) -> "BasePredictor":
    """
    Load a predictor from a ``path:name`` reference (e.g. 'predict.py:Predictor').

    The file is loaded by path, registered in ``sys.modules`` under its base
    name without the ``.py`` suffix, and the named attribute is looked up on
    it.

    Args:
        ref: ``<module path>:<attribute name>``. Without a colon the
            attribute name defaults to ``Predictor``.

    Returns:
        A new instance when the attribute is a class; otherwise the attribute
        itself (e.g. a training function).

    Raises:
        ImportError: If no import spec/loader can be built for the path.
        AttributeError: If the module lacks the named attribute.
    """
    # Split on the LAST colon so a colon inside the path stays in the path
    # (same rule as _load_module_from_ref in cog.command.openapi_schema).
    module_path, class_name = ref.rsplit(":", 1) if ":" in ref else (ref, "Predictor")
    # Drop only the trailing ".py"; str.replace would also eat a ".py" inside
    # the name ("model.pytorch.py" -> "modeltorch").
    module_name = os.path.basename(module_path).removesuffix(".py")

    # Use spec_from_file_location to load from file path
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {module_path}")

    module = importlib.util.module_from_spec(spec)
    # Add module to sys.modules so pickle can find it
    sys.modules[module_name] = module
    spec.loader.exec_module(module)

    predictor = getattr(module, class_name)
    # It could be a class or a function (for training)
    if inspect.isclass(predictor):
        return predictor()
    return predictor


def has_setup_weights(predictor: "BasePredictor") -> bool:
    """Return True when the predictor's ``setup`` accepts a ``weights`` parameter."""
    if not hasattr(predictor, "setup"):
        return False
    return "weights" in inspect.signature(predictor.setup).parameters


def extract_setup_weights(
    predictor: "BasePredictor",
) -> Optional[Union["Path", str]]:
    """
    Return the weights location from the ``COG_WEIGHTS`` environment variable.

    Args:
        predictor: Unused; accepted for interface compatibility.

    Returns:
        The variable's value, or None when it is unset or empty.
    """
    return os.environ.get("COG_WEIGHTS") or None
default=False, help="Ignore SIGTERM and wait for a request to /shutdown (or a SIGINT) before exiting", ) parser.add_argument( "--x-mode", dest="mode", type=Mode, default=Mode.PREDICT, choices=list(Mode), help="Experimental: Run in 'predict' or 'train' mode", ) # Accept but ignore other args for compatibility parser.add_argument("--threads", dest="threads", type=int, default=None) parser.add_argument("--upload-url", dest="upload_url", type=str, default=None) args = parser.parse_args() if args.version: print(f"coglet (Rust) {coglet.__version__}") # type: ignore[attr-defined] sys.exit(0) port = int(os.getenv("PORT", "5000")) is_train = args.mode == Mode.TRAIN # Resolve predictor ref from env vars (set by Dockerfile at build time) if is_train: predictor_ref = os.environ.get("COG_TRAIN_TYPE_STUB") else: predictor_ref = os.environ.get("COG_PREDICT_TYPE_STUB") if not predictor_ref: env_var = "COG_TRAIN_TYPE_STUB" if is_train else "COG_PREDICT_TYPE_STUB" print( f"ERROR: {env_var} environment variable is not set.\n" f"This should be set automatically by 'cog build'. If running manually,\n" f"set it to your predictor reference (e.g. 
@contextmanager
def suppress_output() -> Iterator[None]:
    """
    Silence stdout and stderr at the file-descriptor level for the ``with`` body.

    The descriptors behind ``sys.stdout`` / ``sys.stderr`` are pointed at
    ``os.devnull`` and restored on exit, including when the body raises.

    The Python-level buffers are flushed on both sides of the swap. Without
    that, text written before the context could be flushed into devnull, and
    text written inside it could stay buffered and appear on the real stream
    after the descriptors are restored.
    """
    out_stream, err_stream = sys.stdout, sys.stderr
    out_fd = out_stream.fileno()
    err_fd = err_stream.fileno()

    # Deliver anything already buffered to its real destination first.
    out_stream.flush()
    err_stream.flush()

    out_dup_fd = os.dup(out_fd)
    err_dup_fd = os.dup(err_fd)
    try:
        with (
            open(os.devnull, "w", encoding="utf-8") as null_out,
            open(os.devnull, "w", encoding="utf-8") as null_err,
        ):
            os.dup2(null_out.fileno(), out_fd)
            os.dup2(null_err.fileno(), err_fd)
            try:
                yield
            finally:
                try:
                    # Drain suppressed output while the descriptors still
                    # point at devnull.
                    out_stream.flush()
                    err_stream.flush()
                finally:
                    # Restore even if a flush fails.
                    os.dup2(out_dup_fd, out_fd)
                    os.dup2(err_dup_fd, err_fd)
    finally:
        os.close(out_dup_fd)
        os.close(err_dup_fd)
This module provides core types for defining predictor inputs and outputs: - Path: File path type that supports URL inputs - Secret: Secure string type that masks its value - File: Deprecated file type (use Path instead) - ConcatenateIterator: Streaming output iterator - AsyncConcatenateIterator: Async streaming output iterator """ import io import mimetypes import os import pathlib import shutil import tempfile import urllib.parse import urllib.request from abc import abstractmethod from dataclasses import dataclass from typing import ( Any, AsyncIterator, Dict, Iterator, Optional, TypeVar, ) import requests # Constants for filename handling FILENAME_ILLEGAL_CHARS = set("\u0000/") FILENAME_MAX_LENGTH = 200 def _len_bytes(s: str) -> int: """Return the length of a string in bytes (UTF-8).""" return len(s.encode("utf-8")) def _truncate_filename_bytes(filename: str, length: int) -> str: """Truncate a filename to a maximum byte length, preserving extension.""" if _len_bytes(filename) <= length: return filename # Split filename and extension name, ext = os.path.splitext(filename) # Reserve space for tilde and extension max_name_length = length - _len_bytes(ext) - 1 # Truncate name encoded = name.encode("utf-8") truncated = encoded[:max_name_length].decode("utf-8", errors="ignore") return f"{truncated}~{ext}" def get_filename(url: str) -> str: """Extract a filename from a URL.""" parsed_url = urllib.parse.urlparse(url) if parsed_url.scheme == "data": # Safe: scheme is validated to be 'data:' before urlopen with urllib.request.urlopen(url) as resp: # noqa: S310 mime_type = resp.headers.get_content_type() extension = mimetypes.guess_extension(mime_type) if extension is None: return "file" return "file" + extension basename = os.path.basename(parsed_url.path) basename = urllib.parse.unquote_plus(basename) # Truncate if too long if _len_bytes(basename) > FILENAME_MAX_LENGTH: basename = _truncate_filename_bytes(basename, length=FILENAME_MAX_LENGTH) # Replace illegal 
class URLFile(io.IOBase):
    """
    Lazy, picklable file-like proxy for the body of an HTTP(S) URL.

    No request is made at construction time. The first access that needs the
    underlying stream (see ``__wrapped__``) performs a streaming GET and the
    raw response object is then cached in ``__target__``. Only ``name`` and
    the URL are pickled, so a restored instance fetches again on first use.
    """

    # "__target__" holds the cached raw response, "__url__" the source URL.
    __slots__ = ("__target__", "__url__", "name")

    def __init__(self, url: str, filename: Optional[str] = None) -> None:
        """Record the URL and display name without contacting the server.

        Args:
            url: Source URL; its scheme must be ``http`` or ``https``.
            filename: Optional name to expose as ``.name``. When empty, the
                basename of the URL path is used.

        Raises:
            ValueError: If the URL scheme is not ``http`` or ``https``.
        """
        parsed = urllib.parse.urlparse(url)
        if parsed.scheme not in {"http", "https"}:
            raise ValueError(
                "URLFile requires URL to conform to HTTP or HTTPS protocol"
            )
        if not filename:
            filename = os.path.basename(parsed.path)
        # Written through object.__setattr__ so the proxying __setattr__
        # defined below is not involved.
        object.__setattr__(self, "name", filename)
        object.__setattr__(self, "__url__", url)

    def __del__(self) -> None:
        """Run the base-class finalizer only if a response was ever opened."""
        try:
            object.__getattribute__(self, "__target__")
        except AttributeError:
            # Do nothing when tearing down the object if the response object
            # hasn't been created yet.
            return

        super().__del__()

    # We provide __getstate__ and __setstate__ explicitly to ensure that the
    # object is always picklable.
    def __getstate__(self) -> Dict[str, Any]:
        """Return the picklable state: the name and URL, never the response."""
        return {
            "name": object.__getattribute__(self, "name"),
            "url": object.__getattribute__(self, "__url__"),
        }

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore name and URL; ``__target__`` stays unset until first use."""
        object.__setattr__(self, "name", state["name"])
        object.__setattr__(self, "__url__", state["url"])

    # Proxy getattr/setattr/delattr through to the response object.
    def __setattr__(self, name: str, value: Any) -> None:
        """Set attributes known to this class locally, all others on the response."""
        if hasattr(type(self), name):
            object.__setattr__(self, name, value)
        else:
            setattr(self.__wrapped__, name, value)

    def __getattr__(self, name: str) -> Any:
        """Resolve attributes that normal lookup did not find.

        The internal names raise ``AttributeError`` instead of being forwarded
        (presumably so probing for them never triggers the lazy fetch);
        everything else except ``name`` is read from the response object.
        """
        if name in ("__target__", "__wrapped__", "__url__"):
            raise AttributeError(name)
        elif name == "name":
            return object.__getattribute__(self, "name")
        return getattr(self.__wrapped__, name)

    def __delattr__(self, name: str) -> None:
        """Delete attributes known to this class locally, all others on the response."""
        if hasattr(type(self), name):
            object.__delattr__(self, name)
        else:
            delattr(self.__wrapped__, name)

    # Luckily the only dunder method on HTTPResponse is __iter__
    def __iter__(self) -> Iterator[bytes]:
        """Iterate over the response object (this opens it if necessary)."""
        return iter(self.__wrapped__)

    @property
    def __wrapped__(self) -> Any:
        """Return the raw response object, performing the GET on first access.

        The request is streamed with a 10 second timeout and, when the
        ``COG_USER_AGENT`` environment variable is set, sends it as the
        ``User-Agent`` header. Non-success statuses raise via
        ``raise_for_status()``.
        """
        try:
            # Fast path: the response was already opened and cached.
            return object.__getattribute__(self, "__target__")
        except AttributeError:
            pass
        url = object.__getattribute__(self, "__url__")
        headers = {}
        ua = os.getenv("COG_USER_AGENT")
        if ua:
            headers["User-Agent"] = ua
        resp = requests.get(url, stream=True, timeout=10, headers=headers)
        resp.raise_for_status()
        # Ask the raw stream to decode the body's content encoding on read.
        resp.raw.decode_content = True
        object.__setattr__(self, "__target__", resp.raw)
        return resp.raw

    def __repr__(self) -> str:
        """Describe the proxy without triggering a fetch."""
        try:
            target = object.__getattribute__(self, "__target__")
        except AttributeError:
            return f"<{type(self).__name__} at 0x{id(self):x} for {object.__getattribute__(self, '__url__')!r}>"
        return f"<{type(self).__name__} at 0x{id(self):x} wrapping {target!r}>"
class URLPath(pathlib.PosixPath):
    """
    Path placeholder for a URL whose download is deferred.

    It subclasses ``pathlib.PosixPath`` only so that it passes
    ``isinstance(_, pathlib.Path)`` checks; the content is written to a
    temporary file when ``convert()`` is first called.

    Attributes:
        source: The original URL (also what ``str()`` returns).
        filename: Name used as the temporary file's suffix, limited to
            ``FILENAME_MAX_LENGTH`` bytes.
        fileobj: File-like object the content is read from.
    """

    _path: Optional["Path"]

    # pylint: disable=super-init-not-called
    def __init__(self, *, source: str, filename: str, fileobj: io.IOBase) -> None:
        # The limit is expressed in bytes, so a character count is the wrong
        # guard for multi-byte names. _truncate_filename_bytes returns the
        # name unchanged when it already fits, so it is safe to call always.
        filename = _truncate_filename_bytes(filename, FILENAME_MAX_LENGTH)
        self.source = source
        self.filename = filename
        self.fileobj = fileobj
        self._path = None

    def __new__(cls, *, source: str, filename: str, fileobj: io.IOBase) -> "URLPath":
        # PosixPath.__new__ requires path segments, but we don't have a real path
        # Use a placeholder that will be replaced
        obj = super().__new__(cls, filename)
        return obj

    def convert(self) -> "Path":
        """Copy the content to a temporary file (once) and return its Path.

        Returns:
            The ``Path`` of the temporary file; repeated calls return the
            same path without copying again.

        Raises:
            Whatever reading ``fileobj`` or writing the temporary file
            raises. In that case the partial file is removed.
        """
        if self._path is None:
            dest = tempfile.NamedTemporaryFile(suffix=self.filename, delete=False)
            try:
                # The context manager closes the handle even if the copy fails.
                with dest:
                    shutil.copyfileobj(self.fileobj, dest)
            except BaseException:
                # delete=False means nothing else cleans this up; don't leave
                # a truncated download behind.
                try:
                    os.unlink(dest.name)
                except OSError:
                    pass
                raise
            self._path = Path(dest.name)
        return self._path

    def unlink(self, missing_ok: bool = False) -> None:
        """Remove the temporary file if ``convert()`` has created one."""
        if self._path:
            self._path.unlink(missing_ok=missing_ok)

    def __str__(self) -> str:
        # FastAPI's jsonable_encoder will encode subclasses of pathlib.Path by
        # calling str() on them
        return self.source
class TestEmitMetric:
    """Checks for the deprecated ``emit_metric`` shim, each run in a fresh interpreter."""

    @staticmethod
    def _python(snippet: str) -> "subprocess.CompletedProcess[str]":
        """Execute ``snippet`` via ``python -c`` and capture its output as text."""
        return subprocess.run(
            [sys.executable, "-c", snippet],
            capture_output=True,
            text=True,
        )

    def test_import_succeeds(self) -> None:
        """``from cog import emit_metric`` must not fail."""
        proc = self._python("from cog import emit_metric")
        assert proc.returncode == 0, f"Import failed: {proc.stderr}"

    def test_attribute_access_succeeds(self) -> None:
        """Reading ``cog.emit_metric`` as a module attribute must not fail."""
        proc = self._python("import cog; cog.emit_metric")
        assert proc.returncode == 0, f"Attribute access failed: {proc.stderr}"

    def test_prints_deprecation_to_stderr(self) -> None:
        """The first import writes a deprecation notice to stderr."""
        proc = self._python("from cog import emit_metric")
        assert "emit_metric() is deprecated" in proc.stderr

    def test_message_prints_once(self) -> None:
        """Importing twice in one process prints the notice a single time."""
        snippet = "from cog import emit_metric\nfrom cog import emit_metric\n"
        proc = self._python(snippet)
        assert proc.stderr.count("emit_metric() is deprecated") == 1

    def test_callable(self) -> None:
        """Calling the shim outside a prediction context must not raise."""
        proc = self._python(
            "from cog import emit_metric; emit_metric('output_tokens', 42)"
        )
        assert proc.returncode == 0, f"Call failed: {proc.stderr}"

    def test_module_attribute_callable(self) -> None:
        """The ``cog.emit_metric(...)`` call style (cog-triton, cog-arctic, etc.) works."""
        proc = self._python("import cog; cog.emit_metric('input_token_count', 100)")
        assert proc.returncode == 0, (
            f"Call via module attribute failed: {proc.stderr}"
        )

    def test_unknown_attr_still_raises(self) -> None:
        """Unknown attributes on ``cog`` still raise ``AttributeError``."""
        proc = self._python("import cog; cog.NoSuchAttribute")
        assert proc.returncode != 0
        assert "AttributeError" in proc.stderr
class TestInput:
    """Checks that ``Input()`` returns a ``FieldInfo`` carrying the given options."""

    def test_input_returns_fieldinfo(self) -> None:
        field = Input(description="Test input")
        assert isinstance(field, FieldInfo)

    def test_input_with_default(self) -> None:
        field = Input(default="hello", description="A string")
        assert field.default == "hello"
        assert field.description == "A string"

    def test_input_with_numeric_constraints(self) -> None:
        field = Input(default=5, ge=0, le=10)
        assert field.default == 5
        assert field.ge == 0
        assert field.le == 10

    def test_input_with_string_constraints(self) -> None:
        pattern = r"^\w+$"
        field = Input(min_length=1, max_length=100, regex=pattern)
        assert field.min_length == 1
        assert field.max_length == 100
        assert field.regex == pattern

    def test_input_with_choices(self) -> None:
        options = ["a", "b", "c"]
        field = Input(default="a", choices=["a", "b", "c"])
        assert field.default == "a"
        assert field.choices == options

    def test_input_with_deprecated(self) -> None:
        field = Input(deprecated=True)
        assert field.deprecated is True

    def test_input_default_factory_raises_error(self) -> None:
        with pytest.raises(TypeError, match="default_factory is not supported"):
            Input(default_factory=list)

    def test_input_immutable_defaults_stored_directly(self) -> None:
        immutable_values = ("string", 42, 3.14, True, None, (1, 2), frozenset([1, 2]))
        for value in immutable_values:
            field = Input(default=value)
            assert field.default == value

    def test_input_no_default(self) -> None:
        # Omitting the default marks the parameter as required.
        field = Input(description="Required input")
        assert field.default is None
        assert field.description == "Required input"
def test_subclass_becomes_dataclass(self) -> None: class Output(BaseModel): text: str score: float assert is_dataclass(Output) def test_subclass_can_be_instantiated(self) -> None: class Output(BaseModel): text: str score: float output = Output(text="hello", score=0.9) assert output.text == "hello" assert output.score == 0.9 def test_subclass_with_defaults(self) -> None: class Output(BaseModel): text: str score: float = 0.5 output = Output(text="hello") assert output.text == "hello" assert output.score == 0.5 def test_subclass_with_optional(self) -> None: class Output(BaseModel): text: str metadata: Optional[str] = None output = Output(text="hello") assert output.text == "hello" assert output.metadata is None def test_nested_models(self) -> None: class Inner(BaseModel): value: int class Outer(BaseModel): inner: Inner name: str inner = Inner(value=42) outer = Outer(inner=inner, name="test") assert outer.inner.value == 42 assert outer.name == "test" def test_inheritance(self) -> None: class Base(BaseModel): x: int class Derived(Base): y: str derived = Derived(x=1, y="two") assert derived.x == 1 assert derived.y == "two" def test_auto_dataclass_false(self) -> None: class Manual(BaseModel, auto_dataclass=False): x: int def __init__(self, x: int) -> None: self.x = x # Should not be auto-dataclassed assert not is_dataclass(Manual) # But should still be usable m = Manual(x=5) assert m.x == 5 def test_primary_base_must_be_basemodel(self) -> None: class NotBaseModel: pass try: class Bad(NotBaseModel, BaseModel): # type: ignore[misc] x: int assert False, "Should have raised TypeError" except TypeError as e: assert "must inherit from BaseModel" in str(e) def test_cannot_mixin_dataclass(self) -> None: from dataclasses import dataclass @dataclass class SomeDataclass: y: int try: class Bad(BaseModel, SomeDataclass): # type: ignore[misc] x: int assert False, "Should have raised TypeError" except TypeError as e: assert "Cannot mixin dataclass" in str(e) def 
test_auto_dataclass_inheritance_mismatch(self) -> None: class Parent(BaseModel): x: int try: class Child(Parent, auto_dataclass=False): y: str assert False, "Should have raised ValueError" except ValueError as e: assert "auto_dataclass=True" in str(e) assert "auto_dataclass=False" in str(e) def test_basemodel_asdict(self) -> None: from dataclasses import asdict class Output(BaseModel): weights: Path output = Output(weights=Path("weights.bin")) assert asdict(output) == {"weights": Path("weights.bin")} ================================================ FILE: python/tests/test_predictor.py ================================================ """Tests for cog.predictor module (BasePredictor).""" from typing import Optional from cog import BasePredictor, Path class TestBasePredictor: """Tests for BasePredictor class.""" def test_subclass_can_override_predict(self) -> None: class MyPredictor(BasePredictor): def predict(self, text: str) -> str: return text.upper() predictor = MyPredictor() result = predictor.predict(text="hello") assert result == "HELLO" def test_default_predict_raises(self) -> None: predictor = BasePredictor() try: predictor.predict() assert False, "Should have raised NotImplementedError" except NotImplementedError as e: assert "predict has not been implemented" in str(e) def test_setup_is_optional(self) -> None: class MyPredictor(BasePredictor): def predict(self, x: int) -> int: return x * 2 predictor = MyPredictor() # setup() should not raise predictor.setup() assert predictor.predict(x=5) == 10 def test_setup_with_weights(self) -> None: class MyPredictor(BasePredictor): weights_path: Optional[str] = None def setup(self, weights: Optional[str] = None) -> None: self.weights_path = weights def predict(self, x: int) -> int: return x predictor = MyPredictor() predictor.setup(weights="/path/to/weights") assert predictor.weights_path == "/path/to/weights" def test_setup_with_path_weights(self) -> None: class MyPredictor(BasePredictor): weights_path: Optional[Path] 
class TestSecret:
    """Tests for Secret type."""

    def test_secret_creation(self) -> None:
        secret = Secret(secret_value="my-api-key")
        assert secret.get_secret_value() == "my-api-key"

    def test_secret_masks_in_str(self) -> None:
        secret = Secret(secret_value="my-api-key")
        assert str(secret) == "**********"
        assert "my-api-key" not in str(secret)

    def test_secret_masks_in_repr(self) -> None:
        secret = Secret(secret_value="my-api-key")
        assert "my-api-key" not in repr(secret)
        assert "**********" in repr(secret)

    def test_secret_none_value(self) -> None:
        secret = Secret(secret_value=None)
        assert secret.get_secret_value() is None
        assert str(secret) == ""

    def test_secret_default_none(self) -> None:
        secret = Secret()
        assert secret.get_secret_value() is None

    def test_secret_is_dataclass(self) -> None:
        assert is_dataclass(Secret)

    def test_secret_is_frozen(self) -> None:
        """Assigning to a field of the frozen dataclass must be rejected."""
        from dataclasses import FrozenInstanceError

        secret = Secret(secret_value="test")
        try:
            secret.secret_value = "new"  # type: ignore[misc]
        except FrozenInstanceError:
            pass  # Expected - frozen dataclass
        else:
            # Raised outside the try body: catching a broad Exception around
            # an `assert False` would swallow the AssertionError and make the
            # test pass even when the assignment succeeds.
            raise AssertionError("Should have raised FrozenInstanceError")
        # The rejected assignment must not have changed the stored value.
        assert secret.get_secret_value() == "test"
class TestIterators:
    """Checks that the streaming iterator types derive from the typing ABCs."""

    def test_concatenate_iterator_is_abstract(self) -> None:
        # ConcatenateIterator has to be accepted wherever an Iterator hint is.
        import typing

        assert issubclass(ConcatenateIterator, typing.Iterator)

    def test_async_concatenate_iterator_is_abstract(self) -> None:
        # AsyncConcatenateIterator has to be accepted wherever an AsyncIterator hint is.
        import typing

        assert issubclass(AsyncConcatenateIterator, typing.AsyncIterator)
// Package main provides a simple HTTPS server for testing CA certificate injection.
// Usage: go run ./test-helpers/https-server --cert=server.crt --key=server.key --addr=:8443
package main

import (
	"crypto/tls"
	"flag"
	"fmt"
	"log"
	"net"
	"net/http"
	"os"
	"os/signal"
	"syscall"
)

// main parses the flags, starts an HTTPS server that answers every request
// with "OK", prints "READY" once the server can accept connections, and then
// blocks until SIGINT or SIGTERM is received.
func main() {
	cert := flag.String("cert", "", "Path to TLS certificate file")
	key := flag.String("key", "", "Path to TLS key file")
	addr := flag.String("addr", ":8443", "Address to listen on")
	flag.Parse()

	if *cert == "" || *key == "" {
		fmt.Fprintln(os.Stderr, "Usage: https-server --cert=server.crt --key=server.key [--addr=:8443]")
		os.Exit(1)
	}

	// Load the key pair and bind the socket before announcing readiness.
	// Doing both inside the serving goroutine would let "READY" be printed
	// while the port is still closed or the certificate is unusable.
	keyPair, err := tls.LoadX509KeyPair(*cert, *key)
	if err != nil {
		log.Fatalf("HTTPS server error: %v", err)
	}
	ln, err := net.Listen("tcp", *addr)
	if err != nil {
		log.Fatalf("HTTPS server error: %v", err)
	}

	// Simple handler that returns OK. A private mux keeps the handler off
	// the process-wide http.DefaultServeMux.
	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		fmt.Fprintln(w, "OK")
	})

	server := &http.Server{
		Addr:      *addr,
		Handler:   mux,
		TLSConfig: &tls.Config{Certificates: []tls.Certificate{keyPair}},
	}

	// Serve in a goroutine so main can wait for signals.
	go func() {
		log.Printf("Starting HTTPS server on %s", *addr)
		// The certificate is already in TLSConfig, so no file names are passed.
		if err := server.ServeTLS(ln, "", ""); err != http.ErrServerClosed {
			log.Fatalf("HTTPS server error: %v", err)
		}
	}()

	// Print ready message for test synchronization. The listener is bound at
	// this point, so connections made after reading it are accepted.
	fmt.Println("READY")

	// Wait for interrupt signal
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	<-sigChan

	log.Println("Shutting down HTTPS server")
	if err := server.Close(); err != nil {
		log.Printf("HTTPS server error: %v", err)
	}
}
================================================ FILE: test-integration/test_integration/fixtures/hello-image/cog.yaml ================================================ build: python_version: "3.11" predict: "predict.py:Predictor" image: "r8.im/replicate/hello-image" ================================================ FILE: test-integration/test_integration/fixtures/hello-image/predict.py ================================================ from cog import BasePredictor, Path class Predictor(BasePredictor): def predict(self, word: str) -> Path: return Path("hello.webp") ================================================ FILE: tools/compatgen/internal/cuda.go ================================================ package internal import ( "cmp" "context" "encoding/json" "fmt" "slices" "sort" "strings" "github.com/anaskhan96/soup" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/name" "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/replicate/cog/pkg/config" ) func FetchCUDABaseImages(ctx context.Context) ([]config.CUDABaseImage, error) { url := "https://hub.docker.com/v2/repositories/nvidia/cuda/tags/?page_size=1000&name=devel-ubuntu&ordering=last_updated" tags, err := fetchCUDABaseImageTags(url) if err != nil { return nil, err } var images []config.CUDABaseImage for _, tag := range tags { image, err := parseCUDABaseImage(ctx, tag) if err != nil { return nil, err } images = append(images, *image) } // stable sort for deterministic output slices.SortFunc(images, func(a, b config.CUDABaseImage) int { return cmp.Or( cmp.Compare(a.CUDA, b.CUDA), cmp.Compare(a.CuDNN, b.CuDNN), cmp.Compare(a.Ubuntu, b.Ubuntu), cmp.Compare(a.Tag, b.Tag), ) }) return images, nil } func fetchCUDABaseImageTags(url string) ([]string, error) { tags := []string{} resp, err := soup.Get(url) if err != nil { return tags, fmt.Errorf("Failed to download %s: %w", url, err) } var results struct { Next *string Results []struct { Name string `json:"name"` } 
`json:"results"` } if err := json.Unmarshal([]byte(resp), &results); err != nil { return tags, fmt.Errorf("Failed parse CUDA images json: %w", err) } for _, result := range results.Results { tag := result.Name if strings.Contains(tag, "-cudnn") && !strings.HasSuffix(tag, "-rc") { tags = append(tags, tag) } } // recursive case for pagination if results.Next != nil { nextURL := *results.Next nextTags, err := fetchCUDABaseImageTags(nextURL) if err != nil { return tags, err } tags = append(tags, nextTags...) } sort.Sort(sort.Reverse(sort.StringSlice(tags))) return tags, nil } // parseCUDABaseImage fetches the Docker image config for an nvidia/cuda tag // and extracts CUDA and CuDNN versions from environment variables. This is // necessary because newer nvidia/cuda tags no longer include the CuDNN version // in the tag itself (e.g. "12.9.1-cudnn-devel-ubuntu24.04" instead of // "12.6.3-cudnn9-devel-ubuntu22.04"). func parseCUDABaseImage(ctx context.Context, tag string) (*config.CUDABaseImage, error) { fmt.Println("parsing", tag) baseImg := &config.CUDABaseImage{ Tag: tag, IsDevel: strings.Contains(tag, "-devel"), } if parts := strings.Split(tag, "ubuntu"); len(parts) == 2 { baseImg.Ubuntu = parts[1] } else { return nil, fmt.Errorf("invalid tag, must end in ubuntu: %q", tag) } ref, err := name.ParseReference(fmt.Sprintf("nvidia/cuda:%s", tag)) if err != nil { return nil, fmt.Errorf("failed to parse reference %s: %w", tag, err) } img, err := remote.Image(ref, remote.WithContext(ctx), remote.WithAuthFromKeychain(authn.DefaultKeychain)) if err != nil { return nil, fmt.Errorf("failed to get image %s: %w", tag, err) } cfg, err := img.ConfigFile() if err != nil { return nil, fmt.Errorf("failed to get config file %s: %w", tag, err) } for _, envVal := range cfg.Config.Env { parts := strings.SplitN(envVal, "=", 2) if len(parts) != 2 { continue } switch parts[0] { case "CUDA_VERSION": baseImg.CUDA = parts[1] case "NV_CUDNN_VERSION": // downstream code expects only the major 
version component baseImg.CuDNN = strings.Split(parts[1], ".")[0] } } if baseImg.CuDNN == "" { return nil, fmt.Errorf("no CuDNN version found in image config for tag %s", tag) } return baseImg, nil } ================================================ FILE: tools/compatgen/internal/tensorflow.go ================================================ package internal import ( "fmt" "strconv" "strings" "github.com/replicate/cog/pkg/util/version" "github.com/anaskhan96/soup" "github.com/replicate/cog/pkg/config" ) func FetchTensorFlowCompatibilityMatrix() ([]config.TFCompatibility, error) { url := "https://www.tensorflow.org/install/source" minCudaVersion := strconv.Itoa(config.MinimumMajorCudaVersion) resp, err := soup.Get(url) if err != nil { return nil, fmt.Errorf("Failed to download %s: %w", url, err) } doc := soup.HTMLParse(resp) gpuHeading := doc.Find("h4", "id", "gpu") table := gpuHeading.FindNextElementSibling() rows := table.FindAll("tr") compats := []config.TFCompatibility{} for _, row := range rows[1:] { cells := row.FindAll("td") gpuPackage, packageVersion := split2(cells[0].Text(), "-") pythonVersions, err := parsePythonVersionsCell(cells[1].Text()) if err != nil { return nil, err } cuDNN := cells[4].Text() cuda := cells[5].Text() if !version.Greater(cuda, minCudaVersion) && !version.Equal(cuda, minCudaVersion) { continue } compat := config.TFCompatibility{ TF: packageVersion, TFCPUPackage: "tensorflow==" + packageVersion, TFGPUPackage: gpuPackage + "==" + packageVersion, CUDA: cuda, CuDNN: cuDNN, Pythons: pythonVersions, } compats = append(compats, compat) } // sanity check if len(compats) < 12 { return nil, fmt.Errorf("Tensorflow compatibility matrix only had %d rows, has the html changed?", len(compats)) } return compats, nil } func parsePythonVersionsCell(val string) ([]string, error) { versions := []string{} parts := strings.SplitSeq(val, ",") for part := range parts { part = strings.TrimSpace(part) if strings.Contains(part, "-") { start, end := split2(part, 
"-") startMajor, startMinor, err := splitPythonVersion(start) if err != nil { return nil, err } endMajor, endMinor, err := splitPythonVersion(end) if err != nil { return nil, err } if startMajor != endMajor { return nil, fmt.Errorf("Invalid start and end minor versions: %d, %d", startMajor, endMajor) } for minor := startMinor; minor <= endMinor; minor++ { versions = append(versions, newVersion(startMajor, minor)) } } else { versions = append(versions, part) } } return versions, nil } func newVersion(major int, minor int) string { return fmt.Sprintf("%d.%d", major, minor) } func splitPythonVersion(version string) (major int, minor int, err error) { version = strings.TrimSpace(version) majorStr, minorStr := split2(version, ".") major, err = strconv.Atoi(majorStr) if err != nil { return 0, 0, err } minor, err = strconv.Atoi(minorStr) if err != nil { return 0, 0, err } return major, minor, nil } ================================================ FILE: tools/compatgen/internal/torch.go ================================================ package internal import ( "errors" "fmt" "net/url" "regexp" "slices" "sort" "strconv" "strings" "github.com/anaskhan96/soup" "github.com/hashicorp/go-version" "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/env" "github.com/replicate/cog/pkg/util/console" ) var ErrorBadPytorchFormat = errors.New("The pytorch version format could not be parsed.") func FetchTorchCompatibilityMatrix() ([]config.TorchCompatibility, error) { compats := []config.TorchCompatibility{} var err error compats, err = fetchCurrentTorchVersions(compats) if err != nil { return nil, err } compats, err = fetchPreviousTorchVersions(compats) if err != nil { return nil, err } // Remove entries with no supported Python versions filtered := make([]config.TorchCompatibility, 0, len(compats)) for _, c := range compats { if len(c.Pythons) > 0 { filtered = append(filtered, c) } else { console.Warnf("Dropping %s: no supported Python versions", c.Torch) } } compats = 
filtered // sanity check if len(compats) < 21 { return nil, fmt.Errorf("PyTorch compatibility matrix only had %d rows, has the html changed?", len(compats)) } return compats, nil } func FetchTorchPackages(name string) ([]TorchPackage, error) { url := pytorchURL(name) return fetchTorchPackagesFromURL(url) } func getLatestVersion(packages []TorchPackage) string { latestVersion, _ := version.NewVersion("0.0.0") for _, pkg := range packages { v, err := version.NewVersion(pkg.Version) if err != nil { fmt.Println("error parsing version:", pkg.Version) continue } if v.GreaterThan(latestVersion) { latestVersion = v } } return latestVersion.String() } func fetchCurrentTorchVersions(compats []config.TorchCompatibility) ([]config.TorchCompatibility, error) { // For the latest PyTorch version, we can just grab the latest of each packages from the repository. // We then install the packages in the same way as we do for 1.12.x: // https://pytorch.org/get-started/previous-versions/#v1121 torchPackages, err := FetchTorchPackages("torch") if err != nil { return nil, fmt.Errorf("Error fetching PyTorch packages: %w", err) } torchVisionPackages, err := FetchTorchPackages("torchvision") if err != nil { return nil, fmt.Errorf("Error fetching PyTorch packages: %w", err) } torchAudioPackages, err := FetchTorchPackages("torchaudio") if err != nil { return nil, fmt.Errorf("Error fetching PyTorch packages: %w", err) } latestTorchVersion := getLatestVersion(torchPackages) latestTorchvisionVersion := getLatestVersion(torchVisionPackages) latestTorchaudioVersion := getLatestVersion(torchAudioPackages) torchCompats := map[string]config.TorchCompatibility{} for _, pkg := range torchPackages { if pkg.Version != latestTorchVersion { continue } if val, ok := torchCompats[pkg.Name]; ok { if !slices.Contains(val.Pythons, pkg.PythonVersion) { val.Pythons = append(val.Pythons, pkg.PythonVersion) } torchCompats[pkg.Name] = val } else { torchCompats[pkg.Name] = config.TorchCompatibility{ Torch: pkg.Name, 
Torchvision: latestTorchvisionVersion, Torchaudio: latestTorchaudioVersion, CUDA: pkg.CUDA, ExtraIndexURL: pytorchURL(pkg.Variant), Pythons: []string{pkg.PythonVersion}, } } } for _, compat := range torchCompats { compats = append(compats, compat) } return compats, nil } func parseTorchInstallString(s string, defaultVersions map[string]string, cuda *string) (*config.TorchCompatibility, error) { // for example: // pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 // pip install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html libVersions := map[string]string{} findLinks := "" extraIndexURL := "" skipNext := false // Simple parser for pip install strings fields := strings.Fields(s) for i, item := range fields { // Ideally we want to be able to consume the next token, but golang has no simple way of doing that without constructing a channel if skipNext { skipNext = false continue } switch item { case "pip", "pip3", "install": continue case "-f": findLinks = fields[i+1] skipNext = true continue case "--index-url", "--extra-index-url": extraIndexURL = fields[i+1] skipNext = true continue } libParts := strings.Split(item, "==") libName := libParts[0] if _, ok := defaultVersions[libName]; !ok { return nil, fmt.Errorf("Unknown token when parsing torch string: %s", item) } if len(libParts) == 1 { libVersions[libName] = defaultVersions[libName] } else { libVersions[libName] = libParts[1] } } torch, ok := libVersions["torch"] if !ok { return nil, fmt.Errorf("Missing torch version") } torchvision, ok := libVersions["torchvision"] if !ok { return nil, fmt.Errorf("Missing torchvision version") } torchaudio := libVersions["torchaudio"] pythons, err := FindCompatiblePythonVersions(torch, torchvision, torchaudio, extraIndexURL, findLinks) if err != nil { return nil, err } return &config.TorchCompatibility{ Torch: torch, Torchvision: torchvision, Torchaudio: torchaudio, 
FindLinks: findLinks, ExtraIndexURL: extraIndexURL, CUDA: cuda, Pythons: pythons, }, nil } func fetchPreviousTorchVersions(compats []config.TorchCompatibility) ([]config.TorchCompatibility, error) { // For previous versions, we need to scrape the PyTorch website. // The reason we can't fetch it from the PyPI repository like the latest version is // because we don't know what versions of torch, torchvision, and torchaudio are compatible with each other. url := "https://pytorch.org/get-started/previous-versions/" resp, err := soup.Get(url) if err != nil { return nil, fmt.Errorf("Failed to download %s: %w", url, err) } doc := soup.HTMLParse(resp) for _, h5 := range doc.FindAll("h5") { if strings.TrimSpace(h5.Text()) == "Linux and Windows" { highlight := h5.FindNextElementSibling() code := highlight.Find("code") compats, err = parsePreviousTorchVersionsCode(code.Text(), compats) if err != nil { return nil, err } } } return compats, nil } func parsePreviousTorchVersionsCode(code string, compats []config.TorchCompatibility) ([]config.TorchCompatibility, error) { // e.g. 
// pytorchWheelRegexp matches a PyTorch wheel filename, optionally preceded by
// a path prefix, and captures:
//
//	basever  - dotted numeric version, e.g. "2.7.1"
//	suffix   - optional post/dev/rc markers, e.g. ".post2"
//	variant  - optional local version after "+", e.g. "cu128" or "cpu.cxx11.abi"
//	pyver    - python tag, e.g. "cp312"
//	platform - platform tag, e.g. "linux_x86_64"
//
// It is compiled once at package initialisation rather than on every call.
var pytorchWheelRegexp = regexp.MustCompile(
	`.+?-(?P<basever>\d+(?:\.\d+)*)(?P<suffix>(?:[._]?(?:post|dev|rc)\d+)*)?(?:\+(?P<variant>[a-z0-9_.]+))?-(?P<pyver>[a-z0-9_.]+)-[a-z0-9_.]+-(?P<platform>.+?)\.whl`,
)

// ExtractSubFeaturesFromPytorchVersion splits a PyTorch wheel filename (which
// may be URL-encoded and may carry a leading directory such as "cu111/") into
// its components.
//
// The results are, in order:
//
//	name     - base version plus suffix plus "+variant" when a variant exists
//	version  - base numeric version only
//	variant  - local version label without the "+", or "" when absent
//	pyver    - python tag with any leading "cp" removed, e.g. "312"
//	platform - platform tag
//
// An error is returned when the input cannot be URL-decoded or does not look
// like a wheel filename; all string results are empty in that case.
func ExtractSubFeaturesFromPytorchVersion(pytorchVersion string) (string, string, string, string, string, error) {
	decoded, err := url.PathUnescape(pytorchVersion)
	if err != nil {
		return "", "", "", "", "", fmt.Errorf("failed to decode filename: %w", err)
	}

	matches := pytorchWheelRegexp.FindStringSubmatch(decoded)
	if len(matches) == 0 {
		return "", "", "", "", "", fmt.Errorf("invalid PyTorch wheel filename: %s", decoded)
	}

	// Look captures up by name so the code does not depend on group order.
	group := func(name string) string {
		return matches[pytorchWheelRegexp.SubexpIndex(name)]
	}

	base := group("basever")
	variant := group("variant")

	name := base + group("suffix")
	if variant != "" {
		name += "+" + variant
	}

	// "cp310" -> "310"; tags without the "cp" prefix are returned unchanged.
	pyver := strings.TrimPrefix(group("pyver"), "cp")

	return name, base, variant, pyver, group("platform"), nil
}
pkgPythonVersions := map[string]bool{} for _, pkg := range pkgs { pkgPythonVersions[pkg.PythonVersion] = true } for pythonVersion := range pythonVersions { _, ok := pkgPythonVersions[pythonVersion] if !ok { delete(pythonVersions, pythonVersion) } } } validPythonVersions := make([]string, 0, len(pythonVersions)) for k := range pythonVersions { validPythonVersions = append(validPythonVersions, k) } sort.Strings(validPythonVersions) return validPythonVersions, nil } func fetchTorchPackagesFromURL(url string) ([]TorchPackage, error) { resp, err := soup.Get(url) if err != nil { return nil, fmt.Errorf("Failed to download %s: %w", url, err) } doc := soup.HTMLParse(resp) links := doc.FindAll("a") packages := []TorchPackage{} for _, link := range links { name, version, variant, pythonVersion, platform, err := ExtractSubFeaturesFromPytorchVersion(link.Text()) if err != nil { console.Warnf("Failed to parse pytorch version: %v", err) continue } if (platform != "linux_x86_64" && platform != "manylinux_2_28_x86_64" && platform != "manylinux1_x86_64") || strings.Contains(name, ".cxx") { continue } var cuda *string switch { case variant == "cpu": cuda = nil case variant == "": cuda = nil case strings.HasPrefix(variant, "cu"): // cu92 -> 9.2 c := strings.TrimPrefix(variant, "cu") c = c[:len(c)-1] + "." + c[len(c)-1:] cuda = &c default: // rocm etc continue } // 310 -> 3.10 pythonVersion = pythonVersion[:1] + "." 
// TorchPackage describes a single wheel of a PyTorch-family package as listed
// on a package index page.
type TorchPackage struct {
	// Name is the full version label, including any suffix and "+variant",
	// for example "2.7.1+cu128".
	Name string
	// Version is the base numeric version, for example "2.7.1".
	Version string
	// Variant is the local version label, for example "cu128" or "cpu";
	// it may be empty.
	Variant string
	// CUDA is the dotted CUDA version (for example "12.8"), or nil when the
	// wheel is not a CUDA build.
	CUDA *string
	// PythonVersion is the dotted Python version, for example "3.10".
	PythonVersion string
}

// Equals reports whether c and other describe the same package.
//
// CUDA is compared by value rather than by pointer identity: two packages are
// equal only when both have no CUDA version or both point at equal strings.
// A package with a CUDA version never equals one without.
func (c *TorchPackage) Equals(other TorchPackage) bool {
	if (c.CUDA == nil) != (other.CUDA == nil) {
		return false
	}
	if c.CUDA != nil && *c.CUDA != *other.CUDA {
		return false
	}
	return c.Name == other.Name &&
		c.Version == other.Version &&
		c.Variant == other.Variant &&
		c.PythonVersion == other.PythonVersion
}
*http.Request) { content, err := os.ReadFile("torch_test.html") if err != nil { log.Fatalf("Error reading file: %v", err) } w.WriteHeader(http.StatusOK) w.Write(content) })) defer server.Close() url, err := url.Parse(server.URL) require.NoError(t, err) t.Setenv(env.SchemeEnvVarName, url.Scheme) t.Setenv(env.PytorchHostEnvVarName, url.Host) torchPackages, err := FetchTorchPackages("torch") require.NoError(t, err) torch271Packages := []TorchPackage{} for _, pkg := range torchPackages { if strings.Contains(pkg.Name, "2.7.1+cu128") { torch271Packages = append(torch271Packages, pkg) } } cuda128 := "12.8" require.Equal(t, []TorchPackage{ { Name: "2.7.1+cu128", Version: "2.7.1", Variant: "cu128", CUDA: &cuda128, PythonVersion: "3.10", }, { Name: "2.7.1+cu128", Version: "2.7.1", Variant: "cu128", CUDA: &cuda128, PythonVersion: "3.11", }, { Name: "2.7.1+cu128", Version: "2.7.1", Variant: "cu128", CUDA: &cuda128, PythonVersion: "3.12", }, { Name: "2.7.1+cu128", Version: "2.7.1", Variant: "cu128", CUDA: &cuda128, PythonVersion: "3.13", }, }, torch271Packages) } func TestIsValidPytorchVersionFormat(t *testing.T) { name, version, variant, pythonVersion, platform, err := ExtractSubFeaturesFromPytorchVersion("torch-2.7.1+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl") require.NoError(t, err) require.Equal(t, "2.7.1+cpu.cxx11.abi", name) require.Equal(t, "2.7.1", version) require.Equal(t, "cpu.cxx11.abi", variant) require.Equal(t, "312", pythonVersion) require.Equal(t, "linux_x86_64", platform) } func TestIsValidPytorchVersionFormatWithOldVersion(t *testing.T) { name, version, variant, pythonVersion, platform, err := ExtractSubFeaturesFromPytorchVersion("torch-1.10.0+cpu-cp310-cp310-linux_x86_64.whl") require.NoError(t, err) require.Equal(t, "1.10.0+cpu", name) require.Equal(t, "1.10.0", version) require.Equal(t, "cpu", variant) require.Equal(t, "310", pythonVersion) require.Equal(t, "linux_x86_64", platform) } func TestIsValidPytorchAudioVersionFormat(t *testing.T) { name, version, 
variant, pythonVersion, platform, err := ExtractSubFeaturesFromPytorchVersion("torchaudio-2.7.1+xpu-cp313-cp313t-win_amd64.whl") require.NoError(t, err) require.Equal(t, "2.7.1+xpu", name) require.Equal(t, "2.7.1", version) require.Equal(t, "xpu", variant) require.Equal(t, "313", pythonVersion) require.Equal(t, "win_amd64", platform) } func TestIsValidPytorchAudioVersionFormatBasic(t *testing.T) { name, version, variant, pythonVersion, platform, err := ExtractSubFeaturesFromPytorchVersion("torchaudio-0.8.1-cp39-none-win_amd64.whl") require.NoError(t, err) require.Equal(t, "0.8.1", name) require.Equal(t, "0.8.1", version) require.Equal(t, "", variant) require.Equal(t, "39", pythonVersion) require.Equal(t, "win_amd64", platform) } func TestIsValidPytorchVisionVersionFormatPostRelease(t *testing.T) { name, version, variant, pythonVersion, platform, err := ExtractSubFeaturesFromPytorchVersion("torchvision-0.4.1.post2-cp37-cp37m-macosx_10_9_x86_64.whl") require.NoError(t, err) require.Equal(t, "0.4.1.post2", name) require.Equal(t, "0.4.1", version) require.Equal(t, "", variant) require.Equal(t, "37", pythonVersion) require.Equal(t, "macosx_10_9_x86_64", platform) } func TestIsValidPytorchVisionEarlyVersion(t *testing.T) { name, version, variant, pythonVersion, platform, err := ExtractSubFeaturesFromPytorchVersion("torchvision-0.14.1+cu116-cp310-cp310-linux_x86_64.whl") require.NoError(t, err) require.Equal(t, "0.14.1+cu116", name) require.Equal(t, "0.14.1", version) require.Equal(t, "cu116", variant) require.Equal(t, "310", pythonVersion) require.Equal(t, "linux_x86_64", platform) } func TestIsValidPytorchAudioEarlyVersion(t *testing.T) { name, version, variant, pythonVersion, platform, err := ExtractSubFeaturesFromPytorchVersion("torchaudio-0.9.1-cp39-cp39-linux_x86_64.whl") require.NoError(t, err) require.Equal(t, "0.9.1", name) require.Equal(t, "0.9.1", version) require.Equal(t, "", variant) require.Equal(t, "39", pythonVersion) require.Equal(t, "linux_x86_64", 
platform) } func TestURLEncodedVersion(t *testing.T) { name, version, variant, pythonVersion, platform, err := ExtractSubFeaturesFromPytorchVersion("torchtext-0.17.0%2Bcpu-cp39-cp39-win_amd64.whl") require.NoError(t, err) require.Equal(t, "0.17.0+cpu", name) require.Equal(t, "0.17.0", version) require.Equal(t, "cpu", variant) require.Equal(t, "39", pythonVersion) require.Equal(t, "win_amd64", platform) } func TestVersionUnderFolder(t *testing.T) { name, version, variant, pythonVersion, platform, err := ExtractSubFeaturesFromPytorchVersion("cu111/torch-1.8.0%2Bcu111-cp36-cp36m-linux_x86_64.whl") require.NoError(t, err) require.Equal(t, "1.8.0+cu111", name) require.Equal(t, "1.8.0", version) require.Equal(t, "cu111", variant) require.Equal(t, "36", pythonVersion) require.Equal(t, "linux_x86_64", platform) } func TestPythonMVersion(t *testing.T) { name, version, variant, pythonVersion, platform, err := ExtractSubFeaturesFromPytorchVersion("torchaudio-0.7.2-cp36-cp36m-linux_x86_64.whl") require.NoError(t, err) require.Equal(t, "0.7.2", name) require.Equal(t, "0.7.2", version) require.Equal(t, "", variant) require.Equal(t, "36", pythonVersion) require.Equal(t, "linux_x86_64", platform) } ================================================ FILE: tools/compatgen/internal/torch_test.html ================================================

Links for torch

torch-2.0.0+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.0.0+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.0.0+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.0.0+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.0.1+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.0.1+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.0.1+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.0.1+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.1.0+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.1.0+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.1.0+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.1.0+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.1.1+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.1.1+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.1.1+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.1.1+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.1.2+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.1.2+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.1.2+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.1.2+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.2.0+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.2.0+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.2.0+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.2.0+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.2.0+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.2.1+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.2.1+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.2.1+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.2.1+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.2.1+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.2.2+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.2.2+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.2.2+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.2.2+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.2.2+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.3.0+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.3.0+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.3.0+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.3.0+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.3.0+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.3.1+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.3.1+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.3.1+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.3.1+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.3.1+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.4.0+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.4.0+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.4.0+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.4.0+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.4.0+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.4.1+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.4.1+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.4.1+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.4.1+cpu.cxx11.abi-cp38-cp38-linux_x86_64.whl
torch-2.4.1+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.5.0+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.5.0+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.5.0+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.5.0+cpu.cxx11.abi-cp313-cp313-linux_x86_64.whl
torch-2.5.0+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.5.1+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.5.1+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.5.1+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.5.1+cpu.cxx11.abi-cp313-cp313-linux_x86_64.whl
torch-2.5.1+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.6.0+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.6.0+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.6.0+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.6.0+cpu.cxx11.abi-cp313-cp313-linux_x86_64.whl
torch-2.6.0+cpu.cxx11.abi-cp313-cp313t-linux_x86_64.whl
torch-2.6.0+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.7.0+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.7.0+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.7.0+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.7.0+cpu.cxx11.abi-cp313-cp313-linux_x86_64.whl
torch-2.7.0+cpu.cxx11.abi-cp313-cp313t-linux_x86_64.whl
torch-2.7.0+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-2.7.1+cpu.cxx11.abi-cp310-cp310-linux_x86_64.whl
torch-2.7.1+cpu.cxx11.abi-cp311-cp311-linux_x86_64.whl
torch-2.7.1+cpu.cxx11.abi-cp312-cp312-linux_x86_64.whl
torch-2.7.1+cpu.cxx11.abi-cp313-cp313-linux_x86_64.whl
torch-2.7.1+cpu.cxx11.abi-cp313-cp313t-linux_x86_64.whl
torch-2.7.1+cpu.cxx11.abi-cp39-cp39-linux_x86_64.whl
torch-0.3.0.post4-cp27-cp27m-linux_x86_64.whl
torch-0.3.0.post4-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0.post4-cp35-cp35m-linux_x86_64.whl
torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl
torch-0.3.1-cp27-cp27m-linux_x86_64.whl
torch-0.3.1-cp27-cp27mu-linux_x86_64.whl
torch-0.3.1-cp35-cp35m-linux_x86_64.whl
torch-0.3.1-cp36-cp36m-linux_x86_64.whl
torch-0.4.0-cp27-cp27m-linux_x86_64.whl
torch-0.4.0-cp27-cp27mu-linux_x86_64.whl
torch-0.4.0-cp35-cp35m-linux_x86_64.whl
torch-0.4.0-cp35-cp35m-win_amd64.whl
torch-0.4.0-cp36-cp36m-linux_x86_64.whl
torch-0.4.0-cp36-cp36m-win_amd64.whl
torch-0.4.1-cp27-cp27m-linux_x86_64.whl
torch-0.4.1-cp27-cp27mu-linux_x86_64.whl
torch-0.4.1-cp35-cp35m-linux_x86_64.whl
torch-0.4.1-cp35-cp35m-win_amd64.whl
torch-0.4.1-cp36-cp36m-linux_x86_64.whl
torch-0.4.1-cp36-cp36m-win_amd64.whl
torch-0.4.1-cp37-cp37m-linux_x86_64.whl
torch-0.4.1-cp37-cp37m-win_amd64.whl
torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl
torch-1.0.0-cp27-cp27m-linux_x86_64.whl
torch-1.0.0-cp27-cp27mu-linux_x86_64.whl
torch-1.0.0-cp27-none-macosx_10_6_x86_64.whl
torch-1.0.0-cp35-cp35m-linux_x86_64.whl
torch-1.0.0-cp35-cp35m-win_amd64.whl
torch-1.0.0-cp35-none-macosx_10_6_x86_64.whl
torch-1.0.0-cp36-cp36m-linux_x86_64.whl
torch-1.0.0-cp36-cp36m-win_amd64.whl
torch-1.0.0-cp36-none-macosx_10_7_x86_64.whl
torch-1.0.0-cp37-cp37m-linux_x86_64.whl
torch-1.0.0-cp37-cp37m-win_amd64.whl
torch-1.0.0-cp37-none-macosx_10_7_x86_64.whl
torch-1.0.1-cp27-cp27m-linux_x86_64.whl
torch-1.0.1-cp27-cp27mu-linux_x86_64.whl
torch-1.0.1-cp27-none-macosx_10_6_x86_64.whl
torch-1.0.1-cp35-cp35m-linux_x86_64.whl
torch-1.0.1-cp35-cp35m-win_amd64.whl
torch-1.0.1-cp35-none-macosx_10_6_x86_64.whl
torch-1.0.1-cp36-cp36m-linux_x86_64.whl
torch-1.0.1-cp36-cp36m-win_amd64.whl
torch-1.0.1-cp36-none-macosx_10_7_x86_64.whl
torch-1.0.1-cp37-cp37m-linux_x86_64.whl
torch-1.0.1-cp37-cp37m-win_amd64.whl
torch-1.0.1-cp37-none-macosx_10_7_x86_64.whl
torch-1.0.1.post2-cp27-cp27m-linux_x86_64.whl
torch-1.0.1.post2-cp27-cp27mu-linux_x86_64.whl
torch-1.0.1.post2-cp35-cp35m-linux_x86_64.whl
torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl
torch-1.1.0-cp27-cp27m-linux_x86_64.whl
torch-1.1.0-cp27-cp27mu-linux_x86_64.whl
torch-1.1.0-cp27-none-macosx_10_6_x86_64.whl
torch-1.1.0-cp35-cp35m-linux_x86_64.whl
torch-1.1.0-cp35-cp35m-win_amd64.whl
torch-1.1.0-cp35-none-macosx_10_6_x86_64.whl
torch-1.1.0-cp36-cp36m-linux_x86_64.whl
torch-1.1.0-cp36-cp36m-win_amd64.whl
torch-1.1.0-cp36-none-macosx_10_7_x86_64.whl
torch-1.1.0-cp37-cp37m-linux_x86_64.whl
torch-1.1.0-cp37-cp37m-win_amd64.whl
torch-1.1.0-cp37-none-macosx_10_7_x86_64.whl
torch-1.1.0.post2-cp27-none-macosx_10_6_x86_64.whl
torch-1.1.0.post2-cp35-none-macosx_10_6_x86_64.whl
torch-1.1.0.post2-cp36-none-macosx_10_7_x86_64.whl
torch-1.1.0.post2-cp37-none-macosx_10_7_x86_64.whl
torch-1.10.0+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.10.0+cpu-cp36-cp36m-win_amd64.whl
torch-1.10.0+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.10.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.10.0+cpu-cp38-cp38-linux_x86_64.whl
torch-1.10.0+cpu-cp38-cp38-win_amd64.whl
torch-1.10.0+cpu-cp39-cp39-linux_x86_64.whl
torch-1.10.0+cpu-cp39-cp39-win_amd64.whl
torch-1.10.0-cp36-cp36m-manylinux2014_aarch64.whl
torch-1.10.0-cp36-none-macosx_10_9_x86_64.whl
torch-1.10.0-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.10.0-cp37-none-macosx_10_9_x86_64.whl
torch-1.10.0-cp38-cp38-manylinux2014_aarch64.whl
torch-1.10.0-cp38-none-macosx_10_9_x86_64.whl
torch-1.10.0-cp38-none-macosx_11_0_arm64.whl
torch-1.10.0-cp39-cp39-manylinux2014_aarch64.whl
torch-1.10.0-cp39-none-macosx_10_9_x86_64.whl
torch-1.10.0-cp39-none-macosx_11_0_arm64.whl
torch-1.10.1+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.10.1+cpu-cp36-cp36m-win_amd64.whl
torch-1.10.1+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.10.1+cpu-cp37-cp37m-win_amd64.whl
torch-1.10.1+cpu-cp38-cp38-linux_x86_64.whl
torch-1.10.1+cpu-cp38-cp38-win_amd64.whl
torch-1.10.1+cpu-cp39-cp39-linux_x86_64.whl
torch-1.10.1+cpu-cp39-cp39-win_amd64.whl
torch-1.10.1-cp36-cp36m-manylinux2014_aarch64.whl
torch-1.10.1-cp36-none-macosx_10_9_x86_64.whl
torch-1.10.1-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.10.1-cp37-none-macosx_10_9_x86_64.whl
torch-1.10.1-cp38-cp38-manylinux2014_aarch64.whl
torch-1.10.1-cp38-none-macosx_10_9_x86_64.whl
torch-1.10.1-cp38-none-macosx_11_0_arm64.whl
torch-1.10.1-cp39-cp39-manylinux2014_aarch64.whl
torch-1.10.1-cp39-none-macosx_10_9_x86_64.whl
torch-1.10.1-cp39-none-macosx_11_0_arm64.whl
torch-1.10.2+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.10.2+cpu-cp36-cp36m-win_amd64.whl
torch-1.10.2+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.10.2+cpu-cp37-cp37m-win_amd64.whl
torch-1.10.2+cpu-cp38-cp38-linux_x86_64.whl
torch-1.10.2+cpu-cp38-cp38-win_amd64.whl
torch-1.10.2+cpu-cp39-cp39-linux_x86_64.whl
torch-1.10.2+cpu-cp39-cp39-win_amd64.whl
torch-1.10.2-cp310-cp310-manylinux2014_aarch64.whl
torch-1.10.2-cp36-cp36m-manylinux2014_aarch64.whl
torch-1.10.2-cp36-none-macosx_10_9_x86_64.whl
torch-1.10.2-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.10.2-cp37-none-macosx_10_9_x86_64.whl
torch-1.10.2-cp38-cp38-manylinux2014_aarch64.whl
torch-1.10.2-cp38-none-macosx_10_9_x86_64.whl
torch-1.10.2-cp38-none-macosx_11_0_arm64.whl
torch-1.10.2-cp39-cp39-manylinux2014_aarch64.whl
torch-1.10.2-cp39-none-macosx_10_9_x86_64.whl
torch-1.10.2-cp39-none-macosx_11_0_arm64.whl
torch-1.11.0+cpu-cp310-cp310-linux_x86_64.whl
torch-1.11.0+cpu-cp310-cp310-win_amd64.whl
torch-1.11.0+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.11.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.11.0+cpu-cp38-cp38-linux_x86_64.whl
torch-1.11.0+cpu-cp38-cp38-win_amd64.whl
torch-1.11.0+cpu-cp39-cp39-linux_x86_64.whl
torch-1.11.0+cpu-cp39-cp39-win_amd64.whl
torch-1.11.0-cp310-none-macosx_10_9_x86_64.whl
torch-1.11.0-cp310-none-macosx_11_0_arm64.whl
torch-1.11.0-cp37-none-macosx_10_9_x86_64.whl
torch-1.11.0-cp38-none-macosx_10_9_x86_64.whl
torch-1.11.0-cp38-none-macosx_11_0_arm64.whl
torch-1.11.0-cp39-none-macosx_10_9_x86_64.whl
torch-1.11.0-cp39-none-macosx_11_0_arm64.whl
torch-1.12.0+cpu-cp310-cp310-linux_x86_64.whl
torch-1.12.0+cpu-cp310-cp310-win_amd64.whl
torch-1.12.0+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.12.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.12.0+cpu-cp38-cp38-linux_x86_64.whl
torch-1.12.0+cpu-cp38-cp38-win_amd64.whl
torch-1.12.0+cpu-cp39-cp39-linux_x86_64.whl
torch-1.12.0+cpu-cp39-cp39-win_amd64.whl
torch-1.12.0-cp310-cp310-manylinux2014_aarch64.whl
torch-1.12.0-cp310-none-macosx_10_9_x86_64.whl
torch-1.12.0-cp310-none-macosx_11_0_arm64.whl
torch-1.12.0-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.12.0-cp37-none-macosx_10_9_x86_64.whl
torch-1.12.0-cp37-none-macosx_11_0_arm64.whl
torch-1.12.0-cp38-cp38-manylinux2014_aarch64.whl
torch-1.12.0-cp38-none-macosx_10_9_x86_64.whl
torch-1.12.0-cp38-none-macosx_11_0_arm64.whl
torch-1.12.0-cp39-cp39-manylinux2014_aarch64.whl
torch-1.12.0-cp39-none-macosx_10_9_x86_64.whl
torch-1.12.0-cp39-none-macosx_11_0_arm64.whl
torch-1.12.1+cpu-cp310-cp310-linux_x86_64.whl
torch-1.12.1+cpu-cp310-cp310-win_amd64.whl
torch-1.12.1+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.12.1+cpu-cp37-cp37m-win_amd64.whl
torch-1.12.1+cpu-cp38-cp38-linux_x86_64.whl
torch-1.12.1+cpu-cp38-cp38-win_amd64.whl
torch-1.12.1+cpu-cp39-cp39-linux_x86_64.whl
torch-1.12.1+cpu-cp39-cp39-win_amd64.whl
torch-1.12.1-cp310-none-macosx_10_9_x86_64.whl
torch-1.12.1-cp310-none-macosx_11_0_arm64.whl
torch-1.12.1-cp37-none-macosx_10_9_x86_64.whl
torch-1.12.1-cp37-none-macosx_11_0_arm64.whl
torch-1.12.1-cp38-none-macosx_10_9_x86_64.whl
torch-1.12.1-cp38-none-macosx_11_0_arm64.whl
torch-1.12.1-cp39-none-macosx_10_9_x86_64.whl
torch-1.12.1-cp39-none-macosx_11_0_arm64.whl
torch-1.13.0+cpu-cp310-cp310-linux_x86_64.whl
torch-1.13.0+cpu-cp310-cp310-win_amd64.whl
torch-1.13.0+cpu-cp311-cp311-linux_x86_64.whl
torch-1.13.0+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.13.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.13.0+cpu-cp38-cp38-linux_x86_64.whl
torch-1.13.0+cpu-cp38-cp38-win_amd64.whl
torch-1.13.0+cpu-cp39-cp39-linux_x86_64.whl
torch-1.13.0+cpu-cp39-cp39-win_amd64.whl
torch-1.13.0-cp310-none-macosx_10_9_x86_64.whl
torch-1.13.0-cp310-none-macosx_11_0_arm64.whl
torch-1.13.0-cp37-none-macosx_10_9_x86_64.whl
torch-1.13.0-cp37-none-macosx_11_0_arm64.whl
torch-1.13.0-cp38-none-macosx_10_9_x86_64.whl
torch-1.13.0-cp38-none-macosx_11_0_arm64.whl
torch-1.13.0-cp39-none-macosx_10_9_x86_64.whl
torch-1.13.0-cp39-none-macosx_11_0_arm64.whl
torch-1.13.1+cpu-cp310-cp310-linux_x86_64.whl
torch-1.13.1+cpu-cp310-cp310-win_amd64.whl
torch-1.13.1+cpu-cp311-cp311-linux_x86_64.whl
torch-1.13.1+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.13.1+cpu-cp37-cp37m-win_amd64.whl
torch-1.13.1+cpu-cp38-cp38-linux_x86_64.whl
torch-1.13.1+cpu-cp38-cp38-win_amd64.whl
torch-1.13.1+cpu-cp39-cp39-linux_x86_64.whl
torch-1.13.1+cpu-cp39-cp39-win_amd64.whl
torch-1.13.1-cp310-none-macosx_10_9_x86_64.whl
torch-1.13.1-cp310-none-macosx_11_0_arm64.whl
torch-1.13.1-cp37-none-macosx_10_9_x86_64.whl
torch-1.13.1-cp37-none-macosx_11_0_arm64.whl
torch-1.13.1-cp38-none-macosx_10_9_x86_64.whl
torch-1.13.1-cp38-none-macosx_11_0_arm64.whl
torch-1.13.1-cp39-none-macosx_10_9_x86_64.whl
torch-1.13.1-cp39-none-macosx_11_0_arm64.whl
torch-1.2.0+cpu-cp27-cp27m-manylinux1_x86_64.whl
torch-1.2.0+cpu-cp27-cp27mu-manylinux1_x86_64.whl
torch-1.2.0+cpu-cp35-cp35m-manylinux1_x86_64.whl
torch-1.2.0+cpu-cp35-cp35m-win_amd64.whl
torch-1.2.0+cpu-cp36-cp36m-manylinux1_x86_64.whl
torch-1.2.0+cpu-cp36-cp36m-win_amd64.whl
torch-1.2.0+cpu-cp37-cp37m-manylinux1_x86_64.whl
torch-1.2.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.2.0-cp27-none-macosx_10_6_x86_64.whl
torch-1.2.0-cp35-none-macosx_10_6_x86_64.whl
torch-1.2.0-cp36-none-macosx_10_7_x86_64.whl
torch-1.2.0-cp37-none-macosx_10_7_x86_64.whl
torch-1.3.0+cpu-cp27-cp27m-linux_x86_64.whl
torch-1.3.0+cpu-cp27-cp27mu-linux_x86_64.whl
torch-1.3.0+cpu-cp35-cp35m-linux_x86_64.whl
torch-1.3.0+cpu-cp35-cp35m-win_amd64.whl
torch-1.3.0+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.3.0+cpu-cp36-cp36m-win_amd64.whl
torch-1.3.0+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.3.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.3.0-cp35-none-macosx_10_6_x86_64.whl
torch-1.3.0-cp36-none-macosx_10_7_x86_64.whl
torch-1.3.0-cp37-none-macosx_10_9_x86_64.whl
torch-1.3.0.post2-cp27-none-macosx_10_7_x86_64.whl
torch-1.3.0.post2-cp35-none-macosx_10_6_x86_64.whl
torch-1.3.0.post2-cp36-none-macosx_10_7_x86_64.whl
torch-1.3.0.post2-cp37-none-macosx_10_9_x86_64.whl
torch-1.3.1+cpu-cp27-cp27m-linux_x86_64.whl
torch-1.3.1+cpu-cp27-cp27mu-linux_x86_64.whl
torch-1.3.1+cpu-cp35-cp35m-linux_x86_64.whl
torch-1.3.1+cpu-cp35-cp35m-win_amd64.whl
torch-1.3.1+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.3.1+cpu-cp36-cp36m-win_amd64.whl
torch-1.3.1+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.3.1+cpu-cp37-cp37m-win_amd64.whl
torch-1.3.1-cp27-none-macosx_10_7_x86_64.whl
torch-1.3.1-cp35-none-macosx_10_6_x86_64.whl
torch-1.3.1-cp36-none-macosx_10_7_x86_64.whl
torch-1.3.1-cp37-none-macosx_10_7_x86_64.whl
torch-1.4.0+cpu-cp27-cp27m-linux_x86_64.whl
torch-1.4.0+cpu-cp27-cp27mu-linux_x86_64.whl
torch-1.4.0+cpu-cp35-cp35m-linux_x86_64.whl
torch-1.4.0+cpu-cp35-cp35m-win_amd64.whl
torch-1.4.0+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.4.0+cpu-cp36-cp36m-win_amd64.whl
torch-1.4.0+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.4.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.4.0+cpu-cp38-cp38-linux_x86_64.whl
torch-1.4.0+cpu-cp38-cp38-win_amd64.whl
torch-1.4.0-cp27-none-macosx_10_7_x86_64.whl
torch-1.4.0-cp35-none-macosx_10_6_x86_64.whl
torch-1.4.0-cp36-none-macosx_10_7_x86_64.whl
torch-1.4.0-cp36-none-macosx_10_9_x86_64.whl
torch-1.4.0-cp37-none-macosx_10_7_x86_64.whl
torch-1.4.0-cp37-none-macosx_10_9_x86_64.whl
torch-1.4.0-cp38-none-macosx_10_9_x86_64.whl
torch-1.5.0+cpu-cp27-cp27m-linux_x86_64.whl
torch-1.5.0+cpu-cp27-cp27mu-linux_x86_64.whl
torch-1.5.0+cpu-cp35-cp35m-linux_x86_64.whl
torch-1.5.0+cpu-cp35-cp35m-win_amd64.whl
torch-1.5.0+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.5.0+cpu-cp36-cp36m-win_amd64.whl
torch-1.5.0+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.5.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.5.0+cpu-cp38-cp38-linux_x86_64.whl
torch-1.5.0+cpu-cp38-cp38-win_amd64.whl
torch-1.5.0-cp27-none-macosx_10_7_x86_64.whl
torch-1.5.0-cp35-none-macosx_10_6_x86_64.whl
torch-1.5.0-cp36-none-macosx_10_9_x86_64.whl
torch-1.5.0-cp37-none-macosx_10_9_x86_64.whl
torch-1.5.0-cp38-none-macosx_10_9_x86_64.whl
torch-1.5.1+cpu-cp35-cp35m-linux_x86_64.whl
torch-1.5.1+cpu-cp35-cp35m-win_amd64.whl
torch-1.5.1+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.5.1+cpu-cp36-cp36m-win_amd64.whl
torch-1.5.1+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.5.1+cpu-cp37-cp37m-win_amd64.whl
torch-1.5.1+cpu-cp38-cp38-linux_x86_64.whl
torch-1.5.1+cpu-cp38-cp38-win_amd64.whl
torch-1.5.1-cp35-none-macosx_10_6_x86_64.whl
torch-1.5.1-cp36-none-macosx_10_9_x86_64.whl
torch-1.5.1-cp37-none-macosx_10_9_x86_64.whl
torch-1.5.1-cp38-none-macosx_10_9_x86_64.whl
torch-1.6.0+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.6.0+cpu-cp36-cp36m-win_amd64.whl
torch-1.6.0+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.6.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.6.0+cpu-cp38-cp38-linux_x86_64.whl
torch-1.6.0+cpu-cp38-cp38-win_amd64.whl
torch-1.6.0-cp36-none-macosx_10_9_x86_64.whl
torch-1.6.0-cp37-none-macosx_10_9_x86_64.whl
torch-1.6.0-cp38-none-macosx_10_9_x86_64.whl
torch-1.7.0+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.7.0+cpu-cp36-cp36m-win_amd64.whl
torch-1.7.0+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.7.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.7.0+cpu-cp38-cp38-linux_x86_64.whl
torch-1.7.0+cpu-cp38-cp38-win_amd64.whl
torch-1.7.0-cp36-none-macosx_10_9_x86_64.whl
torch-1.7.0-cp36-none-macosx_11_0_x86_64.whl
torch-1.7.0-cp37-none-macosx_10_9_x86_64.whl
torch-1.7.0-cp37-none-macosx_11_0_x86_64.whl
torch-1.7.0-cp38-none-macosx_10_9_x86_64.whl
torch-1.7.0-cp38-none-macosx_11_0_x86_64.whl
torch-1.7.1+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.7.1+cpu-cp36-cp36m-win_amd64.whl
torch-1.7.1+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.7.1+cpu-cp37-cp37m-win_amd64.whl
torch-1.7.1+cpu-cp38-cp38-linux_x86_64.whl
torch-1.7.1+cpu-cp38-cp38-win_amd64.whl
torch-1.7.1+cpu-cp39-cp39-linux_x86_64.whl
torch-1.7.1+cpu-cp39-cp39-win_amd64.whl
torch-1.7.1-cp36-cp36m-linux_aarch64.whl
torch-1.7.1-cp36-none-macosx_10_9_x86_64.whl
torch-1.7.1-cp37-cp37m-linux_aarch64.whl
torch-1.7.1-cp37-none-macosx_10_9_x86_64.whl
torch-1.7.1-cp38-cp38-linux_aarch64.whl
torch-1.7.1-cp38-none-macosx_10_9_x86_64.whl
torch-1.7.1-cp39-cp39-linux_aarch64.whl
torch-1.7.1-cp39-none-macosx_10_9_x86_64.whl
torch-1.8.0+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.8.0+cpu-cp36-cp36m-win_amd64.whl
torch-1.8.0+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.8.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.8.0+cpu-cp38-cp38-linux_x86_64.whl
torch-1.8.0+cpu-cp38-cp38-win_amd64.whl
torch-1.8.0+cpu-cp39-cp39-linux_x86_64.whl
torch-1.8.0+cpu-cp39-cp39-win_amd64.whl
torch-1.8.0-cp36-cp36m-manylinux2014_aarch64.whl
torch-1.8.0-cp36-none-macosx_10_9_x86_64.whl
torch-1.8.0-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.8.0-cp37-none-macosx_10_9_x86_64.whl
torch-1.8.0-cp38-cp38-manylinux2014_aarch64.whl
torch-1.8.0-cp38-none-macosx_10_9_x86_64.whl
torch-1.8.0-cp38-none-macosx_11_1_arm64.whl
torch-1.8.0-cp39-cp39-manylinux2014_aarch64.whl
torch-1.8.0-cp39-none-macosx_10_9_x86_64.whl
torch-1.8.1+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.8.1+cpu-cp36-cp36m-win_amd64.whl
torch-1.8.1+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.8.1+cpu-cp37-cp37m-win_amd64.whl
torch-1.8.1+cpu-cp38-cp38-linux_x86_64.whl
torch-1.8.1+cpu-cp38-cp38-win_amd64.whl
torch-1.8.1+cpu-cp39-cp39-linux_x86_64.whl
torch-1.8.1+cpu-cp39-cp39-win_amd64.whl
torch-1.8.1-cp36-cp36m-manylinux2014_aarch64.whl
torch-1.8.1-cp36-none-macosx_10_9_x86_64.whl
torch-1.8.1-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.8.1-cp37-none-macosx_10_9_x86_64.whl
torch-1.8.1-cp38-cp38-manylinux2014_aarch64.whl
torch-1.8.1-cp38-none-macosx_10_9_x86_64.whl
torch-1.8.1-cp38-none-macosx_11_0_arm64.whl
torch-1.8.1-cp39-cp39-manylinux2014_aarch64.whl
torch-1.8.1-cp39-none-macosx_10_9_x86_64.whl
torch-1.9.0+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.9.0+cpu-cp36-cp36m-win_amd64.whl
torch-1.9.0+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.9.0+cpu-cp37-cp37m-win_amd64.whl
torch-1.9.0+cpu-cp38-cp38-linux_x86_64.whl
torch-1.9.0+cpu-cp38-cp38-win_amd64.whl
torch-1.9.0+cpu-cp39-cp39-linux_x86_64.whl
torch-1.9.0+cpu-cp39-cp39-win_amd64.whl
torch-1.9.0-cp36-cp36m-linux_aarch64.whl
torch-1.9.0-cp36-cp36m-manylinux2014_aarch64.whl
torch-1.9.0-cp36-none-macosx_10_9_x86_64.whl
torch-1.9.0-cp37-cp37m-linux_aarch64.whl
torch-1.9.0-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.9.0-cp37-none-macosx_10_9_x86_64.whl
torch-1.9.0-cp38-cp38-linux_aarch64.whl
torch-1.9.0-cp38-cp38-manylinux2014_aarch64.whl
torch-1.9.0-cp38-none-macosx_10_9_x86_64.whl
torch-1.9.0-cp38-none-macosx_11_0_arm64.whl
torch-1.9.0-cp39-cp39-linux_aarch64.whl
torch-1.9.0-cp39-cp39-manylinux2014_aarch64.whl
torch-1.9.0-cp39-none-macosx_10_9_x86_64.whl
torch-1.9.0-cp39-none-macosx_11_0_arm64.whl
torch-1.9.1+cpu-cp36-cp36m-linux_x86_64.whl
torch-1.9.1+cpu-cp36-cp36m-win_amd64.whl
torch-1.9.1+cpu-cp37-cp37m-linux_x86_64.whl
torch-1.9.1+cpu-cp37-cp37m-win_amd64.whl
torch-1.9.1+cpu-cp38-cp38-linux_x86_64.whl
torch-1.9.1+cpu-cp38-cp38-win_amd64.whl
torch-1.9.1+cpu-cp39-cp39-linux_x86_64.whl
torch-1.9.1+cpu-cp39-cp39-win_amd64.whl
torch-1.9.1-cp36-none-macosx_10_9_x86_64.whl
torch-1.9.1-cp37-none-macosx_10_9_x86_64.whl
torch-1.9.1-cp38-none-macosx_10_9_x86_64.whl
torch-1.9.1-cp38-none-macosx_11_0_arm64.whl
torch-1.9.1-cp39-none-macosx_10_9_x86_64.whl
torch-1.9.1-cp39-none-macosx_11_0_arm64.whl
torch-2.0.0+cpu-cp310-cp310-linux_x86_64.whl
torch-2.0.0+cpu-cp310-cp310-win_amd64.whl
torch-2.0.0+cpu-cp311-cp311-linux_x86_64.whl
torch-2.0.0+cpu-cp311-cp311-win_amd64.whl
torch-2.0.0+cpu-cp38-cp38-linux_x86_64.whl
torch-2.0.0+cpu-cp38-cp38-win_amd64.whl
torch-2.0.0+cpu-cp39-cp39-linux_x86_64.whl
torch-2.0.0+cpu-cp39-cp39-win_amd64.whl
torch-2.0.0-cp310-none-macosx_10_9_x86_64.whl
torch-2.0.0-cp310-none-macosx_11_0_arm64.whl
torch-2.0.0-cp311-none-macosx_10_9_x86_64.whl
torch-2.0.0-cp311-none-macosx_11_0_arm64.whl
torch-2.0.0-cp38-none-macosx_10_9_x86_64.whl
torch-2.0.0-cp38-none-macosx_11_0_arm64.whl
torch-2.0.0-cp39-none-macosx_10_9_x86_64.whl
torch-2.0.0-cp39-none-macosx_11_0_arm64.whl
torch-2.0.1+cpu-cp310-cp310-linux_x86_64.whl
torch-2.0.1+cpu-cp310-cp310-win_amd64.whl
torch-2.0.1+cpu-cp311-cp311-linux_x86_64.whl
torch-2.0.1+cpu-cp311-cp311-win_amd64.whl
torch-2.0.1+cpu-cp38-cp38-linux_x86_64.whl
torch-2.0.1+cpu-cp38-cp38-win_amd64.whl
torch-2.0.1+cpu-cp39-cp39-linux_x86_64.whl
torch-2.0.1+cpu-cp39-cp39-win_amd64.whl
torch-2.0.1-cp310-none-macosx_10_9_x86_64.whl
torch-2.0.1-cp310-none-macosx_11_0_arm64.whl
torch-2.0.1-cp311-none-macosx_10_9_x86_64.whl
torch-2.0.1-cp311-none-macosx_11_0_arm64.whl
torch-2.0.1-cp38-none-macosx_10_9_x86_64.whl
torch-2.0.1-cp38-none-macosx_11_0_arm64.whl
torch-2.0.1-cp39-none-macosx_10_9_x86_64.whl
torch-2.0.1-cp39-none-macosx_11_0_arm64.whl
torch-2.1.0+cpu-cp310-cp310-linux_x86_64.whl
torch-2.1.0+cpu-cp310-cp310-win_amd64.whl
torch-2.1.0+cpu-cp311-cp311-linux_x86_64.whl
torch-2.1.0+cpu-cp311-cp311-win_amd64.whl
torch-2.1.0+cpu-cp38-cp38-linux_x86_64.whl
torch-2.1.0+cpu-cp38-cp38-win_amd64.whl
torch-2.1.0+cpu-cp39-cp39-linux_x86_64.whl
torch-2.1.0+cpu-cp39-cp39-win_amd64.whl
torch-2.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.0-cp310-none-macosx_10_9_x86_64.whl
torch-2.1.0-cp310-none-macosx_11_0_arm64.whl
torch-2.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.0-cp311-none-macosx_10_9_x86_64.whl
torch-2.1.0-cp311-none-macosx_11_0_arm64.whl
torch-2.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.0-cp38-none-macosx_10_9_x86_64.whl
torch-2.1.0-cp38-none-macosx_11_0_arm64.whl
torch-2.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.0-cp39-none-macosx_10_9_x86_64.whl
torch-2.1.0-cp39-none-macosx_11_0_arm64.whl
torch-2.1.1+cpu-cp310-cp310-linux_x86_64.whl
torch-2.1.1+cpu-cp310-cp310-win_amd64.whl
torch-2.1.1+cpu-cp311-cp311-linux_x86_64.whl
torch-2.1.1+cpu-cp311-cp311-win_amd64.whl
torch-2.1.1+cpu-cp38-cp38-linux_x86_64.whl
torch-2.1.1+cpu-cp38-cp38-win_amd64.whl
torch-2.1.1+cpu-cp39-cp39-linux_x86_64.whl
torch-2.1.1+cpu-cp39-cp39-win_amd64.whl
torch-2.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.1-cp310-none-macosx_10_9_x86_64.whl
torch-2.1.1-cp310-none-macosx_11_0_arm64.whl
torch-2.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.1-cp311-none-macosx_10_9_x86_64.whl
torch-2.1.1-cp311-none-macosx_11_0_arm64.whl
torch-2.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.1-cp38-none-macosx_10_9_x86_64.whl
torch-2.1.1-cp38-none-macosx_11_0_arm64.whl
torch-2.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.1-cp39-none-macosx_10_9_x86_64.whl
torch-2.1.1-cp39-none-macosx_11_0_arm64.whl
torch-2.1.2+cpu-cp310-cp310-linux_x86_64.whl
torch-2.1.2+cpu-cp310-cp310-win_amd64.whl
torch-2.1.2+cpu-cp311-cp311-linux_x86_64.whl
torch-2.1.2+cpu-cp311-cp311-win_amd64.whl
torch-2.1.2+cpu-cp38-cp38-linux_x86_64.whl
torch-2.1.2+cpu-cp38-cp38-win_amd64.whl
torch-2.1.2+cpu-cp39-cp39-linux_x86_64.whl
torch-2.1.2+cpu-cp39-cp39-win_amd64.whl
torch-2.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.2-cp310-none-macosx_10_9_x86_64.whl
torch-2.1.2-cp310-none-macosx_11_0_arm64.whl
torch-2.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.2-cp311-none-macosx_10_9_x86_64.whl
torch-2.1.2-cp311-none-macosx_11_0_arm64.whl
torch-2.1.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.2-cp38-none-macosx_10_9_x86_64.whl
torch-2.1.2-cp38-none-macosx_11_0_arm64.whl
torch-2.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.1.2-cp39-none-macosx_10_9_x86_64.whl
torch-2.1.2-cp39-none-macosx_11_0_arm64.whl
torch-2.2.0+cpu-cp310-cp310-linux_x86_64.whl
torch-2.2.0+cpu-cp310-cp310-win_amd64.whl
torch-2.2.0+cpu-cp311-cp311-linux_x86_64.whl
torch-2.2.0+cpu-cp311-cp311-win_amd64.whl
torch-2.2.0+cpu-cp312-cp312-linux_x86_64.whl
torch-2.2.0+cpu-cp312-cp312-win_amd64.whl
torch-2.2.0+cpu-cp38-cp38-linux_x86_64.whl
torch-2.2.0+cpu-cp38-cp38-win_amd64.whl
torch-2.2.0+cpu-cp39-cp39-linux_x86_64.whl
torch-2.2.0+cpu-cp39-cp39-win_amd64.whl
torch-2.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.0-cp310-none-macosx_10_9_x86_64.whl
torch-2.2.0-cp310-none-macosx_11_0_arm64.whl
torch-2.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.0-cp311-none-macosx_10_9_x86_64.whl
torch-2.2.0-cp311-none-macosx_11_0_arm64.whl
torch-2.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.0-cp312-none-macosx_10_9_x86_64.whl
torch-2.2.0-cp312-none-macosx_11_0_arm64.whl
torch-2.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.0-cp38-none-macosx_10_9_x86_64.whl
torch-2.2.0-cp38-none-macosx_11_0_arm64.whl
torch-2.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.0-cp39-none-macosx_10_9_x86_64.whl
torch-2.2.0-cp39-none-macosx_11_0_arm64.whl
torch-2.2.1+cpu-cp310-cp310-linux_x86_64.whl
torch-2.2.1+cpu-cp310-cp310-win_amd64.whl
torch-2.2.1+cpu-cp311-cp311-linux_x86_64.whl
torch-2.2.1+cpu-cp311-cp311-win_amd64.whl
torch-2.2.1+cpu-cp312-cp312-linux_x86_64.whl
torch-2.2.1+cpu-cp312-cp312-win_amd64.whl
torch-2.2.1+cpu-cp38-cp38-linux_x86_64.whl
torch-2.2.1+cpu-cp38-cp38-win_amd64.whl
torch-2.2.1+cpu-cp39-cp39-linux_x86_64.whl
torch-2.2.1+cpu-cp39-cp39-win_amd64.whl
torch-2.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.1-cp310-none-macosx_10_9_x86_64.whl
torch-2.2.1-cp310-none-macosx_11_0_arm64.whl
torch-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.1-cp311-none-macosx_10_9_x86_64.whl
torch-2.2.1-cp311-none-macosx_11_0_arm64.whl
torch-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.1-cp312-none-macosx_10_9_x86_64.whl
torch-2.2.1-cp312-none-macosx_11_0_arm64.whl
torch-2.2.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.1-cp38-none-macosx_10_9_x86_64.whl
torch-2.2.1-cp38-none-macosx_11_0_arm64.whl
torch-2.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.1-cp39-none-macosx_10_9_x86_64.whl
torch-2.2.1-cp39-none-macosx_11_0_arm64.whl
torch-2.2.2+cpu-cp310-cp310-linux_x86_64.whl
torch-2.2.2+cpu-cp310-cp310-win_amd64.whl
torch-2.2.2+cpu-cp311-cp311-linux_x86_64.whl
torch-2.2.2+cpu-cp311-cp311-win_amd64.whl
torch-2.2.2+cpu-cp312-cp312-linux_x86_64.whl
torch-2.2.2+cpu-cp312-cp312-win_amd64.whl
torch-2.2.2+cpu-cp38-cp38-linux_x86_64.whl
torch-2.2.2+cpu-cp38-cp38-win_amd64.whl
torch-2.2.2+cpu-cp39-cp39-linux_x86_64.whl
torch-2.2.2+cpu-cp39-cp39-win_amd64.whl
torch-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.2-cp310-none-macosx_10_9_x86_64.whl
torch-2.2.2-cp310-none-macosx_11_0_arm64.whl
torch-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.2-cp311-none-macosx_10_9_x86_64.whl
torch-2.2.2-cp311-none-macosx_11_0_arm64.whl
torch-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.2-cp312-none-macosx_10_9_x86_64.whl
torch-2.2.2-cp312-none-macosx_11_0_arm64.whl
torch-2.2.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.2-cp38-none-macosx_10_9_x86_64.whl
torch-2.2.2-cp38-none-macosx_11_0_arm64.whl
torch-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.2.2-cp39-none-macosx_10_9_x86_64.whl
torch-2.2.2-cp39-none-macosx_11_0_arm64.whl
torch-2.3.0+cpu-cp310-cp310-linux_x86_64.whl
torch-2.3.0+cpu-cp310-cp310-win_amd64.whl
torch-2.3.0+cpu-cp311-cp311-linux_x86_64.whl
torch-2.3.0+cpu-cp311-cp311-win_amd64.whl
torch-2.3.0+cpu-cp312-cp312-linux_x86_64.whl
torch-2.3.0+cpu-cp312-cp312-win_amd64.whl
torch-2.3.0+cpu-cp38-cp38-linux_x86_64.whl
torch-2.3.0+cpu-cp38-cp38-win_amd64.whl
torch-2.3.0+cpu-cp39-cp39-linux_x86_64.whl
torch-2.3.0+cpu-cp39-cp39-win_amd64.whl
torch-2.3.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.3.0-cp310-none-macosx_11_0_arm64.whl
torch-2.3.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.3.0-cp311-none-macosx_11_0_arm64.whl
torch-2.3.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.3.0-cp312-none-macosx_11_0_arm64.whl
torch-2.3.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.3.0-cp38-none-macosx_11_0_arm64.whl
torch-2.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.3.0-cp39-none-macosx_11_0_arm64.whl
torch-2.3.1+cpu-cp310-cp310-linux_x86_64.whl
torch-2.3.1+cpu-cp310-cp310-win_amd64.whl
torch-2.3.1+cpu-cp311-cp311-linux_x86_64.whl
torch-2.3.1+cpu-cp311-cp311-win_amd64.whl
torch-2.3.1+cpu-cp312-cp312-linux_x86_64.whl
torch-2.3.1+cpu-cp312-cp312-win_amd64.whl
torch-2.3.1+cpu-cp38-cp38-linux_x86_64.whl
torch-2.3.1+cpu-cp38-cp38-win_amd64.whl
torch-2.3.1+cpu-cp39-cp39-linux_x86_64.whl
torch-2.3.1+cpu-cp39-cp39-win_amd64.whl
torch-2.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.3.1-cp310-none-macosx_11_0_arm64.whl
torch-2.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.3.1-cp311-none-macosx_11_0_arm64.whl
torch-2.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.3.1-cp312-none-macosx_11_0_arm64.whl
torch-2.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.3.1-cp38-none-macosx_11_0_arm64.whl
torch-2.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.3.1-cp39-none-macosx_11_0_arm64.whl
torch-2.4.0+cpu-cp310-cp310-linux_x86_64.whl
torch-2.4.0+cpu-cp310-cp310-win_amd64.whl
torch-2.4.0+cpu-cp311-cp311-linux_x86_64.whl
torch-2.4.0+cpu-cp311-cp311-win_amd64.whl
torch-2.4.0+cpu-cp312-cp312-linux_x86_64.whl
torch-2.4.0+cpu-cp312-cp312-win_amd64.whl
torch-2.4.0+cpu-cp38-cp38-linux_x86_64.whl
torch-2.4.0+cpu-cp38-cp38-win_amd64.whl
torch-2.4.0+cpu-cp39-cp39-linux_x86_64.whl
torch-2.4.0+cpu-cp39-cp39-win_amd64.whl
torch-2.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.4.0-cp310-none-macosx_11_0_arm64.whl
torch-2.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.4.0-cp311-none-macosx_11_0_arm64.whl
torch-2.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.4.0-cp312-none-macosx_11_0_arm64.whl
torch-2.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.4.0-cp38-none-macosx_11_0_arm64.whl
torch-2.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.4.0-cp39-none-macosx_11_0_arm64.whl
torch-2.4.1+cpu-cp310-cp310-linux_x86_64.whl
torch-2.4.1+cpu-cp310-cp310-win_amd64.whl
torch-2.4.1+cpu-cp311-cp311-linux_x86_64.whl
torch-2.4.1+cpu-cp311-cp311-win_amd64.whl
torch-2.4.1+cpu-cp312-cp312-linux_x86_64.whl
torch-2.4.1+cpu-cp312-cp312-win_amd64.whl
torch-2.4.1+cpu-cp38-cp38-linux_x86_64.whl
torch-2.4.1+cpu-cp38-cp38-win_amd64.whl
torch-2.4.1+cpu-cp39-cp39-linux_x86_64.whl
torch-2.4.1+cpu-cp39-cp39-win_amd64.whl
torch-2.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.4.1-cp310-none-macosx_11_0_arm64.whl
torch-2.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.4.1-cp311-none-macosx_11_0_arm64.whl
torch-2.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.4.1-cp312-none-macosx_11_0_arm64.whl
torch-2.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.4.1-cp38-none-macosx_11_0_arm64.whl
torch-2.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.4.1-cp39-none-macosx_11_0_arm64.whl
torch-2.5.0+cpu-cp310-cp310-linux_x86_64.whl
torch-2.5.0+cpu-cp310-cp310-win_amd64.whl
torch-2.5.0+cpu-cp311-cp311-linux_x86_64.whl
torch-2.5.0+cpu-cp311-cp311-win_amd64.whl
torch-2.5.0+cpu-cp312-cp312-linux_x86_64.whl
torch-2.5.0+cpu-cp312-cp312-win_amd64.whl
torch-2.5.0+cpu-cp313-cp313-linux_x86_64.whl
torch-2.5.0+cpu-cp39-cp39-linux_x86_64.whl
torch-2.5.0+cpu-cp39-cp39-win_amd64.whl
torch-2.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.5.0-cp310-none-macosx_11_0_arm64.whl
torch-2.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.5.0-cp311-none-macosx_11_0_arm64.whl
torch-2.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.5.0-cp312-none-macosx_11_0_arm64.whl
torch-2.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.5.0-cp39-none-macosx_11_0_arm64.whl
torch-2.5.1+cpu-cp310-cp310-linux_x86_64.whl
torch-2.5.1+cpu-cp310-cp310-win_amd64.whl
torch-2.5.1+cpu-cp311-cp311-linux_x86_64.whl
torch-2.5.1+cpu-cp311-cp311-win_amd64.whl
torch-2.5.1+cpu-cp312-cp312-linux_x86_64.whl
torch-2.5.1+cpu-cp312-cp312-win_amd64.whl
torch-2.5.1+cpu-cp313-cp313-linux_x86_64.whl
torch-2.5.1+cpu-cp39-cp39-linux_x86_64.whl
torch-2.5.1+cpu-cp39-cp39-win_amd64.whl
torch-2.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.5.1-cp310-none-macosx_11_0_arm64.whl
torch-2.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.5.1-cp311-none-macosx_11_0_arm64.whl
torch-2.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.5.1-cp312-none-macosx_11_0_arm64.whl
torch-2.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
torch-2.5.1-cp39-none-macosx_11_0_arm64.whl
torch-2.6.0+cpu-cp310-cp310-linux_x86_64.whl
torch-2.6.0+cpu-cp310-cp310-manylinux_2_28_aarch64.whl
torch-2.6.0+cpu-cp310-cp310-win_amd64.whl
torch-2.6.0+cpu-cp311-cp311-linux_x86_64.whl
torch-2.6.0+cpu-cp311-cp311-manylinux_2_28_aarch64.whl
torch-2.6.0+cpu-cp311-cp311-win_amd64.whl
torch-2.6.0+cpu-cp312-cp312-linux_x86_64.whl
torch-2.6.0+cpu-cp312-cp312-manylinux_2_28_aarch64.whl
torch-2.6.0+cpu-cp312-cp312-win_amd64.whl
torch-2.6.0+cpu-cp313-cp313-linux_x86_64.whl
torch-2.6.0+cpu-cp313-cp313-manylinux_2_28_aarch64.whl
torch-2.6.0+cpu-cp313-cp313-win_amd64.whl
torch-2.6.0+cpu-cp313-cp313t-linux_x86_64.whl
torch-2.6.0+cpu-cp313-cp313t-manylinux_2_28_aarch64.whl
torch-2.6.0+cpu-cp39-cp39-linux_x86_64.whl
torch-2.6.0+cpu-cp39-cp39-manylinux_2_28_aarch64.whl
torch-2.6.0+cpu-cp39-cp39-win_amd64.whl
torch-2.6.0-cp310-none-macosx_11_0_arm64.whl
torch-2.6.0-cp311-none-macosx_11_0_arm64.whl
torch-2.6.0-cp312-none-macosx_11_0_arm64.whl
torch-2.6.0-cp313-none-macosx_11_0_arm64.whl
torch-2.6.0-cp39-none-macosx_11_0_arm64.whl
torch-2.7.0+cpu-cp310-cp310-manylinux_2_28_aarch64.whl
torch-2.7.0+cpu-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.0+cpu-cp310-cp310-win_amd64.whl
torch-2.7.0+cpu-cp311-cp311-manylinux_2_28_aarch64.whl
torch-2.7.0+cpu-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.0+cpu-cp311-cp311-win_amd64.whl
torch-2.7.0+cpu-cp312-cp312-manylinux_2_28_aarch64.whl
torch-2.7.0+cpu-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.0+cpu-cp312-cp312-win_amd64.whl
torch-2.7.0+cpu-cp313-cp313-manylinux_2_28_aarch64.whl
torch-2.7.0+cpu-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.0+cpu-cp313-cp313-win_amd64.whl
torch-2.7.0+cpu-cp313-cp313t-manylinux_2_28_aarch64.whl
torch-2.7.0+cpu-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.0+cpu-cp313-cp313t-win_amd64.whl
torch-2.7.0+cpu-cp39-cp39-manylinux_2_28_aarch64.whl
torch-2.7.0+cpu-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.7.0+cpu-cp39-cp39-win_amd64.whl
torch-2.7.0-cp310-none-macosx_11_0_arm64.whl
torch-2.7.0-cp311-none-macosx_11_0_arm64.whl
torch-2.7.0-cp312-none-macosx_11_0_arm64.whl
torch-2.7.0-cp313-cp313t-macosx_14_0_arm64.whl
torch-2.7.0-cp313-none-macosx_11_0_arm64.whl
torch-2.7.0-cp39-none-macosx_11_0_arm64.whl
torch-2.7.1+cpu-cp310-cp310-manylinux_2_28_aarch64.whl
torch-2.7.1+cpu-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.1+cpu-cp310-cp310-win_amd64.whl
torch-2.7.1+cpu-cp311-cp311-manylinux_2_28_aarch64.whl
torch-2.7.1+cpu-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.1+cpu-cp311-cp311-win_amd64.whl
torch-2.7.1+cpu-cp312-cp312-manylinux_2_28_aarch64.whl
torch-2.7.1+cpu-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.1+cpu-cp312-cp312-win_amd64.whl
torch-2.7.1+cpu-cp313-cp313-manylinux_2_28_aarch64.whl
torch-2.7.1+cpu-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.1+cpu-cp313-cp313-win_amd64.whl
torch-2.7.1+cpu-cp313-cp313t-manylinux_2_28_aarch64.whl
torch-2.7.1+cpu-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.1+cpu-cp313-cp313t-win_amd64.whl
torch-2.7.1+cpu-cp39-cp39-manylinux_2_28_aarch64.whl
torch-2.7.1+cpu-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.7.1+cpu-cp39-cp39-win_amd64.whl
torch-2.7.1-cp310-none-macosx_11_0_arm64.whl
torch-2.7.1-cp311-none-macosx_11_0_arm64.whl
torch-2.7.1-cp312-none-macosx_11_0_arm64.whl
torch-2.7.1-cp313-cp313t-macosx_14_0_arm64.whl
torch-2.7.1-cp313-none-macosx_11_0_arm64.whl
torch-2.7.1-cp39-none-macosx_11_0_arm64.whl
torch-2.6.0.dev20240914+cpu-cp310-cp310-linux_x86_64.whl
torch-2.6.0.dev20240914+cpu-cp311-cp311-linux_x86_64.whl
torch-2.6.0.dev20240914+cpu-cp312-cp312-linux_x86_64.whl
torch-2.6.0.dev20240914+cpu-cp39-cp39-linux_x86_64.whl
torch-1.0.0-cp27-cp27m-linux_x86_64.whl
torch-1.0.0-cp27-cp27mu-linux_x86_64.whl
torch-1.0.0-cp35-cp35m-linux_x86_64.whl
torch-1.0.0-cp35-cp35m-win_amd64.whl
torch-1.0.0-cp36-cp36m-linux_x86_64.whl
torch-1.0.0-cp36-cp36m-win_amd64.whl
torch-1.0.0-cp37-cp37m-linux_x86_64.whl
torch-1.0.0-cp37-cp37m-win_amd64.whl
torch-1.0.1-cp27-cp27m-linux_x86_64.whl
torch-1.0.1-cp27-cp27mu-linux_x86_64.whl
torch-1.0.1-cp35-cp35m-linux_x86_64.whl
torch-1.0.1-cp35-cp35m-win_amd64.whl
torch-1.0.1-cp36-cp36m-linux_x86_64.whl
torch-1.0.1-cp36-cp36m-win_amd64.whl
torch-1.0.1-cp37-cp37m-linux_x86_64.whl
torch-1.0.1-cp37-cp37m-win_amd64.whl
torch-1.0.1.post2-cp27-cp27m-linux_x86_64.whl
torch-1.0.1.post2-cp27-cp27mu-linux_x86_64.whl
torch-1.0.1.post2-cp35-cp35m-linux_x86_64.whl
torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl
torch-1.1.0-cp27-cp27m-linux_x86_64.whl
torch-1.1.0-cp27-cp27mu-linux_x86_64.whl
torch-1.1.0-cp35-cp35m-linux_x86_64.whl
torch-1.1.0-cp35-cp35m-win_amd64.whl
torch-1.1.0-cp36-cp36m-linux_x86_64.whl
torch-1.1.0-cp36-cp36m-win_amd64.whl
torch-1.1.0-cp37-cp37m-linux_x86_64.whl
torch-1.1.0-cp37-cp37m-win_amd64.whl
torch-1.2.0-cp27-cp27m-manylinux1_x86_64.whl
torch-1.2.0-cp27-cp27mu-manylinux1_x86_64.whl
torch-1.2.0-cp35-cp35m-manylinux1_x86_64.whl
torch-1.2.0-cp35-cp35m-win_amd64.whl
torch-1.2.0-cp36-cp36m-manylinux1_x86_64.whl
torch-1.2.0-cp36-cp36m-win_amd64.whl
torch-1.2.0-cp37-cp37m-manylinux1_x86_64.whl
torch-1.2.0-cp37-cp37m-win_amd64.whl
torch-1.3.0+cu100-cp27-cp27m-linux_x86_64.whl
torch-1.3.0+cu100-cp27-cp27mu-linux_x86_64.whl
torch-1.3.0+cu100-cp35-cp35m-linux_x86_64.whl
torch-1.3.0+cu100-cp36-cp36m-linux_x86_64.whl
torch-1.3.0+cu100-cp37-cp37m-linux_x86_64.whl
torch-1.3.1+cu100-cp27-cp27m-linux_x86_64.whl
torch-1.3.1+cu100-cp27-cp27mu-linux_x86_64.whl
torch-1.3.1+cu100-cp35-cp35m-linux_x86_64.whl
torch-1.3.1+cu100-cp36-cp36m-linux_x86_64.whl
torch-1.3.1+cu100-cp37-cp37m-linux_x86_64.whl
torch-1.4.0+cu100-cp27-cp27m-linux_x86_64.whl
torch-1.4.0+cu100-cp27-cp27mu-linux_x86_64.whl
torch-1.4.0+cu100-cp35-cp35m-linux_x86_64.whl
torch-1.4.0+cu100-cp36-cp36m-linux_x86_64.whl
torch-1.4.0+cu100-cp37-cp37m-linux_x86_64.whl
torch-1.4.0+cu100-cp38-cp38-linux_x86_64.whl
torch-1.3.0-cp27-cp27m-manylinux1_x86_64.whl
torch-1.3.0-cp27-cp27mu-manylinux1_x86_64.whl
torch-1.3.0-cp35-cp35m-manylinux1_x86_64.whl
torch-1.3.0-cp35-cp35m-win_amd64.whl
torch-1.3.0-cp36-cp36m-manylinux1_x86_64.whl
torch-1.3.0-cp36-cp36m-win_amd64.whl
torch-1.3.0-cp37-cp37m-manylinux1_x86_64.whl
torch-1.3.0-cp37-cp37m-win_amd64.whl
torch-1.3.1-cp27-cp27m-linux_x86_64.whl
torch-1.3.1-cp27-cp27mu-linux_x86_64.whl
torch-1.3.1-cp35-cp35m-linux_x86_64.whl
torch-1.3.1-cp35-cp35m-win_amd64.whl
torch-1.3.1-cp36-cp36m-linux_x86_64.whl
torch-1.3.1-cp36-cp36m-win_amd64.whl
torch-1.3.1-cp37-cp37m-linux_x86_64.whl
torch-1.3.1-cp37-cp37m-win_amd64.whl
torch-1.4.0-cp27-cp27m-linux_x86_64.whl
torch-1.4.0-cp27-cp27mu-linux_x86_64.whl
torch-1.4.0-cp35-cp35m-linux_x86_64.whl
torch-1.4.0-cp35-cp35m-win_amd64.whl
torch-1.4.0-cp36-cp36m-linux_x86_64.whl
torch-1.4.0-cp36-cp36m-win_amd64.whl
torch-1.4.0-cp37-cp37m-linux_x86_64.whl
torch-1.4.0-cp37-cp37m-win_amd64.whl
torch-1.4.0-cp38-cp38-linux_x86_64.whl
torch-1.4.0-cp38-cp38-win_amd64.whl
torch-1.5.0+cu101-cp35-cp35m-linux_x86_64.whl
torch-1.5.0+cu101-cp35-cp35m-win_amd64.whl
torch-1.5.0+cu101-cp36-cp36m-linux_x86_64.whl
torch-1.5.0+cu101-cp36-cp36m-win_amd64.whl
torch-1.5.0+cu101-cp37-cp37m-linux_x86_64.whl
torch-1.5.0+cu101-cp37-cp37m-win_amd64.whl
torch-1.5.0+cu101-cp38-cp38-linux_x86_64.whl
torch-1.5.0+cu101-cp38-cp38-win_amd64.whl
torch-1.5.1+cu101-cp35-cp35m-linux_x86_64.whl
torch-1.5.1+cu101-cp35-cp35m-win_amd64.whl
torch-1.5.1+cu101-cp36-cp36m-linux_x86_64.whl
torch-1.5.1+cu101-cp36-cp36m-win_amd64.whl
torch-1.5.1+cu101-cp37-cp37m-linux_x86_64.whl
torch-1.5.1+cu101-cp37-cp37m-win_amd64.whl
torch-1.5.1+cu101-cp38-cp38-linux_x86_64.whl
torch-1.5.1+cu101-cp38-cp38-win_amd64.whl
torch-1.6.0+cu101-cp36-cp36m-linux_x86_64.whl
torch-1.6.0+cu101-cp36-cp36m-win_amd64.whl
torch-1.6.0+cu101-cp37-cp37m-linux_x86_64.whl
torch-1.6.0+cu101-cp37-cp37m-win_amd64.whl
torch-1.6.0+cu101-cp38-cp38-linux_x86_64.whl
torch-1.6.0+cu101-cp38-cp38-win_amd64.whl
torch-1.7.0+cu101-cp36-cp36m-linux_x86_64.whl
torch-1.7.0+cu101-cp36-cp36m-win_amd64.whl
torch-1.7.0+cu101-cp37-cp37m-linux_x86_64.whl
torch-1.7.0+cu101-cp37-cp37m-win_amd64.whl
torch-1.7.0+cu101-cp38-cp38-linux_x86_64.whl
torch-1.7.0+cu101-cp38-cp38-win_amd64.whl
torch-1.7.1+cu101-cp36-cp36m-linux_x86_64.whl
torch-1.7.1+cu101-cp36-cp36m-win_amd64.whl
torch-1.7.1+cu101-cp37-cp37m-linux_x86_64.whl
torch-1.7.1+cu101-cp37-cp37m-win_amd64.whl
torch-1.7.1+cu101-cp38-cp38-linux_x86_64.whl
torch-1.7.1+cu101-cp38-cp38-win_amd64.whl
torch-1.7.1+cu101-cp39-cp39-linux_x86_64.whl
torch-1.7.1+cu101-cp39-cp39-win_amd64.whl
torch-1.8.0+cu101-cp36-cp36m-linux_x86_64.whl
torch-1.8.0+cu101-cp36-cp36m-win_amd64.whl
torch-1.8.0+cu101-cp37-cp37m-linux_x86_64.whl
torch-1.8.0+cu101-cp37-cp37m-win_amd64.whl
torch-1.8.0+cu101-cp38-cp38-linux_x86_64.whl
torch-1.8.0+cu101-cp38-cp38-win_amd64.whl
torch-1.8.0+cu101-cp39-cp39-linux_x86_64.whl
torch-1.8.0+cu101-cp39-cp39-win_amd64.whl
torch-1.8.1+cu101-cp36-cp36m-linux_x86_64.whl
torch-1.8.1+cu101-cp36-cp36m-win_amd64.whl
torch-1.8.1+cu101-cp37-cp37m-linux_x86_64.whl
torch-1.8.1+cu101-cp37-cp37m-win_amd64.whl
torch-1.8.1+cu101-cp38-cp38-linux_x86_64.whl
torch-1.8.1+cu101-cp38-cp38-win_amd64.whl
torch-1.8.1+cu101-cp39-cp39-linux_x86_64.whl
torch-1.8.1+cu101-cp39-cp39-win_amd64.whl
torch-1.10.0+cu102-cp36-cp36m-linux_x86_64.whl
torch-1.10.0+cu102-cp36-cp36m-win_amd64.whl
torch-1.10.0+cu102-cp37-cp37m-linux_x86_64.whl
torch-1.10.0+cu102-cp37-cp37m-win_amd64.whl
torch-1.10.0+cu102-cp38-cp38-linux_x86_64.whl
torch-1.10.0+cu102-cp38-cp38-win_amd64.whl
torch-1.10.0+cu102-cp39-cp39-linux_x86_64.whl
torch-1.10.0+cu102-cp39-cp39-win_amd64.whl
torch-1.10.1+cu102-cp36-cp36m-linux_x86_64.whl
torch-1.10.1+cu102-cp36-cp36m-win_amd64.whl
torch-1.10.1+cu102-cp37-cp37m-linux_x86_64.whl
torch-1.10.1+cu102-cp37-cp37m-win_amd64.whl
torch-1.10.1+cu102-cp38-cp38-linux_x86_64.whl
torch-1.10.1+cu102-cp38-cp38-win_amd64.whl
torch-1.10.1+cu102-cp39-cp39-linux_x86_64.whl
torch-1.10.1+cu102-cp39-cp39-win_amd64.whl
torch-1.10.2+cu102-cp36-cp36m-linux_x86_64.whl
torch-1.10.2+cu102-cp36-cp36m-win_amd64.whl
torch-1.10.2+cu102-cp37-cp37m-linux_x86_64.whl
torch-1.10.2+cu102-cp37-cp37m-win_amd64.whl
torch-1.10.2+cu102-cp38-cp38-linux_x86_64.whl
torch-1.10.2+cu102-cp38-cp38-win_amd64.whl
torch-1.10.2+cu102-cp39-cp39-linux_x86_64.whl
torch-1.10.2+cu102-cp39-cp39-win_amd64.whl
torch-1.11.0+cu102-cp310-cp310-linux_x86_64.whl
torch-1.11.0+cu102-cp37-cp37m-linux_x86_64.whl
torch-1.11.0+cu102-cp38-cp38-linux_x86_64.whl
torch-1.11.0+cu102-cp39-cp39-linux_x86_64.whl
torch-1.12.0+cu102-cp310-cp310-linux_x86_64.whl
torch-1.12.0+cu102-cp37-cp37m-linux_x86_64.whl
torch-1.12.0+cu102-cp38-cp38-linux_x86_64.whl
torch-1.12.0+cu102-cp39-cp39-linux_x86_64.whl
torch-1.12.1+cu102-cp310-cp310-linux_x86_64.whl
torch-1.12.1+cu102-cp37-cp37m-linux_x86_64.whl
torch-1.12.1+cu102-cp38-cp38-linux_x86_64.whl
torch-1.12.1+cu102-cp39-cp39-linux_x86_64.whl
torch-1.5.0-cp35-cp35m-linux_x86_64.whl
torch-1.5.0-cp35-cp35m-win_amd64.whl
torch-1.5.0-cp36-cp36m-linux_x86_64.whl
torch-1.5.0-cp36-cp36m-win_amd64.whl
torch-1.5.0-cp37-cp37m-linux_x86_64.whl
torch-1.5.0-cp37-cp37m-win_amd64.whl
torch-1.5.0-cp38-cp38-linux_x86_64.whl
torch-1.5.0-cp38-cp38-win_amd64.whl
torch-1.5.1-cp35-cp35m-linux_x86_64.whl
torch-1.5.1-cp35-cp35m-win_amd64.whl
torch-1.5.1-cp36-cp36m-linux_x86_64.whl
torch-1.5.1-cp36-cp36m-win_amd64.whl
torch-1.5.1-cp37-cp37m-linux_x86_64.whl
torch-1.5.1-cp37-cp37m-win_amd64.whl
torch-1.5.1-cp38-cp38-linux_x86_64.whl
torch-1.5.1-cp38-cp38-win_amd64.whl
torch-1.6.0-cp36-cp36m-linux_x86_64.whl
torch-1.6.0-cp36-cp36m-win_amd64.whl
torch-1.6.0-cp37-cp37m-linux_x86_64.whl
torch-1.6.0-cp37-cp37m-win_amd64.whl
torch-1.6.0-cp38-cp38-linux_x86_64.whl
torch-1.6.0-cp38-cp38-win_amd64.whl
torch-1.7.0-cp36-cp36m-linux_x86_64.whl
torch-1.7.0-cp36-cp36m-win_amd64.whl
torch-1.7.0-cp37-cp37m-linux_x86_64.whl
torch-1.7.0-cp37-cp37m-win_amd64.whl
torch-1.7.0-cp38-cp38-linux_x86_64.whl
torch-1.7.0-cp38-cp38-win_amd64.whl
torch-1.7.1-cp36-cp36m-linux_x86_64.whl
torch-1.7.1-cp36-cp36m-win_amd64.whl
torch-1.7.1-cp37-cp37m-linux_x86_64.whl
torch-1.7.1-cp37-cp37m-win_amd64.whl
torch-1.7.1-cp38-cp38-linux_x86_64.whl
torch-1.7.1-cp38-cp38-win_amd64.whl
torch-1.7.1-cp39-cp39-linux_x86_64.whl
torch-1.7.1-cp39-cp39-win_amd64.whl
torch-1.8.0-cp36-cp36m-linux_x86_64.whl
torch-1.8.0-cp36-cp36m-win_amd64.whl
torch-1.8.0-cp37-cp37m-linux_x86_64.whl
torch-1.8.0-cp37-cp37m-win_amd64.whl
torch-1.8.0-cp38-cp38-linux_x86_64.whl
torch-1.8.0-cp38-cp38-win_amd64.whl
torch-1.8.0-cp39-cp39-linux_x86_64.whl
torch-1.8.0-cp39-cp39-win_amd64.whl
torch-1.8.1+cu102-cp36-cp36m-linux_x86_64.whl
torch-1.8.1+cu102-cp36-cp36m-win_amd64.whl
torch-1.8.1+cu102-cp37-cp37m-linux_x86_64.whl
torch-1.8.1+cu102-cp37-cp37m-win_amd64.whl
torch-1.8.1+cu102-cp38-cp38-linux_x86_64.whl
torch-1.8.1+cu102-cp38-cp38-win_amd64.whl
torch-1.8.1+cu102-cp39-cp39-linux_x86_64.whl
torch-1.8.1+cu102-cp39-cp39-win_amd64.whl
torch-1.9.0+cu102-cp36-cp36m-linux_x86_64.whl
torch-1.9.0+cu102-cp36-cp36m-win_amd64.whl
torch-1.9.0+cu102-cp37-cp37m-linux_x86_64.whl
torch-1.9.0+cu102-cp37-cp37m-win_amd64.whl
torch-1.9.0+cu102-cp38-cp38-linux_x86_64.whl
torch-1.9.0+cu102-cp38-cp38-win_amd64.whl
torch-1.9.0+cu102-cp39-cp39-linux_x86_64.whl
torch-1.9.0+cu102-cp39-cp39-win_amd64.whl
torch-1.9.1+cu102-cp36-cp36m-linux_x86_64.whl
torch-1.9.1+cu102-cp36-cp36m-win_amd64.whl
torch-1.9.1+cu102-cp37-cp37m-linux_x86_64.whl
torch-1.9.1+cu102-cp37-cp37m-win_amd64.whl
torch-1.9.1+cu102-cp38-cp38-linux_x86_64.whl
torch-1.9.1+cu102-cp38-cp38-win_amd64.whl
torch-1.9.1+cu102-cp39-cp39-linux_x86_64.whl
torch-1.9.1+cu102-cp39-cp39-win_amd64.whl
torch-1.7.0+cu110-cp36-cp36m-linux_x86_64.whl
torch-1.7.0+cu110-cp36-cp36m-win_amd64.whl
torch-1.7.0+cu110-cp37-cp37m-linux_x86_64.whl
torch-1.7.0+cu110-cp37-cp37m-win_amd64.whl
torch-1.7.0+cu110-cp38-cp38-linux_x86_64.whl
torch-1.7.0+cu110-cp38-cp38-win_amd64.whl
torch-1.7.1+cu110-cp36-cp36m-linux_x86_64.whl
torch-1.7.1+cu110-cp36-cp36m-win_amd64.whl
torch-1.7.1+cu110-cp37-cp37m-linux_x86_64.whl
torch-1.7.1+cu110-cp37-cp37m-win_amd64.whl
torch-1.7.1+cu110-cp38-cp38-linux_x86_64.whl
torch-1.7.1+cu110-cp38-cp38-win_amd64.whl
torch-1.7.1+cu110-cp39-cp39-linux_x86_64.whl
torch-1.7.1+cu110-cp39-cp39-win_amd64.whl
torch-1.10.0+cu111-cp36-cp36m-linux_x86_64.whl
torch-1.10.0+cu111-cp36-cp36m-win_amd64.whl
torch-1.10.0+cu111-cp37-cp37m-linux_x86_64.whl
torch-1.10.0+cu111-cp37-cp37m-win_amd64.whl
torch-1.10.0+cu111-cp38-cp38-linux_x86_64.whl
torch-1.10.0+cu111-cp38-cp38-win_amd64.whl
torch-1.10.0+cu111-cp39-cp39-linux_x86_64.whl
torch-1.10.0+cu111-cp39-cp39-win_amd64.whl
torch-1.10.1+cu111-cp36-cp36m-linux_x86_64.whl
torch-1.10.1+cu111-cp36-cp36m-win_amd64.whl
torch-1.10.1+cu111-cp37-cp37m-linux_x86_64.whl
torch-1.10.1+cu111-cp37-cp37m-win_amd64.whl
torch-1.10.1+cu111-cp38-cp38-linux_x86_64.whl
torch-1.10.1+cu111-cp38-cp38-win_amd64.whl
torch-1.10.1+cu111-cp39-cp39-linux_x86_64.whl
torch-1.10.1+cu111-cp39-cp39-win_amd64.whl
torch-1.10.2+cu111-cp36-cp36m-linux_x86_64.whl
torch-1.10.2+cu111-cp36-cp36m-win_amd64.whl
torch-1.10.2+cu111-cp37-cp37m-linux_x86_64.whl
torch-1.10.2+cu111-cp37-cp37m-win_amd64.whl
torch-1.10.2+cu111-cp38-cp38-linux_x86_64.whl
torch-1.10.2+cu111-cp38-cp38-win_amd64.whl
torch-1.10.2+cu111-cp39-cp39-linux_x86_64.whl
torch-1.10.2+cu111-cp39-cp39-win_amd64.whl
torch-1.8.0+cu111-cp36-cp36m-linux_x86_64.whl
torch-1.8.0+cu111-cp36-cp36m-win_amd64.whl
torch-1.8.0+cu111-cp37-cp37m-linux_x86_64.whl
torch-1.8.0+cu111-cp37-cp37m-win_amd64.whl
torch-1.8.0+cu111-cp38-cp38-linux_x86_64.whl
torch-1.8.0+cu111-cp38-cp38-win_amd64.whl
torch-1.8.0+cu111-cp39-cp39-linux_x86_64.whl
torch-1.8.0+cu111-cp39-cp39-win_amd64.whl
torch-1.8.1+cu111-cp36-cp36m-linux_x86_64.whl
torch-1.8.1+cu111-cp36-cp36m-win_amd64.whl
torch-1.8.1+cu111-cp37-cp37m-linux_x86_64.whl
torch-1.8.1+cu111-cp37-cp37m-win_amd64.whl
torch-1.8.1+cu111-cp38-cp38-linux_x86_64.whl
torch-1.8.1+cu111-cp38-cp38-win_amd64.whl
torch-1.8.1+cu111-cp39-cp39-linux_x86_64.whl
torch-1.8.1+cu111-cp39-cp39-win_amd64.whl
torch-1.9.0+cu111-cp36-cp36m-linux_x86_64.whl
torch-1.9.0+cu111-cp36-cp36m-win_amd64.whl
torch-1.9.0+cu111-cp37-cp37m-linux_x86_64.whl
torch-1.9.0+cu111-cp37-cp37m-win_amd64.whl
torch-1.9.0+cu111-cp38-cp38-linux_x86_64.whl
torch-1.9.0+cu111-cp38-cp38-win_amd64.whl
torch-1.9.0+cu111-cp39-cp39-linux_x86_64.whl
torch-1.9.0+cu111-cp39-cp39-win_amd64.whl
torch-1.9.1+cu111-cp36-cp36m-linux_x86_64.whl
torch-1.9.1+cu111-cp36-cp36m-win_amd64.whl
torch-1.9.1+cu111-cp37-cp37m-linux_x86_64.whl
torch-1.9.1+cu111-cp37-cp37m-win_amd64.whl
torch-1.9.1+cu111-cp38-cp38-linux_x86_64.whl
torch-1.9.1+cu111-cp38-cp38-win_amd64.whl
torch-1.9.1+cu111-cp39-cp39-linux_x86_64.whl
torch-1.9.1+cu111-cp39-cp39-win_amd64.whl
torch-1.10.0+cu113-cp36-cp36m-linux_x86_64.whl
torch-1.10.0+cu113-cp36-cp36m-win_amd64.whl
torch-1.10.0+cu113-cp37-cp37m-linux_x86_64.whl
torch-1.10.0+cu113-cp37-cp37m-win_amd64.whl
torch-1.10.0+cu113-cp38-cp38-linux_x86_64.whl
torch-1.10.0+cu113-cp38-cp38-win_amd64.whl
torch-1.10.0+cu113-cp39-cp39-linux_x86_64.whl
torch-1.10.0+cu113-cp39-cp39-win_amd64.whl
torch-1.10.1+cu113-cp36-cp36m-linux_x86_64.whl
torch-1.10.1+cu113-cp36-cp36m-win_amd64.whl
torch-1.10.1+cu113-cp37-cp37m-linux_x86_64.whl
torch-1.10.1+cu113-cp37-cp37m-win_amd64.whl
torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl
torch-1.10.1+cu113-cp38-cp38-win_amd64.whl
torch-1.10.1+cu113-cp39-cp39-linux_x86_64.whl
torch-1.10.1+cu113-cp39-cp39-win_amd64.whl
torch-1.10.2+cu113-cp36-cp36m-linux_x86_64.whl
torch-1.10.2+cu113-cp36-cp36m-win_amd64.whl
torch-1.10.2+cu113-cp37-cp37m-linux_x86_64.whl
torch-1.10.2+cu113-cp37-cp37m-win_amd64.whl
torch-1.10.2+cu113-cp38-cp38-linux_x86_64.whl
torch-1.10.2+cu113-cp38-cp38-win_amd64.whl
torch-1.10.2+cu113-cp39-cp39-linux_x86_64.whl
torch-1.10.2+cu113-cp39-cp39-win_amd64.whl
torch-1.11.0+cu113-cp310-cp310-linux_x86_64.whl
torch-1.11.0+cu113-cp310-cp310-win_amd64.whl
torch-1.11.0+cu113-cp37-cp37m-linux_x86_64.whl
torch-1.11.0+cu113-cp37-cp37m-win_amd64.whl
torch-1.11.0+cu113-cp38-cp38-linux_x86_64.whl
torch-1.11.0+cu113-cp38-cp38-win_amd64.whl
torch-1.11.0+cu113-cp39-cp39-linux_x86_64.whl
torch-1.11.0+cu113-cp39-cp39-win_amd64.whl
torch-1.12.0+cu113-cp310-cp310-linux_x86_64.whl
torch-1.12.0+cu113-cp310-cp310-win_amd64.whl
torch-1.12.0+cu113-cp37-cp37m-linux_x86_64.whl
torch-1.12.0+cu113-cp37-cp37m-win_amd64.whl
torch-1.12.0+cu113-cp38-cp38-linux_x86_64.whl
torch-1.12.0+cu113-cp38-cp38-win_amd64.whl
torch-1.12.0+cu113-cp39-cp39-linux_x86_64.whl
torch-1.12.0+cu113-cp39-cp39-win_amd64.whl
torch-1.12.1+cu113-cp310-cp310-linux_x86_64.whl
torch-1.12.1+cu113-cp310-cp310-win_amd64.whl
torch-1.12.1+cu113-cp37-cp37m-linux_x86_64.whl
torch-1.12.1+cu113-cp37-cp37m-win_amd64.whl
torch-1.12.1+cu113-cp38-cp38-linux_x86_64.whl
torch-1.12.1+cu113-cp38-cp38-win_amd64.whl
torch-1.12.1+cu113-cp39-cp39-linux_x86_64.whl
torch-1.12.1+cu113-cp39-cp39-win_amd64.whl
torch-1.11.0+cu115-cp310-cp310-linux_x86_64.whl
torch-1.11.0+cu115-cp310-cp310-win_amd64.whl
torch-1.11.0+cu115-cp37-cp37m-linux_x86_64.whl
torch-1.11.0+cu115-cp37-cp37m-win_amd64.whl
torch-1.11.0+cu115-cp38-cp38-linux_x86_64.whl
torch-1.11.0+cu115-cp38-cp38-win_amd64.whl
torch-1.11.0+cu115-cp39-cp39-linux_x86_64.whl
torch-1.11.0+cu115-cp39-cp39-win_amd64.whl
torch-1.12.0+cu116-cp310-cp310-linux_x86_64.whl
torch-1.12.0+cu116-cp310-cp310-win_amd64.whl
torch-1.12.0+cu116-cp37-cp37m-linux_x86_64.whl
torch-1.12.0+cu116-cp37-cp37m-win_amd64.whl
torch-1.12.0+cu116-cp38-cp38-linux_x86_64.whl
torch-1.12.0+cu116-cp38-cp38-win_amd64.whl
torch-1.12.0+cu116-cp39-cp39-linux_x86_64.whl
torch-1.12.0+cu116-cp39-cp39-win_amd64.whl
torch-1.12.1+cu116-cp310-cp310-linux_x86_64.whl
torch-1.12.1+cu116-cp310-cp310-win_amd64.whl
torch-1.12.1+cu116-cp37-cp37m-linux_x86_64.whl
torch-1.12.1+cu116-cp37-cp37m-win_amd64.whl
torch-1.12.1+cu116-cp38-cp38-linux_x86_64.whl
torch-1.12.1+cu116-cp38-cp38-win_amd64.whl
torch-1.12.1+cu116-cp39-cp39-linux_x86_64.whl
torch-1.12.1+cu116-cp39-cp39-win_amd64.whl
torch-1.13.0+cu116-cp310-cp310-linux_x86_64.whl
torch-1.13.0+cu116-cp310-cp310-win_amd64.whl
torch-1.13.0+cu116-cp311-cp311-linux_x86_64.whl
torch-1.13.0+cu116-cp37-cp37m-linux_x86_64.whl
torch-1.13.0+cu116-cp37-cp37m-win_amd64.whl
torch-1.13.0+cu116-cp38-cp38-linux_x86_64.whl
torch-1.13.0+cu116-cp38-cp38-win_amd64.whl
torch-1.13.0+cu116-cp39-cp39-linux_x86_64.whl
torch-1.13.0+cu116-cp39-cp39-win_amd64.whl
torch-1.13.1+cu116-cp310-cp310-linux_x86_64.whl
torch-1.13.1+cu116-cp310-cp310-win_amd64.whl
torch-1.13.1+cu116-cp311-cp311-linux_x86_64.whl
torch-1.13.1+cu116-cp37-cp37m-linux_x86_64.whl
torch-1.13.1+cu116-cp37-cp37m-win_amd64.whl
torch-1.13.1+cu116-cp38-cp38-linux_x86_64.whl
torch-1.13.1+cu116-cp38-cp38-win_amd64.whl
torch-1.13.1+cu116-cp39-cp39-linux_x86_64.whl
torch-1.13.1+cu116-cp39-cp39-win_amd64.whl
torch-1.13.0+cu117-cp310-cp310-linux_x86_64.whl
torch-1.13.0+cu117-cp310-cp310-win_amd64.whl
torch-1.13.0+cu117-cp311-cp311-linux_x86_64.whl
torch-1.13.0+cu117-cp37-cp37m-linux_x86_64.whl
torch-1.13.0+cu117-cp37-cp37m-win_amd64.whl
torch-1.13.0+cu117-cp38-cp38-linux_x86_64.whl
torch-1.13.0+cu117-cp38-cp38-win_amd64.whl
torch-1.13.0+cu117-cp39-cp39-linux_x86_64.whl
torch-1.13.0+cu117-cp39-cp39-win_amd64.whl
torch-1.13.1+cu117-cp310-cp310-linux_x86_64.whl
torch-1.13.1+cu117-cp310-cp310-win_amd64.whl
torch-1.13.1+cu117-cp311-cp311-linux_x86_64.whl
torch-1.13.1+cu117-cp37-cp37m-linux_x86_64.whl
torch-1.13.1+cu117-cp37-cp37m-win_amd64.whl
torch-1.13.1+cu117-cp38-cp38-linux_x86_64.whl
torch-1.13.1+cu117-cp38-cp38-win_amd64.whl
torch-1.13.1+cu117-cp39-cp39-linux_x86_64.whl
torch-1.13.1+cu117-cp39-cp39-win_amd64.whl
torch-2.0.0+cu117-cp310-cp310-linux_x86_64.whl
torch-2.0.0+cu117-cp310-cp310-win_amd64.whl
torch-2.0.0+cu117-cp311-cp311-linux_x86_64.whl
torch-2.0.0+cu117-cp311-cp311-win_amd64.whl
torch-2.0.0+cu117-cp38-cp38-linux_x86_64.whl
torch-2.0.0+cu117-cp38-cp38-win_amd64.whl
torch-2.0.0+cu117-cp39-cp39-linux_x86_64.whl
torch-2.0.0+cu117-cp39-cp39-win_amd64.whl
torch-2.0.1+cu117-cp310-cp310-linux_x86_64.whl
torch-2.0.1+cu117-cp310-cp310-win_amd64.whl
torch-2.0.1+cu117-cp311-cp311-linux_x86_64.whl
torch-2.0.1+cu117-cp311-cp311-win_amd64.whl
torch-2.0.1+cu117-cp38-cp38-linux_x86_64.whl
torch-2.0.1+cu117-cp38-cp38-win_amd64.whl
torch-2.0.1+cu117-cp39-cp39-linux_x86_64.whl
torch-2.0.1+cu117-cp39-cp39-win_amd64.whl
torch-1.13.0+cu117.with.pypi.cudnn-cp310-cp310-linux_x86_64.whl
torch-1.13.0+cu117.with.pypi.cudnn-cp311-cp311-linux_x86_64.whl
torch-1.13.0+cu117.with.pypi.cudnn-cp37-cp37m-linux_x86_64.whl
torch-1.13.0+cu117.with.pypi.cudnn-cp38-cp38-linux_x86_64.whl
torch-1.13.0+cu117.with.pypi.cudnn-cp39-cp39-linux_x86_64.whl
torch-1.13.1+cu117.with.pypi.cudnn-cp310-cp310-linux_x86_64.whl
torch-1.13.1+cu117.with.pypi.cudnn-cp311-cp311-linux_x86_64.whl
torch-1.13.1+cu117.with.pypi.cudnn-cp37-cp37m-linux_x86_64.whl
torch-1.13.1+cu117.with.pypi.cudnn-cp38-cp38-linux_x86_64.whl
torch-1.13.1+cu117.with.pypi.cudnn-cp39-cp39-linux_x86_64.whl
torch-2.0.0+cu117.with.pypi.cudnn-cp310-cp310-linux_x86_64.whl
torch-2.0.0+cu117.with.pypi.cudnn-cp311-cp311-linux_x86_64.whl
torch-2.0.0+cu117.with.pypi.cudnn-cp38-cp38-linux_x86_64.whl
torch-2.0.0+cu117.with.pypi.cudnn-cp39-cp39-linux_x86_64.whl
torch-2.0.1+cu117.with.pypi.cudnn-cp310-cp310-linux_x86_64.whl
torch-2.0.1+cu117.with.pypi.cudnn-cp311-cp311-linux_x86_64.whl
torch-2.0.1+cu117.with.pypi.cudnn-cp38-cp38-linux_x86_64.whl
torch-2.0.1+cu117.with.pypi.cudnn-cp39-cp39-linux_x86_64.whl
torch-2.0.0+cu118-cp310-cp310-linux_x86_64.whl
torch-2.0.0+cu118-cp310-cp310-win_amd64.whl
torch-2.0.0+cu118-cp311-cp311-linux_x86_64.whl
torch-2.0.0+cu118-cp311-cp311-win_amd64.whl
torch-2.0.0+cu118-cp38-cp38-linux_x86_64.whl
torch-2.0.0+cu118-cp38-cp38-win_amd64.whl
torch-2.0.0+cu118-cp39-cp39-linux_x86_64.whl
torch-2.0.0+cu118-cp39-cp39-win_amd64.whl
torch-2.0.1+cu118-cp310-cp310-linux_x86_64.whl
torch-2.0.1+cu118-cp310-cp310-win_amd64.whl
torch-2.0.1+cu118-cp311-cp311-linux_x86_64.whl
torch-2.0.1+cu118-cp311-cp311-win_amd64.whl
torch-2.0.1+cu118-cp38-cp38-linux_x86_64.whl
torch-2.0.1+cu118-cp38-cp38-win_amd64.whl
torch-2.0.1+cu118-cp39-cp39-linux_x86_64.whl
torch-2.0.1+cu118-cp39-cp39-win_amd64.whl
torch-2.1.0+cu118-cp310-cp310-linux_x86_64.whl
torch-2.1.0+cu118-cp310-cp310-win_amd64.whl
torch-2.1.0+cu118-cp311-cp311-linux_x86_64.whl
torch-2.1.0+cu118-cp311-cp311-win_amd64.whl
torch-2.1.0+cu118-cp38-cp38-linux_x86_64.whl
torch-2.1.0+cu118-cp38-cp38-win_amd64.whl
torch-2.1.0+cu118-cp39-cp39-linux_x86_64.whl
torch-2.1.0+cu118-cp39-cp39-win_amd64.whl
torch-2.1.1+cu118-cp310-cp310-linux_x86_64.whl
torch-2.1.1+cu118-cp310-cp310-win_amd64.whl
torch-2.1.1+cu118-cp311-cp311-linux_x86_64.whl
torch-2.1.1+cu118-cp311-cp311-win_amd64.whl
torch-2.1.1+cu118-cp38-cp38-linux_x86_64.whl
torch-2.1.1+cu118-cp38-cp38-win_amd64.whl
torch-2.1.1+cu118-cp39-cp39-linux_x86_64.whl
torch-2.1.1+cu118-cp39-cp39-win_amd64.whl
torch-2.1.2+cu118-cp310-cp310-linux_x86_64.whl
torch-2.1.2+cu118-cp310-cp310-win_amd64.whl
torch-2.1.2+cu118-cp311-cp311-linux_x86_64.whl
torch-2.1.2+cu118-cp311-cp311-win_amd64.whl
torch-2.1.2+cu118-cp38-cp38-linux_x86_64.whl
torch-2.1.2+cu118-cp38-cp38-win_amd64.whl
torch-2.1.2+cu118-cp39-cp39-linux_x86_64.whl
torch-2.1.2+cu118-cp39-cp39-win_amd64.whl
torch-2.2.0+cu118-cp310-cp310-linux_x86_64.whl
torch-2.2.0+cu118-cp310-cp310-win_amd64.whl
torch-2.2.0+cu118-cp311-cp311-linux_x86_64.whl
torch-2.2.0+cu118-cp311-cp311-win_amd64.whl
torch-2.2.0+cu118-cp312-cp312-linux_x86_64.whl
torch-2.2.0+cu118-cp312-cp312-win_amd64.whl
torch-2.2.0+cu118-cp38-cp38-linux_x86_64.whl
torch-2.2.0+cu118-cp38-cp38-win_amd64.whl
torch-2.2.0+cu118-cp39-cp39-linux_x86_64.whl
torch-2.2.0+cu118-cp39-cp39-win_amd64.whl
torch-2.2.1+cu118-cp310-cp310-linux_x86_64.whl
torch-2.2.1+cu118-cp310-cp310-win_amd64.whl
torch-2.2.1+cu118-cp311-cp311-linux_x86_64.whl
torch-2.2.1+cu118-cp311-cp311-win_amd64.whl
torch-2.2.1+cu118-cp312-cp312-linux_x86_64.whl
torch-2.2.1+cu118-cp312-cp312-win_amd64.whl
torch-2.2.1+cu118-cp38-cp38-linux_x86_64.whl
torch-2.2.1+cu118-cp38-cp38-win_amd64.whl
torch-2.2.1+cu118-cp39-cp39-linux_x86_64.whl
torch-2.2.1+cu118-cp39-cp39-win_amd64.whl
torch-2.2.2+cu118-cp310-cp310-linux_x86_64.whl
torch-2.2.2+cu118-cp310-cp310-win_amd64.whl
torch-2.2.2+cu118-cp311-cp311-linux_x86_64.whl
torch-2.2.2+cu118-cp311-cp311-win_amd64.whl
torch-2.2.2+cu118-cp312-cp312-linux_x86_64.whl
torch-2.2.2+cu118-cp312-cp312-win_amd64.whl
torch-2.2.2+cu118-cp38-cp38-linux_x86_64.whl
torch-2.2.2+cu118-cp38-cp38-win_amd64.whl
torch-2.2.2+cu118-cp39-cp39-linux_x86_64.whl
torch-2.2.2+cu118-cp39-cp39-win_amd64.whl
torch-2.3.0+cu118-cp310-cp310-linux_x86_64.whl
torch-2.3.0+cu118-cp310-cp310-win_amd64.whl
torch-2.3.0+cu118-cp311-cp311-linux_x86_64.whl
torch-2.3.0+cu118-cp311-cp311-win_amd64.whl
torch-2.3.0+cu118-cp312-cp312-linux_x86_64.whl
torch-2.3.0+cu118-cp312-cp312-win_amd64.whl
torch-2.3.0+cu118-cp38-cp38-linux_x86_64.whl
torch-2.3.0+cu118-cp38-cp38-win_amd64.whl
torch-2.3.0+cu118-cp39-cp39-linux_x86_64.whl
torch-2.3.0+cu118-cp39-cp39-win_amd64.whl
torch-2.3.1+cu118-cp310-cp310-linux_x86_64.whl
torch-2.3.1+cu118-cp310-cp310-win_amd64.whl
torch-2.3.1+cu118-cp311-cp311-linux_x86_64.whl
torch-2.3.1+cu118-cp311-cp311-win_amd64.whl
torch-2.3.1+cu118-cp312-cp312-linux_x86_64.whl
torch-2.3.1+cu118-cp312-cp312-win_amd64.whl
torch-2.3.1+cu118-cp38-cp38-linux_x86_64.whl
torch-2.3.1+cu118-cp38-cp38-win_amd64.whl
torch-2.3.1+cu118-cp39-cp39-linux_x86_64.whl
torch-2.3.1+cu118-cp39-cp39-win_amd64.whl
torch-2.4.0+cu118-cp310-cp310-linux_x86_64.whl
torch-2.4.0+cu118-cp310-cp310-win_amd64.whl
torch-2.4.0+cu118-cp311-cp311-linux_x86_64.whl
torch-2.4.0+cu118-cp311-cp311-win_amd64.whl
torch-2.4.0+cu118-cp312-cp312-linux_x86_64.whl
torch-2.4.0+cu118-cp312-cp312-win_amd64.whl
torch-2.4.0+cu118-cp38-cp38-linux_x86_64.whl
torch-2.4.0+cu118-cp38-cp38-win_amd64.whl
torch-2.4.0+cu118-cp39-cp39-linux_x86_64.whl
torch-2.4.0+cu118-cp39-cp39-win_amd64.whl
torch-2.4.1+cu118-cp310-cp310-linux_x86_64.whl
torch-2.4.1+cu118-cp310-cp310-win_amd64.whl
torch-2.4.1+cu118-cp311-cp311-linux_x86_64.whl
torch-2.4.1+cu118-cp311-cp311-win_amd64.whl
torch-2.4.1+cu118-cp312-cp312-linux_x86_64.whl
torch-2.4.1+cu118-cp312-cp312-win_amd64.whl
torch-2.4.1+cu118-cp38-cp38-linux_x86_64.whl
torch-2.4.1+cu118-cp38-cp38-win_amd64.whl
torch-2.4.1+cu118-cp39-cp39-linux_x86_64.whl
torch-2.4.1+cu118-cp39-cp39-win_amd64.whl
torch-2.5.0+cu118-cp310-cp310-linux_x86_64.whl
torch-2.5.0+cu118-cp310-cp310-win_amd64.whl
torch-2.5.0+cu118-cp311-cp311-linux_x86_64.whl
torch-2.5.0+cu118-cp311-cp311-win_amd64.whl
torch-2.5.0+cu118-cp312-cp312-linux_x86_64.whl
torch-2.5.0+cu118-cp312-cp312-win_amd64.whl
torch-2.5.0+cu118-cp313-cp313-linux_x86_64.whl
torch-2.5.0+cu118-cp39-cp39-linux_x86_64.whl
torch-2.5.0+cu118-cp39-cp39-win_amd64.whl
torch-2.5.1+cu118-cp310-cp310-linux_x86_64.whl
torch-2.5.1+cu118-cp310-cp310-win_amd64.whl
torch-2.5.1+cu118-cp311-cp311-linux_x86_64.whl
torch-2.5.1+cu118-cp311-cp311-win_amd64.whl
torch-2.5.1+cu118-cp312-cp312-linux_x86_64.whl
torch-2.5.1+cu118-cp312-cp312-win_amd64.whl
torch-2.5.1+cu118-cp313-cp313-linux_x86_64.whl
torch-2.5.1+cu118-cp39-cp39-linux_x86_64.whl
torch-2.5.1+cu118-cp39-cp39-win_amd64.whl
torch-2.6.0+cu118-cp310-cp310-linux_x86_64.whl
torch-2.6.0+cu118-cp310-cp310-win_amd64.whl
torch-2.6.0+cu118-cp311-cp311-linux_x86_64.whl
torch-2.6.0+cu118-cp311-cp311-win_amd64.whl
torch-2.6.0+cu118-cp312-cp312-linux_x86_64.whl
torch-2.6.0+cu118-cp312-cp312-win_amd64.whl
torch-2.6.0+cu118-cp313-cp313-linux_x86_64.whl
torch-2.6.0+cu118-cp313-cp313-win_amd64.whl
torch-2.6.0+cu118-cp313-cp313t-linux_x86_64.whl
torch-2.6.0+cu118-cp39-cp39-linux_x86_64.whl
torch-2.6.0+cu118-cp39-cp39-win_amd64.whl
torch-2.7.0+cu118-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.0+cu118-cp310-cp310-win_amd64.whl
torch-2.7.0+cu118-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.0+cu118-cp311-cp311-win_amd64.whl
torch-2.7.0+cu118-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.0+cu118-cp312-cp312-win_amd64.whl
torch-2.7.0+cu118-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.0+cu118-cp313-cp313-win_amd64.whl
torch-2.7.0+cu118-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.0+cu118-cp313-cp313t-win_amd64.whl
torch-2.7.0+cu118-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.7.0+cu118-cp39-cp39-win_amd64.whl
torch-2.7.1+cu118-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.1+cu118-cp310-cp310-win_amd64.whl
torch-2.7.1+cu118-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.1+cu118-cp311-cp311-win_amd64.whl
torch-2.7.1+cu118-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.1+cu118-cp312-cp312-win_amd64.whl
torch-2.7.1+cu118-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.1+cu118-cp313-cp313-win_amd64.whl
torch-2.7.1+cu118-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.1+cu118-cp313-cp313t-win_amd64.whl
torch-2.7.1+cu118-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.7.1+cu118-cp39-cp39-win_amd64.whl
torch-2.1.0+cu121-cp310-cp310-linux_x86_64.whl
torch-2.1.0+cu121-cp310-cp310-win_amd64.whl
torch-2.1.0+cu121-cp311-cp311-linux_x86_64.whl
torch-2.1.0+cu121-cp311-cp311-win_amd64.whl
torch-2.1.0+cu121-cp38-cp38-linux_x86_64.whl
torch-2.1.0+cu121-cp38-cp38-win_amd64.whl
torch-2.1.0+cu121-cp39-cp39-linux_x86_64.whl
torch-2.1.0+cu121-cp39-cp39-win_amd64.whl
torch-2.1.1+cu121-cp310-cp310-linux_x86_64.whl
torch-2.1.1+cu121-cp310-cp310-win_amd64.whl
torch-2.1.1+cu121-cp311-cp311-linux_x86_64.whl
torch-2.1.1+cu121-cp311-cp311-win_amd64.whl
torch-2.1.1+cu121-cp38-cp38-linux_x86_64.whl
torch-2.1.1+cu121-cp38-cp38-win_amd64.whl
torch-2.1.1+cu121-cp39-cp39-linux_x86_64.whl
torch-2.1.1+cu121-cp39-cp39-win_amd64.whl
torch-2.1.2+cu121-cp310-cp310-linux_x86_64.whl
torch-2.1.2+cu121-cp310-cp310-win_amd64.whl
torch-2.1.2+cu121-cp311-cp311-linux_x86_64.whl
torch-2.1.2+cu121-cp311-cp311-win_amd64.whl
torch-2.1.2+cu121-cp38-cp38-linux_x86_64.whl
torch-2.1.2+cu121-cp38-cp38-win_amd64.whl
torch-2.1.2+cu121-cp39-cp39-linux_x86_64.whl
torch-2.1.2+cu121-cp39-cp39-win_amd64.whl
torch-2.2.0+cu121-cp310-cp310-linux_x86_64.whl
torch-2.2.0+cu121-cp310-cp310-win_amd64.whl
torch-2.2.0+cu121-cp311-cp311-linux_x86_64.whl
torch-2.2.0+cu121-cp311-cp311-win_amd64.whl
torch-2.2.0+cu121-cp312-cp312-linux_x86_64.whl
torch-2.2.0+cu121-cp312-cp312-win_amd64.whl
torch-2.2.0+cu121-cp38-cp38-linux_x86_64.whl
torch-2.2.0+cu121-cp38-cp38-win_amd64.whl
torch-2.2.0+cu121-cp39-cp39-linux_x86_64.whl
torch-2.2.0+cu121-cp39-cp39-win_amd64.whl
torch-2.2.1+cu121-cp310-cp310-linux_x86_64.whl
torch-2.2.1+cu121-cp310-cp310-win_amd64.whl
torch-2.2.1+cu121-cp311-cp311-linux_x86_64.whl
torch-2.2.1+cu121-cp311-cp311-win_amd64.whl
torch-2.2.1+cu121-cp312-cp312-linux_x86_64.whl
torch-2.2.1+cu121-cp312-cp312-win_amd64.whl
torch-2.2.1+cu121-cp38-cp38-linux_x86_64.whl
torch-2.2.1+cu121-cp38-cp38-win_amd64.whl
torch-2.2.1+cu121-cp39-cp39-linux_x86_64.whl
torch-2.2.1+cu121-cp39-cp39-win_amd64.whl
torch-2.2.2+cu121-cp310-cp310-linux_x86_64.whl
torch-2.2.2+cu121-cp310-cp310-win_amd64.whl
torch-2.2.2+cu121-cp311-cp311-linux_x86_64.whl
torch-2.2.2+cu121-cp311-cp311-win_amd64.whl
torch-2.2.2+cu121-cp312-cp312-linux_x86_64.whl
torch-2.2.2+cu121-cp312-cp312-win_amd64.whl
torch-2.2.2+cu121-cp38-cp38-linux_x86_64.whl
torch-2.2.2+cu121-cp38-cp38-win_amd64.whl
torch-2.2.2+cu121-cp39-cp39-linux_x86_64.whl
torch-2.2.2+cu121-cp39-cp39-win_amd64.whl
torch-2.3.0+cu121-cp310-cp310-linux_x86_64.whl
torch-2.3.0+cu121-cp310-cp310-win_amd64.whl
torch-2.3.0+cu121-cp311-cp311-linux_x86_64.whl
torch-2.3.0+cu121-cp311-cp311-win_amd64.whl
torch-2.3.0+cu121-cp312-cp312-linux_x86_64.whl
torch-2.3.0+cu121-cp312-cp312-win_amd64.whl
torch-2.3.0+cu121-cp38-cp38-linux_x86_64.whl
torch-2.3.0+cu121-cp38-cp38-win_amd64.whl
torch-2.3.0+cu121-cp39-cp39-linux_x86_64.whl
torch-2.3.0+cu121-cp39-cp39-win_amd64.whl
torch-2.3.1+cu121-cp310-cp310-linux_x86_64.whl
torch-2.3.1+cu121-cp310-cp310-win_amd64.whl
torch-2.3.1+cu121-cp311-cp311-linux_x86_64.whl
torch-2.3.1+cu121-cp311-cp311-win_amd64.whl
torch-2.3.1+cu121-cp312-cp312-linux_x86_64.whl
torch-2.3.1+cu121-cp312-cp312-win_amd64.whl
torch-2.3.1+cu121-cp38-cp38-linux_x86_64.whl
torch-2.3.1+cu121-cp38-cp38-win_amd64.whl
torch-2.3.1+cu121-cp39-cp39-linux_x86_64.whl
torch-2.3.1+cu121-cp39-cp39-win_amd64.whl
torch-2.4.0+cu121-cp310-cp310-linux_x86_64.whl
torch-2.4.0+cu121-cp310-cp310-win_amd64.whl
torch-2.4.0+cu121-cp311-cp311-linux_x86_64.whl
torch-2.4.0+cu121-cp311-cp311-win_amd64.whl
torch-2.4.0+cu121-cp312-cp312-linux_x86_64.whl
torch-2.4.0+cu121-cp312-cp312-win_amd64.whl
torch-2.4.0+cu121-cp38-cp38-linux_x86_64.whl
torch-2.4.0+cu121-cp38-cp38-win_amd64.whl
torch-2.4.0+cu121-cp39-cp39-linux_x86_64.whl
torch-2.4.0+cu121-cp39-cp39-win_amd64.whl
torch-2.4.1+cu121-cp310-cp310-linux_x86_64.whl
torch-2.4.1+cu121-cp310-cp310-win_amd64.whl
torch-2.4.1+cu121-cp311-cp311-linux_x86_64.whl
torch-2.4.1+cu121-cp311-cp311-win_amd64.whl
torch-2.4.1+cu121-cp312-cp312-linux_x86_64.whl
torch-2.4.1+cu121-cp312-cp312-win_amd64.whl
torch-2.4.1+cu121-cp38-cp38-linux_x86_64.whl
torch-2.4.1+cu121-cp38-cp38-win_amd64.whl
torch-2.4.1+cu121-cp39-cp39-linux_x86_64.whl
torch-2.4.1+cu121-cp39-cp39-win_amd64.whl
torch-2.5.0+cu121-cp310-cp310-linux_x86_64.whl
torch-2.5.0+cu121-cp310-cp310-win_amd64.whl
torch-2.5.0+cu121-cp311-cp311-linux_x86_64.whl
torch-2.5.0+cu121-cp311-cp311-win_amd64.whl
torch-2.5.0+cu121-cp312-cp312-linux_x86_64.whl
torch-2.5.0+cu121-cp312-cp312-win_amd64.whl
torch-2.5.0+cu121-cp313-cp313-linux_x86_64.whl
torch-2.5.0+cu121-cp39-cp39-linux_x86_64.whl
torch-2.5.0+cu121-cp39-cp39-win_amd64.whl
torch-2.5.1+cu121-cp310-cp310-linux_x86_64.whl
torch-2.5.1+cu121-cp310-cp310-win_amd64.whl
torch-2.5.1+cu121-cp311-cp311-linux_x86_64.whl
torch-2.5.1+cu121-cp311-cp311-win_amd64.whl
torch-2.5.1+cu121-cp312-cp312-linux_x86_64.whl
torch-2.5.1+cu121-cp312-cp312-win_amd64.whl
torch-2.5.1+cu121-cp313-cp313-linux_x86_64.whl
torch-2.5.1+cu121-cp39-cp39-linux_x86_64.whl
torch-2.5.1+cu121-cp39-cp39-win_amd64.whl
torch-2.4.0+cu121-cp310-cp310-linux_x86_64.whl
torch-2.4.1+cu121-cp310-cp310-linux_x86_64.whl
torch-2.5.0+cu121-cp310-cp310-linux_x86_64.whl
torch-2.5.1+cu121-cp310-cp310-linux_x86_64.whl
torch-2.1.0+cu121.with.pypi.cudnn-cp310-cp310-linux_x86_64.whl
torch-2.1.0+cu121.with.pypi.cudnn-cp311-cp311-linux_x86_64.whl
torch-2.1.0+cu121.with.pypi.cudnn-cp38-cp38-linux_x86_64.whl
torch-2.1.0+cu121.with.pypi.cudnn-cp39-cp39-linux_x86_64.whl
torch-2.1.1+cu121.with.pypi.cudnn-cp310-cp310-linux_x86_64.whl
torch-2.1.1+cu121.with.pypi.cudnn-cp311-cp311-linux_x86_64.whl
torch-2.1.1+cu121.with.pypi.cudnn-cp38-cp38-linux_x86_64.whl
torch-2.1.1+cu121.with.pypi.cudnn-cp39-cp39-linux_x86_64.whl
torch-2.1.2+cu121.with.pypi.cudnn-cp310-cp310-linux_x86_64.whl
torch-2.1.2+cu121.with.pypi.cudnn-cp311-cp311-linux_x86_64.whl
torch-2.1.2+cu121.with.pypi.cudnn-cp38-cp38-linux_x86_64.whl
torch-2.1.2+cu121.with.pypi.cudnn-cp39-cp39-linux_x86_64.whl
torch-2.4.0+cu124-cp310-cp310-linux_x86_64.whl
torch-2.4.0+cu124-cp310-cp310-win_amd64.whl
torch-2.4.0+cu124-cp311-cp311-linux_x86_64.whl
torch-2.4.0+cu124-cp311-cp311-win_amd64.whl
torch-2.4.0+cu124-cp312-cp312-linux_x86_64.whl
torch-2.4.0+cu124-cp312-cp312-win_amd64.whl
torch-2.4.0+cu124-cp38-cp38-linux_x86_64.whl
torch-2.4.0+cu124-cp38-cp38-win_amd64.whl
torch-2.4.0+cu124-cp39-cp39-linux_x86_64.whl
torch-2.4.0+cu124-cp39-cp39-win_amd64.whl
torch-2.4.0-cp310-cp310-linux_aarch64.whl
torch-2.4.0-cp311-cp311-linux_aarch64.whl
torch-2.4.0-cp312-cp312-linux_aarch64.whl
torch-2.4.0-cp38-cp38-linux_aarch64.whl
torch-2.4.0-cp39-cp39-linux_aarch64.whl
torch-2.4.1+cu124-cp310-cp310-linux_x86_64.whl
torch-2.4.1+cu124-cp310-cp310-win_amd64.whl
torch-2.4.1+cu124-cp311-cp311-linux_x86_64.whl
torch-2.4.1+cu124-cp311-cp311-win_amd64.whl
torch-2.4.1+cu124-cp312-cp312-linux_x86_64.whl
torch-2.4.1+cu124-cp312-cp312-win_amd64.whl
torch-2.4.1+cu124-cp38-cp38-linux_x86_64.whl
torch-2.4.1+cu124-cp38-cp38-win_amd64.whl
torch-2.4.1+cu124-cp39-cp39-linux_x86_64.whl
torch-2.4.1+cu124-cp39-cp39-win_amd64.whl
torch-2.4.1-cp310-cp310-linux_aarch64.whl
torch-2.4.1-cp311-cp311-linux_aarch64.whl
torch-2.4.1-cp312-cp312-linux_aarch64.whl
torch-2.4.1-cp38-cp38-linux_aarch64.whl
torch-2.4.1-cp39-cp39-linux_aarch64.whl
torch-2.5.0+cu124-cp310-cp310-linux_x86_64.whl
torch-2.5.0+cu124-cp310-cp310-win_amd64.whl
torch-2.5.0+cu124-cp311-cp311-linux_x86_64.whl
torch-2.5.0+cu124-cp311-cp311-win_amd64.whl
torch-2.5.0+cu124-cp312-cp312-linux_x86_64.whl
torch-2.5.0+cu124-cp312-cp312-win_amd64.whl
torch-2.5.0+cu124-cp313-cp313-linux_x86_64.whl
torch-2.5.0+cu124-cp39-cp39-linux_x86_64.whl
torch-2.5.0+cu124-cp39-cp39-win_amd64.whl
torch-2.5.0-cp310-cp310-linux_aarch64.whl
torch-2.5.0-cp311-cp311-linux_aarch64.whl
torch-2.5.0-cp312-cp312-linux_aarch64.whl
torch-2.5.0-cp39-cp39-linux_aarch64.whl
torch-2.5.1+cu124-cp310-cp310-linux_x86_64.whl
torch-2.5.1+cu124-cp310-cp310-win_amd64.whl
torch-2.5.1+cu124-cp311-cp311-linux_x86_64.whl
torch-2.5.1+cu124-cp311-cp311-win_amd64.whl
torch-2.5.1+cu124-cp312-cp312-linux_x86_64.whl
torch-2.5.1+cu124-cp312-cp312-win_amd64.whl
torch-2.5.1+cu124-cp313-cp313-linux_x86_64.whl
torch-2.5.1+cu124-cp39-cp39-linux_x86_64.whl
torch-2.5.1+cu124-cp39-cp39-win_amd64.whl
torch-2.5.1-cp310-cp310-linux_aarch64.whl
torch-2.5.1-cp311-cp311-linux_aarch64.whl
torch-2.5.1-cp312-cp312-linux_aarch64.whl
torch-2.5.1-cp39-cp39-linux_aarch64.whl
torch-2.6.0+cu124-cp310-cp310-linux_x86_64.whl
torch-2.6.0+cu124-cp310-cp310-win_amd64.whl
torch-2.6.0+cu124-cp311-cp311-linux_x86_64.whl
torch-2.6.0+cu124-cp311-cp311-win_amd64.whl
torch-2.6.0+cu124-cp312-cp312-linux_x86_64.whl
torch-2.6.0+cu124-cp312-cp312-win_amd64.whl
torch-2.6.0+cu124-cp313-cp313-linux_x86_64.whl
torch-2.6.0+cu124-cp313-cp313-win_amd64.whl
torch-2.6.0+cu124-cp313-cp313t-linux_x86_64.whl
torch-2.6.0+cu124-cp39-cp39-linux_x86_64.whl
torch-2.6.0+cu124-cp39-cp39-win_amd64.whl
torch-2.6.0+cu124-cp311-cp311-linux_x86_64.whl
torch-2.6.0+cu126-cp310-cp310-linux_aarch64.whl
torch-2.6.0+cu126-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.6.0+cu126-cp310-cp310-win_amd64.whl
torch-2.6.0+cu126-cp311-cp311-linux_aarch64.whl
torch-2.6.0+cu126-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.6.0+cu126-cp311-cp311-win_amd64.whl
torch-2.6.0+cu126-cp312-cp312-linux_aarch64.whl
torch-2.6.0+cu126-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.6.0+cu126-cp312-cp312-win_amd64.whl
torch-2.6.0+cu126-cp313-cp313-linux_aarch64.whl
torch-2.6.0+cu126-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.6.0+cu126-cp313-cp313-win_amd64.whl
torch-2.6.0+cu126-cp313-cp313t-linux_aarch64.whl
torch-2.6.0+cu126-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.6.0+cu126-cp39-cp39-linux_aarch64.whl
torch-2.6.0+cu126-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.6.0+cu126-cp39-cp39-win_amd64.whl
torch-2.7.0+cu126-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.0+cu126-cp310-cp310-win_amd64.whl
torch-2.7.0+cu126-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.0+cu126-cp311-cp311-win_amd64.whl
torch-2.7.0+cu126-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.0+cu126-cp312-cp312-win_amd64.whl
torch-2.7.0+cu126-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.0+cu126-cp313-cp313-win_amd64.whl
torch-2.7.0+cu126-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.0+cu126-cp313-cp313t-win_amd64.whl
torch-2.7.0+cu126-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.7.0+cu126-cp39-cp39-win_amd64.whl
torch-2.7.1+cu126-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.1+cu126-cp310-cp310-win_amd64.whl
torch-2.7.1+cu126-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.1+cu126-cp311-cp311-win_amd64.whl
torch-2.7.1+cu126-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.1+cu126-cp312-cp312-win_amd64.whl
torch-2.7.1+cu126-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.1+cu126-cp313-cp313-win_amd64.whl
torch-2.7.1+cu126-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.1+cu126-cp313-cp313t-win_amd64.whl
torch-2.7.1+cu126-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.7.1+cu126-cp39-cp39-win_amd64.whl
torch-2.7.0+cu126-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.1+cu126-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.0+cu128-cp310-cp310-manylinux_2_28_aarch64.whl
torch-2.7.0+cu128-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.0+cu128-cp310-cp310-win_amd64.whl
torch-2.7.0+cu128-cp311-cp311-manylinux_2_28_aarch64.whl
torch-2.7.0+cu128-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.0+cu128-cp311-cp311-win_amd64.whl
torch-2.7.0+cu128-cp312-cp312-manylinux_2_28_aarch64.whl
torch-2.7.0+cu128-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.0+cu128-cp312-cp312-win_amd64.whl
torch-2.7.0+cu128-cp313-cp313-manylinux_2_28_aarch64.whl
torch-2.7.0+cu128-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.0+cu128-cp313-cp313-win_amd64.whl
torch-2.7.0+cu128-cp313-cp313t-manylinux_2_28_aarch64.whl
torch-2.7.0+cu128-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.0+cu128-cp313-cp313t-win_amd64.whl
torch-2.7.0+cu128-cp39-cp39-manylinux_2_28_aarch64.whl
torch-2.7.0+cu128-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.7.0+cu128-cp39-cp39-win_amd64.whl
torch-2.7.1+cu128-cp310-cp310-manylinux_2_28_aarch64.whl
torch-2.7.1+cu128-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.1+cu128-cp310-cp310-win_amd64.whl
torch-2.7.1+cu128-cp311-cp311-manylinux_2_28_aarch64.whl
torch-2.7.1+cu128-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.1+cu128-cp311-cp311-win_amd64.whl
torch-2.7.1+cu128-cp312-cp312-manylinux_2_28_aarch64.whl
torch-2.7.1+cu128-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.1+cu128-cp312-cp312-win_amd64.whl
torch-2.7.1+cu128-cp313-cp313-manylinux_2_28_aarch64.whl
torch-2.7.1+cu128-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.1+cu128-cp313-cp313-win_amd64.whl
torch-2.7.1+cu128-cp313-cp313t-manylinux_2_28_aarch64.whl
torch-2.7.1+cu128-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.1+cu128-cp313-cp313t-win_amd64.whl
torch-2.7.1+cu128-cp39-cp39-manylinux_2_28_aarch64.whl
torch-2.7.1+cu128-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.7.1+cu128-cp39-cp39-win_amd64.whl
torch-0.1.10.post1-cp27-none-linux_x86_64.whl
torch-0.1.10.post1-cp35-cp35m-linux_x86_64.whl
torch-0.1.10.post1-cp36-cp36m-linux_x86_64.whl
torch-0.1.10.post2-cp27-none-linux_x86_64.whl
torch-0.1.10.post2-cp35-cp35m-linux_x86_64.whl
torch-0.1.10.post2-cp36-cp36m-linux_x86_64.whl
torch-0.1.11.post4-cp27-none-linux_x86_64.whl
torch-0.1.11.post4-cp35-cp35m-linux_x86_64.whl
torch-0.1.11.post4-cp36-cp36m-linux_x86_64.whl
torch-0.1.11.post5-cp27-none-linux_x86_64.whl
torch-0.1.11.post5-cp35-cp35m-linux_x86_64.whl
torch-0.1.11.post5-cp36-cp36m-linux_x86_64.whl
torch-0.1.12.post1-cp27-none-linux_x86_64.whl
torch-0.1.12.post1-cp35-cp35m-linux_x86_64.whl
torch-0.1.12.post1-cp36-cp36m-linux_x86_64.whl
torch-0.1.12.post2-cp27-none-linux_x86_64.whl
torch-0.1.12.post2-cp35-cp35m-linux_x86_64.whl
torch-0.1.12.post2-cp36-cp36m-linux_x86_64.whl
torch-0.1.6.post20-cp27-cp27mu-linux_x86_64.whl
torch-0.1.6.post20-cp35-cp35m-linux_x86_64.whl
torch-0.1.6.post22-cp27-none-linux_x86_64.whl
torch-0.1.6.post22-cp35-cp35m-linux_x86_64.whl
torch-0.1.7.post2-cp27-none-linux_x86_64.whl
torch-0.1.7.post2-cp35-cp35m-linux_x86_64.whl
torch-0.1.7.post2-cp36-cp36m-linux_x86_64.whl
torch-0.1.8.post1-cp27-none-linux_x86_64.whl
torch-0.1.8.post1-cp35-cp35m-linux_x86_64.whl
torch-0.1.8.post1-cp36-cp36m-linux_x86_64.whl
torch-0.1.9.post1-cp27-none-linux_x86_64.whl
torch-0.1.9.post1-cp35-cp35m-linux_x86_64.whl
torch-0.1.9.post1-cp36-cp36m-linux_x86_64.whl
torch-0.1.9.post2-cp27-none-linux_x86_64.whl
torch-0.1.9.post2-cp35-cp35m-linux_x86_64.whl
torch-0.1.9.post2-cp36-cp36m-linux_x86_64.whl
torch-0.2.0.post1-cp27-cp27m-manylinux1_x86_64.whl
torch-0.2.0.post1-cp27-cp27mu-manylinux1_x86_64.whl
torch-0.2.0.post1-cp35-cp35m-manylinux1_x86_64.whl
torch-0.2.0.post1-cp36-cp36m-manylinux1_x86_64.whl
torch-0.2.0.post2-cp27-cp27m-manylinux1_x86_64.whl
torch-0.2.0.post2-cp27-cp27mu-manylinux1_x86_64.whl
torch-0.2.0.post2-cp35-cp35m-manylinux1_x86_64.whl
torch-0.2.0.post2-cp36-cp36m-manylinux1_x86_64.whl
torch-0.2.0.post3-cp27-cp27m-manylinux1_x86_64.whl
torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
torch-0.2.0.post3-cp35-cp35m-manylinux1_x86_64.whl
torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
torch-0.3.0-cp27-cp27m-linux_x86_64.whl
torch-0.3.0-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0-cp35-cp35m-linux_x86_64.whl
torch-0.3.0-cp36-cp36m-linux_x86_64.whl
torch-0.3.0.post2-cp27-cp27m-linux_x86_64.whl
torch-0.3.0.post2-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0.post2-cp35-cp35m-linux_x86_64.whl
torch-0.3.0.post2-cp36-cp36m-linux_x86_64.whl
torch-0.3.0.post3-cp27-cp27m-linux_x86_64.whl
torch-0.3.0.post3-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0.post3-cp35-cp35m-linux_x86_64.whl
torch-0.3.0.post3-cp36-cp36m-linux_x86_64.whl
torch-0.3.0.post4-cp27-cp27m-linux_x86_64.whl
torch-0.3.0.post4-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0.post4-cp35-cp35m-linux_x86_64.whl
torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl
torch-0.1.10.post1-cp27-none-linux_x86_64.whl
torch-0.1.10.post1-cp35-cp35m-linux_x86_64.whl
torch-0.1.10.post1-cp36-cp36m-linux_x86_64.whl
torch-0.1.10.post2-cp27-none-linux_x86_64.whl
torch-0.1.10.post2-cp35-cp35m-linux_x86_64.whl
torch-0.1.10.post2-cp36-cp36m-linux_x86_64.whl
torch-0.1.11.post4-cp27-none-linux_x86_64.whl
torch-0.1.11.post4-cp35-cp35m-linux_x86_64.whl
torch-0.1.11.post4-cp36-cp36m-linux_x86_64.whl
torch-0.1.11.post5-cp27-none-linux_x86_64.whl
torch-0.1.11.post5-cp35-cp35m-linux_x86_64.whl
torch-0.1.11.post5-cp36-cp36m-linux_x86_64.whl
torch-0.1.12.post1-cp27-none-linux_x86_64.whl
torch-0.1.12.post1-cp35-cp35m-linux_x86_64.whl
torch-0.1.12.post1-cp36-cp36m-linux_x86_64.whl
torch-0.1.12.post2-cp27-none-linux_x86_64.whl
torch-0.1.12.post2-cp35-cp35m-linux_x86_64.whl
torch-0.1.12.post2-cp36-cp36m-linux_x86_64.whl
torch-0.1.6.post20-cp27-cp27mu-linux_x86_64.whl
torch-0.1.6.post20-cp35-cp35m-linux_x86_64.whl
torch-0.1.6.post22-cp27-none-linux_x86_64.whl
torch-0.1.6.post22-cp35-cp35m-linux_x86_64.whl
torch-0.1.7.post2-cp27-none-linux_x86_64.whl
torch-0.1.7.post2-cp35-cp35m-linux_x86_64.whl
torch-0.1.7.post2-cp36-cp36m-linux_x86_64.whl
torch-0.1.8.post1-cp27-none-linux_x86_64.whl
torch-0.1.8.post1-cp35-cp35m-linux_x86_64.whl
torch-0.1.8.post1-cp36-cp36m-linux_x86_64.whl
torch-0.1.9.post1-cp27-none-linux_x86_64.whl
torch-0.1.9.post1-cp35-cp35m-linux_x86_64.whl
torch-0.1.9.post1-cp36-cp36m-linux_x86_64.whl
torch-0.1.9.post2-cp27-none-linux_x86_64.whl
torch-0.1.9.post2-cp35-cp35m-linux_x86_64.whl
torch-0.1.9.post2-cp36-cp36m-linux_x86_64.whl
torch-0.2.0.post2-cp27-cp27m-manylinux1_x86_64.whl
torch-0.2.0.post2-cp27-cp27mu-manylinux1_x86_64.whl
torch-0.2.0.post2-cp35-cp35m-manylinux1_x86_64.whl
torch-0.2.0.post2-cp36-cp36m-manylinux1_x86_64.whl
torch-0.2.0.post3-cp27-cp27m-manylinux1_x86_64.whl
torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
torch-0.2.0.post3-cp35-cp35m-manylinux1_x86_64.whl
torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
torch-0.3.0-cp27-cp27m-linux_x86_64.whl
torch-0.3.0-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0-cp35-cp35m-linux_x86_64.whl
torch-0.3.0-cp36-cp36m-linux_x86_64.whl
torch-0.3.0.post2-cp27-cp27m-linux_x86_64.whl
torch-0.3.0.post2-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0.post2-cp35-cp35m-linux_x86_64.whl
torch-0.3.0.post2-cp36-cp36m-linux_x86_64.whl
torch-0.3.0.post3-cp27-cp27m-linux_x86_64.whl
torch-0.3.0.post3-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0.post3-cp35-cp35m-linux_x86_64.whl
torch-0.3.0.post3-cp36-cp36m-linux_x86_64.whl
torch-0.3.0.post4-cp27-cp27m-linux_x86_64.whl
torch-0.3.0.post4-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0.post4-cp35-cp35m-linux_x86_64.whl
torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl
torch-0.3.1-cp27-cp27m-linux_x86_64.whl
torch-0.3.1-cp27-cp27mu-linux_x86_64.whl
torch-0.3.1-cp35-cp35m-linux_x86_64.whl
torch-0.3.1-cp36-cp36m-linux_x86_64.whl
torch-0.4.0-cp27-cp27m-linux_x86_64.whl
torch-0.4.0-cp27-cp27mu-linux_x86_64.whl
torch-0.4.0-cp35-cp35m-linux_x86_64.whl
torch-0.4.0-cp35-cp35m-win_amd64.whl
torch-0.4.0-cp36-cp36m-linux_x86_64.whl
torch-0.4.0-cp36-cp36m-win_amd64.whl
torch-0.4.1-cp27-cp27m-linux_x86_64.whl
torch-0.4.1-cp27-cp27mu-linux_x86_64.whl
torch-0.4.1-cp35-cp35m-linux_x86_64.whl
torch-0.4.1-cp35-cp35m-win_amd64.whl
torch-0.4.1-cp36-cp36m-linux_x86_64.whl
torch-0.4.1-cp36-cp36m-win_amd64.whl
torch-0.4.1-cp37-cp37m-linux_x86_64.whl
torch-0.4.1-cp37-cp37m-win_amd64.whl
torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl
torch-1.0.0-cp27-cp27m-linux_x86_64.whl
torch-1.0.0-cp27-cp27mu-linux_x86_64.whl
torch-1.0.0-cp35-cp35m-linux_x86_64.whl
torch-1.0.0-cp35-cp35m-win_amd64.whl
torch-1.0.0-cp36-cp36m-linux_x86_64.whl
torch-1.0.0-cp36-cp36m-win_amd64.whl
torch-1.0.0-cp37-cp37m-linux_x86_64.whl
torch-1.0.0-cp37-cp37m-win_amd64.whl
torch-1.0.1-cp27-cp27m-linux_x86_64.whl
torch-1.0.1-cp27-cp27mu-linux_x86_64.whl
torch-1.0.1-cp35-cp35m-linux_x86_64.whl
torch-1.0.1-cp35-cp35m-win_amd64.whl
torch-1.0.1-cp36-cp36m-linux_x86_64.whl
torch-1.0.1-cp36-cp36m-win_amd64.whl
torch-1.0.1-cp37-cp37m-linux_x86_64.whl
torch-1.0.1-cp37-cp37m-win_amd64.whl
torch-1.0.1.post2-cp27-cp27m-linux_x86_64.whl
torch-1.0.1.post2-cp27-cp27mu-linux_x86_64.whl
torch-1.0.1.post2-cp35-cp35m-linux_x86_64.whl
torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl
torch-0.3.0-cp27-cp27m-linux_x86_64.whl
torch-0.3.0-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0-cp35-cp35m-linux_x86_64.whl
torch-0.3.0-cp36-cp36m-linux_x86_64.whl
torch-0.3.0.post2-cp27-cp27m-linux_x86_64.whl
torch-0.3.0.post2-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0.post2-cp35-cp35m-linux_x86_64.whl
torch-0.3.0.post2-cp36-cp36m-linux_x86_64.whl
torch-0.3.0.post3-cp27-cp27m-linux_x86_64.whl
torch-0.3.0.post3-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0.post3-cp35-cp35m-linux_x86_64.whl
torch-0.3.0.post3-cp36-cp36m-linux_x86_64.whl
torch-0.3.0.post4-cp27-cp27m-linux_x86_64.whl
torch-0.3.0.post4-cp27-cp27mu-linux_x86_64.whl
torch-0.3.0.post4-cp35-cp35m-linux_x86_64.whl
torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl
torch-0.3.1-cp27-cp27m-linux_x86_64.whl
torch-0.3.1-cp27-cp27mu-linux_x86_64.whl
torch-0.3.1-cp35-cp35m-linux_x86_64.whl
torch-0.3.1-cp36-cp36m-linux_x86_64.whl
torch-0.4.0-cp27-cp27m-linux_x86_64.whl
torch-0.4.0-cp27-cp27mu-linux_x86_64.whl
torch-0.4.0-cp35-cp35m-linux_x86_64.whl
torch-0.4.0-cp35-cp35m-win_amd64.whl
torch-0.4.0-cp36-cp36m-linux_x86_64.whl
torch-0.4.0-cp36-cp36m-win_amd64.whl
torch-0.4.1-cp27-cp27m-linux_x86_64.whl
torch-0.4.1-cp27-cp27mu-linux_x86_64.whl
torch-0.4.1-cp35-cp35m-linux_x86_64.whl
torch-0.4.1-cp35-cp35m-win_amd64.whl
torch-0.4.1-cp36-cp36m-linux_x86_64.whl
torch-0.4.1-cp36-cp36m-win_amd64.whl
torch-0.4.1-cp37-cp37m-linux_x86_64.whl
torch-0.4.1-cp37-cp37m-win_amd64.whl
torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl
torch-1.0.0-cp27-cp27m-linux_x86_64.whl
torch-1.0.0-cp27-cp27mu-linux_x86_64.whl
torch-1.0.0-cp35-cp35m-linux_x86_64.whl
torch-1.0.0-cp35-cp35m-win_amd64.whl
torch-1.0.0-cp36-cp36m-linux_x86_64.whl
torch-1.0.0-cp36-cp36m-win_amd64.whl
torch-1.0.0-cp37-cp37m-linux_x86_64.whl
torch-1.0.0-cp37-cp37m-win_amd64.whl
torch-1.0.1-cp27-cp27m-linux_x86_64.whl
torch-1.0.1-cp27-cp27mu-linux_x86_64.whl
torch-1.0.1-cp35-cp35m-linux_x86_64.whl
torch-1.0.1-cp35-cp35m-win_amd64.whl
torch-1.0.1-cp36-cp36m-linux_x86_64.whl
torch-1.0.1-cp36-cp36m-win_amd64.whl
torch-1.0.1-cp37-cp37m-linux_x86_64.whl
torch-1.0.1-cp37-cp37m-win_amd64.whl
torch-1.0.1.post2-cp27-cp27m-linux_x86_64.whl
torch-1.0.1.post2-cp27-cp27mu-linux_x86_64.whl
torch-1.0.1.post2-cp35-cp35m-linux_x86_64.whl
torch-1.0.1.post2-cp36-cp36m-linux_x86_64.whl
torch-1.0.1.post2-cp37-cp37m-linux_x86_64.whl
torch-1.1.0-cp27-cp27m-linux_x86_64.whl
torch-1.1.0-cp27-cp27mu-linux_x86_64.whl
torch-1.1.0-cp35-cp35m-linux_x86_64.whl
torch-1.1.0-cp35-cp35m-win_amd64.whl
torch-1.1.0-cp36-cp36m-linux_x86_64.whl
torch-1.1.0-cp36-cp36m-win_amd64.whl
torch-1.1.0-cp37-cp37m-linux_x86_64.whl
torch-1.1.0-cp37-cp37m-win_amd64.whl
torch-0.3.1-cp27-cp27m-linux_x86_64.whl
torch-0.3.1-cp27-cp27mu-linux_x86_64.whl
torch-0.3.1-cp35-cp35m-linux_x86_64.whl
torch-0.3.1-cp36-cp36m-linux_x86_64.whl
torch-0.4.0-cp27-cp27m-linux_x86_64.whl
torch-0.4.0-cp27-cp27mu-linux_x86_64.whl
torch-0.4.0-cp35-cp35m-linux_x86_64.whl
torch-0.4.0-cp35-cp35m-win_amd64.whl
torch-0.4.0-cp36-cp36m-linux_x86_64.whl
torch-0.4.0-cp36-cp36m-win_amd64.whl
torch-0.4.1-cp27-cp27m-linux_x86_64.whl
torch-0.4.1-cp27-cp27mu-linux_x86_64.whl
torch-0.4.1-cp35-cp35m-linux_x86_64.whl
torch-0.4.1-cp35-cp35m-win_amd64.whl
torch-0.4.1-cp36-cp36m-linux_x86_64.whl
torch-0.4.1-cp36-cp36m-win_amd64.whl
torch-0.4.1-cp37-cp37m-linux_x86_64.whl
torch-0.4.1-cp37-cp37m-win_amd64.whl
torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl
torch-1.2.0+cu92-cp27-cp27m-manylinux1_x86_64.whl
torch-1.2.0+cu92-cp27-cp27mu-manylinux1_x86_64.whl
torch-1.2.0+cu92-cp35-cp35m-manylinux1_x86_64.whl
torch-1.2.0+cu92-cp35-cp35m-win_amd64.whl
torch-1.2.0+cu92-cp36-cp36m-manylinux1_x86_64.whl
torch-1.2.0+cu92-cp36-cp36m-win_amd64.whl
torch-1.2.0+cu92-cp37-cp37m-manylinux1_x86_64.whl
torch-1.2.0+cu92-cp37-cp37m-win_amd64.whl
torch-1.3.0+cu92-cp27-cp27m-linux_x86_64.whl
torch-1.3.0+cu92-cp27-cp27mu-linux_x86_64.whl
torch-1.3.0+cu92-cp35-cp35m-linux_x86_64.whl
torch-1.3.0+cu92-cp35-cp35m-win_amd64.whl
torch-1.3.0+cu92-cp36-cp36m-linux_x86_64.whl
torch-1.3.0+cu92-cp36-cp36m-win_amd64.whl
torch-1.3.0+cu92-cp37-cp37m-linux_x86_64.whl
torch-1.3.0+cu92-cp37-cp37m-win_amd64.whl
torch-1.3.1+cu92-cp27-cp27m-linux_x86_64.whl
torch-1.3.1+cu92-cp27-cp27mu-linux_x86_64.whl
torch-1.3.1+cu92-cp35-cp35m-linux_x86_64.whl
torch-1.3.1+cu92-cp35-cp35m-win_amd64.whl
torch-1.3.1+cu92-cp36-cp36m-linux_x86_64.whl
torch-1.3.1+cu92-cp36-cp36m-win_amd64.whl
torch-1.3.1+cu92-cp37-cp37m-linux_x86_64.whl
torch-1.3.1+cu92-cp37-cp37m-win_amd64.whl
torch-1.4.0+cu92-cp27-cp27m-linux_x86_64.whl
torch-1.4.0+cu92-cp27-cp27mu-linux_x86_64.whl
torch-1.4.0+cu92-cp35-cp35m-linux_x86_64.whl
torch-1.4.0+cu92-cp35-cp35m-win_amd64.whl
torch-1.4.0+cu92-cp36-cp36m-linux_x86_64.whl
torch-1.4.0+cu92-cp36-cp36m-win_amd64.whl
torch-1.4.0+cu92-cp37-cp37m-linux_x86_64.whl
torch-1.4.0+cu92-cp37-cp37m-win_amd64.whl
torch-1.4.0+cu92-cp38-cp38-linux_x86_64.whl
torch-1.4.0+cu92-cp38-cp38-win_amd64.whl
torch-1.5.0+cu92-cp27-cp27m-linux_x86_64.whl
torch-1.5.0+cu92-cp27-cp27mu-linux_x86_64.whl
torch-1.5.0+cu92-cp35-cp35m-linux_x86_64.whl
torch-1.5.0+cu92-cp35-cp35m-win_amd64.whl
torch-1.5.0+cu92-cp36-cp36m-linux_x86_64.whl
torch-1.5.0+cu92-cp36-cp36m-win_amd64.whl
torch-1.5.0+cu92-cp37-cp37m-linux_x86_64.whl
torch-1.5.0+cu92-cp37-cp37m-win_amd64.whl
torch-1.5.0+cu92-cp38-cp38-linux_x86_64.whl
torch-1.5.0+cu92-cp38-cp38-win_amd64.whl
torch-1.5.1+cu92-cp35-cp35m-linux_x86_64.whl
torch-1.5.1+cu92-cp35-cp35m-win_amd64.whl
torch-1.5.1+cu92-cp36-cp36m-linux_x86_64.whl
torch-1.5.1+cu92-cp36-cp36m-win_amd64.whl
torch-1.5.1+cu92-cp37-cp37m-linux_x86_64.whl
torch-1.5.1+cu92-cp37-cp37m-win_amd64.whl
torch-1.5.1+cu92-cp38-cp38-linux_x86_64.whl
torch-1.5.1+cu92-cp38-cp38-win_amd64.whl
torch-1.6.0+cu92-cp36-cp36m-linux_x86_64.whl
torch-1.6.0+cu92-cp37-cp37m-linux_x86_64.whl
torch-1.6.0+cu92-cp38-cp38-linux_x86_64.whl
torch-1.7.0+cu92-cp36-cp36m-linux_x86_64.whl
torch-1.7.0+cu92-cp37-cp37m-linux_x86_64.whl
torch-1.7.0+cu92-cp38-cp38-linux_x86_64.whl
torch-1.7.1+cu92-cp36-cp36m-linux_x86_64.whl
torch-1.7.1+cu92-cp37-cp37m-linux_x86_64.whl
torch-1.7.1+cu92-cp38-cp38-linux_x86_64.whl
torch-1.7.1+cu92-cp39-cp39-linux_x86_64.whl
torch-1.8.0+rocm3.10-cp36-cp36m-linux_x86_64.whl
torch-1.8.0+rocm3.10-cp37-cp37m-linux_x86_64.whl
torch-1.8.0+rocm3.10-cp38-cp38-linux_x86_64.whl
torch-1.8.0+rocm3.10-cp39-cp39-linux_x86_64.whl
torch-1.8.1+rocm3.10-cp36-cp36m-linux_x86_64.whl
torch-1.8.1+rocm3.10-cp37-cp37m-linux_x86_64.whl
torch-1.8.1+rocm3.10-cp38-cp38-linux_x86_64.whl
torch-1.8.1+rocm3.10-cp39-cp39-linux_x86_64.whl
torch-1.7.1+rocm3.7-cp36-cp36m-linux_x86_64.whl
torch-1.7.1+rocm3.7-cp37-cp37m-linux_x86_64.whl
torch-1.7.1+rocm3.7-cp38-cp38-linux_x86_64.whl
torch-1.7.1+rocm3.7-cp39-cp39-linux_x86_64.whl
torch-1.7.1+rocm3.8-cp36-cp36m-linux_x86_64.whl
torch-1.7.1+rocm3.8-cp37-cp37m-linux_x86_64.whl
torch-1.7.1+rocm3.8-cp38-cp38-linux_x86_64.whl
torch-1.7.1+rocm3.8-cp39-cp39-linux_x86_64.whl
torch-1.10.0+rocm4.0.1-cp36-cp36m-linux_x86_64.whl
torch-1.10.0+rocm4.0.1-cp37-cp37m-linux_x86_64.whl
torch-1.10.0+rocm4.0.1-cp38-cp38-linux_x86_64.whl
torch-1.10.0+rocm4.0.1-cp39-cp39-linux_x86_64.whl
torch-1.10.1+rocm4.0.1-cp36-cp36m-linux_x86_64.whl
torch-1.10.1+rocm4.0.1-cp37-cp37m-linux_x86_64.whl
torch-1.10.1+rocm4.0.1-cp38-cp38-linux_x86_64.whl
torch-1.10.1+rocm4.0.1-cp39-cp39-linux_x86_64.whl
torch-1.10.2+rocm4.0.1-cp36-cp36m-linux_x86_64.whl
torch-1.10.2+rocm4.0.1-cp37-cp37m-linux_x86_64.whl
torch-1.10.2+rocm4.0.1-cp38-cp38-linux_x86_64.whl
torch-1.10.2+rocm4.0.1-cp39-cp39-linux_x86_64.whl
torch-1.8.0+rocm4.0.1-cp36-cp36m-linux_x86_64.whl
torch-1.8.0+rocm4.0.1-cp37-cp37m-linux_x86_64.whl
torch-1.8.0+rocm4.0.1-cp38-cp38-linux_x86_64.whl
torch-1.8.0+rocm4.0.1-cp39-cp39-linux_x86_64.whl
torch-1.8.1+rocm4.0.1-cp36-cp36m-linux_x86_64.whl
torch-1.8.1+rocm4.0.1-cp37-cp37m-linux_x86_64.whl
torch-1.8.1+rocm4.0.1-cp38-cp38-linux_x86_64.whl
torch-1.8.1+rocm4.0.1-cp39-cp39-linux_x86_64.whl
torch-1.9.0+rocm4.0.1-cp36-cp36m-linux_x86_64.whl
torch-1.9.0+rocm4.0.1-cp37-cp37m-linux_x86_64.whl
torch-1.9.0+rocm4.0.1-cp38-cp38-linux_x86_64.whl
torch-1.9.0+rocm4.0.1-cp39-cp39-linux_x86_64.whl
torch-1.9.1+rocm4.0.1-cp36-cp36m-linux_x86_64.whl
torch-1.9.1+rocm4.0.1-cp37-cp37m-linux_x86_64.whl
torch-1.9.1+rocm4.0.1-cp38-cp38-linux_x86_64.whl
torch-1.9.1+rocm4.0.1-cp39-cp39-linux_x86_64.whl
torch-1.10.0+rocm4.1-cp36-cp36m-linux_x86_64.whl
torch-1.10.0+rocm4.1-cp37-cp37m-linux_x86_64.whl
torch-1.10.0+rocm4.1-cp38-cp38-linux_x86_64.whl
torch-1.10.0+rocm4.1-cp39-cp39-linux_x86_64.whl
torch-1.10.1+rocm4.1-cp36-cp36m-linux_x86_64.whl
torch-1.10.1+rocm4.1-cp37-cp37m-linux_x86_64.whl
torch-1.10.1+rocm4.1-cp38-cp38-linux_x86_64.whl
torch-1.10.1+rocm4.1-cp39-cp39-linux_x86_64.whl
torch-1.10.2+rocm4.1-cp36-cp36m-linux_x86_64.whl
torch-1.10.2+rocm4.1-cp37-cp37m-linux_x86_64.whl
torch-1.10.2+rocm4.1-cp38-cp38-linux_x86_64.whl
torch-1.10.2+rocm4.1-cp39-cp39-linux_x86_64.whl
torch-1.9.0+rocm4.1-cp36-cp36m-linux_x86_64.whl
torch-1.9.0+rocm4.1-cp37-cp37m-linux_x86_64.whl
torch-1.9.0+rocm4.1-cp38-cp38-linux_x86_64.whl
torch-1.9.0+rocm4.1-cp39-cp39-linux_x86_64.whl
torch-1.9.1+rocm4.1-cp36-cp36m-linux_x86_64.whl
torch-1.9.1+rocm4.1-cp37-cp37m-linux_x86_64.whl
torch-1.9.1+rocm4.1-cp38-cp38-linux_x86_64.whl
torch-1.9.1+rocm4.1-cp39-cp39-linux_x86_64.whl
torch-1.10.0+rocm4.2-cp36-cp36m-linux_x86_64.whl
torch-1.10.0+rocm4.2-cp37-cp37m-linux_x86_64.whl
torch-1.10.0+rocm4.2-cp38-cp38-linux_x86_64.whl
torch-1.10.0+rocm4.2-cp39-cp39-linux_x86_64.whl
torch-1.10.1+rocm4.2-cp36-cp36m-linux_x86_64.whl
torch-1.10.1+rocm4.2-cp37-cp37m-linux_x86_64.whl
torch-1.10.1+rocm4.2-cp38-cp38-linux_x86_64.whl
torch-1.10.1+rocm4.2-cp39-cp39-linux_x86_64.whl
torch-1.10.2+rocm4.2-cp36-cp36m-linux_x86_64.whl
torch-1.10.2+rocm4.2-cp37-cp37m-linux_x86_64.whl
torch-1.10.2+rocm4.2-cp38-cp38-linux_x86_64.whl
torch-1.10.2+rocm4.2-cp39-cp39-linux_x86_64.whl
torch-1.9.0+rocm4.2-cp36-cp36m-linux_x86_64.whl
torch-1.9.0+rocm4.2-cp37-cp37m-linux_x86_64.whl
torch-1.9.0+rocm4.2-cp38-cp38-linux_x86_64.whl
torch-1.9.0+rocm4.2-cp39-cp39-linux_x86_64.whl
torch-1.9.1+rocm4.2-cp36-cp36m-linux_x86_64.whl
torch-1.9.1+rocm4.2-cp37-cp37m-linux_x86_64.whl
torch-1.9.1+rocm4.2-cp38-cp38-linux_x86_64.whl
torch-1.9.1+rocm4.2-cp39-cp39-linux_x86_64.whl
torch-1.11.0+rocm4.3.1-cp310-cp310-linux_x86_64.whl
torch-1.11.0+rocm4.3.1-cp37-cp37m-linux_x86_64.whl
torch-1.11.0+rocm4.3.1-cp38-cp38-linux_x86_64.whl
torch-1.11.0+rocm4.3.1-cp39-cp39-linux_x86_64.whl
torch-1.11.0+rocm4.5.2-cp310-cp310-linux_x86_64.whl
torch-1.11.0+rocm4.5.2-cp37-cp37m-linux_x86_64.whl
torch-1.11.0+rocm4.5.2-cp38-cp38-linux_x86_64.whl
torch-1.11.0+rocm4.5.2-cp39-cp39-linux_x86_64.whl
torch-1.12.0+rocm5.0-cp310-cp310-linux_x86_64.whl
torch-1.12.0+rocm5.0-cp37-cp37m-linux_x86_64.whl
torch-1.12.0+rocm5.0-cp38-cp38-linux_x86_64.whl
torch-1.12.0+rocm5.0-cp39-cp39-linux_x86_64.whl
torch-1.12.1+rocm5.0-cp310-cp310-linux_x86_64.whl
torch-1.12.1+rocm5.0-cp37-cp37m-linux_x86_64.whl
torch-1.12.1+rocm5.0-cp38-cp38-linux_x86_64.whl
torch-1.12.1+rocm5.0-cp39-cp39-linux_x86_64.whl
torch-1.12.0+rocm5.1.1-cp310-cp310-linux_x86_64.whl
torch-1.12.0+rocm5.1.1-cp37-cp37m-linux_x86_64.whl
torch-1.12.0+rocm5.1.1-cp38-cp38-linux_x86_64.whl
torch-1.12.0+rocm5.1.1-cp39-cp39-linux_x86_64.whl
torch-1.12.1+rocm5.1.1-cp310-cp310-linux_x86_64.whl
torch-1.12.1+rocm5.1.1-cp37-cp37m-linux_x86_64.whl
torch-1.12.1+rocm5.1.1-cp38-cp38-linux_x86_64.whl
torch-1.12.1+rocm5.1.1-cp39-cp39-linux_x86_64.whl
torch-1.13.0+rocm5.1.1-cp310-cp310-linux_x86_64.whl
torch-1.13.0+rocm5.1.1-cp37-cp37m-linux_x86_64.whl
torch-1.13.0+rocm5.1.1-cp38-cp38-linux_x86_64.whl
torch-1.13.0+rocm5.1.1-cp39-cp39-linux_x86_64.whl
torch-1.13.1+rocm5.1.1-cp310-cp310-linux_x86_64.whl
torch-1.13.1+rocm5.1.1-cp37-cp37m-linux_x86_64.whl
torch-1.13.1+rocm5.1.1-cp38-cp38-linux_x86_64.whl
torch-1.13.1+rocm5.1.1-cp39-cp39-linux_x86_64.whl
torch-1.13.0+rocm5.2-cp310-cp310-linux_x86_64.whl
torch-1.13.0+rocm5.2-cp37-cp37m-linux_x86_64.whl
torch-1.13.0+rocm5.2-cp38-cp38-linux_x86_64.whl
torch-1.13.0+rocm5.2-cp39-cp39-linux_x86_64.whl
torch-1.13.1+rocm5.2-cp310-cp310-linux_x86_64.whl
torch-1.13.1+rocm5.2-cp37-cp37m-linux_x86_64.whl
torch-1.13.1+rocm5.2-cp38-cp38-linux_x86_64.whl
torch-1.13.1+rocm5.2-cp39-cp39-linux_x86_64.whl
torch-2.0.0+rocm5.3-cp310-cp310-linux_x86_64.whl
torch-2.0.0+rocm5.3-cp38-cp38-linux_x86_64.whl
torch-2.0.0+rocm5.3-cp39-cp39-linux_x86_64.whl
torch-2.0.1+rocm5.3-cp310-cp310-linux_x86_64.whl
torch-2.0.1+rocm5.3-cp311-cp311-linux_x86_64.whl
torch-2.0.1+rocm5.3-cp38-cp38-linux_x86_64.whl
torch-2.0.1+rocm5.3-cp39-cp39-linux_x86_64.whl
torch-2.0.0+rocm5.4.2-cp310-cp310-linux_x86_64.whl
torch-2.0.0+rocm5.4.2-cp38-cp38-linux_x86_64.whl
torch-2.0.0+rocm5.4.2-cp39-cp39-linux_x86_64.whl
torch-2.0.1+rocm5.4.2-cp310-cp310-linux_x86_64.whl
torch-2.0.1+rocm5.4.2-cp311-cp311-linux_x86_64.whl
torch-2.0.1+rocm5.4.2-cp38-cp38-linux_x86_64.whl
torch-2.0.1+rocm5.4.2-cp39-cp39-linux_x86_64.whl
torch-2.1.0+rocm5.5-cp310-cp310-linux_x86_64.whl
torch-2.1.0+rocm5.5-cp311-cp311-linux_x86_64.whl
torch-2.1.0+rocm5.5-cp38-cp38-linux_x86_64.whl
torch-2.1.0+rocm5.5-cp39-cp39-linux_x86_64.whl
torch-2.1.1+rocm5.5-cp310-cp310-linux_x86_64.whl
torch-2.1.1+rocm5.5-cp311-cp311-linux_x86_64.whl
torch-2.1.1+rocm5.5-cp38-cp38-linux_x86_64.whl
torch-2.1.1+rocm5.5-cp39-cp39-linux_x86_64.whl
torch-2.1.2+rocm5.5-cp310-cp310-linux_x86_64.whl
torch-2.1.2+rocm5.5-cp311-cp311-linux_x86_64.whl
torch-2.1.2+rocm5.5-cp38-cp38-linux_x86_64.whl
torch-2.1.2+rocm5.5-cp39-cp39-linux_x86_64.whl
torch-2.1.0+rocm5.6-cp310-cp310-linux_x86_64.whl
torch-2.1.0+rocm5.6-cp311-cp311-linux_x86_64.whl
torch-2.1.0+rocm5.6-cp38-cp38-linux_x86_64.whl
torch-2.1.0+rocm5.6-cp39-cp39-linux_x86_64.whl
torch-2.1.1+rocm5.6-cp310-cp310-linux_x86_64.whl
torch-2.1.1+rocm5.6-cp311-cp311-linux_x86_64.whl
torch-2.1.1+rocm5.6-cp38-cp38-linux_x86_64.whl
torch-2.1.1+rocm5.6-cp39-cp39-linux_x86_64.whl
torch-2.1.2+rocm5.6-cp310-cp310-linux_x86_64.whl
torch-2.1.2+rocm5.6-cp311-cp311-linux_x86_64.whl
torch-2.1.2+rocm5.6-cp38-cp38-linux_x86_64.whl
torch-2.1.2+rocm5.6-cp39-cp39-linux_x86_64.whl
torch-2.2.0+rocm5.6-cp310-cp310-linux_x86_64.whl
torch-2.2.0+rocm5.6-cp311-cp311-linux_x86_64.whl
torch-2.2.0+rocm5.6-cp312-cp312-linux_x86_64.whl
torch-2.2.0+rocm5.6-cp38-cp38-linux_x86_64.whl
torch-2.2.0+rocm5.6-cp39-cp39-linux_x86_64.whl
torch-2.2.1+rocm5.6-cp310-cp310-linux_x86_64.whl
torch-2.2.1+rocm5.6-cp311-cp311-linux_x86_64.whl
torch-2.2.1+rocm5.6-cp312-cp312-linux_x86_64.whl
torch-2.2.1+rocm5.6-cp38-cp38-linux_x86_64.whl
torch-2.2.1+rocm5.6-cp39-cp39-linux_x86_64.whl
torch-2.2.2+rocm5.6-cp310-cp310-linux_x86_64.whl
torch-2.2.2+rocm5.6-cp311-cp311-linux_x86_64.whl
torch-2.2.2+rocm5.6-cp312-cp312-linux_x86_64.whl
torch-2.2.2+rocm5.6-cp38-cp38-linux_x86_64.whl
torch-2.2.2+rocm5.6-cp39-cp39-linux_x86_64.whl
torch-2.2.0+rocm5.7-cp310-cp310-linux_x86_64.whl
torch-2.2.0+rocm5.7-cp311-cp311-linux_x86_64.whl
torch-2.2.0+rocm5.7-cp312-cp312-linux_x86_64.whl
torch-2.2.0+rocm5.7-cp38-cp38-linux_x86_64.whl
torch-2.2.0+rocm5.7-cp39-cp39-linux_x86_64.whl
torch-2.2.1+rocm5.7-cp310-cp310-linux_x86_64.whl
torch-2.2.1+rocm5.7-cp311-cp311-linux_x86_64.whl
torch-2.2.1+rocm5.7-cp312-cp312-linux_x86_64.whl
torch-2.2.1+rocm5.7-cp38-cp38-linux_x86_64.whl
torch-2.2.1+rocm5.7-cp39-cp39-linux_x86_64.whl
torch-2.2.2+rocm5.7-cp310-cp310-linux_x86_64.whl
torch-2.2.2+rocm5.7-cp311-cp311-linux_x86_64.whl
torch-2.2.2+rocm5.7-cp312-cp312-linux_x86_64.whl
torch-2.2.2+rocm5.7-cp38-cp38-linux_x86_64.whl
torch-2.2.2+rocm5.7-cp39-cp39-linux_x86_64.whl
torch-2.3.0+rocm5.7-cp310-cp310-linux_x86_64.whl
torch-2.3.0+rocm5.7-cp311-cp311-linux_x86_64.whl
torch-2.3.0+rocm5.7-cp312-cp312-linux_x86_64.whl
torch-2.3.0+rocm5.7-cp38-cp38-linux_x86_64.whl
torch-2.3.0+rocm5.7-cp39-cp39-linux_x86_64.whl
torch-2.3.1+rocm5.7-cp310-cp310-linux_x86_64.whl
torch-2.3.1+rocm5.7-cp311-cp311-linux_x86_64.whl
torch-2.3.1+rocm5.7-cp312-cp312-linux_x86_64.whl
torch-2.3.1+rocm5.7-cp38-cp38-linux_x86_64.whl
torch-2.3.1+rocm5.7-cp39-cp39-linux_x86_64.whl
torch-2.3.0+rocm6.0-cp310-cp310-linux_x86_64.whl
torch-2.3.0+rocm6.0-cp311-cp311-linux_x86_64.whl
torch-2.3.0+rocm6.0-cp312-cp312-linux_x86_64.whl
torch-2.3.0+rocm6.0-cp38-cp38-linux_x86_64.whl
torch-2.3.0+rocm6.0-cp39-cp39-linux_x86_64.whl
torch-2.3.1+rocm6.0-cp310-cp310-linux_x86_64.whl
torch-2.3.1+rocm6.0-cp311-cp311-linux_x86_64.whl
torch-2.3.1+rocm6.0-cp312-cp312-linux_x86_64.whl
torch-2.3.1+rocm6.0-cp38-cp38-linux_x86_64.whl
torch-2.3.1+rocm6.0-cp39-cp39-linux_x86_64.whl
torch-2.4.0+rocm6.0-cp310-cp310-linux_x86_64.whl
torch-2.4.0+rocm6.0-cp311-cp311-linux_x86_64.whl
torch-2.4.0+rocm6.0-cp312-cp312-linux_x86_64.whl
torch-2.4.0+rocm6.0-cp38-cp38-linux_x86_64.whl
torch-2.4.0+rocm6.0-cp39-cp39-linux_x86_64.whl
torch-2.4.1+rocm6.0-cp310-cp310-linux_x86_64.whl
torch-2.4.1+rocm6.0-cp311-cp311-linux_x86_64.whl
torch-2.4.1+rocm6.0-cp312-cp312-linux_x86_64.whl
torch-2.4.1+rocm6.0-cp38-cp38-linux_x86_64.whl
torch-2.4.1+rocm6.0-cp39-cp39-linux_x86_64.whl
torch-2.4.0+rocm6.1-cp310-cp310-linux_x86_64.whl
torch-2.4.0+rocm6.1-cp311-cp311-linux_x86_64.whl
torch-2.4.0+rocm6.1-cp312-cp312-linux_x86_64.whl
torch-2.4.0+rocm6.1-cp38-cp38-linux_x86_64.whl
torch-2.4.0+rocm6.1-cp39-cp39-linux_x86_64.whl
torch-2.4.1+rocm6.1-cp310-cp310-linux_x86_64.whl
torch-2.4.1+rocm6.1-cp311-cp311-linux_x86_64.whl
torch-2.4.1+rocm6.1-cp312-cp312-linux_x86_64.whl
torch-2.4.1+rocm6.1-cp38-cp38-linux_x86_64.whl
torch-2.4.1+rocm6.1-cp39-cp39-linux_x86_64.whl
torch-2.5.0+rocm6.1-cp310-cp310-linux_x86_64.whl
torch-2.5.0+rocm6.1-cp311-cp311-linux_x86_64.whl
torch-2.5.0+rocm6.1-cp312-cp312-linux_x86_64.whl
torch-2.5.0+rocm6.1-cp39-cp39-linux_x86_64.whl
torch-2.5.1+rocm6.1-cp310-cp310-linux_x86_64.whl
torch-2.5.1+rocm6.1-cp311-cp311-linux_x86_64.whl
torch-2.5.1+rocm6.1-cp312-cp312-linux_x86_64.whl
torch-2.5.1+rocm6.1-cp39-cp39-linux_x86_64.whl
torch-2.6.0+rocm6.1-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.6.0+rocm6.1-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.6.0+rocm6.1-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.6.0+rocm6.1-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.6.0+rocm6.1-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.6.0+rocm6.1-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.6.0+rocm6.2.4-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.6.0+rocm6.2.4-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.6.0+rocm6.2.4-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.6.0+rocm6.2.4-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.6.0+rocm6.2.4-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.6.0+rocm6.2.4-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.7.0+rocm6.2.4-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.0+rocm6.2.4-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.0+rocm6.2.4-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.0+rocm6.2.4-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.0+rocm6.2.4-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.0+rocm6.2.4-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.2.4-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.2.4-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.2.4-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.2.4-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.2.4-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.2.4-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.5.0+rocm6.2-cp310-cp310-linux_x86_64.whl
torch-2.5.0+rocm6.2-cp311-cp311-linux_x86_64.whl
torch-2.5.0+rocm6.2-cp312-cp312-linux_x86_64.whl
torch-2.5.0+rocm6.2-cp39-cp39-linux_x86_64.whl
torch-2.5.1+rocm6.2-cp310-cp310-linux_x86_64.whl
torch-2.5.1+rocm6.2-cp311-cp311-linux_x86_64.whl
torch-2.5.1+rocm6.2-cp312-cp312-linux_x86_64.whl
torch-2.5.1+rocm6.2-cp39-cp39-linux_x86_64.whl
torch-2.7.0+rocm6.3-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.0+rocm6.3-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.0+rocm6.3-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.0+rocm6.3-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.0+rocm6.3-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.0+rocm6.3-cp39-cp39-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.3-cp310-cp310-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.3-cp311-cp311-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.3-cp312-cp312-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.3-cp313-cp313-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.3-cp313-cp313t-manylinux_2_28_x86_64.whl
torch-2.7.1+rocm6.3-cp39-cp39-manylinux_2_28_x86_64.whl
torch-0.1-cp27-cp27m-macosx_10_6_x86_64.whl
torch-0.1-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.1.10.post1-cp27-none-macosx_10_7_x86_64.whl
torch-0.1.10.post1-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.1.10.post1-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.1.11.post4-cp27-none-macosx_10_7_x86_64.whl
torch-0.1.11.post4-cp35-cp35m-macosx_10_7_x86_64.whl
torch-0.1.11.post4-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.1.11.post5-cp27-none-macosx_10_7_x86_64.whl
torch-0.1.11.post5-cp35-cp35m-macosx_10_7_x86_64.whl
torch-0.1.11.post5-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.1.12.post1-cp27-none-macosx_10_7_x86_64.whl
torch-0.1.12.post1-cp35-cp35m-macosx_10_7_x86_64.whl
torch-0.1.12.post1-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.1.12.post2-cp27-none-macosx_10_7_x86_64.whl
torch-0.1.12.post2-cp35-cp35m-macosx_10_7_x86_64.whl
torch-0.1.12.post2-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.1.6.post17-cp27-cp27mu-linux_x86_64.whl
torch-0.1.6.post17-cp35-cp35m-linux_x86_64.whl
torch-0.1.6.post20-cp27-cp27mu-linux_x86_64.whl
torch-0.1.6.post20-cp35-cp35m-linux_x86_64.whl
torch-0.1.6.post22-cp27-none-macosx_10_7_x86_64.whl
torch-0.1.6.post22-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.1.7.post2-cp27-none-macosx_10_7_x86_64.whl
torch-0.1.7.post2-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.1.7.post2-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.1.8.post1-cp27-none-macosx_10_7_x86_64.whl
torch-0.1.8.post1-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.1.8.post1-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.1.9.post1-cp27-none-macosx_10_7_x86_64.whl
torch-0.1.9.post1-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.1.9.post1-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.1.9.post2-cp27-none-macosx_10_7_x86_64.whl
torch-0.1.9.post2-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.1.9.post2-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.2.0.post1-cp27-none-macosx_10_7_x86_64.whl
torch-0.2.0.post1-cp35-cp35m-macosx_10_7_x86_64.whl
torch-0.2.0.post1-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.2.0.post2-cp27-none-macosx_10_7_x86_64.whl
torch-0.2.0.post2-cp35-cp35m-macosx_10_7_x86_64.whl
torch-0.2.0.post2-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.2.0.post3-cp27-none-macosx_10_7_x86_64.whl
torch-0.2.0.post3-cp35-cp35m-macosx_10_7_x86_64.whl
torch-0.2.0.post3-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.3.0-cp27-none-macosx_10_6_x86_64.whl
torch-0.3.0-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.3.0-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.3.0.post2-cp27-none-macosx_10_6_x86_64.whl
torch-0.3.0.post2-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.3.0.post2-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.3.0.post3-cp27-none-macosx_10_6_x86_64.whl
torch-0.3.0.post3-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.3.0.post3-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.3.0.post4-cp27-none-macosx_10_6_x86_64.whl
torch-0.3.0.post4-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.3.0.post4-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.3.1-cp27-none-macosx_10_6_x86_64.whl
torch-0.3.1-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.3.1-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.4.0-cp27-none-macosx_10_6_x86_64.whl
torch-0.4.0-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.4.0-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.4.1-cp27-none-macosx_10_6_x86_64.whl
torch-0.4.1-cp35-cp35m-macosx_10_6_x86_64.whl
torch-0.4.1-cp36-cp36m-macosx_10_7_x86_64.whl
torch-0.4.1-cp37-cp37m-macosx_10_7_x86_64.whl
torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl
torch-1.11.0-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.11.0-cp38-cp38-manylinux2014_aarch64.whl
torch-1.11.0-cp39-cp39-manylinux2014_aarch64.whl
torch-1.12.0-cp310-cp310-manylinux2014_aarch64.whl
torch-1.12.0-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.12.0-cp38-cp38-manylinux2014_aarch64.whl
torch-1.12.0-cp39-cp39-manylinux2014_aarch64.whl
torch-1.12.1-cp310-cp310-manylinux2014_aarch64.whl
torch-1.12.1-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.12.1-cp38-cp38-manylinux2014_aarch64.whl
torch-1.12.1-cp39-cp39-manylinux2014_aarch64.whl
torch-1.13.0-cp310-cp310-manylinux2014_aarch64.whl
torch-1.13.0-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.13.0-cp38-cp38-manylinux2014_aarch64.whl
torch-1.13.0-cp39-cp39-manylinux2014_aarch64.whl
torch-1.13.1-cp310-cp310-manylinux2014_aarch64.whl
torch-1.13.1-cp37-cp37m-manylinux2014_aarch64.whl
torch-1.13.1-cp38-cp38-manylinux2014_aarch64.whl
torch-1.13.1-cp39-cp39-manylinux2014_aarch64.whl
torch-2.0.0-1-cp310-cp310-manylinux2014_aarch64.whl
torch-2.0.0-1-cp311-cp311-manylinux2014_aarch64.whl
torch-2.0.0-1-cp38-cp38-manylinux2014_aarch64.whl
torch-2.0.0-1-cp39-cp39-manylinux2014_aarch64.whl
torch-2.0.0-cp310-cp310-manylinux2014_aarch64.whl
torch-2.0.0-cp311-cp311-manylinux2014_aarch64.whl
torch-2.0.0-cp38-cp38-manylinux2014_aarch64.whl
torch-2.0.0-cp39-cp39-manylinux2014_aarch64.whl
torch-2.0.1-cp310-cp310-manylinux2014_aarch64.whl
torch-2.0.1-cp311-cp311-manylinux2014_aarch64.whl
torch-2.0.1-cp38-cp38-manylinux2014_aarch64.whl
torch-2.0.1-cp39-cp39-manylinux2014_aarch64.whl
torch-2.6.0+xpu-cp310-cp310-linux_x86_64.whl
torch-2.6.0+xpu-cp310-cp310-win_amd64.whl
torch-2.6.0+xpu-cp311-cp311-linux_x86_64.whl
torch-2.6.0+xpu-cp311-cp311-win_amd64.whl
torch-2.6.0+xpu-cp312-cp312-linux_x86_64.whl
torch-2.6.0+xpu-cp312-cp312-win_amd64.whl
torch-2.6.0+xpu-cp313-cp313-linux_x86_64.whl
torch-2.6.0+xpu-cp313-cp313-win_amd64.whl
torch-2.6.0+xpu-cp39-cp39-linux_x86_64.whl
torch-2.6.0+xpu-cp39-cp39-win_amd64.whl
torch-2.7.0+xpu-cp310-cp310-linux_x86_64.whl
torch-2.7.0+xpu-cp310-cp310-win_amd64.whl
torch-2.7.0+xpu-cp311-cp311-linux_x86_64.whl
torch-2.7.0+xpu-cp311-cp311-win_amd64.whl
torch-2.7.0+xpu-cp312-cp312-linux_x86_64.whl
torch-2.7.0+xpu-cp312-cp312-win_amd64.whl
torch-2.7.0+xpu-cp313-cp313-linux_x86_64.whl
torch-2.7.0+xpu-cp313-cp313-win_amd64.whl
torch-2.7.0+xpu-cp313-cp313t-linux_x86_64.whl
torch-2.7.0+xpu-cp313-cp313t-win_amd64.whl
torch-2.7.0+xpu-cp39-cp39-linux_x86_64.whl
torch-2.7.0+xpu-cp39-cp39-win_amd64.whl
torch-2.7.1+xpu-cp310-cp310-linux_x86_64.whl
torch-2.7.1+xpu-cp310-cp310-win_amd64.whl
torch-2.7.1+xpu-cp311-cp311-linux_x86_64.whl
torch-2.7.1+xpu-cp311-cp311-win_amd64.whl
torch-2.7.1+xpu-cp312-cp312-linux_x86_64.whl
torch-2.7.1+xpu-cp312-cp312-win_amd64.whl
torch-2.7.1+xpu-cp313-cp313-linux_x86_64.whl
torch-2.7.1+xpu-cp313-cp313-win_amd64.whl
torch-2.7.1+xpu-cp313-cp313t-linux_x86_64.whl
torch-2.7.1+xpu-cp313-cp313t-win_amd64.whl
torch-2.7.1+xpu-cp39-cp39-linux_x86_64.whl
torch-2.7.1+xpu-cp39-cp39-win_amd64.whl
================================================ FILE: tools/compatgen/internal/util.go ================================================ package internal import ( "strings" ) func split2(s string, sep string) (string, string) { parts := strings.SplitN(s, sep, 2) return parts[0], parts[1] } ================================================ FILE: tools/compatgen/main.go ================================================ package main import ( "context" "encoding/json" "os" "github.com/spf13/cobra" "github.com/replicate/cog/pkg/util/console" "github.com/replicate/cog/tools/compatgen/internal" ) func main() { var output string var rootCmd = &cobra.Command{ Use: "compatgen {cuda|torch|tensorflow}", Short: "Generate compatibility matrix for Cog base images", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { ctx := context.Background() target := args[0] var v any var err error switch target { case "cuda": v, err = internal.FetchCUDABaseImages(ctx) if err != nil { console.Fatalf("Failed to fetch CUDA base image tags: %s", err) } case "tensorflow": v, err = internal.FetchTensorFlowCompatibilityMatrix() if err != nil { console.Fatalf("Failed to fetch TensorFlow compatibility matrix: %s", err) } case "torch": v, err = internal.FetchTorchCompatibilityMatrix() if err != nil { console.Fatalf("Failed to fetch PyTorch compatibility matrix: %s", err) } default: console.Fatalf("Unknown target: %s", target) } data, err := json.MarshalIndent(v, "", " ") if err != nil { console.Fatalf("Failed to marshal value: %s", err) } if output != "" { if err := os.WriteFile(output, data, 0o644); err != nil { console.Fatalf("Failed to write to %s: %s", output, err) } console.Infof("Wrote to %s", output) } else { console.Output(string(data)) } }, } rootCmd.Flags().StringVarP(&output, "output", "o", "", "Output flag (optional)") if err := rootCmd.Execute(); err != nil { console.Fatal(err.Error()) } } ================================================ FILE: tools/gendocs/main.go 
================================================ package main import ( "fmt" "os" "path/filepath" "slices" "sort" "strings" "github.com/spf13/cobra" "github.com/spf13/cobra/doc" "github.com/replicate/cog/pkg/cli" "github.com/replicate/cog/pkg/util/console" ) func main() { var output string rootCmd := &cobra.Command{ Use: "gendocs", Short: "Generate CLI reference documentation for Cog", Run: func(cmd *cobra.Command, args []string) { if err := generateDocs(output); err != nil { console.Fatalf("Failed to generate docs: %s", err) } console.Infof("Generated CLI docs at %s", output) }, } rootCmd.Flags().StringVarP(&output, "output", "o", "docs/cli.md", "Output file path") if err := rootCmd.Execute(); err != nil { console.Fatal(err.Error()) } } func generateDocs(outputPath string) error { // Create temporary directory for cobra doc generation tmpDir, err := os.MkdirTemp("", "cog-cli-docs-*") if err != nil { return fmt.Errorf("failed to create temp dir: %w", err) } defer os.RemoveAll(tmpDir) // Get the cog command cmd, err := cli.NewRootCommand() if err != nil { return fmt.Errorf("failed to create root command: %w", err) } // Generate markdown files using cobra/doc if err := doc.GenMarkdownTree(cmd, tmpDir); err != nil { return fmt.Errorf("failed to generate markdown: %w", err) } // Read all generated files files, err := os.ReadDir(tmpDir) if err != nil { return fmt.Errorf("failed to read temp dir: %w", err) } // Sort files to ensure consistent ordering // Order: cog (root), then alphabetically by command name var fileNames []string for _, file := range files { if !file.IsDir() && strings.HasSuffix(file.Name(), ".md") { fileNames = append(fileNames, file.Name()) } } sort.Strings(fileNames) // Build the combined markdown content var content strings.Builder // Write header content.WriteString("# CLI reference\n\n") content.WriteString("\n\n") // Process each command file for _, fileName := range fileNames { filePath := filepath.Join(tmpDir, fileName) data, err := 
os.ReadFile(filePath) if err != nil { return fmt.Errorf("failed to read %s: %w", fileName, err) } // Process the content processed := processCommandDoc(string(data), fileName) content.WriteString(processed) content.WriteString("\n") } // Ensure output directory exists outputDir := filepath.Dir(outputPath) if err := os.MkdirAll(outputDir, 0o755); err != nil { return fmt.Errorf("failed to create output directory: %w", err) } // Write the combined file if err := os.WriteFile(outputPath, []byte(content.String()), 0o644); err != nil { return fmt.Errorf("failed to write output file: %w", err) } return nil } func processCommandDoc(content string, fileName string) string { // Remove the "SEE ALSO" section and everything after it if idx := strings.Index(content, "### SEE ALSO"); idx != -1 { content = content[:idx] } // Remove the "Options inherited from parent commands" section if idx := strings.Index(content, "### Options inherited from parent commands"); idx != -1 { content = content[:idx] } // Remove trailing whitespace content = strings.TrimRight(content, "\n") // Fix command headers to use backticks // Change "## cog init" to "## `cog init`" // Change "### Options" to "**Options**" (not a heading, won't appear in TOC) // Change "### Examples" to "**Examples**" (not a heading, won't appear in TOC) // Remove "### Synopsis" heading but keep its content // Skip the short description if there's a Synopsis section (to avoid duplication) lines := strings.Split(content, "\n") var result []string skipSynopsis := false skipShortDesc := false for _, line := range lines { switch { case strings.HasPrefix(line, "## cog"): // Extract the command name command := strings.TrimPrefix(line, "## ") result = append(result, "## `"+command+"`") // Check if next non-empty line is "### Synopsis" - if so, skip the short desc skipShortDesc = hasSynopsisSection(lines) case skipShortDesc: // Skip the short description line (first non-empty line after header) // Also skip any blank lines that follow 
the header if strings.TrimSpace(line) != "" && !strings.HasPrefix(line, "###") { // This is the short description line, skip it skipShortDesc = false } // If line is blank, we continue skipping until we hit the short desc case line == "### Synopsis": // Skip the "### Synopsis" heading line, but keep content after it skipSynopsis = true case skipSynopsis: // Keep synopsis content until we hit the usage block (```) or another heading switch { case line == "### Examples": skipSynopsis = false // Add blank line before if needed if len(result) > 0 && strings.TrimSpace(result[len(result)-1]) != "" { result = append(result, "") } result = append(result, "**Examples**") case strings.HasPrefix(line, "###"), strings.HasPrefix(line, "```"): skipSynopsis = false // Add blank line before if needed if len(result) > 0 && strings.TrimSpace(result[len(result)-1]) != "" { result = append(result, "") } result = append(result, line) default: // Keep all lines from synopsis (including blank lines for paragraph breaks) result = append(result, line) } case line == "### Options": // Add blank line before if needed if len(result) > 0 && strings.TrimSpace(result[len(result)-1]) != "" { result = append(result, "") } result = append(result, "**Options**") case line == "### Examples": // Add blank line before if needed if len(result) > 0 && strings.TrimSpace(result[len(result)-1]) != "" { result = append(result, "") } result = append(result, "**Examples**") default: result = append(result, line) } } // Remove consecutive blank lines result = removeConsecutiveBlankLines(result) return strings.Join(result, "\n") } // removeConsecutiveBlankLines removes consecutive blank lines, keeping only one func removeConsecutiveBlankLines(lines []string) []string { var result []string prevBlank := false for _, line := range lines { isBlank := strings.TrimSpace(line) == "" if isBlank && prevBlank { // Skip consecutive blank lines continue } result = append(result, line) prevBlank = isBlank } return result } 
// hasSynopsisSection checks if the content has a "### Synopsis" section func hasSynopsisSection(lines []string) bool { return slices.Contains(lines, "### Synopsis") } ================================================ FILE: tools/install.sh ================================================ #!/bin/sh # # This script should be run via curl: # sh -c "$(curl -fsSL https://raw.githubusercontent.com/replicate/cog/main/tools/install.sh)" # or via wget: # sh -c "$(wget -qO- https://raw.githubusercontent.com/replicate/cog/main/tools/install.sh)" # or via fetch: # sh -c "$(fetch -o - https://raw.githubusercontent.com/replicate/cog/main/tools/install.sh)" # # As an alternative, you can first download the install script and run it afterwards: # wget https://raw.githubusercontent.com/replicate/cog/main/tools/install.sh # sh install.sh # # You can tweak the install location by setting the INSTALL_DIR env var when running the script. # INSTALL_DIR=~/my/custom/install/location sh install.sh # # By default, cog will be installed at /usr/local/bin/cog # This install script is based on that of ohmyzsh[1], which is licensed under the MIT License # [1] https://github.com/ohmyzsh/ohmyzsh/blob/master/tools/install.sh # MIT License # Copyright (c) 2009-2022 Robby Russell and contributors (https://github.com/ohmyzsh/ohmyzsh/contributors) # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. set -e set_install_dir() { # Set install directory DEFAULT_INSTALL_DIR="/usr/local/bin" if [ -z "${INSTALL_DIR}" ]; then read -p "Install location? [$DEFAULT_INSTALL_DIR]: " INSTALL_DIR INSTALL_DIR=${INSTALL_DIR:-$DEFAULT_INSTALL_DIR} fi if [ ! -d "$INSTALL_DIR" ]; then echo "The directory $INSTALL_DIR does not exist. Please create it and re-run this script." # Ask user to manually create directory rather than making it for them, # so they don't just type in "y" again and accidentally install at ./y exit 1 fi # Expand abbreviations in INSTALL_DIR INSTALL_DIR=$(cd "$INSTALL_DIR"; pwd) } command_exists() { command -v "$@" >/dev/null 2>&1 } user_can_sudo() { # Check if sudo is installed command_exists $SUDO || return 1 # Termux can't run sudo, so we can detect it and exit the function early. case "$PREFIX" in *com.termux*) return 1 ;; esac # The following command has 3 parts: # # 1. Run `sudo` with `-v`. Does the following: # • with privilege: asks for a password immediately. # • without privilege: exits with error code 1 and prints the message: # Sorry, user may not run sudo on # # 2. Pass `-n` to `sudo` to tell it to not ask for a password. If the # password is not required, the command will finish with exit code 0. # If one is required, sudo will exit with error code 1 and print the # message: # sudo: a password is required # # 3. Check for the words "may not run sudo" in the output to really tell # whether the user has privileges or not. 
For that we have to make sure # to run `sudo` in the default locale (with `LANG=`) so that the message # stays consistent regardless of the user's locale. # ! LANG= $SUDO -n -v 2>&1 | grep -q "may not run $SUDO" } check_docker() { if ! command_exists docker; then echo "Docker is not installed on your system. Please install Docker before proceeding." exit 1 fi if ! docker run hello-world >/dev/null 2>&1; then echo "WARNING: Docker engine is not running, or docker cannot be run without sudo. Please setup Docker so that your user has permission to run it: https://docs.docker.com/engine/install/linux-postinstall/" fi } setup_cog() { COG_LOCATION="${INSTALL_DIR}/cog" BINARY_URI="https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)" if [ -f "$COG_LOCATION" ]; then echo "A file already exists at $COG_LOCATION" echo "Do you want to delete this file and continue with this installation anyway?" read -p "Delete file? (y/N): " choice case "$choice" in y|Y ) echo "Deleting existing file and continuing with installation..."; $SUDO rm $COG_LOCATION;; * ) echo "Exiting installation."; exit 1;; esac fi if command_exists curl; then $SUDO curl -o $COG_LOCATION -L $BINARY_URI elif command_exists wget; then $SUDO wget $BINARY_URI -O $COG_LOCATION elif command_exists fetch; then $SUDO fetch -o $COG_LOCATION $BINARY_URI else echo "One of curl, wget, or fetch must be present for this installer to work." exit 1 fi if [ "$(cat $COG_LOCATION)" = "Not Found" ]; then echo "Error: Cog binary not found at ${BINARY_URI}. Check releases to see if a binary is available for your system." rm $COG_LOCATION exit 1 fi $SUDO chmod +x $COG_LOCATION # On macOS, remove the quarantine attribute that triggers Gatekeeper's # "cannot be opened because the developer cannot be verified" warning. 
if [ "$(uname -s)" = "Darwin" ]; then $SUDO xattr -d com.apple.quarantine "$COG_LOCATION" 2>/dev/null || true fi SHELL_NAME=$(basename "$SHELL") if [[ ":$PATH:" != *":$INSTALL_DIR:"* ]]; then echo "Adding $INSTALL_DIR to PATH in .$SHELL_NAME"rc echo "" >> ~/.$SHELL_NAME"rc" echo "# Created by \`cog\` install script on $(date)" >> ~/.$SHELL_NAME"rc" echo "export PATH=\$PATH:$INSTALL_DIR" >> ~/.$SHELL_NAME"rc" source ~/.$SHELL_NAME"rc" echo "You may need to open a new terminal window to run cog for the first time." fi echo } print_success() { echo "Successfully installed cog. Run \`cog login\` to configure Replicate access" } main() { # Check if macOS if [ "$(uname -s)" = "Darwin" ]; then echo "On macOS, it is recommended to install cog using Homebrew instead:" echo \`brew install replicate/tap/cog\` echo "Do you want to continue with this installation anyway?" read -p "Continue? (y/N): " choice case "$choice" in y|Y ) echo "Continuing with installation...";; * ) echo "Exiting installation."; exit 1;; esac fi set_install_dir # Check if `cog` command already exists if command_exists cog; then echo "A cog command already exists on your system at the following location: $(which cog)". echo "The installations may interfere with one another." echo "Do you want to continue with this installation anyway?" read -p "Continue? (y/N): " choice case "$choice" in y|Y ) echo "Continuing with installation...";; * ) echo "Exiting installation."; exit 1;; esac fi # Check the users sudo privileges if [ -z "${SUDO+set}" ]; then SUDO="sudo" fi if [ ! user_can_sudo ] && [ "${SUDO}" != "" ]; then echo "You need sudo permissions to run this install script. Please try again as a sudoer." exit 1 fi check_docker setup_cog if command_exists cog; then print_success else echo 'Error: cog not installed.' 
exit 1 fi } main "$@" ================================================ FILE: tools/test-harness/.gitignore ================================================ results/*.json __pycache__/ *.pyc .venv/ ================================================ FILE: tools/test-harness/README.md ================================================ # Cog Model Test Harness Automated test harness for validating cog models against new SDK versions. Designed to test any cog model from any repo. ## Quick Start ```bash cd tools/test-harness # Create a venv and install dependencies python3 -m venv .venv source .venv/bin/activate pip install pyyaml # List all models in the manifest python -m harness list # Run all non-GPU models python -m harness run --no-gpu # Run a specific model python -m harness run --model hello-world # Run GPU models only (requires NVIDIA GPU + nvidia-docker) python -m harness run --gpu-only # Output JSON report python -m harness run --no-gpu --output json --output-file results/report.json # Build images only (no predictions) python -m harness build --no-gpu ``` ## Prerequisites - Python 3.10+ - Docker - For GPU models: NVIDIA GPU + nvidia-docker runtime ### Version Resolution By default the harness automatically resolves the **latest stable** versions of both the cog CLI (from GitHub releases) and the Python SDK (from PyPI), skipping any alpha/beta/rc tags. 
You can override either via the CLI or in `manifest.yaml`: ```bash # Use the latest stable CLI + SDK (default) python -m harness run --no-gpu # Pin a specific CLI version python -m harness run --cog-version v0.16.12 --no-gpu # Pin a specific SDK version python -m harness run --sdk-version 0.16.12 --no-gpu # Use a pre-release CLI python -m harness run --cog-version v0.17.0-rc.2 --no-gpu # Use a locally-built binary (overrides --cog-version) python -m harness run --cog-binary ./dist/go/darwin-arm64/cog --no-gpu ``` You can also pin versions in `manifest.yaml` under `defaults`: ```yaml defaults: sdk_version: "latest" # or pin e.g. "0.16.12" cog_version: "latest" # or pin e.g. "v0.16.12" ``` **Resolution priority** (for both CLI and SDK): CLI flag > manifest default > latest stable. ## Manifest Format Models are defined in `manifest.yaml`. Each entry specifies a GitHub repo, subdirectory, test inputs, and expected outputs: ```yaml models: - name: hello-world repo: replicate/cog-examples path: hello-world gpu: false tests: - description: "basic predict" inputs: text: "world" expect: type: exact value: "hello world" ``` ### Model Fields | Field | Required | Description | |-------|----------|-------------| | `name` | yes | Unique identifier for the model | | `repo` | yes | GitHub `owner/repo` to clone | | `path` | no | Subdirectory within the repo (default: `.`) | | `gpu` | no | Whether the model requires a GPU (default: `false`) | | `sdk_version` | no | Override the SDK version (default: from `defaults.sdk_version`) | | `timeout` | no | Per-prediction timeout in seconds (default: 300) | | `requires_env` | no | List of env vars that must be set; model is skipped if missing | | `env` | no | Extra env vars to pass; supports `${VAR}` expansion from host | | `cog_yaml_overrides` | no | Dict deep-merged into the model's cog.yaml | | `tests` | no | List of predict test cases | | `train_tests` | no | List of train test cases | ### Input References Prefix a value with `@` to 
reference a file in `fixtures/`: ```yaml inputs: image: "@test_image.png" # resolves to fixtures/test_image.png ``` ### Validation Types | Type | Fields | Description | |------|--------|-------------| | `exact` | `value` | Output must equal value exactly | | `contains` | `value` | Output must contain the substring | | `regex` | `pattern` | Output must match the regex | | `file_exists` | `mime` (optional) | Output is a file path that must exist | | `json_match` | `match` | Output parsed as JSON must contain the given subset | | `json_keys` | `keys` (optional) | Output parsed as JSON dict must have entries | | `not_empty` | — | Output must be non-empty | ## Adding a New Model Add an entry to `manifest.yaml`: ```yaml - name: my-model repo: myorg/my-model-repo path: "." gpu: true # sdk_version: "0.16.12" # optional per-model override env: HF_TOKEN: "${HF_TOKEN}" timeout: 600 tests: - description: "smoke test" inputs: prompt: "hello" expect: type: contains value: "result" ``` No code changes required. ## CLI Reference ``` usage: cog-test {run,build,list} [options] Commands: run Build and test models (full pipeline) build Build Docker images only (no predictions) list List models defined in the manifest Common options: --manifest PATH Path to manifest.yaml --model NAME Run only this model (repeatable) --no-gpu Skip GPU models --gpu-only Only run GPU models --sdk-version VER SDK version (default: latest stable from PyPI) --cog-version TAG CLI version to download (default: latest stable) --cog-binary PATH Path to local cog binary (overrides --cog-version) --keep-images Don't clean up Docker images after run Run-specific options: --output {console,json} Output format (default: console) --output-file PATH Write report to file ``` ## Architecture ``` tools/test-harness/ ├── manifest.yaml # Declarative test definitions ├── fixtures/ # Test input files (images, etc.) 
├── harness/ │ ├── cli.py # CLI entry point │ ├── cog_resolver.py # Resolves + downloads cog CLI and SDK versions │ ├── runner.py # Clone -> patch -> build -> predict -> validate │ ├── patcher.py # Patches cog.yaml with sdk_version + overrides │ ├── validators.py # Output validation strategies │ └── report.py # Console + JSON report generation ├── results/ # Output reports (gitignored) └── pyproject.toml ``` ================================================ FILE: tools/test-harness/harness/__init__.py ================================================ ================================================ FILE: tools/test-harness/harness/__main__.py ================================================ """Allow running as ``python -m harness``.""" from .cli import main main() ================================================ FILE: tools/test-harness/harness/cli.py ================================================ """CLI entry point for the cog test harness.""" from __future__ import annotations import argparse import logging import sys from pathlib import Path import yaml from .cog_resolver import resolve_cog_binary, resolve_sdk_version from .report import console_report, write_json_report from .runner import ModelResult, Runner def main(argv: list[str] | None = None) -> None: parser = argparse.ArgumentParser( prog="cog-test", description="Test harness for validating cog models against new SDK versions", ) subparsers = parser.add_subparsers(dest="command") # ── run ───────────────────────────────────────────────────────── run_parser = subparsers.add_parser("run", help="Build and test models") _add_common_args(run_parser) run_parser.add_argument( "--output", choices=["console", "json"], default="console", help="Output format (default: console)", ) run_parser.add_argument( "--output-file", type=str, default=None, help="Write report to file instead of stdout", ) # ── build ─────────────────────────────────────────────────────── build_parser = subparsers.add_parser( "build", 
help="Build model images only (no predict)" ) _add_common_args(build_parser) # ── list ──────────────────────────────────────────────────────── list_parser = subparsers.add_parser("list", help="List models in manifest") list_parser.add_argument( "--manifest", type=str, default=None, help="Path to manifest.yaml", ) args = parser.parse_args(argv) if args.command is None: parser.print_help() sys.exit(1) logging.basicConfig( format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%H:%M:%S", ) if args.command == "list": _cmd_list(args) elif args.command == "build": _cmd_build(args) elif args.command == "run": _cmd_run(args) def _cmd_list(args: argparse.Namespace) -> None: manifest = _load_manifest(args.manifest) models = manifest.get("models", []) for m in models: gpu_tag = " [GPU]" if m.get("gpu") else "" req_env = m.get("requires_env", []) env_tag = f" (requires: {', '.join(req_env)})" if req_env else "" print(f" {m['name']:<25} {m['repo']}/{m.get('path', '.')}{gpu_tag}{env_tag}") print(f"\n{len(models)} models total") def _cmd_build(args: argparse.Namespace) -> None: manifest = _load_manifest(args.manifest) models = _filter_models(manifest, args) defaults = manifest.get("defaults", {}) sdk_version, _ = resolve_sdk_version( cli_sdk_version=args.sdk_version, manifest_defaults=defaults, ) cog_binary, cog_version_label = resolve_cog_binary( cog_version=args.cog_version, cog_binary=args.cog_binary, manifest_defaults=defaults, ) log = logging.getLogger(__name__) log.info("Using cog CLI: %s (%s)", cog_binary, cog_version_label) log.info("Using SDK version: %s", sdk_version) runner = Runner( cog_binary=cog_binary, sdk_version=sdk_version, keep_images=True, ) results: list[ModelResult] = [] for model in models: result = ModelResult( name=model["name"], passed=True, gpu=model.get("gpu", False) ) try: model_dir = runner.prepare_model(model) import time start = time.monotonic() runner.build_model(model_dir, model) result.build_duration_s = time.monotonic() 
- start logging.getLogger(__name__).info( "BUILD OK %s (%.1fs)", model["name"], result.build_duration_s ) except Exception as exc: result.passed = False result.error = str(exc) results.append(result) console_report( results, sdk_version=sdk_version or "", cog_version=cog_version_label ) failed = any(not r.passed for r in results) sys.exit(1 if failed else 0) def _cmd_run(args: argparse.Namespace) -> None: manifest = _load_manifest(args.manifest) models = _filter_models(manifest, args) defaults = manifest.get("defaults", {}) sdk_version, _ = resolve_sdk_version( cli_sdk_version=args.sdk_version, manifest_defaults=defaults, ) cog_binary, cog_version_label = resolve_cog_binary( cog_version=args.cog_version, cog_binary=args.cog_binary, manifest_defaults=defaults, ) log = logging.getLogger(__name__) log.info("Using cog CLI: %s (%s)", cog_binary, cog_version_label) log.info("Using SDK version: %s", sdk_version) runner = Runner( cog_binary=cog_binary, sdk_version=sdk_version, keep_images=args.keep_images, ) results: list[ModelResult] = [] try: for model in models: result = runner.run_model(model) results.append(result) finally: if not args.keep_images: runner.cleanup() # Output if args.output == "json": if args.output_file: with open(args.output_file, "w") as f: write_json_report( results, sdk_version=sdk_version or "", cog_version=cog_version_label, stream=f, ) else: write_json_report( results, sdk_version=sdk_version or "", cog_version=cog_version_label, ) else: console_report( results, sdk_version=sdk_version or "", cog_version=cog_version_label, ) if args.output_file: with open(args.output_file, "w") as f: write_json_report( results, sdk_version=sdk_version or "", cog_version=cog_version_label, stream=f, ) failed = any(not r.passed for r in results) sys.exit(1 if failed else 0) # ── Helpers ──────────────────────────────────────────────────────────── def _add_common_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--manifest", type=str, 
default=None, help="Path to manifest.yaml (default: auto-detect)", ) parser.add_argument( "--model", type=str, action="append", default=None, help="Run only specific model(s) by name (repeatable)", ) parser.add_argument( "--no-gpu", action="store_true", help="Skip models that require a GPU", ) parser.add_argument( "--gpu-only", action="store_true", help="Only run models that require a GPU", ) parser.add_argument( "--sdk-version", type=str, default=None, help=( "SDK version to inject into cog.yaml (e.g. 0.16.12). " "Default: latest stable release from PyPI." ), ) parser.add_argument( "--cog-version", type=str, default=None, help=( "Cog CLI version to download and use (e.g. v0.16.12). " "Default: latest stable release. Ignored if --cog-binary is set." ), ) parser.add_argument( "--cog-binary", type=str, default="cog", help="Path to a local cog binary (overrides --cog-version)", ) parser.add_argument( "--keep-images", action="store_true", help="Don't clean up Docker images after run", ) def _load_manifest(manifest_path: str | None) -> dict: if manifest_path: path = Path(manifest_path) else: # Search up from CWD, then fall back to the default location path = Path(__file__).parent.parent / "manifest.yaml" if not path.exists(): print(f"Error: manifest not found at {path}", file=sys.stderr) sys.exit(1) with open(path) as f: return yaml.safe_load(f) def _filter_models(manifest: dict, args: argparse.Namespace) -> list[dict]: models = manifest.get("models", []) if args.model: names = set(args.model) models = [m for m in models if m["name"] in names] found = {m["name"] for m in models} missing = names - found if missing: print(f"Warning: models not found in manifest: {missing}", file=sys.stderr) if getattr(args, "no_gpu", False): models = [m for m in models if not m.get("gpu")] if getattr(args, "gpu_only", False): models = [m for m in models if m.get("gpu")] return models if __name__ == "__main__": main() ================================================ FILE: 
tools/test-harness/harness/cog_resolver.py ================================================ """Resolve and download specific cog CLI and SDK versions.""" from __future__ import annotations import json import logging import os import platform import re import stat import tempfile import urllib.request from pathlib import Path logger = logging.getLogger(__name__) GITHUB_API = "https://api.github.com/repos/replicate/cog/releases" DOWNLOAD_BASE = ( "https://github.com/replicate/cog/releases/download/{tag}/cog_{os}_{arch}" ) PYPI_API = "https://pypi.org/pypi/cog/json" # Pre-release patterns to skip when resolving "latest" _PRERELEASE_RE = re.compile(r"-(alpha|beta|rc|dev)", re.IGNORECASE) def resolve_latest_stable_version() -> str: """Query GitHub releases and return the tag of the latest stable release. Skips any release marked as a prerelease or whose tag contains alpha/beta/rc/dev suffixes. """ url = f"{GITHUB_API}?per_page=50" headers = {"Accept": "application/vnd.github+json"} # Use a token if available to avoid rate limits token = os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN") if token: headers["Authorization"] = f"Bearer {token}" req = urllib.request.Request(url, headers=headers) with urllib.request.urlopen(req, timeout=30) as resp: releases = json.loads(resp.read().decode()) for release in releases: tag = release.get("tag_name", "") if release.get("prerelease") or release.get("draft"): continue if _PRERELEASE_RE.search(tag): continue return tag raise RuntimeError( "Could not find a stable cog release. 
" "Check https://github.com/replicate/cog/releases" ) def _platform_asset_name() -> str: """Return the cog binary asset name for the current platform.""" system = platform.system() # Darwin, Linux machine = platform.machine() # arm64, x86_64, aarch64 if system not in ("Darwin", "Linux"): raise RuntimeError(f"Unsupported OS: {system}") # Normalise architecture names arch_map = { "arm64": "arm64", "aarch64": "arm64", "x86_64": "x86_64", "amd64": "x86_64", } arch = arch_map.get(machine) if not arch: raise RuntimeError(f"Unsupported architecture: {machine}") return f"cog_{system}_{arch}" def download_cog_binary(tag: str, dest_dir: Path | None = None) -> Path: """Download the cog binary for *tag* and return the path to it. The binary is placed in *dest_dir* (default: a new temp directory) and made executable. """ asset = _platform_asset_name() url = DOWNLOAD_BASE.format(tag=tag, os=platform.system(), arch=asset.split("_")[-1]) if dest_dir is None: dest_dir = Path(tempfile.mkdtemp(prefix="cog-bin-")) dest_dir.mkdir(parents=True, exist_ok=True) dest = dest_dir / "cog" logger.info("Downloading cog %s from %s ...", tag, url) req = urllib.request.Request(url) token = os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN") if token: req.add_header("Authorization", f"Bearer {token}") with urllib.request.urlopen(req, timeout=120) as resp, open(dest, "wb") as f: while True: chunk = resp.read(1 << 16) if not chunk: break f.write(chunk) # Make executable dest.chmod(dest.stat().st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) # Verify it works logger.info("Downloaded cog %s -> %s", tag, dest) return dest def resolve_cog_binary( cog_version: str | None, cog_binary: str | None, manifest_defaults: dict | None = None, ) -> tuple[str, str]: """Resolve which cog binary to use. Returns ``(binary_path, version_label)``. Priority: 1. ``--cog-binary`` (explicit path) — use as-is, version label = "custom" 2. ``--cog-version`` — download that specific tag 3. 
``defaults.cog_version`` from manifest — download that tag 4. No version specified — resolve latest stable, download it If *cog_binary* is provided and is not the default ``"cog"``, it takes top priority (the user wants their own binary). """ defaults = manifest_defaults or {} # 1. Explicit --cog-binary (non-default) if cog_binary and cog_binary != "cog": return cog_binary, "custom" # 2. Explicit --cog-version if cog_version: tag = cog_version if cog_version.startswith("v") else f"v{cog_version}" path = download_cog_binary(tag) return str(path), tag # 3. Manifest default manifest_version = defaults.get("cog_version") if manifest_version and manifest_version != "latest": tag = ( manifest_version if manifest_version.startswith("v") else f"v{manifest_version}" ) path = download_cog_binary(tag) return str(path), tag # 4. Resolve latest stable tag = resolve_latest_stable_version() logger.info("Resolved latest stable cog version: %s", tag) path = download_cog_binary(tag) return str(path), tag # ── SDK version resolution ───────────────────────────────────────────── def resolve_latest_sdk_version() -> str: """Query PyPI and return the latest stable version of the ``cog`` package. PyPI's ``info.version`` field always returns the latest non-prerelease version, so no extra filtering is needed. """ req = urllib.request.Request(PYPI_API, headers={"Accept": "application/json"}) with urllib.request.urlopen(req, timeout=30) as resp: data = json.loads(resp.read().decode()) version = data["info"]["version"] logger.info("Resolved latest stable SDK version from PyPI: %s", version) return version def resolve_sdk_version( cli_sdk_version: str | None, manifest_defaults: dict | None = None, ) -> tuple[str, bool]: """Resolve which SDK version to use. Returns ``(version, was_resolved)``. Priority: 1. ``--sdk-version`` CLI flag — use as-is 2. ``defaults.sdk_version`` from manifest (if not ``"latest"``) 3. 
Resolve latest stable from PyPI *was_resolved* is ``True`` when the version was auto-resolved from PyPI. """ defaults = manifest_defaults or {} # 1. Explicit --sdk-version if cli_sdk_version: return cli_sdk_version, False # 2. Manifest default manifest_version = defaults.get("sdk_version") if manifest_version and manifest_version != "latest": return manifest_version, False # 3. Resolve latest stable from PyPI version = resolve_latest_sdk_version() return version, True ================================================ FILE: tools/test-harness/harness/patcher.py ================================================ """Patch cog.yaml files with sdk_version and arbitrary overrides.""" from __future__ import annotations import copy from pathlib import Path from typing import Any import yaml def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: """Recursively merge *override* into *base*, returning a new dict.""" result = copy.deepcopy(base) for key, value in override.items(): if key in result and isinstance(result[key], dict) and isinstance(value, dict): result[key] = deep_merge(result[key], value) else: result[key] = copy.deepcopy(value) return result def patch_cog_yaml( cog_yaml_path: Path, sdk_version: str | None = None, overrides: dict[str, Any] | None = None, ) -> dict[str, Any]: """Read a cog.yaml, apply patches, write it back, and return the final config. Parameters ---------- cog_yaml_path: Path to the cog.yaml file to patch (modified in-place). sdk_version: If set, inject ``build.sdk_version`` into the config. overrides: Arbitrary dict that is deep-merged into the config. Useful for changing python_version, adding system_packages, etc. Returns ------- The patched config dict. 
""" with open(cog_yaml_path) as f: config = yaml.safe_load(f) or {} if sdk_version: config.setdefault("build", {}) config["build"]["sdk_version"] = sdk_version if overrides: config = deep_merge(config, overrides) with open(cog_yaml_path, "w") as f: yaml.dump(config, f, default_flow_style=False, sort_keys=False) return config ================================================ FILE: tools/test-harness/harness/report.py ================================================ """Generate human-readable and machine-readable test reports.""" from __future__ import annotations import json import sys from datetime import datetime, timezone from typing import Any, TextIO from .runner import ModelResult def console_report( results: list[ModelResult], *, sdk_version: str = "", cog_version: str = "", stream: TextIO = sys.stdout, ) -> None: """Print a coloured summary table to the terminal.""" parts = [] if cog_version: parts.append(f"CLI {cog_version}") if sdk_version: parts.append(f"SDK {sdk_version}") version_str = " / ".join(parts) header = ( f"Cog Compatibility Report ({version_str})" if version_str else "Cog Compatibility Report" ) stream.write(f"\n{'=' * len(header)}\n") stream.write(f"{header}\n") stream.write(f"{'=' * len(header)}\n\n") passed = 0 failed = 0 skipped = 0 for r in results: if r.skipped: _write(stream, "SKIP", r.name, r.skip_reason or "", gpu=r.gpu) skipped += 1 continue if r.error: _write(stream, "FAIL", r.name, r.error.splitlines()[0], gpu=r.gpu) failed += 1 continue all_tests = r.test_results + r.train_results if r.passed: timing = _timing_str(r.build_duration_s, all_tests) _write(stream, "PASS", r.name, timing, gpu=r.gpu) passed += 1 else: failures = [t for t in all_tests if not t.passed] msg = f"{len(failures)} test(s) failed" if failures: msg += f": {failures[0].message[:60]}" _write(stream, "FAIL", r.name, msg, gpu=r.gpu) failed += 1 # Print individual test failures indented for t in failures: stream.write(f" FAIL {t.description}: {t.message[:100]}\n") 
stream.write(f"\n{'-' * 40}\n") total = passed + failed + skipped stream.write(f"{passed}/{total} passed") if skipped: stream.write(f", {skipped} skipped") if failed: stream.write(f", {failed} FAILED") stream.write("\n\n") def json_report( results: list[ModelResult], *, sdk_version: str = "", cog_version: str = "", ) -> dict[str, Any]: """Return a JSON-serializable report dict.""" models = [] for r in results: entry: dict[str, Any] = { "name": r.name, "passed": r.passed, "skipped": r.skipped, "gpu": r.gpu, "build_duration_s": round(r.build_duration_s, 2), } if r.skipped: entry["skip_reason"] = r.skip_reason if r.error: entry["error"] = r.error if r.test_results: entry["tests"] = [ { "description": t.description, "passed": t.passed, "message": t.message, "duration_s": round(t.duration_s, 2), } for t in r.test_results ] if r.train_results: entry["train_tests"] = [ { "description": t.description, "passed": t.passed, "message": t.message, "duration_s": round(t.duration_s, 2), } for t in r.train_results ] models.append(entry) total = len(results) passed = sum(1 for r in results if r.passed and not r.skipped) failed = sum(1 for r in results if not r.passed) skipped_count = sum(1 for r in results if r.skipped) return { "timestamp": datetime.now(timezone.utc).isoformat(), "cog_version": cog_version, "sdk_version": sdk_version, "summary": { "total": total, "passed": passed, "failed": failed, "skipped": skipped_count, }, "models": models, } def write_json_report( results: list[ModelResult], *, sdk_version: str = "", cog_version: str = "", stream: TextIO = sys.stdout, ) -> None: """Write JSON report to a stream.""" report = json_report(results, sdk_version=sdk_version, cog_version=cog_version) json.dump(report, stream, indent=2) stream.write("\n") # ── Helpers ──────────────────────────────────────────────────────────── def _write( stream: TextIO, status: str, name: str, detail: str, *, gpu: bool = False ) -> None: icon = {"PASS": "+", "FAIL": "x", "SKIP": "-"}[status] 
gpu_tag = " [GPU]" if gpu else "" stream.write(f" {icon} {name:<25} {detail}{gpu_tag}\n") def _timing_str(build_s: float, tests: list[Any]) -> str: parts = [f"{build_s:.1f}s build"] if tests: total_predict = sum(t.duration_s for t in tests) parts.append(f"{total_predict:.1f}s predict") return f"({', '.join(parts)})" ================================================ FILE: tools/test-harness/harness/runner.py ================================================ """Core test runner: clone, patch, build, predict, validate.""" from __future__ import annotations import logging import os import re import shutil import subprocess import tempfile import time from dataclasses import dataclass, field from pathlib import Path from typing import Any from .patcher import patch_cog_yaml from .validators import ValidationResult, validate logger = logging.getLogger(__name__) # ── Data types ───────────────────────────────────────────────────────── @dataclass class TestCaseResult: description: str passed: bool message: str duration_s: float = 0.0 @dataclass class ModelResult: name: str passed: bool build_duration_s: float = 0.0 test_results: list[TestCaseResult] = field(default_factory=list) train_results: list[TestCaseResult] = field(default_factory=list) error: str | None = None skipped: bool = False skip_reason: str | None = None gpu: bool = False # ── Runner ───────────────────────────────────────────────────────────── class Runner: """Orchestrates the clone -> patch -> build -> predict -> validate cycle.""" def __init__( self, *, cog_binary: str = "cog", sdk_version: str | None = None, fixtures_dir: Path | None = None, work_dir: Path | None = None, keep_images: bool = False, default_timeout: int = 300, ) -> None: self.cog_binary = cog_binary self.sdk_version = sdk_version self.fixtures_dir = fixtures_dir or Path(__file__).parent.parent / "fixtures" self.work_dir = work_dir or Path(tempfile.mkdtemp(prefix="cog-harness-")) self.keep_images = keep_images self.default_timeout = 
default_timeout self._cloned_repos: dict[str, Path] = {} def prepare_model(self, model: dict[str, Any]) -> Path: """Public wrapper around model preparation (clone + patch).""" return self._prepare_model(model) def build_model(self, model_dir: Path, model: dict[str, Any]) -> None: """Public wrapper around ``cog build``.""" self._cog_build(model_dir, model) def run_model(self, model: dict[str, Any]) -> ModelResult: """Run all tests for a single model definition from the manifest.""" name = model["name"] gpu = model.get("gpu", False) result = ModelResult(name=name, passed=True, gpu=gpu) # Check required env vars required_env = model.get("requires_env", []) missing = [v for v in required_env if not os.environ.get(v)] if missing: result.passed = True # not a failure, just skipped result.skipped = True result.skip_reason = f"Missing env vars: {', '.join(missing)}" logger.info("SKIP %s: %s", name, result.skip_reason) return result try: model_dir = self._prepare_model(model) except Exception as exc: result.passed = False result.error = f"Preparation failed: {exc}" logger.error("FAIL %s: %s", name, result.error) return result # Build build_start = time.monotonic() try: self._cog_build(model_dir, model) result.build_duration_s = time.monotonic() - build_start logger.info("BUILD OK %s (%.1fs)", name, result.build_duration_s) except subprocess.CalledProcessError as exc: result.passed = False result.build_duration_s = time.monotonic() - build_start stderr = exc.stderr or "" result.error = f"Build failed:\n{stderr[-2000:]}" logger.error("BUILD FAIL %s:\n%s", name, stderr[-500:]) return result # Train tests for tc in model.get("train_tests", []): tc_result = self._run_train_test(model_dir, model, tc) result.train_results.append(tc_result) if not tc_result.passed: result.passed = False # Predict tests for tc in model.get("tests", []): tc_result = self._run_predict_test(model_dir, model, tc) result.test_results.append(tc_result) if not tc_result.passed: result.passed = False return 
result # ── Internal helpers ─────────────────────────────────────────────── def _prepare_model(self, model: dict[str, Any]) -> Path: """Clone the repo (if needed) and patch cog.yaml. Returns model dir.""" repo = model["repo"] subpath = model.get("path", ".") repo_dir = self._clone_repo(repo) model_dir = repo_dir / subpath if not (model_dir / "cog.yaml").exists(): raise FileNotFoundError(f"No cog.yaml in {model_dir}") sdk_version = model.get("sdk_version", self.sdk_version) overrides = model.get("cog_yaml_overrides") patch_cog_yaml( model_dir / "cog.yaml", sdk_version=sdk_version, overrides=overrides, ) return model_dir def _clone_repo(self, repo: str) -> Path: """Shallow-clone a GitHub repo into the work dir, caching by repo name.""" if repo in self._cloned_repos: return self._cloned_repos[repo] dest = self.work_dir / repo.replace("/", "--") if dest.exists(): shutil.rmtree(dest) url = f"https://github.com/{repo}.git" logger.info("Cloning %s ...", url) subprocess.run( ["git", "clone", "--depth=1", url, str(dest)], check=True, capture_output=True, text=True, ) self._cloned_repos[repo] = dest return dest def _cog_build(self, model_dir: Path, model: dict[str, Any]) -> None: """Run ``cog build`` in the model directory.""" image_tag = f"cog-harness-{model['name']}:test" cmd = [self.cog_binary, "build", "-t", image_tag] env = self._build_env(model) timeout = model.get("timeout", self.default_timeout) subprocess.run( cmd, cwd=model_dir, check=True, capture_output=True, text=True, env=env, timeout=timeout, ) def _run_predict_test( self, model_dir: Path, model: dict[str, Any], tc: dict[str, Any] ) -> TestCaseResult: description = tc.get("description", "predict") start = time.monotonic() cmd = [self.cog_binary, "predict"] for key, value in tc.get("inputs", {}).items(): resolved = self._resolve_input(value) cmd.extend(["-i", f"{key}={resolved}"]) env = self._build_env(model) timeout = model.get("timeout", self.default_timeout) try: proc = subprocess.run( cmd, cwd=model_dir, 
capture_output=True, text=True, env=env, timeout=timeout, ) duration = time.monotonic() - start if proc.returncode != 0: return TestCaseResult( description=description, passed=False, message=f"cog predict exited {proc.returncode}:\n{proc.stderr[-1000:]}", duration_s=duration, ) output = self._extract_output(proc, model_dir) vr: ValidationResult = validate(output, tc.get("expect", {})) logger.info( " %s %s: %s (%.1fs)", "PASS" if vr.passed else "FAIL", description, vr.message[:80], duration, ) return TestCaseResult( description=description, passed=vr.passed, message=vr.message, duration_s=duration, ) except subprocess.TimeoutExpired: duration = time.monotonic() - start return TestCaseResult( description=description, passed=False, message=f"Timed out after {timeout}s", duration_s=duration, ) except Exception as exc: duration = time.monotonic() - start return TestCaseResult( description=description, passed=False, message=f"Unexpected error: {exc}", duration_s=duration, ) def _run_train_test( self, model_dir: Path, model: dict[str, Any], tc: dict[str, Any] ) -> TestCaseResult: description = tc.get("description", "train") start = time.monotonic() cmd = [self.cog_binary, "train"] for key, value in tc.get("inputs", {}).items(): resolved = self._resolve_input(value) cmd.extend(["-i", f"{key}={resolved}"]) env = self._build_env(model) timeout = model.get("timeout", self.default_timeout) try: proc = subprocess.run( cmd, cwd=model_dir, capture_output=True, text=True, env=env, timeout=timeout, ) duration = time.monotonic() - start if proc.returncode != 0: return TestCaseResult( description=description, passed=False, message=f"cog train exited {proc.returncode}:\n{proc.stderr[-1000:]}", duration_s=duration, ) output = self._extract_output(proc, model_dir) vr: ValidationResult = validate(output, tc.get("expect", {})) logger.info( " %s %s: %s (%.1fs)", "PASS" if vr.passed else "FAIL", description, vr.message[:80], duration, ) return TestCaseResult( description=description, 
passed=vr.passed, message=vr.message, duration_s=duration, ) except subprocess.TimeoutExpired: duration = time.monotonic() - start return TestCaseResult( description=description, passed=False, message=f"Timed out after {timeout}s", duration_s=duration, ) except Exception as exc: duration = time.monotonic() - start return TestCaseResult( description=description, passed=False, message=f"Unexpected error: {exc}", duration_s=duration, ) @staticmethod def _extract_output(proc: subprocess.CompletedProcess[str], model_dir: Path) -> str: """Extract the prediction output from cog's stdout/stderr. ``cog predict`` prints text/JSON output to **stdout**. For file outputs (e.g. images) it writes the file to the CWD and prints ``Written output to: `` on **stderr**. We detect the latter pattern and return the absolute path to the file so that the ``file_exists`` validator can verify it. """ # If there's meaningful stdout, prefer that stdout = proc.stdout.strip() if stdout: return proc.stdout # Check stderr for "Written output to: " m = re.search(r"Written output to:\s*(.+)", proc.stderr) if m: rel_path = m.group(1).strip() abs_path = model_dir / rel_path return str(abs_path) # Fallback: return whatever stdout had (possibly empty) return proc.stdout def _resolve_input(self, value: Any) -> str: """Resolve input values — ``@filename`` becomes an absolute fixture path. The path is resolved to an absolute, canonical path (no symlinks or ``..`` components) so that ``cog predict -i image=@/abs/path`` works correctly when cog mounts the file into the container. 
""" s = str(value) if s.startswith("@"): fixture_path = (self.fixtures_dir / s[1:]).resolve() if not fixture_path.exists(): raise FileNotFoundError( f"Fixture not found: {fixture_path} (referenced as {s!r})" ) return f"@{fixture_path}" return s def _build_env(self, model: dict[str, Any]) -> dict[str, str]: """Build environment dict, expanding ${VAR} references from host env.""" env = os.environ.copy() for key, value in model.get("env", {}).items(): resolved = os.path.expandvars(value) env[key] = resolved return env def cleanup(self) -> None: """Remove work directory and optionally docker images.""" if not self.keep_images: # Clean up docker images we created try: proc = subprocess.run( [ "docker", "images", "--filter", "reference=cog-harness-*", "--format", "{{.Repository}}:{{.Tag}}", ], capture_output=True, text=True, ) images = [ line.strip() for line in proc.stdout.splitlines() if line.strip() ] if images: subprocess.run( ["docker", "rmi", "--force"] + images, capture_output=True, text=True, ) except Exception as exc: logger.warning("Failed to clean up Docker images in cleanup(): %s", exc) if self.work_dir.exists(): shutil.rmtree(self.work_dir, ignore_errors=True) ================================================ FILE: tools/test-harness/harness/validators.py ================================================ """Output validation strategies for cog model predictions.""" from __future__ import annotations import json import mimetypes import re from dataclasses import dataclass from pathlib import Path from typing import Any @dataclass class ValidationResult: passed: bool message: str def validate(output: str, expect: dict[str, Any]) -> ValidationResult: """Dispatch to the appropriate validator based on ``expect["type"]``.""" vtype = expect.get("type", "not_empty") validator = _VALIDATORS.get(vtype) if validator is None: return ValidationResult( passed=False, message=f"Unknown validation type: {vtype!r}", ) return validator(output, expect) # ── Individual validators 
def _validate_file_exists(output: str, expect: dict[str, Any]) -> ValidationResult:
    """Validate that ``output`` names an existing file (or is a URL).

    ``output`` is stripped of surrounding whitespace and quotes before use.
    ``http://`` and ``https://`` values pass without being fetched. Any other
    value is treated as a local path that must point at a regular file.

    Args:
        output: The captured prediction output.
        expect: Validation spec. If ``mime`` is set, the MIME type guessed
            from the file name must equal it.

    Returns:
        A ``ValidationResult`` describing the outcome.
    """
    path_str = output.strip().strip("'\"")

    # Path("") is the current directory, which always exists, so empty output
    # must be rejected explicitly or it would pass this check.
    if not path_str:
        return ValidationResult(
            passed=False,
            message="Output is empty; expected a file path",
        )

    # cog predict may output a URL or a path -- for local testing it's a path
    if path_str.startswith(("http://", "https://")):
        # Can't verify remote files; treat as pass
        return ValidationResult(passed=True, message=f"Output is a URL: {path_str}")

    path = Path(path_str)
    if not path.exists():
        return ValidationResult(
            passed=False,
            message=f"Output file does not exist: {path}",
        )
    if not path.is_file():
        # A directory (or other non-file) is not a valid file output.
        return ValidationResult(
            passed=False,
            message=f"Output path is not a regular file: {path}",
        )

    expected_mime = expect.get("mime")
    if expected_mime:
        # The guess is based on the file name only, not its contents.
        guessed, _ = mimetypes.guess_type(str(path))
        if guessed != expected_mime:
            return ValidationResult(
                passed=False,
                message=f"Expected MIME {expected_mime}, got {guessed} for {path}",
            )

    return ValidationResult(passed=True, message=f"File exists: {path}")
Got: {list(parsed.keys())}", ) elif not parsed: return ValidationResult( passed=False, message="Expected non-empty JSON object, got empty dict", ) return ValidationResult( passed=True, message=f"JSON dict with {len(parsed)} keys: {list(parsed.keys())[:5]}", ) def _validate_not_empty(output: str, _expect: dict[str, Any]) -> ValidationResult: if output.strip(): return ValidationResult(passed=True, message="Output is non-empty") return ValidationResult(passed=False, message="Output is empty") # ── Helpers ──────────────────────────────────────────────────────────── def _is_subset(subset: Any, superset: Any) -> bool: """Check that *subset* is recursively contained in *superset*.""" if isinstance(subset, dict) and isinstance(superset, dict): return all( k in superset and _is_subset(v, superset[k]) for k, v in subset.items() ) if isinstance(subset, list) and isinstance(superset, list): return all( any(_is_subset(s_item, p_item) for p_item in superset) for s_item in subset ) return subset == superset # ── Registry ─────────────────────────────────────────────────────────── _VALIDATORS = { "exact": _validate_exact, "contains": _validate_contains, "regex": _validate_regex, "file_exists": _validate_file_exists, "json_match": _validate_json_match, "json_keys": _validate_json_keys, "not_empty": _validate_not_empty, } ================================================ FILE: tools/test-harness/manifest.yaml ================================================ # Cog Model Test Manifest # ======================= # Each entry defines a model to test, its inputs, and expected outputs. # # Input values prefixed with "@" are resolved as fixture file paths relative # to the fixtures/ directory (e.g. "@test_image.png" -> fixtures/test_image.png). 
# # Validation types: # exact - output string must equal `value` exactly # contains - output string must contain `value` as a substring # regex - output string must match `pattern` # file_exists - output is a file path; optionally check `mime` type # json_match - parse output as JSON, assert `match` is a subset # json_keys - parse output as JSON dict, assert it has entries # not_empty - output is non-empty (loose smoke test) defaults: sdk_version: "latest" # "latest" = newest stable from PyPI; or pin e.g. "0.16.12" cog_version: "latest" # "latest" = newest stable release; or pin e.g. "v0.16.12" models: # ── cog-examples (CPU) ────────────────────────────────────────────── - name: hello-world repo: replicate/cog-examples path: hello-world gpu: false tests: - description: "basic predict" inputs: text: "world" expect: type: exact value: "hello world" - name: canary repo: replicate/cog-examples path: canary gpu: false tests: - description: "streaming concatenate iterator" inputs: text: "friend" expect: type: contains value: "friend" - name: blur repo: replicate/cog-examples path: blur gpu: false tests: - description: "blur an image" inputs: image: "@test_image.png" blur: 5 expect: type: file_exists mime: "image/png" - name: hello-image repo: replicate/cog-examples path: hello-image gpu: false tests: - description: "return a static image" inputs: {} expect: type: file_exists - name: hello-concurrency repo: replicate/cog-examples path: hello-concurrency gpu: false tests: - description: "async streaming output" inputs: total: 3 interval: 0 expect: type: contains value: "Apple" - name: hello-context repo: replicate/cog-examples path: hello-context gpu: false # NOTE: This model uses current_scope().context which may not be # available in coglet yet. A failure here is a real compatibility signal. 
tests: - description: "returns input and context" inputs: text: "testing" expect: type: json_match match: inputs: text: "testing" - name: hello-train repo: replicate/cog-examples path: hello-train gpu: false # NOTE: `cog train` in the RC may have input validation issues # (validates against predict schema instead of train schema). # The train_test below may fail — that's a real compatibility signal. train_tests: - description: "train produces weights file" inputs: prefix: "custom" expect: type: not_empty tests: - description: "predict with default weights" inputs: text: "world" expect: type: contains value: "world" # ── cog-examples (GPU required) ───────────────────────────────────── - name: resnet repo: replicate/cog-examples path: resnet gpu: true tests: - description: "classify hotdog image" inputs: image: "@hotdog.png" expect: type: json_keys - name: z-image-turbo repo: replicate/cog-examples path: z-image-turbo gpu: true timeout: 600 tests: - description: "generate image from prompt" inputs: prompt: "a cat sitting on a windowsill" expect: type: file_exists mime: "image/png" # ── cog-examples (requires external API, optional) ────────────────── - name: hello-replicate repo: replicate/cog-examples path: hello-replicate gpu: false requires_env: - REPLICATE_API_TOKEN tests: - description: "round-trip through replicate API" inputs: image: "@test_image.png" expect: type: file_exists # ── External models (add your own below) ──────────────────────────── # - name: my-custom-model # repo: myorg/my-model-repo # path: "." 
# gpu: true # # sdk_version: "0.16.12" # optional per-model override # env: # HF_TOKEN: "${HF_TOKEN}" # timeout: 600 # tests: # - description: "smoke test" # inputs: # prompt: "hello" # expect: # type: contains # value: "result" ================================================ FILE: tools/test-harness/pyproject.toml ================================================ [project] name = "cog-test-harness" version = "0.1.0" description = "Test harness for validating cog models against new SDK versions" requires-python = ">=3.10" dependencies = [ "pyyaml>=6.0", ] [project.scripts] cog-test = "harness.cli:main" ================================================ FILE: tools/test-harness/results/.gitkeep ================================================ ================================================ FILE: tools/test-registry-util/README.md ================================================ # `test-registry-util` A tool for creating and inspecting a local registry for testing. ## Purpose We have a lot of intricate image manipulation code that needs to be tested. Mocks aren't great for this because we need to make sure the code works with actual data. This tool helps set up real data for a test registry. ## Usage Image data is stored in `pkg/registry_testhelpers/testdata` and matches the structure expected by `distribution/distribution`. During tests an ephemeral registry is spun up on a random local port, populated with the image data, and torn down when the test finishes.
### Booting a registry in a test: ```go import "github.com/replicate/cog/pkg/registry_testhelpers" func TestMyFunction(t *testing.T) { registryContainer := registry_testhelpers.StartTestRegistry(ctx) image := registryContainer.ImageRef("alpine:latest") // use image as a real image reference } ``` ### Inspect the current images in the registry: ```bash go run ./tools/test-registry-util catalog ``` will print something like: ``` alpine:latest application/vnd.oci.image.index.v1+json index -> sha256:9a0ff41dccad7a96f324a4655a715c623ed3511c7336361ffa9dadcecbdb99e5 linux/amd64 -> sha256:1c4eef651f65e2f7daee7ee785882ac164b02b78fb74503052a26dc061c90474 linux/arm64 -> sha256:757d680068d77be46fd1ea20fb21db16f150468c5e7079a08a2e4705aec096ac python:3.10 application/vnd.oci.image.manifest.v1+json single platform image -> sha256:f33bb19d5a518ba7e0353b6da48d58a04ef674de0bab0810e4751230ea1d4b19 ``` You can then use these images in your tests using references like: - `localhost:<port>/alpine:latest` to get a multi-platform index - `localhost:<port>/alpine:latest` with platform `linux/amd64` to get a single image from a multi-platform index - `localhost:<port>/alpine:latest@sha256:1c4eef651f65e2f7daee7ee785882ac164b02b78fb74503052a26dc061c90474` to get a specific image - `localhost:<port>/python:3.10` to get a single-platform image ### Initialize a new registry storage To create a new directory of images, run: ``` go run ./tools/test-registry-util init ``` This will download all the images specified in `main.go` and save them to `pkg/registry_testhelpers/testdata`. ### Run a registry This is just a convenience to inspect a registry outside of a test.
``` go run ./tools/test-registry-util run ``` ================================================ FILE: tools/test-registry-util/main.go ================================================ package main import ( "context" "fmt" "os" "os/signal" "path/filepath" "strings" "time" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/mount" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/name" v1 "github.com/google/go-containerregistry/pkg/v1" "github.com/google/go-containerregistry/pkg/v1/empty" "github.com/google/go-containerregistry/pkg/v1/mutate" "github.com/google/go-containerregistry/pkg/v1/remote" "github.com/google/go-containerregistry/pkg/v1/types" "github.com/spf13/cobra" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/registry" "github.com/testcontainers/testcontainers-go/wait" "github.com/replicate/cog/pkg/util/files" ) // images to download and push to the registry. Keep the images sizes small since they're stored in git. // For reference, the `alpine:latest` image for `linux/amd64` ~3.5MB compressed. 
var images = []struct { Image string Platforms []string SinglePlatform string }{ { Image: "alpine:latest", Platforms: []string{ "linux/amd64", "linux/arm64", }, }, } // relative to the root of the repo var destinationDir string = "pkg/registry_testhelpers/testdata" func main() { rootCmd := &cobra.Command{ Use: "test-registry-util", } rootCmd.PersistentFlags().StringVar(&destinationDir, "storage-dir", destinationDir, "path to the directory where the registry will store its data") rootCmd.AddCommand( &cobra.Command{ Use: "init", RunE: func(cmd *cobra.Command, args []string) error { return runAndInit(cmd.Context(), destinationDir) }, }, ) rootCmd.AddCommand( &cobra.Command{ Use: "catalog", RunE: func(cmd *cobra.Command, args []string) error { return runAndCatalog(cmd.Context(), destinationDir) }, }, ) rootCmd.AddCommand( &cobra.Command{ Use: "run", RunE: func(cmd *cobra.Command, args []string) error { ctx, cancel := signal.NotifyContext(cmd.Context(), os.Interrupt) defer cancel() c, port, err := startRegistryTC(cmd.Context(), destinationDir) if err != nil { return err } defer func() { if err := c.Terminate(cmd.Context()); err != nil { fmt.Println("Failed to terminate registry:", err) } }() fmt.Println("Registry running at", fmt.Sprintf("localhost:%d", port)) <-ctx.Done() return nil }, }, ) if err := rootCmd.Execute(); err != nil { fmt.Println("Failed to run:", err) os.Exit(1) } } func runAndInit(ctx context.Context, dstDir string) error { if empty, err := files.IsEmpty(dstDir); err != nil { return fmt.Errorf("failed to check if destination directory is empty: %w", err) } else if !empty { return fmt.Errorf("destination directory %s is not empty", dstDir) } if err := os.MkdirAll(dstDir, 0o755); err != nil { return fmt.Errorf("failed to create destination directory: %w", err) } tmpDir, err := os.MkdirTemp("", "test-registry-") if err != nil { return err } defer os.RemoveAll(tmpDir) reg, hostPort, err := startRegistryTC(ctx, tmpDir) if err != nil { return err } defer 
func() { if err := reg.Terminate(ctx); err != nil { fmt.Println("Failed to terminate registry:", err) } }() addr := fmt.Sprintf("localhost:%d", hostPort) for _, src := range images { destRepo := fmt.Sprintf("%s/%s", addr, strings.Split(src.Image, ":")[0]) // e.g. localhost:5000/alpine tagPart := strings.Split(src.Image, ":")[1] if src.SinglePlatform != "" { osArch := strings.SplitN(src.SinglePlatform, "/", 2) plat := v1.Platform{OS: osArch[0], Architecture: osArch[1]} // Pull source image for specified platform srcRef, err := name.ParseReference(src.Image) if err != nil { return fmt.Errorf("parse reference: %w", err) } srcImg, err := remote.Image(srcRef, remote.WithPlatform(plat), remote.WithContext(ctx)) if err != nil { return err } // Push with desired tag destRef, err := name.ParseReference(fmt.Sprintf("%s:%s", destRepo, tagPart), name.Insecure) if err != nil { return fmt.Errorf("parse reference: %w", err) } if err := remote.Write(destRef, srcImg, remote.WithContext(ctx), remote.WithAuth(authn.Anonymous)); err != nil { return fmt.Errorf("write %s: %w", destRef, err) } fmt.Printf("✅ pushed single-platform image %s\n", destRef.Name()) continue } idx := mutate.IndexMediaType(empty.Index, types.OCIImageIndex) // start empty for _, platStr := range src.Platforms { osArch := strings.SplitN(platStr, "/", 2) plat := v1.Platform{OS: osArch[0], Architecture: osArch[1]} // 1. pull source manifest for this platform srcRef, err := name.ParseReference(src.Image) if err != nil { return fmt.Errorf("parse reference: %w", err) } srcImg, err := remote.Image(srcRef, remote.WithPlatform(plat), remote.WithContext(ctx)) if err != nil { return err } // 2. 
push it *by digest* into the new registry digest, _ := srcImg.Digest() destDigestRef, err := name.ParseReference(fmt.Sprintf("%s@%s", destRepo, digest.String()), name.Insecure) if err != nil { return fmt.Errorf("parse reference: %w", err) } if err := remote.Write(destDigestRef, srcImg, remote.WithContext(ctx), remote.WithAuth(authn.Anonymous)); err != nil { return fmt.Errorf("write %s: %w", destDigestRef, err) } // 3. add it to the (soon‑to‑be) index idx = mutate.AppendManifests(idx, mutate.IndexAddendum{Add: srcImg, Descriptor: v1.Descriptor{Platform: &plat}}) fmt.Printf("✅ pushed %s for %s/%s\n", destDigestRef.Name(), plat.OS, plat.Architecture) } // 4. push the assembled index and tag it indexTag, err := name.ParseReference(fmt.Sprintf("%s:%s", destRepo, tagPart), name.Insecure) if err != nil { return fmt.Errorf("parse reference: %w", err) } if err := remote.WriteIndex(indexTag, idx, remote.WithContext(ctx), remote.WithAuth(authn.Anonymous)); err != nil { return fmt.Errorf("write index %s: %w", indexTag, err) } fmt.Printf("🏷️ tagged multi-arch index %s\n", indexTag.Name()) } fmt.Println("Copying registry data to", dstDir) if err := os.CopyFS(dstDir, os.DirFS(tmpDir)); err != nil { return fmt.Errorf("failed to copy registry data: %w", err) } if err := catalog(ctx, addr); err != nil { return fmt.Errorf("catalog tree: %w", err) } return nil } func runAndCatalog(ctx context.Context, dir string) error { dir, err := filepath.Abs(dir) if err != nil { return fmt.Errorf("failed to get absolute path: %w", err) } reg, _, err := startRegistryTC(ctx, dir) if err != nil { return err } defer func() { if err := reg.Terminate(ctx); err != nil { fmt.Println("Failed to terminate registry:", err) } }() if err := catalog(ctx, reg.RegistryName); err != nil { return fmt.Errorf("catalog: %w", err) } return nil } func catalog(ctx context.Context, addr string) error { opts := []remote.Option{ remote.WithContext(ctx), remote.WithAuth(authn.Anonymous), // local registry } reg, err := 
name.NewRegistry(addr, name.Insecure) if err != nil { return fmt.Errorf("new registry: %w", err) } // first, list all repositories repos, err := remote.Catalog(ctx, reg, opts...) if err != nil { return err } for _, repoName := range repos { repo := reg.Repo(repoName) // second, list all tags tagNames, err := remote.List(repo, opts...) if err != nil { return err } for _, tagName := range tagNames { // third, get the manifest ref, err := name.ParseReference(fmt.Sprintf("%s/%s:%s", addr, repoName, tagName)) if err != nil { return fmt.Errorf("parse reference: %w", err) } desc, err := remote.Get(ref, opts...) if err != nil { return err } repoTag := fmt.Sprintf("%s:%s", ref.Context().RepositoryStr(), ref.Identifier()) switch mt := desc.MediaType; mt { case types.OCIImageIndex, types.DockerManifestList: fmt.Printf("%s %s\n index -> %s\n", repoTag, mt, desc.Digest) idx, _ := desc.ImageIndex() im, _ := idx.IndexManifest() for _, m := range im.Manifests { fmt.Printf(" %s -> %s\n", m.Platform.String(), m.Digest, ) } default: // single‑platform image fmt.Printf("%s %s\n single platform image -> %s\n", repoTag, mt, desc.Digest) } } } return nil } func startRegistryTC(ctx context.Context, dir string) (*registry.RegistryContainer, int, error) { dir, err := filepath.Abs(dir) if err != nil { return nil, 0, fmt.Errorf("failed to get absolute path: %w", err) } reg, err := registry.Run(ctx, "registry:3", testcontainers.WithHostConfigModifier(func(hostConfig *container.HostConfig) { hostConfig.Mounts = []mount.Mount{ { Type: "bind", Source: dir, Target: "/var/lib/registry", }, } }), testcontainers.WithWaitStrategy( wait.ForHTTP("/v2/").WithPort("5000/tcp"). 
WithStartupTimeout(10*time.Second), ), ) if err != nil { return nil, 0, fmt.Errorf("start registry: %w", err) } port, err := reg.MappedPort(ctx, "5000/tcp") if err != nil { if err := reg.Terminate(ctx); err != nil { fmt.Println("Failed to terminate registry:", err) } return nil, 0, fmt.Errorf("mapped port: %w", err) } return reg, port.Int(), nil } ================================================ FILE: tools/weights-gen/README.md ================================================ # weights-gen A tool for generating random weight files and optionally a `weights.lock` file for testing. ## Installation ```bash go install github.com/replicate/cog/tools/weights-gen@latest ``` ## Usage ```bash # If installed via go install weights-gen [flags] # Or run directly from the repository go run ./tools/weights-gen [flags] ``` ## Flags | Flag | Short | Default | Description | |------|-------|---------|-------------| | `--count` | `-n` | `3` | Number of random weight files to generate | | `--min-size` | | `25mb` | Minimum file size (e.g., `12mb`, `25MB`, `1gb`) | | `--max-size` | | `50mb` | Maximum file size (e.g., `50mb`, `100MB`, `1gb`) | | `--output-dir` | | temp dir | Directory to write generated weight files | | `--output` | `-o` | `weights.lock` | Output path for weights.lock file | | `--dest-prefix` | | `/cache/` | Prefix for destination paths in lock file | | `--no-lock` | | `false` | Skip generating the weights.lock file | ## Examples ```bash # Generate 3 random files (25-50MB each) with a weights.lock file go run ./tools/weights-gen # Generate 5 files between 12-50MB go run ./tools/weights-gen --count 5 --min-size 12mb --max-size 50mb # Generate files to a specific output directory go run ./tools/weights-gen --output-dir ./my-weights/ # Generate only weight files without a lock file go run ./tools/weights-gen --output-dir ./my-weights/ --no-lock # Generate files with custom destination prefix go run ./tools/weights-gen --output-dir ./my-weights/ --dest-prefix /models/ ``` 
## Output The tool generates: - Random binary weight files named `weights-001.bin`, `weights-002.bin`, etc. - A `weights.lock` file (unless `--no-lock` is specified) containing metadata about each file including SHA256 digests for both original and gzip-compressed content. The path to the generated files is always printed to stdout. ## How the lock file works The `weights.lock` file contains a `dest` field for each weight file. By default, `dest` paths use the `/cache/` prefix, which is the standard location for weights in Cog containers. Use `--dest-prefix` to override this behavior if you need different paths in the lock file (e.g., `/models/` or local paths for testing). ================================================ FILE: tools/weights-gen/main.go ================================================ // tools/weights-gen/main.go package main import ( "crypto/sha256" "encoding/hex" "fmt" "math/rand" "os" "path/filepath" "strconv" "strings" "time" "github.com/spf13/cobra" "github.com/replicate/cog/pkg/model" ) func main() { var ( destPrefix string outputPath string outputDir string count int minSize string maxSize string noLock bool ) cmd := &cobra.Command{ Use: "weights-gen", Short: "Generate random weight files and optionally a weights.lock file", Long: `This tool generates random weight files and optionally a weights.lock file for testing. It creates random binary files of configurable size and computes their digests, simulating what a future "cog weights" command would do with real weight files. By default, both weight files and a weights.lock file are generated. Use --no-lock to generate only the weight files without the lock file. The lock file's dest paths default to /cache/ for container paths. Use --dest-prefix to override this. 
Examples: # Generate 3 random files (25-50MB each) with defaults (includes weights.lock) weights-gen # Generate 5 files between 12-50MB weights-gen --count 5 --min-size 12mb --max-size 50mb # Generate files to a specific output directory weights-gen --output-dir ./my-weights/ # Generate only weight files without a lock file weights-gen --output-dir ./my-weights/ --no-lock # Generate files with custom destination prefix weights-gen --output-dir ./my-weights/ --dest-prefix /models/`, RunE: func(cmd *cobra.Command, args []string) error { minBytes, err := parseSize(minSize) if err != nil { return fmt.Errorf("invalid --min-size: %w", err) } maxBytes, err := parseSize(maxSize) if err != nil { return fmt.Errorf("invalid --max-size: %w", err) } if minBytes > maxBytes { return fmt.Errorf("--min-size (%s) cannot be greater than --max-size (%s)", minSize, maxSize) } if count < 1 { return fmt.Errorf("--count must be at least 1") } return generateWeights(outputDir, destPrefix, outputPath, count, minBytes, maxBytes, !noLock) }, } cmd.Flags().StringVar(&destPrefix, "dest-prefix", "/cache/", "Prefix for destination paths in lock file (default: /cache/)") cmd.Flags().StringVarP(&outputPath, "output", "o", "weights.lock", "Output path for weights.lock file") cmd.Flags().StringVar(&outputDir, "output-dir", "", "Directory to write generated weight files (default: temp dir)") cmd.Flags().IntVarP(&count, "count", "n", 3, "Number of random weight files to generate") cmd.Flags().StringVar(&minSize, "min-size", "25mb", "Minimum file size (e.g., 12mb, 25MB, 1gb)") cmd.Flags().StringVar(&maxSize, "max-size", "50mb", "Maximum file size (e.g., 50mb, 100MB, 1gb)") cmd.Flags().BoolVar(&noLock, "no-lock", false, "Skip generating the weights.lock file") if err := cmd.Execute(); err != nil { os.Exit(1) } } // parseSize parses a size string like "25mb", "50MB", "1gb" into bytes. 
func parseSize(s string) (int64, error) { s = strings.TrimSpace(strings.ToLower(s)) if s == "" { return 0, fmt.Errorf("empty size string") } var multiplier int64 = 1 var numStr string switch { case strings.HasSuffix(s, "gb"): multiplier = 1024 * 1024 * 1024 numStr = strings.TrimSuffix(s, "gb") case strings.HasSuffix(s, "mb"): multiplier = 1024 * 1024 numStr = strings.TrimSuffix(s, "mb") case strings.HasSuffix(s, "kb"): multiplier = 1024 numStr = strings.TrimSuffix(s, "kb") case strings.HasSuffix(s, "b"): numStr = strings.TrimSuffix(s, "b") default: // Assume bytes if no suffix numStr = s } num, err := strconv.ParseFloat(strings.TrimSpace(numStr), 64) if err != nil { return 0, fmt.Errorf("invalid number: %s", numStr) } if num < 0 { return 0, fmt.Errorf("size cannot be negative") } return int64(num * float64(multiplier)), nil } func generateWeights(outputDir, destPrefix, outputPath string, count int, minSize, maxSize int64, generateLock bool) error { // Determine where to write files var filesDir string if outputDir != "" { // User specified an output directory if err := os.MkdirAll(outputDir, 0o755); err != nil { return fmt.Errorf("create output directory: %w", err) } filesDir = outputDir } else { // Use a temp directory (not cleaned up so user can access the files) tmpDir, err := os.MkdirTemp("", "weights-gen-") if err != nil { return fmt.Errorf("create temp directory: %w", err) } filesDir = tmpDir } // Seed random number generator // Using math/rand is fine for test data generation - we don't need crypto randomness rng := rand.New(rand.NewSource(time.Now().UnixNano())) //nolint:gosec // Generate random files fmt.Printf("Generating %d random weight files (%s - %s each)...\n", count, formatSize(minSize), formatSize(maxSize)) var files []model.WeightFile for i := 1; i <= count; i++ { // Random size between min and max var size int64 if minSize == maxSize { size = minSize } else { size = minSize + rng.Int63n(maxSize-minSize+1) } filename := 
fmt.Sprintf("weights-%03d.bin", i) filePath := filepath.Join(filesDir, filename) fmt.Printf(" Creating %s (%s)...\n", filename, formatSize(size)) if err := generateRandomFile(filePath, size, rng); err != nil { return fmt.Errorf("generate %s: %w", filename, err) } if generateLock { wf, err := processFile(filePath, filesDir, destPrefix) if err != nil { return fmt.Errorf("process %s: %w", filename, err) } files = append(files, *wf) fmt.Printf(" Processed: %s -> %s\n", wf.Name, wf.Dest) } else { fmt.Printf(" Created: %s\n", filename) } } if generateLock { lock := &model.WeightsLock{ Version: "1", Created: time.Now().UTC(), Files: files, } if err := lock.Save(outputPath); err != nil { return err } fmt.Printf("\nGenerated %s with %d files\n", outputPath, len(files)) } else { fmt.Printf("\nGenerated %d weight files (no lock file)\n", count) } fmt.Printf("Weight files written to: %s\n", filesDir) return nil } // generateRandomFile creates a file filled with random data of the specified size. func generateRandomFile(path string, size int64, rng *rand.Rand) error { f, err := os.Create(path) if err != nil { return fmt.Errorf("create file: %w", err) } defer f.Close() // Write in chunks to avoid allocating huge buffers const chunkSize = 1024 * 1024 // 1MB chunks chunk := make([]byte, chunkSize) remaining := size for remaining > 0 { toWrite := min(remaining, chunkSize) // Fill chunk with random data _, _ = rng.Read(chunk[:toWrite]) if _, err := f.Write(chunk[:toWrite]); err != nil { return fmt.Errorf("write: %w", err) } remaining -= toWrite } return nil } // formatSize formats bytes into a human-readable string. 
func formatSize(bytes int64) string { const ( kb = 1024 mb = kb * 1024 gb = mb * 1024 ) switch { case bytes >= gb: return fmt.Sprintf("%.1fGB", float64(bytes)/float64(gb)) case bytes >= mb: return fmt.Sprintf("%.1fMB", float64(bytes)/float64(mb)) case bytes >= kb: return fmt.Sprintf("%.1fKB", float64(bytes)/float64(kb)) default: return fmt.Sprintf("%dB", bytes) } } func processFile(path, baseDir, destPrefix string) (*model.WeightFile, error) { // Read file data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read file: %w", err) } // Compute digest hash := sha256.Sum256(data) digest := "sha256:" + hex.EncodeToString(hash[:]) // Compute relative path for dest relPath, err := filepath.Rel(baseDir, path) if err != nil { return nil, fmt.Errorf("rel path: %w", err) } dest := filepath.Join(destPrefix, relPath) // Normalize to forward slashes for container paths dest = strings.ReplaceAll(dest, "\\", "/") // Generate a simple identifier from the filename (without extension) baseName := filepath.Base(path) name := baseName[:len(baseName)-len(filepath.Ext(baseName))] size := int64(len(data)) return &model.WeightFile{ Name: name, Dest: dest, Digest: digest, DigestOriginal: digest, Size: size, SizeUncompressed: size, MediaType: model.MediaTypeWeightLayer, }, nil }