[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled\n__pycache__/\n*.py[cod]\n\n# Virtual env\n.venv/\nvenv/\n\n# Output\noutput/\n\n# Logs\nlogs/\n\n# IDE\n.idea/\n.vscode/\n\n# Environment\n.env\n\n# SQLite\n*.db\n\n# OS\n.DS_Store\n"
  },
  {
    "path": "Dockerfile",
    "content": "FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04\n\nWORKDIR /app\n\n# System deps\nRUN apt-get update && apt-get install -y --no-install-recommends \\\n    python3.11 python3.11-venv python3-pip git && \\\n    rm -rf /var/lib/apt/lists/*\n\n# Create venv\nRUN python3.11 -m venv /opt/venv\nENV PATH=\"/opt/venv/bin:$PATH\"\n\n# Install Python deps\nCOPY pyproject.toml .\nRUN pip install --no-cache-dir --upgrade pip && \\\n    pip install --no-cache-dir \".[dev]\"\n\n# Install PyTorch with CUDA 12.4 (override index for torch)\nRUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu124\n\n# Copy application\nCOPY app/ app/\nCOPY tests/ tests/\n\n# Create output directory\nRUN mkdir -p /app/output\n\nEXPOSE 8000\n\nCMD [\"uvicorn\", \"app.main:app\", \"--host\", \"0.0.0.0\", \"--port\", \"8000\"]\n"
  },
  {
    "path": "README.md",
    "content": "# Flux Image Service\n\nLocal image generation API powered by **Flux.1 Schnell 12B** running on an H100 GPU with FP8 quantization.\n\n## Features\n\n- **Single image generation** — async (returns job IDs) or sync (waits for result), up to 4 images per request\n- **Batch/dataset generation** — submit up to 50K prompts via API or CSV/JSON file upload\n- **Priority queue** — real-time requests always jump ahead of batch jobs\n- **SSE progress streaming** — monitor batch progress in real-time\n- **Crash recovery** — pending jobs auto-resume on server restart (persisted in SQLite)\n- **Multiple formats** — PNG (with metadata), JPEG, WebP\n- **Custom resolutions** — any multiple of 8, up to 2048px\n\n## Prerequisites\n\n- **Python** 3.10+\n- **NVIDIA GPU** with CUDA 12.x (tested on H100 20GB)\n- **NVIDIA drivers** 535+ with CUDA toolkit\n- **System RAM** 32 GB recommended (16 GB minimum)\n- **Disk** ~12 GB for model weights + storage for generated images\n\n## Setup\n\n```bash\n# 1. Clone / navigate to the project\ncd flux-image-service\n\n# 2. Create a virtual environment\npython3 -m venv .venv\nsource .venv/bin/activate\n\n# 3. Install dependencies (includes torch, diffusers, fastapi, etc.)\npip install -e \".[dev]\"\n\n# 4. Copy and configure environment variables\ncp .env.example .env\n# Edit .env if needed (model ID, storage path, etc.)\n```\n\n## Running the Server\n\n```bash\n# Activate venv (if not already)\nsource .venv/bin/activate\n\n# Start the server\nuvicorn app.main:app --host 0.0.0.0 --port 8000\n\n# Or with auto-reload for development (not recommended for GPU — reloads unload the model)\nuvicorn app.main:app --host 0.0.0.0 --port 8000 --reload\n```\n\nThe first startup will:\n1. Download the Flux.1 Schnell model from HuggingFace (~12 GB)\n2. Compile CUDA kernels via `torch.compile` (~30-60s one-time cost per restart)\n3. 
Run a warm-up inference\n\nOnce you see `Flux Image Service ready`, the server is accepting requests.\n\n## Running Tests\n\n```bash\nsource .venv/bin/activate\n\n# Run all tests\npytest tests/ -v\n\n# Run a specific test file\npytest tests/test_api_generate.py -v\n\n# Run with coverage\npip install pytest-cov\npytest tests/ -v --cov=app --cov-report=term-missing\n```\n\n## Docker\n\n```bash\n# Build\ndocker build -t flux-image-service .\n\n# Run (requires nvidia-container-toolkit)\ndocker run --gpus all -p 8000:8000 -v ./output:/app/output flux-image-service\n\n# Run with custom env\ndocker run --gpus all -p 8000:8000 \\\n  -e MODEL_ID=black-forest-labs/FLUX.1-schnell \\\n  -e STORAGE_DIR=/app/output \\\n  -v ./output:/app/output \\\n  flux-image-service\n```\n\n## Linting\n\n```bash\n# Check\nruff check app/ tests/\n\n# Auto-fix\nruff check app/ tests/ --fix\n\n# Format\nruff format app/ tests/\n```\n\n---\n\n## API Reference\n\n### `POST /generate` — Async Image Generation\n\nSubmit a generation request (1–4 images). 
Returns immediately with job IDs.\n\n**Request:**\n```bash\ncurl -X POST http://localhost:8000/generate \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"prompt\": \"a majestic mountain landscape at sunset\",\n    \"width\": 1024,\n    \"height\": 1024,\n    \"num_images\": 2,\n    \"seed\": 42\n  }'\n```\n\n**Response** `200 OK`:\n```json\n{\n  \"job_ids\": [\"f47ac10b-58cc-4372-a567-0e02b2c3d479\", \"a1b2c3d4-5678-9abc-def0-1234567890ab\"],\n  \"status\": \"pending\",\n  \"image_urls\": []\n}\n```\n\n**Validation error** `422`:\n```json\n{\n  \"detail\": [\n    {\n      \"type\": \"value_error\",\n      \"loc\": [\"body\", \"width\"],\n      \"msg\": \"Value error, width must be a multiple of 8\"\n    }\n  ]\n}\n```\n\nAll request parameters for `ImageRequest`:\n\n| Field | Type | Default | Description |\n|-------|------|---------|-------------|\n| `prompt` | string | *required* | Text prompt (1–2000 chars) |\n| `negative_prompt` | string \\| null | null | Negative prompt |\n| `width` | int | 1024 | Image width (64–2048, multiple of 8) |\n| `height` | int | 1024 | Image height (64–2048, multiple of 8) |\n| `num_steps` | int | 4 | Inference steps (1–50) |\n| `guidance_scale` | float | 0.0 | CFG scale (Schnell doesn't use CFG) |\n| `seed` | int \\| null | null | Random seed (0–4294967295) |\n| `num_images` | int | 1 | Number of images (1–4) |\n| `format` | string | \"png\" | Output format: `png`, `jpeg`, `webp` |\n\n---\n\n### `GET /generate/stream` — SSE Progress Stream for Single Generate\n\nStream status updates for one or more job IDs returned by `/generate` until all are terminal.\n\n**Request:**\n```bash\ncurl -N \"http://localhost:8000/generate/stream?job_id=f47ac10b-58cc-4372-a567-0e02b2c3d479&job_id=a1b2c3d4-5678-9abc-def0-1234567890ab\"\n```\n\n**Events:**\n- `progress` — emitted periodically with per-job status and aggregate counters\n- `done` — emitted once when all jobs are in `completed`, `failed`, or `cancelled`\n\n**Example event 
payload:**\n```json\n{\n  \"total\": 2,\n  \"completed\": 1,\n  \"failed\": 0,\n  \"cancelled\": 0,\n  \"pending\": 1,\n  \"jobs\": [\n    {\n      \"id\": \"f47ac10b-58cc-4372-a567-0e02b2c3d479\",\n      \"status\": \"completed\",\n      \"image_url\": \"/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image\"\n    },\n    {\n      \"id\": \"a1b2c3d4-5678-9abc-def0-1234567890ab\",\n      \"status\": \"processing\",\n      \"image_url\": null\n    }\n  ]\n}\n```\n\n---\n\n### `POST /generate/sync` — Synchronous Generation\n\nSame request body as `/generate`. Blocks until all images are generated (up to 60s timeout).\n\n**Request:**\n```bash\ncurl -X POST http://localhost:8000/generate/sync \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\"prompt\": \"a cute cat wearing a hat\", \"num_images\": 1}' \\\n  --max-time 60\n```\n\n**Response** `200 OK`:\n```json\n{\n  \"job_ids\": [\"f47ac10b-58cc-4372-a567-0e02b2c3d479\"],\n  \"status\": \"completed\",\n  \"image_urls\": [\"/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image\"]\n}\n```\n\n**Timeout** `408`:\n```json\n{\"detail\": \"Generation timed out\"}\n```\n\n---\n\n### `GET /jobs/{job_id}/image` — Serve Generated Image\n\nReturns the image file for a completed job.\n\n**Request:**\n```bash\ncurl http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image --output image.png\n```\n\n**Response:** Binary image file with appropriate `Content-Type` (`image/png`, `image/jpeg`, or `image/webp`).\n\n**Not found** `404`:\n```json\n{\"detail\": \"Image not available\"}\n```\n\n---\n\n### `GET /jobs/{job_id}` — Get Job Details\n\n**Request:**\n```bash\ncurl http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479\n```\n\n**Response** `200 OK`:\n```json\n{\n  \"id\": \"f47ac10b-58cc-4372-a567-0e02b2c3d479\",\n  \"prompt\": \"a majestic mountain landscape at sunset\",\n  \"negative_prompt\": null,\n  \"width\": 1024,\n  \"height\": 1024,\n  \"num_steps\": 4,\n  \"guidance_scale\": 0.0,\n  \"seed\": 42,\n  
\"status\": \"completed\",\n  \"priority\": 1,\n  \"format\": \"png\",\n  \"file_path\": \"output/images/2026-03-08/f47ac10b-58cc-4372-a567-0e02b2c3d479.png\",\n  \"image_url\": \"/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image\",\n  \"error_message\": null,\n  \"batch_id\": null,\n  \"created_at\": \"2026-03-08T12:00:00+00:00\",\n  \"started_at\": \"2026-03-08T12:00:01+00:00\",\n  \"completed_at\": \"2026-03-08T12:00:03+00:00\"\n}\n```\n\n---\n\n### `GET /jobs` — List Jobs\n\n**Request:**\n```bash\n# List all jobs (paginated)\ncurl \"http://localhost:8000/jobs?page=1&page_size=20\"\n\n# Filter by status\ncurl \"http://localhost:8000/jobs?status=completed\"\n\n# Filter by batch\ncurl \"http://localhost:8000/jobs?batch_id=batch-uuid-here\"\n```\n\n**Response** `200 OK`:\n```json\n{\n  \"jobs\": [\n    {\n      \"id\": \"f47ac10b-...\",\n      \"prompt\": \"a mountain\",\n      \"status\": \"completed\",\n      \"width\": 1024,\n      \"height\": 1024,\n      \"image_url\": \"/jobs/f47ac10b-.../image\",\n      \"...\": \"...\"\n    }\n  ],\n  \"total\": 150,\n  \"page\": 1,\n  \"page_size\": 20\n}\n```\n\n---\n\n### `DELETE /jobs/{job_id}` — Cancel a Job\n\nCancels a pending/processing job. Already completed/failed jobs are returned as-is.\n\n**Request:**\n```bash\ncurl -X DELETE http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479\n```\n\n**Response** `200 OK`:\n```json\n{\n  \"id\": \"f47ac10b-58cc-4372-a567-0e02b2c3d479\",\n  \"status\": \"cancelled\",\n  \"...\": \"...\"\n}\n```\n\n---\n\n### `POST /batch` — Create Batch Job\n\nSubmit 1–50,000 prompts as a batch. 
All jobs get BATCH priority (lower than real-time).\n\n**Request:**\n```bash\ncurl -X POST http://localhost:8000/batch \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"name\": \"training-dataset-v1\",\n    \"prompts\": [\n      {\"prompt\": \"a red sports car\", \"width\": 512, \"height\": 512},\n      {\"prompt\": \"a blue ocean wave\", \"width\": 1024, \"height\": 768},\n      {\"prompt\": \"a forest path in autumn\", \"seed\": 123}\n    ]\n  }'\n```\n\n**Response** `200 OK`:\n```json\n{\n  \"batch_id\": \"b1c2d3e4-5678-9abc-def0-1234567890ab\",\n  \"name\": \"training-dataset-v1\",\n  \"total\": 3,\n  \"completed\": 0,\n  \"failed\": 0,\n  \"cancelled\": 0,\n  \"pending\": 3,\n  \"status\": \"pending\",\n  \"estimated_remaining_seconds\": null,\n  \"created_at\": \"2026-03-08T12:00:00+00:00\",\n  \"completed_at\": null\n}\n```\n\n---\n\n### `POST /batch/from-file` — Create Batch from File Upload\n\nUpload a CSV or JSON file of prompts.\n\n**CSV upload:**\n```bash\ncurl -X POST http://localhost:8000/batch/from-file \\\n  -F \"file=@prompts.csv\" \\\n  -F \"name=my-dataset\"\n```\n\nCSV format (only `prompt` column is required):\n```csv\nprompt,width,height,num_steps,seed,format\na red car,512,512,4,42,png\na blue house,1024,1024,,,\na green tree,,,,,\n```\n\n**JSON upload:**\n```bash\ncurl -X POST http://localhost:8000/batch/from-file \\\n  -F \"file=@prompts.json\" \\\n  -F \"name=my-dataset\"\n```\n\nJSON format:\n```json\n[\n  {\"prompt\": \"a red car\", \"width\": 512, \"height\": 512, \"seed\": 42},\n  {\"prompt\": \"a blue house\"},\n  {\"prompt\": \"a green tree\", \"format\": \"jpeg\"}\n]\n```\n\n**Response:** Same as `POST /batch`.\n\n---\n\n### `GET /batch/{batch_id}` — Get Batch Progress\n\n**Request:**\n```bash\ncurl http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab\n```\n\n**Response** `200 OK`:\n```json\n{\n  \"batch_id\": \"b1c2d3e4-5678-9abc-def0-1234567890ab\",\n  \"name\": \"training-dataset-v1\",\n  \"total\": 
1000,\n  \"completed\": 342,\n  \"failed\": 2,\n  \"cancelled\": 0,\n  \"pending\": 656,\n  \"status\": \"processing\",\n  \"estimated_remaining_seconds\": 1968.0,\n  \"created_at\": \"2026-03-08T12:00:00+00:00\",\n  \"completed_at\": null\n}\n```\n\n---\n\n### `GET /batch/{batch_id}/stream` — SSE Progress Stream\n\nReal-time Server-Sent Events stream of batch progress. Updates every 2 seconds.\n\n**Request:**\n```bash\ncurl -N http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab/stream\n```\n\n**Response** (event stream):\n```\nevent: progress\ndata: {\"batch_id\":\"b1c2d3e4-...\",\"total\":1000,\"completed\":342,\"failed\":2,\"pending\":656,\"status\":\"processing\",\"estimated_remaining_seconds\":1968.0}\n\nevent: progress\ndata: {\"batch_id\":\"b1c2d3e4-...\",\"total\":1000,\"completed\":343,\"failed\":2,\"pending\":655,\"status\":\"processing\",\"estimated_remaining_seconds\":1965.0}\n\nevent: done\ndata: {\"batch_id\":\"b1c2d3e4-...\",\"total\":1000,\"completed\":998,\"failed\":2,\"pending\":0,\"status\":\"failed\"}\n```\n\n---\n\n### `DELETE /batch/{batch_id}` — Cancel Batch\n\nCancels all remaining pending jobs in the batch.\n\n**Request:**\n```bash\ncurl -X DELETE http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab\n```\n\n**Response** `200 OK`:\n```json\n{\n  \"batch_id\": \"b1c2d3e4-5678-9abc-def0-1234567890ab\",\n  \"total\": 1000,\n  \"completed\": 342,\n  \"failed\": 2,\n  \"cancelled\": 656,\n  \"pending\": 0,\n  \"status\": \"cancelled\",\n  \"...\": \"...\"\n}\n```\n\n---\n\n### `POST /batch/{batch_id}/retry-failed` — Retry Failed Jobs\n\nRe-enqueue only the failed jobs in a batch.\n\n**Request:**\n```bash\ncurl -X POST http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab/retry-failed\n```\n\n**Response** `200 OK`:\n```json\n{\n  \"retried\": 2,\n  \"job_ids\": [\"job-uuid-1\", \"job-uuid-2\"]\n}\n```\n\n---\n\n### `GET /health` — Health Check\n\n**Request:**\n```bash\ncurl 
http://localhost:8000/health\n```\n\n**Response** `200 OK`:\n```json\n{\n  \"status\": \"ok\",\n  \"model_loaded\": true,\n  \"vram\": {\n    \"available\": true,\n    \"allocated_gb\": 12.34,\n    \"reserved_gb\": 14.50,\n    \"max_allocated_gb\": 16.78,\n    \"total_gb\": 20.00,\n    \"device_name\": \"NVIDIA H100\"\n  },\n  \"worker\": {\n    \"queue_depth\": 5,\n    \"jobs_processed\": 142,\n    \"total_inference_time_s\": 398.50,\n    \"avg_inference_time_s\": 2.81\n  },\n  \"uptime_seconds\": 3600.0\n}\n```\n\n---\n\n## Architecture\n\n```\nClient → FastAPI → asyncio.PriorityQueue → GPU Worker → Flux Engine\n                       ↓                        ↓\n                    SQLite                   Local FS\n                  (job state)              (images)\n```\n\n- **Single GPU worker** — processes one image at a time (VRAM constraint)\n- **Priority queue** — REALTIME(1) always before BATCH(5)\n- **OOM recovery** — catches GPU OOM, clears cache, continues processing\n- **Crash recovery** — on startup, all PENDING/PROCESSING jobs in SQLite are automatically re-enqueued\n\n### Queue Persistence\n\nThe in-memory `asyncio.PriorityQueue` is **lost on server restart**. However, **no work is lost** because:\n\n1. All jobs are persisted to SQLite the moment they're created (before being enqueued)\n2. On startup, `resume_pending_jobs()` scans for any jobs in `PENDING` or `PROCESSING` state and re-enqueues them\n3. Completed jobs and their images are already on disk\n\nThis means a restart simply rebuilds the queue from the database. 
The only effect is a brief pause while the model reloads.\n\n## Configuration\n\nAll settings can be overridden via environment variables or `.env` file:\n\n| Variable | Default | Description |\n|----------|---------|-------------|\n| `MODEL_ID` | `black-forest-labs/FLUX.1-schnell` | HuggingFace model ID |\n| `DEVICE` | `cuda` | Torch device |\n| `DTYPE` | `float8_e4m3fn` | Model precision |\n| `ENABLE_TORCH_COMPILE` | `true` | Enable torch.compile optimization |\n| `DEFAULT_NUM_STEPS` | `4` | Default inference steps |\n| `DEFAULT_GUIDANCE_SCALE` | `0.0` | Default CFG scale |\n| `MAX_RESOLUTION` | `2048` | Maximum allowed width/height |\n| `MAX_IMMEDIATE_IMAGES` | `4` | Max images per /generate request |\n| `STORAGE_DIR` | `./output` | Where images and DB are stored |\n| `MAX_QUEUE_SIZE` | `10000` | Maximum queue capacity |\n| `SYNC_TIMEOUT` | `60.0` | Timeout for /generate/sync (seconds) |\n| `HOST` | `0.0.0.0` | Server bind address |\n| `PORT` | `8000` | Server port |\n\n## Performance (H100 20GB, FP8)\n\n| Resolution | Time per image | 10K dataset |\n|-----------|---------------|-------------|\n| 512×512 | ~0.8–1.2s | ~2–3 hours |\n| 1024×1024 | ~2–4s | ~6–10 hours |\n| 1536×640 | ~2–3s | ~5–8 hours |\n"
  },
  {
    "path": "app/__init__.py",
    "content": ""
  },
  {
    "path": "app/api/__init__.py",
    "content": ""
  },
  {
    "path": "app/api/batch.py",
    "content": "\"\"\"Batch / dataset generation endpoints with SSE progress streaming.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport csv\nimport io\nimport json\nimport os\nimport tempfile\nimport zipfile\nfrom pathlib import Path\nfrom typing import AsyncGenerator\n\nfrom fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File\nfrom fastapi.responses import StreamingResponse\nfrom sse_starlette.sse import EventSourceResponse\nfrom sqlalchemy import select\nfrom sqlalchemy.ext.asyncio import AsyncSession\n\nfrom app.database import get_session\nfrom app.models.job import Job\nfrom app.models.enums import ImageFormat\nfrom app.schemas.batch import BatchProgress, BatchRequest\nfrom app.schemas.generate import ImageRequest\nfrom app.services.generation_service import create_batch\nfrom app.services.job_service import (\n    cancel_batch,\n    get_batch_progress,\n    retry_failed_in_batch,\n)\n\nrouter = APIRouter(prefix=\"/batch\", tags=[\"batch\"])\n\n\ndef _get_worker(request: Request):\n    return request.app.state.worker\n\n\n@router.post(\"\", response_model=BatchProgress)\nasync def create_batch_job(\n    body: BatchRequest,\n    request: Request,\n    session: AsyncSession = Depends(get_session),\n):\n    \"\"\"Create a batch of image generation jobs from a list of prompts.\"\"\"\n    worker = _get_worker(request)\n    batch_id = await create_batch(session, body.name, body.prompts, worker)\n    progress = await get_batch_progress(session, batch_id)\n    return BatchProgress(**progress)\n\n\n@router.post(\"/from-file\", response_model=BatchProgress)\nasync def create_batch_from_file(\n    request: Request,\n    file: UploadFile = File(...),\n    name: str | None = None,\n    session: AsyncSession = Depends(get_session),\n):\n    \"\"\"Create a batch from a CSV or JSON file upload.\n\n    CSV format: prompt,width,height,num_steps,seed,format\n    JSON format: array of ImageRequest objects\n    \"\"\"\n    worker = 
_get_worker(request)\n    content = await file.read()\n\n    filename = file.filename or \"\"\n    if filename.endswith(\".json\"):\n        prompts = _parse_json(content)\n    elif filename.endswith(\".csv\"):\n        prompts = _parse_csv(content)\n    else:\n        raise HTTPException(\n            status_code=400,\n            detail=\"Unsupported file format. Use .csv or .json\",\n        )\n\n    if not prompts:\n        raise HTTPException(status_code=400, detail=\"No valid prompts found in file\")\n\n    batch_id = await create_batch(session, name or filename, prompts, worker)\n    progress = await get_batch_progress(session, batch_id)\n    return BatchProgress(**progress)\n\n\n@router.get(\"/{batch_id}\", response_model=BatchProgress)\nasync def get_batch_status(\n    batch_id: str,\n    session: AsyncSession = Depends(get_session),\n):\n    progress = await get_batch_progress(session, batch_id)\n    if not progress:\n        raise HTTPException(status_code=404, detail=\"Batch not found\")\n    return BatchProgress(**progress)\n\n\n@router.get(\"/{batch_id}/stream\")\nasync def stream_batch_progress(\n    batch_id: str,\n    request: Request,\n    session: AsyncSession = Depends(get_session),\n):\n    \"\"\"SSE endpoint streaming batch progress updates until completion.\"\"\"\n    progress = await get_batch_progress(session, batch_id)\n    if not progress:\n        raise HTTPException(status_code=404, detail=\"Batch not found\")\n\n    async def event_generator() -> AsyncGenerator[dict, None]:\n        while True:\n            if await request.is_disconnected():\n                break\n\n            async with _fresh_session() as sess:\n                p = await get_batch_progress(sess, batch_id)\n\n            if p is None:\n                break\n\n            # Calculate ETA\n            eta = _estimate_eta(p, request.app.state.worker)\n\n            yield {\n                \"event\": \"progress\",\n                \"data\": json.dumps({**p, 
\"estimated_remaining_seconds\": eta}),\n            }\n\n            if p[\"status\"] in (\"completed\", \"failed\", \"cancelled\"):\n                yield {\"event\": \"done\", \"data\": json.dumps(p)}\n                break\n\n            await asyncio.sleep(2)\n\n    return EventSourceResponse(event_generator())\n\n\n@router.get(\"/{batch_id}/download\")\nasync def download_batch_images(\n    batch_id: str,\n    session: AsyncSession = Depends(get_session),\n):\n    \"\"\"Download all completed images in a batch as a ZIP archive.\n\n    Streams the ZIP incrementally so the full archive is never held\n    in memory.  File I/O runs in a thread pool to avoid blocking the\n    async event loop.\n    \"\"\"\n    progress = await get_batch_progress(session, batch_id)\n    if not progress:\n        raise HTTPException(status_code=404, detail=\"Batch not found\")\n\n    result = await session.execute(\n        select(Job).where(Job.batch_id == batch_id, Job.status == \"completed\")\n    )\n    jobs = result.scalars().all()\n\n    if not jobs:\n        raise HTTPException(status_code=404, detail=\"No completed images in this batch\")\n\n    # Collect valid file paths up-front (lightweight check).\n    file_entries: list[tuple[str, Path]] = []\n    for job in jobs:\n        if not job.file_path:\n            continue\n        fp = Path(job.file_path)\n        if fp.is_file():\n            file_entries.append((fp.name, fp))\n\n    if not file_entries:\n        raise HTTPException(status_code=404, detail=\"No image files found on disk\")\n\n    batch_name = progress.get(\"name\") or batch_id[:12]\n    safe_name = \"\".join(\n        c if c.isalnum() or c in (\" \", \"-\", \"_\") else \"_\" for c in batch_name\n    ).strip()\n    filename = f\"{safe_name}.zip\"\n\n    async def _stream_zip() -> AsyncGenerator[bytes, None]:\n        \"\"\"Build a valid ZIP in a temp file (off the event loop), then\n        stream it in chunks so memory stays flat.\"\"\"\n        loop = 
asyncio.get_running_loop()\n        tmp_fd, tmp_path = tempfile.mkstemp(suffix=\".zip\")\n        os.close(tmp_fd)\n\n        def _build_zip() -> None:\n            with zipfile.ZipFile(tmp_path, \"w\", zipfile.ZIP_STORED) as zf:\n                for arc_name, file_path in file_entries:\n                    zf.write(str(file_path), arc_name)\n\n        try:\n            # Build the ZIP in a worker thread — no event-loop blocking.\n            await loop.run_in_executor(None, _build_zip)\n\n            # Stream the temp file in 256 KB chunks.\n            chunk_size = 256 * 1024\n            with open(tmp_path, \"rb\") as f:\n                while True:\n                    chunk = await loop.run_in_executor(None, f.read, chunk_size)\n                    if not chunk:\n                        break\n                    yield chunk\n        finally:\n            try:\n                os.unlink(tmp_path)\n            except OSError:\n                pass\n\n    return StreamingResponse(\n        _stream_zip(),\n        media_type=\"application/zip\",\n        headers={\"Content-Disposition\": f'attachment; filename=\"{filename}\"'},\n    )\n\n\n@router.delete(\"/{batch_id}\", response_model=BatchProgress)\nasync def cancel_batch_job(\n    batch_id: str,\n    session: AsyncSession = Depends(get_session),\n):\n    result = await cancel_batch(session, batch_id)\n    if not result:\n        raise HTTPException(status_code=404, detail=\"Batch not found\")\n    return BatchProgress(**result)\n\n\n@router.post(\"/{batch_id}/retry-failed\")\nasync def retry_failed_jobs(\n    batch_id: str,\n    request: Request,\n    session: AsyncSession = Depends(get_session),\n):\n    \"\"\"Re-enqueue all failed jobs in a batch.\"\"\"\n    worker = _get_worker(request)\n    job_ids = await retry_failed_in_batch(session, batch_id)\n    if not job_ids:\n        raise HTTPException(status_code=404, detail=\"No failed jobs found in batch\")\n\n    for jid in job_ids:\n        await 
worker.enqueue(jid, priority=5)\n\n    return {\"retried\": len(job_ids), \"job_ids\": job_ids}\n\n\n# ── File parsing helpers ─────────────────────────────────────────\n\n\ndef _parse_json(content: bytes) -> list[ImageRequest]:\n    try:\n        data = json.loads(content)\n    except json.JSONDecodeError as e:\n        raise HTTPException(status_code=400, detail=f\"Invalid JSON: {e}\")\n\n    if not isinstance(data, list):\n        raise HTTPException(status_code=400, detail=\"JSON must be an array of objects\")\n\n    prompts = []\n    for i, item in enumerate(data):\n        try:\n            prompts.append(ImageRequest(**item))\n        except Exception as e:\n            raise HTTPException(status_code=400, detail=f\"Invalid entry at index {i}: {e}\")\n    return prompts\n\n\ndef _parse_csv(content: bytes) -> list[ImageRequest]:\n    try:\n        text = content.decode(\"utf-8\")\n    except UnicodeDecodeError:\n        raise HTTPException(status_code=400, detail=\"CSV must be UTF-8 encoded\")\n\n    reader = csv.DictReader(io.StringIO(text))\n    prompts = []\n    for i, row in enumerate(reader):\n        try:\n            kwargs: dict = {\"prompt\": row[\"prompt\"]}\n            if \"width\" in row and row[\"width\"]:\n                kwargs[\"width\"] = int(row[\"width\"])\n            if \"height\" in row and row[\"height\"]:\n                kwargs[\"height\"] = int(row[\"height\"])\n            if \"num_steps\" in row and row[\"num_steps\"]:\n                kwargs[\"num_steps\"] = int(row[\"num_steps\"])\n            if \"seed\" in row and row[\"seed\"]:\n                kwargs[\"seed\"] = int(row[\"seed\"])\n            if \"format\" in row and row[\"format\"]:\n                kwargs[\"format\"] = ImageFormat(row[\"format\"].strip().lower())\n            prompts.append(ImageRequest(**kwargs))\n        except KeyError:\n            raise HTTPException(\n                status_code=400, detail=f\"Row {i + 1}: 'prompt' column is required\"\n            
)\n        except Exception as e:\n            raise HTTPException(status_code=400, detail=f\"Row {i + 1}: {e}\")\n    return prompts\n\n\ndef _estimate_eta(progress: dict, worker) -> float | None:\n    \"\"\"Estimate remaining seconds based on average inference time.\"\"\"\n    stats = worker.get_stats()\n    avg_time = stats.get(\"avg_inference_time_s\", 0)\n    if avg_time <= 0:\n        return None\n    remaining = progress.get(\"pending\", 0)\n    return round(remaining * avg_time, 1)\n\n\n# Helper to get fresh session for SSE generator\nfrom app.database import async_session_factory as _fresh_session  # noqa: E402\n"
  },
  {
    "path": "app/api/generate.py",
    "content": "\"\"\"Image generation endpoints — async and sync modes.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport json\nfrom pathlib import Path\nfrom typing import AsyncGenerator\n\nfrom fastapi import APIRouter, Depends, HTTPException, Query, Request\nfrom fastapi.responses import FileResponse\nfrom sse_starlette.sse import EventSourceResponse\nfrom sqlalchemy import select\nfrom sqlalchemy.ext.asyncio import AsyncSession\n\nfrom app.config import settings\nfrom app.database import get_session\nfrom app.models.job import Job\nfrom app.schemas.generate import ImageRequest, ImageResponse\nfrom app.services.generation_service import create_single_job\nfrom app.services.job_service import get_job\nfrom app.services.moderation_service import ModerationError, check_prompt_safety\nfrom app.utils.storage import image_url_path\n\nrouter = APIRouter(tags=[\"generate\"])\n\n_TERMINAL_JOB_STATUSES = {\"completed\", \"failed\", \"cancelled\"}\n\n\ndef _get_worker(request: Request):\n    return request.app.state.worker\n\n\n@router.post(\"/generate\", response_model=ImageResponse)\nasync def generate_async(\n    body: ImageRequest,\n    request: Request,\n    session: AsyncSession = Depends(get_session),\n):\n    \"\"\"Submit an image generation request (1–4 images). 
Returns immediately with job IDs.\"\"\"\n    try:\n        is_safe, reason = await check_prompt_safety(body.prompt)\n    except ModerationError:\n        raise HTTPException(status_code=503, detail=\"Content moderation service unavailable\")\n    if not is_safe:\n        raise HTTPException(\n            status_code=400,\n            detail=f\"Prompt contains unsafe or explicit content: {reason}\",\n        )\n\n    worker = _get_worker(request)\n    job_ids = await create_single_job(session, body, worker)\n    return ImageResponse(job_ids=job_ids, status=\"pending\")\n\n\n@router.get(\"/generate/stream\")\nasync def stream_generate_progress(\n    request: Request,\n    job_id: list[str] = Query(..., description=\"One or more job IDs to monitor\"),\n    session: AsyncSession = Depends(get_session),\n):\n    \"\"\"SSE stream for one or more generate job IDs until all reach terminal state.\"\"\"\n    unique_job_ids = list(dict.fromkeys(job_id))\n\n    # Fail fast if any requested job ID does not exist.\n    checks = [await get_job(session, jid) for jid in unique_job_ids]\n    missing = [jid for jid, detail in zip(unique_job_ids, checks, strict=False) if detail is None]\n    if missing:\n        raise HTTPException(\n            status_code=404,\n            detail=f\"Job not found: {', '.join(missing)}\",\n        )\n\n    async def event_generator() -> AsyncGenerator[dict, None]:\n        while True:\n            if await request.is_disconnected():\n                break\n\n            # Expire identity map so concurrent worker updates are visible.\n            session.expire_all()\n            details = [await get_job(session, jid) for jid in unique_job_ids]\n\n            if any(d is None for d in details):\n                yield {\n                    \"event\": \"error\",\n                    \"data\": json.dumps({\"detail\": \"Job no longer exists\"}),\n                }\n                break\n\n            jobs = [d.model_dump() for d in details if d is not 
None]\n            total = len(jobs)\n            completed = sum(1 for d in details if d and d.status == \"completed\")\n            failed = sum(1 for d in details if d and d.status == \"failed\")\n            cancelled = sum(1 for d in details if d and d.status == \"cancelled\")\n            terminal = sum(1 for d in details if d and d.status in _TERMINAL_JOB_STATUSES)\n\n            payload = {\n                \"total\": total,\n                \"completed\": completed,\n                \"failed\": failed,\n                \"cancelled\": cancelled,\n                \"pending\": total - terminal,\n                \"jobs\": jobs,\n            }\n            yield {\"event\": \"progress\", \"data\": json.dumps(payload)}\n\n            if terminal == total:\n                yield {\"event\": \"done\", \"data\": json.dumps(payload)}\n                break\n\n            await asyncio.sleep(1.0)\n\n    return EventSourceResponse(event_generator())\n\n\n@router.post(\"/generate/sync\", response_model=ImageResponse)\nasync def generate_sync(\n    body: ImageRequest,\n    request: Request,\n    session: AsyncSession = Depends(get_session),\n):\n    \"\"\"Submit and wait for image(s) to be generated (up to SYNC_TIMEOUT seconds).\"\"\"\n    try:\n        is_safe, reason = await check_prompt_safety(body.prompt)\n    except ModerationError:\n        raise HTTPException(status_code=503, detail=\"Content moderation service unavailable\")\n    if not is_safe:\n        raise HTTPException(\n            status_code=400,\n            detail=f\"Prompt contains unsafe or explicit content: {reason}\",\n        )\n\n    worker = _get_worker(request)\n    job_ids = await create_single_job(session, body, worker)\n\n    # Poll until all jobs complete or timeout\n    loop = asyncio.get_running_loop()\n    deadline = loop.time() + settings.SYNC_TIMEOUT\n    while loop.time() < deadline:\n        await asyncio.sleep(0.3)\n        # Expire identity map so concurrent worker updates are visible.\n        session.expire_all()\n        details = [await get_job(session, jid) for jid in 
job_ids]\n        if all(d and d.status in (\"completed\", \"failed\", \"cancelled\") for d in details):\n            break\n    else:\n        raise HTTPException(status_code=408, detail=\"Generation timed out\")\n\n    failed = [d for d in details if d and d.status == \"failed\"]\n    if failed and len(failed) == len(details):\n        raise HTTPException(\n            status_code=500,\n            detail=failed[0].error_message or \"Generation failed\",\n        )\n\n    return ImageResponse(\n        job_ids=job_ids,\n        status=\"completed\",\n        image_urls=[d.image_url if d and d.status == \"completed\" else None for d in details],\n    )\n\n\n@router.get(\"/jobs/{job_id}/image\")\nasync def serve_image(\n    job_id: str,\n    session: AsyncSession = Depends(get_session),\n):\n    \"\"\"Serve the generated image file for a completed job.\"\"\"\n    result = await session.execute(select(Job).where(Job.id == job_id))\n    job = result.scalar_one_or_none()\n    if not job:\n        raise HTTPException(status_code=404, detail=\"Job not found\")\n    if job.status != \"completed\" or not job.file_path:\n        raise HTTPException(status_code=404, detail=\"Image not available\")\n\n    path = Path(job.file_path)\n    if not path.is_file():\n        raise HTTPException(status_code=404, detail=\"Image file missing from disk\")\n\n    media_types = {\n        \"png\": \"image/png\",\n        \"jpeg\": \"image/jpeg\",\n        \"jpg\": \"image/jpeg\",\n        \"webp\": \"image/webp\",\n    }\n    ext = path.suffix.lstrip(\".\")\n    media_type = media_types.get(ext, \"application/octet-stream\")\n\n    return FileResponse(path, media_type=media_type, filename=path.name)\n"
  },
  {
    "path": "app/api/health.py",
    "content": "\"\"\"Health and status endpoint.\"\"\"\n\nfrom __future__ import annotations\n\nimport time\n\nfrom fastapi import APIRouter, Request\n\nrouter = APIRouter(tags=[\"health\"])\n\n_start_time: float = 0.0\n\n\ndef set_start_time() -> None:\n    global _start_time\n    _start_time = time.time()\n\n\n@router.get(\"/health\")\nasync def health_check(request: Request):\n    engine = request.app.state.engine\n    worker = request.app.state.worker\n\n    uptime = time.time() - _start_time if _start_time else 0\n\n    return {\n        \"status\": \"ok\",\n        \"model_loaded\": engine.is_loaded,\n        \"vram\": engine.get_vram_stats(),\n        \"worker\": worker.get_stats(),\n        \"uptime_seconds\": round(uptime, 1),\n    }\n"
  },
  {
    "path": "app/api/jobs.py",
    "content": "\"\"\"Job management endpoints.\"\"\"\n\nfrom __future__ import annotations\n\nfrom fastapi import APIRouter, Depends, HTTPException, Query\nfrom sqlalchemy.ext.asyncio import AsyncSession\n\nfrom app.database import get_session\nfrom app.schemas.job import JobDetail, JobList\nfrom app.services.job_service import cancel_job, get_job, list_jobs\n\nrouter = APIRouter(prefix=\"/jobs\", tags=[\"jobs\"])\n\n\n@router.get(\"\", response_model=JobList)\nasync def list_all_jobs(\n    status: str | None = Query(None, description=\"Filter by status\"),\n    batch_id: str | None = Query(None, description=\"Filter by batch ID\"),\n    page: int = Query(1, ge=1),\n    page_size: int = Query(50, ge=1, le=200),\n    session: AsyncSession = Depends(get_session),\n):\n    jobs, total = await list_jobs(session, status=status, batch_id=batch_id, page=page, page_size=page_size)\n    return JobList(jobs=jobs, total=total, page=page, page_size=page_size)\n\n\n@router.get(\"/{job_id}\", response_model=JobDetail)\nasync def get_job_detail(\n    job_id: str,\n    session: AsyncSession = Depends(get_session),\n):\n    job = await get_job(session, job_id)\n    if not job:\n        raise HTTPException(status_code=404, detail=\"Job not found\")\n    return job\n\n\n@router.delete(\"/{job_id}\", response_model=JobDetail)\nasync def cancel_job_endpoint(\n    job_id: str,\n    session: AsyncSession = Depends(get_session),\n):\n    job = await cancel_job(session, job_id)\n    if not job:\n        raise HTTPException(status_code=404, detail=\"Job not found\")\n    return job\n"
  },
  {
    "path": "app/config.py",
    "content": "from pydantic_settings import BaseSettings\nfrom pathlib import Path\n\n\nclass Settings(BaseSettings):\n    model_config = {\"env_file\": \".env\", \"env_file_encoding\": \"utf-8\"}\n\n    # ── Dev / mock mode ──\n    MOCK_MODE: bool = False\n\n    # ── Model ──\n    MODEL_ID: str = \"black-forest-labs/FLUX.1-schnell\"\n    DEVICE: str = \"cuda\"\n    DTYPE: str = \"float8_e4m3fn\"\n    ENABLE_TORCH_COMPILE: bool = True\n\n    # ── Generation defaults ──\n    DEFAULT_NUM_STEPS: int = 4\n    DEFAULT_GUIDANCE_SCALE: float = 0.0\n    MAX_RESOLUTION: int = 2048\n    MIN_RESOLUTION: int = 64\n\n    # ── Storage ──\n    STORAGE_DIR: Path = Path(\"./output\")\n\n    # ── Logging ──\n    LOG_DIR: Path = Path(\"./logs\")\n    LOG_LEVEL: str = \"INFO\"\n    LOG_ROTATION_WHEN: str = \"midnight\"  # midnight, h, m, s, etc.\n    LOG_RETENTION_COUNT: int = 30  # number of rotated files to keep\n\n    # ── Queue ──\n    MAX_QUEUE_SIZE: int = 10000\n\n    # ── Server ──\n    HOST: str = \"0.0.0.0\"\n    PORT: int = 8000\n\n    # ── Immediate generation ──\n    MAX_IMMEDIATE_IMAGES: int = 4\n\n    # ── Sync generation timeout (seconds) ──\n    SYNC_TIMEOUT: float = 60.0\n\n    # ── Content moderation ──\n    MODERATION_ENABLED: bool = True\n    MODERATION_MODEL_ID: str = \"google/gemma-3-1b-it\"\n    HF_API_TOKEN: str | None = None\n    MODERATION_API_KEY: str = \"\"\n    MODERATION_API_URL: str = \"https://api.openai.com/v1/moderations\"\n    MODERATION_TIMEOUT: float = 10.0\n\n\nsettings = Settings()\n"
  },
  {
    "path": "app/database.py",
    "content": "from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine\n\nfrom app.config import settings\n\nDATABASE_URL = f\"sqlite+aiosqlite:///{settings.STORAGE_DIR / 'flux_service.db'}\"\n\nengine = create_async_engine(DATABASE_URL, echo=False)\nasync_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)\n\n\nasync def init_db() -> None:\n    from app.models.job import Base\n\n    async with engine.begin() as conn:\n        await conn.run_sync(Base.metadata.create_all)\n\n\nasync def get_session() -> AsyncSession:  # type: ignore[misc]\n    async with async_session_factory() as session:\n        yield session  # type: ignore[misc]\n"
  },
  {
    "path": "app/engine/__init__.py",
    "content": ""
  },
  {
    "path": "app/engine/engine_config.py",
    "content": "\"\"\"Resolution presets and engine configuration.\"\"\"\n\nPRESET_RESOLUTIONS: dict[str, tuple[int, int]] = {\n    \"square_sm\": (512, 512),\n    \"square\": (1024, 1024),\n    \"landscape\": (1344, 768),\n    \"portrait\": (768, 1344),\n    \"wide\": (1536, 640),\n}\n\n# Approximate peak VRAM (GB) by max dimension for FP8 Flux Schnell.\n# Used as a safety guard — not an exact model; conservative estimates.\nVRAM_ESTIMATES: dict[int, float] = {\n    512: 13.5,\n    768: 14.5,\n    1024: 16.0,\n    1536: 18.5,\n    2048: 20.0,\n}\n\n\ndef estimated_vram_gb(width: int, height: int) -> float:\n    \"\"\"Return conservative VRAM estimate for a given resolution.\"\"\"\n    max_dim = max(width, height)\n    for threshold in sorted(VRAM_ESTIMATES.keys()):\n        if max_dim <= threshold:\n            return VRAM_ESTIMATES[threshold]\n    return VRAM_ESTIMATES[max(VRAM_ESTIMATES.keys())]\n"
  },
  {
    "path": "app/engine/flux_engine.py",
    "content": "\"\"\"Flux.1 Schnell inference engine — model loading, generation, VRAM management.\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nimport time\nfrom dataclasses import dataclass, field\n\nimport torch\nfrom PIL import Image\n\nfrom app.config import settings\n\nlogger = logging.getLogger(__name__)\n\n# Dtype mapping\n_DTYPE_MAP: dict[str, torch.dtype] = {\n    \"float16\": torch.float16,\n    \"bfloat16\": torch.bfloat16,\n    \"float8_e4m3fn\": torch.float8_e4m3fn,\n    \"float32\": torch.float32,\n}\n\n\n@dataclass\nclass GenerationResult:\n    image: Image.Image\n    seed: int\n    inference_time: float  # seconds\n\n\n@dataclass\nclass FluxEngine:\n    \"\"\"Singleton engine wrapping the Flux.1 Schnell diffusion pipeline.\"\"\"\n\n    _pipeline: object | None = field(default=None, repr=False)\n    _loaded: bool = False\n    _device: str = \"cuda\"\n    _dtype: torch.dtype = torch.bfloat16\n\n    def load_model(self) -> None:\n        \"\"\"Load the Flux pipeline onto the GPU.\"\"\"\n        if self._loaded:\n            logger.info(\"Model already loaded, skipping.\")\n            return\n\n        from diffusers import FluxPipeline\n\n        self._device = settings.DEVICE\n        self._dtype = _DTYPE_MAP.get(settings.DTYPE, torch.bfloat16)\n\n        logger.info(\n            \"Loading Flux.1 Schnell model=%s device=%s dtype=%s\",\n            settings.MODEL_ID,\n            self._device,\n            self._dtype,\n        )\n\n        load_kwargs: dict = {\n            \"torch_dtype\": self._dtype,\n        }\n\n        # For FP8 on H100, use the fp8 variant if available\n        if self._dtype == torch.float8_e4m3fn:\n            load_kwargs[\"variant\"] = \"fp8\"\n\n        self._pipeline = FluxPipeline.from_pretrained(\n            settings.MODEL_ID,\n            **load_kwargs,\n        )\n        self._pipeline.to(self._device)\n        self._pipeline.set_progress_bar_config(disable=True)\n\n        # torch.compile for 
H100 Tensor Core acceleration\n        if settings.ENABLE_TORCH_COMPILE and self._device == \"cuda\":\n            logger.info(\"Compiling transformer with torch.compile (mode=max-autotune-no-cudagraphs)...\")\n            self._pipeline.transformer = torch.compile(\n                self._pipeline.transformer,\n                mode=\"max-autotune-no-cudagraphs\",\n            )\n\n        self._loaded = True\n        logger.info(\"Model loaded. VRAM: %.2f GB\", self.get_vram_usage_gb())\n\n    def warmup(self) -> None:\n        \"\"\"Run a single throwaway inference to trigger torch.compile tracing.\"\"\"\n        if not self._loaded:\n            raise RuntimeError(\"Model not loaded — call load_model() first\")\n\n        logger.info(\"Running warm-up inference (first run compiles CUDA kernels)...\")\n        start = time.perf_counter()\n        self.generate(prompt=\"warmup\", width=512, height=512, num_steps=1, seed=0)\n        elapsed = time.perf_counter() - start\n        logger.info(\"Warm-up completed in %.1fs\", elapsed)\n\n    @torch.inference_mode()\n    def generate(\n        self,\n        prompt: str,\n        width: int = 1024,\n        height: int = 1024,\n        num_steps: int = 4,\n        guidance_scale: float = 0.0,\n        seed: int | None = None,\n    ) -> GenerationResult:\n        \"\"\"Generate a single image. 
Runs synchronously on the GPU.\"\"\"\n        if not self._loaded:\n            raise RuntimeError(\"Model not loaded — call load_model() first\")\n\n        if seed is None:\n            seed = torch.randint(0, 2**32, (1,)).item()\n\n        generator = torch.Generator(device=self._device).manual_seed(seed)\n\n        start = time.perf_counter()\n\n        output = self._pipeline(\n            prompt=prompt,\n            width=width,\n            height=height,\n            num_inference_steps=num_steps,\n            guidance_scale=guidance_scale,\n            generator=generator,\n            output_type=\"pil\",\n        )\n\n        elapsed = time.perf_counter() - start\n        image: Image.Image = output.images[0]\n\n        logger.info(\n            \"Generated %dx%d in %.2fs (steps=%d, seed=%d)\",\n            width,\n            height,\n            elapsed,\n            num_steps,\n            seed,\n        )\n\n        return GenerationResult(image=image, seed=seed, inference_time=elapsed)\n\n    def get_vram_usage_gb(self) -> float:\n        \"\"\"Return current GPU memory allocated in GB.\"\"\"\n        if not torch.cuda.is_available():\n            return 0.0\n        return torch.cuda.memory_allocated(self._device) / (1024**3)\n\n    def get_vram_stats(self) -> dict:\n        \"\"\"Return detailed VRAM statistics.\"\"\"\n        if not torch.cuda.is_available():\n            return {\"available\": False}\n        return {\n            \"available\": True,\n            \"allocated_gb\": round(torch.cuda.memory_allocated(self._device) / (1024**3), 2),\n            \"reserved_gb\": round(torch.cuda.memory_reserved(self._device) / (1024**3), 2),\n            \"max_allocated_gb\": round(\n                torch.cuda.max_memory_allocated(self._device) / (1024**3), 2\n            ),\n            \"total_gb\": round(torch.cuda.get_device_properties(0).total_memory / (1024**3), 2),\n            \"device_name\": torch.cuda.get_device_name(0),\n        }\n\n    
def unload(self) -> None:\n        \"\"\"Free the model and clear CUDA cache.\"\"\"\n        if self._pipeline is not None:\n            del self._pipeline\n            self._pipeline = None\n        self._loaded = False\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n            torch.cuda.reset_peak_memory_stats()\n        logger.info(\"Model unloaded, CUDA cache cleared.\")\n\n    @property\n    def is_loaded(self) -> bool:\n        return self._loaded\n"
  },
  {
    "path": "app/engine/mock_engine.py",
    "content": "\"\"\"Mock FluxEngine for local development without GPU or model weights.\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nimport random\nimport time\nfrom dataclasses import dataclass\n\nfrom PIL import Image, ImageDraw\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass GenerationResult:\n    image: Image.Image\n    seed: int\n    inference_time: float\n\n\n@dataclass\nclass MockFluxEngine:\n    \"\"\"Drop-in replacement for FluxEngine that generates gradient placeholder images.\"\"\"\n\n    _loaded: bool = False\n\n    def load_model(self) -> None:\n        logger.info(\"MockFluxEngine: loaded (no actual model).\")\n        self._loaded = True\n\n    def warmup(self) -> None:\n        logger.info(\"MockFluxEngine: warmup (no-op).\")\n\n    def generate(\n        self,\n        prompt: str,\n        width: int = 1024,\n        height: int = 1024,\n        num_steps: int = 4,\n        guidance_scale: float = 0.0,\n        seed: int | None = None,\n    ) -> GenerationResult:\n        if seed is None:\n            seed = random.randint(0, 2**32 - 1)\n\n        start = time.perf_counter()\n\n        rng = random.Random(seed)\n        r1, g1, b1 = rng.randint(40, 180), rng.randint(40, 180), rng.randint(40, 180)\n        r2, g2, b2 = rng.randint(40, 180), rng.randint(40, 180), rng.randint(40, 180)\n\n        img = Image.new(\"RGB\", (width, height))\n        pixels = img.load()\n        for y in range(height):\n            t = y / max(height - 1, 1)\n            r = int(r1 + (r2 - r1) * t)\n            g = int(g1 + (g2 - g1) * t)\n            b = int(b1 + (b2 - b1) * t)\n            for x in range(width):\n                pixels[x, y] = (r, g, b)\n\n        draw = ImageDraw.Draw(img)\n        lines = [\n            f\"[MOCK] {prompt[:60]}\",\n            f\"Seed: {seed}  |  {width}x{height}  |  steps={num_steps}\",\n        ]\n        draw.text((20, 20), \"\\n\".join(lines), fill=(255, 255, 255))\n\n        # Small artificial 
delay to simulate inference\n        time.sleep(2.05 + rng.random() * 0.15)\n        elapsed = time.perf_counter() - start\n\n        logger.info(\n            \"MockFluxEngine: generated %dx%d in %.2fs (seed=%d)\",\n            width, height, elapsed, seed,\n        )\n        return GenerationResult(image=img, seed=seed, inference_time=elapsed)\n\n    def get_vram_usage_gb(self) -> float:\n        return 0.0\n\n    def get_vram_stats(self) -> dict:\n        return {\"available\": False, \"mock_mode\": True}\n\n    def unload(self) -> None:\n        self._loaded = False\n        logger.info(\"MockFluxEngine: unloaded.\")\n\n    @property\n    def is_loaded(self) -> bool:\n        return self._loaded\n"
  },
  {
    "path": "app/main.py",
    "content": "\"\"\"FastAPI application — lifespan manages model loading, worker, and DB.\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom contextlib import asynccontextmanager\nfrom logging.handlers import TimedRotatingFileHandler\n\nfrom fastapi import FastAPI\nfrom fastapi.middleware.cors import CORSMiddleware\n\nfrom app.api import batch, generate, health, jobs\nfrom app.config import settings\nfrom app.database import init_db\nfrom app.utils.moderation import get_moderation_engine\nfrom app.utils.storage import ensure_storage_dirs\nfrom app.worker.gpu_worker import GPUWorker\n\n\ndef _setup_logging() -> None:\n    \"\"\"Configure logging to both console and a time-rotating file.\"\"\"\n    settings.LOG_DIR.mkdir(parents=True, exist_ok=True)\n\n    log_format = \"%(asctime)s %(levelname)-8s %(name)s — %(message)s\"\n    level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)\n\n    root = logging.getLogger()\n    root.setLevel(level)\n\n    # Console handler\n    console = logging.StreamHandler()\n    console.setLevel(level)\n    console.setFormatter(logging.Formatter(log_format))\n    root.addHandler(console)\n\n    # Time-rotating file handler\n    file_handler = TimedRotatingFileHandler(\n        filename=settings.LOG_DIR / \"flux-service.log\",\n        when=settings.LOG_ROTATION_WHEN,\n        backupCount=settings.LOG_RETENTION_COUNT,\n        encoding=\"utf-8\",\n    )\n    file_handler.setLevel(level)\n    file_handler.setFormatter(logging.Formatter(log_format))\n    file_handler.suffix = \"%Y-%m-%d\"  # rotated files: flux-service.log.2026-03-08\n    root.addHandler(file_handler)\n\n\n_setup_logging()\nlogger = logging.getLogger(__name__)\n\n\n@asynccontextmanager\nasync def lifespan(app: FastAPI):\n    # ── Startup ──────────────────────────────────────────────────\n    logger.info(\"Starting Flux Image Service...\")\n\n    # Storage directories\n    ensure_storage_dirs()\n\n    # Database\n    await init_db()\n    
logger.info(\"Database initialized.\")\n\n    # Flux engine (mock or real)\n    if settings.MOCK_MODE:\n        from app.engine.mock_engine import MockFluxEngine\n        engine = MockFluxEngine()\n    else:\n        from app.engine.flux_engine import FluxEngine\n        engine = FluxEngine()\n    engine.load_model()\n    engine.warmup()\n    app.state.engine = engine\n\n    # Content moderation engine (skip model loading in mock mode)\n    moderation_engine = get_moderation_engine()\n    if not settings.MOCK_MODE:\n        moderation_engine.load(settings.MODERATION_MODEL_ID)\n    else:\n        logger.info(\"Mock mode: skipping moderation model load.\")\n    app.state.moderation_engine = moderation_engine\n\n    # GPU worker\n    worker = GPUWorker(engine=engine)\n    worker.start()\n    resumed = await worker.resume_pending_jobs()\n    if resumed:\n        logger.info(\"Resumed %d pending jobs from previous session.\", resumed)\n    app.state.worker = worker\n\n    # Start time for health endpoint\n    health.set_start_time()\n\n    logger.info(\"Flux Image Service ready. Queue depth: %d\", worker.queue_depth)\n\n    yield\n\n    # ── Shutdown ─────────────────────────────────────────────────\n    logger.info(\"Shutting down Flux Image Service...\")\n    await worker.shutdown()\n    engine.unload()\n    logger.info(\"Shutdown complete.\")\n\n\napp = FastAPI(\n    title=\"Flux Image Service\",\n    description=\"Local image generation API powered by Flux.1 Schnell 12B\",\n    version=\"0.1.0\",\n    lifespan=lifespan,\n)\n\n# CORS — allow all for local development\napp.add_middleware(\n    CORSMiddleware,\n    allow_origins=[\"*\"],\n    allow_methods=[\"*\"],\n    allow_headers=[\"*\"],\n)\n\n# Mount routers\napp.include_router(generate.router)\napp.include_router(batch.router)\napp.include_router(jobs.router)\napp.include_router(health.router)\n"
  },
  {
    "path": "app/models/__init__.py",
    "content": ""
  },
  {
    "path": "app/models/enums.py",
    "content": "import enum\n\n\nclass JobStatus(str, enum.Enum):\n    PENDING = \"pending\"\n    PROCESSING = \"processing\"\n    COMPLETED = \"completed\"\n    FAILED = \"failed\"\n    CANCELLED = \"cancelled\"\n\n\nclass JobPriority(int, enum.Enum):\n    REALTIME = 1\n    BATCH = 5\n\n\nclass ImageFormat(str, enum.Enum):\n    PNG = \"png\"\n    JPEG = \"jpeg\"\n    WEBP = \"webp\"\n"
  },
  {
    "path": "app/models/job.py",
    "content": "import uuid\nfrom datetime import datetime, timezone\n\nfrom sqlalchemy import (\n    Column,\n    DateTime,\n    Enum,\n    Float,\n    ForeignKey,\n    Integer,\n    String,\n    Text,\n)\nfrom sqlalchemy.orm import DeclarativeBase, relationship\n\n\nclass Base(DeclarativeBase):\n    pass\n\n\ndef _utcnow() -> datetime:\n    return datetime.now(timezone.utc)\n\n\nclass Job(Base):\n    __tablename__ = \"jobs\"\n\n    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))\n    prompt = Column(Text, nullable=False)\n    negative_prompt = Column(Text, nullable=True)\n    width = Column(Integer, nullable=False, default=1024)\n    height = Column(Integer, nullable=False, default=1024)\n    num_steps = Column(Integer, nullable=False, default=4)\n    guidance_scale = Column(Float, nullable=False, default=0.0)\n    seed = Column(Integer, nullable=True)\n    status = Column(\n        Enum(\"pending\", \"processing\", \"completed\", \"failed\", \"cancelled\", name=\"job_status\"),\n        nullable=False,\n        default=\"pending\",\n    )\n    priority = Column(Integer, nullable=False, default=1)\n    format = Column(\n        Enum(\"png\", \"jpeg\", \"webp\", name=\"image_format\"),\n        nullable=False,\n        default=\"png\",\n    )\n    file_path = Column(String(512), nullable=True)\n    error_message = Column(Text, nullable=True)\n    batch_id = Column(String(36), ForeignKey(\"batch_jobs.id\"), nullable=True)\n    created_at = Column(DateTime(timezone=True), nullable=False, default=_utcnow)\n    started_at = Column(DateTime(timezone=True), nullable=True)\n    completed_at = Column(DateTime(timezone=True), nullable=True)\n\n    batch = relationship(\"BatchJob\", back_populates=\"jobs\")\n\n\nclass BatchJob(Base):\n    __tablename__ = \"batch_jobs\"\n\n    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))\n    name = Column(String(255), nullable=True)\n    total_count = Column(Integer, 
nullable=False, default=0)\n    completed_count = Column(Integer, nullable=False, default=0)\n    failed_count = Column(Integer, nullable=False, default=0)\n    status = Column(\n        Enum(\"pending\", \"processing\", \"completed\", \"failed\", \"cancelled\", name=\"batch_status\"),\n        nullable=False,\n        default=\"pending\",\n    )\n    created_at = Column(DateTime(timezone=True), nullable=False, default=_utcnow)\n    completed_at = Column(DateTime(timezone=True), nullable=True)\n\n    jobs = relationship(\"Job\", back_populates=\"batch\", lazy=\"selectin\")\n"
  },
  {
    "path": "app/schemas/__init__.py",
    "content": ""
  },
  {
    "path": "app/schemas/batch.py",
    "content": "from pydantic import BaseModel, Field\n\nfrom app.schemas.generate import ImageRequest\n\n\nclass BatchRequest(BaseModel):\n    name: str | None = None\n    prompts: list[ImageRequest] = Field(..., min_length=1, max_length=50000)\n\n\nclass BatchProgress(BaseModel):\n    batch_id: str\n    name: str | None\n    total: int\n    completed: int\n    failed: int\n    cancelled: int\n    pending: int\n    status: str\n    estimated_remaining_seconds: float | None = None\n    created_at: str\n    completed_at: str | None = None\n"
  },
  {
    "path": "app/schemas/generate.py",
    "content": "from pydantic import BaseModel, Field, field_validator, model_validator\n\nfrom app.config import settings\nfrom app.models.enums import ImageFormat\n\n\nclass ImageRequest(BaseModel):\n    prompt: str = Field(..., min_length=1, max_length=2000)\n    negative_prompt: str | None = None\n    width: int = Field(default=1024, ge=settings.MIN_RESOLUTION, le=settings.MAX_RESOLUTION)\n    height: int = Field(default=1024, ge=settings.MIN_RESOLUTION, le=settings.MAX_RESOLUTION)\n    num_steps: int = Field(default=settings.DEFAULT_NUM_STEPS, ge=1, le=50)\n    guidance_scale: float = Field(default=settings.DEFAULT_GUIDANCE_SCALE, ge=0.0, le=20.0)\n    seed: int | None = Field(default=None, ge=0, le=2**32 - 1)\n    num_images: int = Field(default=1, ge=1, le=settings.MAX_IMMEDIATE_IMAGES)\n    format: ImageFormat = ImageFormat.PNG\n\n    @field_validator(\"prompt\")\n    @classmethod\n    def check_content_safety(cls, v: str) -> str:\n        from app.utils.moderation import check_prompt\n\n        check_prompt(v)\n        return v\n\n    @model_validator(mode=\"after\")\n    def validate_resolution(self) -> \"ImageRequest\":\n        if self.width % 8 != 0:\n            raise ValueError(\"width must be a multiple of 8\")\n        if self.height % 8 != 0:\n            raise ValueError(\"height must be a multiple of 8\")\n        return self\n\n\nclass ImageResponse(BaseModel):\n    job_ids: list[str]\n    status: str\n    image_urls: list[str | None] = []\n\n\nPRESET_RESOLUTIONS: dict[str, tuple[int, int]] = {\n    \"square_sm\": (512, 512),\n    \"square\": (1024, 1024),\n    \"landscape\": (1344, 768),\n    \"portrait\": (768, 1344),\n    \"wide\": (1536, 640),\n}\n"
  },
  {
    "path": "app/schemas/job.py",
    "content": "from pydantic import BaseModel\n\n\nclass JobDetail(BaseModel):\n    id: str\n    prompt: str\n    negative_prompt: str | None\n    width: int\n    height: int\n    num_steps: int\n    guidance_scale: float\n    seed: int | None\n    status: str\n    priority: int\n    format: str\n    file_path: str | None\n    image_url: str | None\n    error_message: str | None\n    batch_id: str | None\n    created_at: str\n    started_at: str | None\n    completed_at: str | None\n\n\nclass JobList(BaseModel):\n    jobs: list[JobDetail]\n    total: int\n    page: int\n    page_size: int\n"
  },
  {
    "path": "app/services/__init__.py",
    "content": ""
  },
  {
    "path": "app/services/generation_service.py",
    "content": "\"\"\"Orchestrates image generation: validate → persist → enqueue.\"\"\"\n\nfrom __future__ import annotations\n\nimport uuid\nfrom typing import TYPE_CHECKING\n\nfrom sqlalchemy.ext.asyncio import AsyncSession\n\nfrom app.models.enums import JobPriority\nfrom app.models.job import BatchJob, Job\nfrom app.schemas.generate import ImageRequest\n\nif TYPE_CHECKING:\n    from app.worker.gpu_worker import GPUWorker\n\n\nasync def create_single_job(\n    session: AsyncSession,\n    request: ImageRequest,\n    worker: \"GPUWorker\",\n    priority: int = JobPriority.REALTIME,\n) -> list[str]:\n    \"\"\"Create generation job(s) for the request. Returns list of job_ids.\n\n    Creates one job per requested image (num_images). Each job gets a unique\n    seed derived from the base seed (if provided) to ensure varied outputs.\n    \"\"\"\n    job_ids: list[str] = []\n\n    queue_priority = int(priority)\n\n    for i in range(request.num_images):\n        # Derive per-image seed: if user provided a seed, offset it per image\n        seed = None\n        if request.seed is not None:\n            seed = (request.seed + i) % (2**32)\n\n        job = Job(\n            id=str(uuid.uuid4()),\n            prompt=request.prompt,\n            negative_prompt=request.negative_prompt,\n            width=request.width,\n            height=request.height,\n            num_steps=request.num_steps,\n            guidance_scale=request.guidance_scale,\n            seed=seed,\n            format=request.format.value,\n            priority=queue_priority,\n            status=\"pending\",\n        )\n        session.add(job)\n        job_ids.append(job.id)\n\n    await session.commit()\n\n    for jid in job_ids:\n        await worker.enqueue(jid, queue_priority)\n\n    return job_ids\n\n\nasync def create_batch(\n    session: AsyncSession,\n    name: str | None,\n    prompts: list[ImageRequest],\n    worker: \"GPUWorker\",\n) -> str:\n    \"\"\"Create a batch of generation jobs. 
Returns batch_id.\"\"\"\n    batch_id = str(uuid.uuid4())\n    batch = BatchJob(\n        id=batch_id,\n        name=name,\n        total_count=len(prompts),\n        status=\"pending\",\n    )\n    session.add(batch)\n\n    jobs: list[Job] = []\n    for req in prompts:\n        job = Job(\n            id=str(uuid.uuid4()),\n            prompt=req.prompt,\n            negative_prompt=req.negative_prompt,\n            width=req.width,\n            height=req.height,\n            num_steps=req.num_steps,\n            guidance_scale=req.guidance_scale,\n            seed=req.seed,\n            format=req.format.value,\n            priority=int(JobPriority.BATCH),\n            status=\"pending\",\n            batch_id=batch_id,\n        )\n        jobs.append(job)\n\n    session.add_all(jobs)\n    await session.commit()\n\n    # Enqueue all jobs\n    for job in jobs:\n        await worker.enqueue(job.id, int(JobPriority.BATCH))\n\n    return batch_id\n"
  },
  {
    "path": "app/services/job_service.py",
    "content": "\"\"\"Job and batch CRUD operations.\"\"\"\n\nfrom __future__ import annotations\n\nfrom datetime import datetime, timezone\n\nfrom sqlalchemy import func, select\nfrom sqlalchemy.ext.asyncio import AsyncSession\n\nfrom app.models.enums import JobStatus\nfrom app.models.job import BatchJob, Job\nfrom app.schemas.job import JobDetail\nfrom app.utils.storage import image_url_path\n\n\ndef _job_to_detail(job: Job) -> JobDetail:\n    return JobDetail(\n        id=job.id,\n        prompt=job.prompt,\n        negative_prompt=job.negative_prompt,\n        width=job.width,\n        height=job.height,\n        num_steps=job.num_steps,\n        guidance_scale=job.guidance_scale,\n        seed=job.seed,\n        status=job.status,\n        priority=job.priority,\n        format=job.format,\n        file_path=job.file_path,\n        image_url=image_url_path(job.id) if job.status == \"completed\" else None,\n        error_message=job.error_message,\n        batch_id=job.batch_id,\n        created_at=job.created_at.isoformat() if job.created_at else None,\n        started_at=job.started_at.isoformat() if job.started_at else None,\n        completed_at=job.completed_at.isoformat() if job.completed_at else None,\n    )\n\n\nasync def get_job(session: AsyncSession, job_id: str) -> JobDetail | None:\n    result = await session.execute(select(Job).where(Job.id == job_id))\n    job = result.scalar_one_or_none()\n    if not job:\n        return None\n    return _job_to_detail(job)\n\n\nasync def list_jobs(\n    session: AsyncSession,\n    status: str | None = None,\n    batch_id: str | None = None,\n    page: int = 1,\n    page_size: int = 50,\n) -> tuple[list[JobDetail], int]:\n    query = select(Job).order_by(Job.created_at.desc())\n    count_query = select(func.count()).select_from(Job)\n\n    if status:\n        query = query.where(Job.status == status)\n        count_query = count_query.where(Job.status == status)\n    if batch_id:\n        query = 
query.where(Job.batch_id == batch_id)\n        count_query = count_query.where(Job.batch_id == batch_id)\n\n    total = (await session.execute(count_query)).scalar() or 0\n    query = query.offset((page - 1) * page_size).limit(page_size)\n    result = await session.execute(query)\n    jobs = [_job_to_detail(j) for j in result.scalars().all()]\n    return jobs, total\n\n\nasync def cancel_job(session: AsyncSession, job_id: str) -> JobDetail | None:\n    result = await session.execute(select(Job).where(Job.id == job_id))\n    job = result.scalar_one_or_none()\n    if not job:\n        return None\n    if job.status in (\"completed\", \"failed\", \"cancelled\"):\n        return _job_to_detail(job)\n\n    job.status = \"cancelled\"\n    job.completed_at = datetime.now(timezone.utc)\n    await session.commit()\n    return _job_to_detail(job)\n\n\nasync def get_batch_progress(session: AsyncSession, batch_id: str) -> dict | None:\n    result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))\n    batch = result.scalar_one_or_none()\n    if not batch:\n        return None\n\n    cancelled = (\n        await session.execute(\n            select(func.count()).where(Job.batch_id == batch_id, Job.status == \"cancelled\")\n        )\n    ).scalar() or 0\n\n    pending = batch.total_count - batch.completed_count - batch.failed_count - cancelled\n\n    return {\n        \"batch_id\": batch.id,\n        \"name\": batch.name,\n        \"total\": batch.total_count,\n        \"completed\": batch.completed_count,\n        \"failed\": batch.failed_count,\n        \"cancelled\": cancelled,\n        \"pending\": pending,\n        \"status\": batch.status,\n        \"created_at\": batch.created_at.isoformat() if batch.created_at else None,\n        \"completed_at\": batch.completed_at.isoformat() if batch.completed_at else None,\n    }\n\n\nasync def cancel_batch(session: AsyncSession, batch_id: str) -> dict | None:\n    result = await 
session.execute(select(BatchJob).where(BatchJob.id == batch_id))\n    batch = result.scalar_one_or_none()\n    if not batch:\n        return None\n\n    # Cancel all pending jobs in this batch\n    pending_result = await session.execute(\n        select(Job).where(Job.batch_id == batch_id, Job.status == \"pending\")\n    )\n    now = datetime.now(timezone.utc)\n    for job in pending_result.scalars().all():\n        job.status = \"cancelled\"\n        job.completed_at = now\n\n    batch.status = \"cancelled\"\n    batch.completed_at = now\n    await session.commit()\n\n    return await get_batch_progress(session, batch_id)\n\n\nasync def retry_failed_in_batch(session: AsyncSession, batch_id: str) -> list[str]:\n    \"\"\"Re-set failed jobs in a batch back to pending. Returns list of job IDs.\"\"\"\n    result = await session.execute(\n        select(Job).where(Job.batch_id == batch_id, Job.status == \"failed\")\n    )\n    job_ids = []\n    for job in result.scalars().all():\n        job.status = \"pending\"\n        job.error_message = None\n        job.completed_at = None\n        job_ids.append(job.id)\n\n    if job_ids:\n        # Reset batch status\n        batch_result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))\n        batch = batch_result.scalar_one_or_none()\n        if batch:\n            batch.status = \"processing\"\n            batch.completed_at = None\n            batch.failed_count = 0\n\n        await session.commit()\n\n    return job_ids\n"
  },
  {
    "path": "app/services/moderation_service.py",
    "content": "\"\"\"Content moderation service – checks prompt safety via a third-party API.\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\n\nimport httpx\n\nfrom app.config import settings\n\nlogger = logging.getLogger(__name__)\n\n\nclass ModerationError(Exception):\n    \"\"\"Raised when the moderation API call fails unexpectedly.\"\"\"\n\n\nasync def check_prompt_safety(prompt: str) -> tuple[bool, str | None]:\n    \"\"\"Check whether a prompt is safe using the configured moderation API.\n\n    Calls the API specified by ``MODERATION_API_URL`` (default: OpenAI\n    moderation endpoint) and returns ``(is_safe, reason)``.  When moderation\n    is disabled or no API key is configured the function returns ``(True, None)``\n    without making any network request.\n\n    Args:\n        prompt: The user-supplied generation prompt to evaluate.\n\n    Returns:\n        A ``(is_safe, reason)`` tuple.  ``is_safe`` is ``True`` when the prompt\n        is considered safe.  When flagged, ``reason`` contains the triggered\n        content categories (e.g. 
``\"sexual, violence\"``).\n\n    Raises:\n        ModerationError: If the API request fails and the failure cannot be\n            handled gracefully.\n    \"\"\"\n    if not settings.MODERATION_ENABLED:\n        return True, None\n\n    if not settings.MODERATION_API_KEY:\n        logger.warning(\n            \"Content moderation is enabled but MODERATION_API_KEY is not set; \"\n            \"skipping moderation check.\"\n        )\n        return True, None\n\n    try:\n        async with httpx.AsyncClient(timeout=settings.MODERATION_TIMEOUT) as client:\n            response = await client.post(\n                settings.MODERATION_API_URL,\n                headers={\n                    \"Authorization\": f\"Bearer {settings.MODERATION_API_KEY}\",\n                    \"Content-Type\": \"application/json\",\n                },\n                json={\"input\": prompt},\n            )\n            response.raise_for_status()\n            data = response.json()\n    except httpx.HTTPStatusError as exc:\n        logger.error(\n            \"Moderation API returned an error status %s: %s\",\n            exc.response.status_code,\n            exc.response.text,\n        )\n        raise ModerationError(\n            f\"Moderation API returned HTTP {exc.response.status_code}\"\n        ) from exc\n    except httpx.HTTPError as exc:\n        logger.error(\"Moderation API request failed: %s\", exc)\n        raise ModerationError(f\"Moderation API request failed: {exc}\") from exc\n\n    # Parse OpenAI-compatible moderation response:\n    # {\"results\": [{\"flagged\": bool, \"categories\": {category: bool, ...}}]}\n    results = data.get(\"results\", [])\n    if not results:\n        return True, None\n\n    result = results[0]\n    if result.get(\"flagged\", False):\n        categories: dict[str, bool] = result.get(\"categories\", {})\n        flagged_categories = [cat for cat, triggered in categories.items() if triggered]\n        reason = \", 
\".join(flagged_categories) if flagged_categories else \"unsafe content\"\n        return False, reason\n\n    return True, None\n"
  },
  {
    "path": "app/utils/__init__.py",
    "content": ""
  },
  {
    "path": "app/utils/image_utils.py",
    "content": "\"\"\"Image saving, format conversion, and metadata embedding.\"\"\"\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\n\nfrom PIL import Image, PngImagePlugin\n\nfrom app.models.enums import ImageFormat\n\n\ndef save_image(\n    image: Image.Image,\n    path: Path,\n    fmt: ImageFormat,\n    quality: int = 95,\n    prompt: str | None = None,\n    seed: int | None = None,\n) -> None:\n    \"\"\"Save a PIL Image in the requested format with embedded metadata.\"\"\"\n    if fmt == ImageFormat.PNG:\n        meta = PngImagePlugin.PngInfo()\n        if prompt:\n            meta.add_text(\"prompt\", prompt)\n        if seed is not None:\n            meta.add_text(\"seed\", str(seed))\n        image.save(path, format=\"PNG\", pnginfo=meta)\n\n    elif fmt == ImageFormat.JPEG:\n        # JPEG doesn't support alpha\n        if image.mode in (\"RGBA\", \"LA\"):\n            image = image.convert(\"RGB\")\n        image.save(path, format=\"JPEG\", quality=quality)\n\n    elif fmt == ImageFormat.WEBP:\n        image.save(path, format=\"WEBP\", quality=quality)\n\n    else:\n        image.save(path, format=\"PNG\")\n"
  },
  {
    "path": "app/utils/moderation.py",
    "content": "\"\"\"Content moderation using Gemma-3-1b-it (local) with HF Inference API fallback.\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom dataclasses import dataclass\n\nlogger = logging.getLogger(__name__)\n\n_SAFETY_SYSTEM_PROMPT = (\n    \"You are a content safety classifier. \"\n    \"Determine whether the following text prompt for image generation contains \"\n    \"harmful, illegal, sexually explicit, or violent content. \"\n    \"Respond with exactly one word: 'SAFE' if the prompt is acceptable, \"\n    \"or 'UNSAFE' if it should be rejected. Do not include any other text.\"\n)\n\n\n@dataclass\nclass ModerationResult:\n    is_safe: bool\n    reason: str = \"\"\n    source: str = \"\"\n\n\nclass ModerationEngine:\n    \"\"\"Content moderation engine.\n\n    Primary check: locally-loaded ``google/gemma-3-1b-it`` text-generation pipeline.\n    Fallback check: HuggingFace Inference API (agent API) using the same model.\n    If both are unavailable the engine fails open and logs a warning.\n    \"\"\"\n\n    def __init__(self) -> None:\n        self._pipeline = None\n        self._model_id: str | None = None\n\n    # ── Loading ──────────────────────────────────────────────────────────────\n\n    def load(self, model_id: str = \"google/gemma-3-1b-it\") -> None:\n        \"\"\"Load *model_id* locally for synchronous inference.\"\"\"\n        try:\n            import torch\n            from transformers import pipeline\n\n            logger.info(\"Loading moderation model %s …\", model_id)\n            self._pipeline = pipeline(\n                \"text-generation\",\n                model=model_id,\n                device_map=\"auto\",\n                torch_dtype=torch.bfloat16,\n                max_new_tokens=10,\n            )\n            self._model_id = model_id\n            logger.info(\"Moderation model %s loaded successfully.\", model_id)\n        except Exception as exc:  # noqa: BLE001\n            logger.warning(\"Failed 
to load moderation model %s: %s\", model_id, exc)\n            self._pipeline = None\n\n    @property\n    def is_loaded(self) -> bool:\n        \"\"\"Return True when the local pipeline is ready.\"\"\"\n        return self._pipeline is not None\n\n    # ── Local-model check ────────────────────────────────────────────────────\n\n    def _parse_pipeline_output(self, raw: object) -> str:\n        \"\"\"Extract the assistant reply text from a text-generation pipeline result.\"\"\"\n        # The pipeline returns a list of dicts like:\n        # [{\"generated_text\": [{\"role\": \"system\", ...}, {\"role\": \"assistant\", \"content\": \"SAFE\"}]}]\n        if isinstance(raw, list) and raw:\n            inner = raw[0].get(\"generated_text\", raw[0])\n            if isinstance(inner, list) and inner:\n                last = inner[-1]\n                if isinstance(last, dict):\n                    return last.get(\"content\", \"\").strip().upper()\n            return str(inner).strip().upper()\n        return str(raw).strip().upper()\n\n    def check_with_local_model(self, prompt: str) -> ModerationResult:\n        \"\"\"Run the locally-loaded Gemma pipeline and return a :class:`ModerationResult`.\"\"\"\n        if not self.is_loaded:\n            raise RuntimeError(\"Local moderation model is not loaded\")\n        try:\n            messages = [\n                {\"role\": \"system\", \"content\": _SAFETY_SYSTEM_PROMPT},\n                {\"role\": \"user\", \"content\": f\"Prompt: {prompt}\"},\n            ]\n            raw = self._pipeline(messages)\n            response = self._parse_pipeline_output(raw)\n            is_safe = \"UNSAFE\" not in response\n            return ModerationResult(\n                is_safe=is_safe,\n                reason=\"\" if is_safe else f\"Flagged by local model: {response}\",\n                source=\"local\",\n            )\n        except Exception as exc:\n            raise RuntimeError(f\"Local model inference failed: {exc}\") 
from exc\n\n    # ── Agent-API fallback ───────────────────────────────────────────────────\n\n    def check_with_agent_api(\n        self,\n        prompt: str,\n        *,\n        model_id: str | None = None,\n        token: str | None = None,\n    ) -> ModerationResult:\n        \"\"\"Use the HuggingFace Inference API (agent API) as a fallback safety check.\"\"\"\n        from huggingface_hub import InferenceClient\n\n        _model = model_id or self._model_id or \"google/gemma-3-1b-it\"\n        client = InferenceClient(model=_model, token=token)\n        messages = [\n            {\"role\": \"system\", \"content\": _SAFETY_SYSTEM_PROMPT},\n            {\"role\": \"user\", \"content\": f\"Prompt: {prompt}\"},\n        ]\n        result = client.chat_completion(messages=messages, max_tokens=10)\n        response = result.choices[0].message.content.strip().upper()\n        is_safe = \"UNSAFE\" not in response\n        return ModerationResult(\n            is_safe=is_safe,\n            reason=\"\" if is_safe else f\"Flagged by agent API: {response}\",\n            source=\"agent_api\",\n        )\n\n    # ── Unified entry-point ──────────────────────────────────────────────────\n\n    def check(self, prompt: str, *, hf_token: str | None = None) -> ModerationResult:\n        \"\"\"Check *prompt* safety.\n\n        Tries the local Gemma model first; falls back to the HF Inference API;\n        if both are unavailable the call succeeds with a warning (fail-open).\n        \"\"\"\n        # 1. Local model\n        if self.is_loaded:\n            try:\n                return self.check_with_local_model(prompt)\n            except Exception as exc:  # noqa: BLE001\n                logger.warning(\n                    \"Local moderation check failed — falling back to agent API: %s\", exc\n                )\n\n        # 2. 
Agent API fallback\n        if hf_token or self._model_id:\n            try:\n                return self.check_with_agent_api(\n                    prompt,\n                    model_id=self._model_id,\n                    token=hf_token,\n                )\n            except Exception as exc:  # noqa: BLE001\n                logger.warning(\n                    \"Agent API moderation check failed — allowing prompt through: %s\", exc\n                )\n\n        # 3. Both unavailable — fail open\n        logger.warning(\n            \"No moderation backend available; prompt allowed through without safety check.\"\n        )\n        return ModerationResult(is_safe=True, reason=\"Moderation unavailable\", source=\"none\")\n\n\n# ── Module-level singleton ────────────────────────────────────────────────────\n\n_moderation_engine: ModerationEngine = ModerationEngine()\n\n\ndef get_moderation_engine() -> ModerationEngine:\n    \"\"\"Return the global :class:`ModerationEngine` singleton.\"\"\"\n    return _moderation_engine\n\n\ndef check_prompt(prompt: str) -> None:\n    \"\"\"Validate *prompt* against the moderation engine.\n\n    Raises :class:`ValueError` if the prompt is rejected.\n    Does nothing when ``MODERATION_ENABLED`` is ``False`` or when no moderation\n    backend is configured (fail-open to avoid breaking existing tests).\n    \"\"\"\n    from app.config import settings\n\n    if not settings.MODERATION_ENABLED:\n        return\n\n    result = _moderation_engine.check(prompt, hf_token=settings.HF_API_TOKEN)\n    if not result.is_safe:\n        raise ValueError(result.reason or \"Prompt rejected by content moderation\")\n"
  },
  {
    "path": "app/utils/storage.py",
    "content": "\"\"\"Local filesystem storage management for generated images.\"\"\"\n\nfrom __future__ import annotations\n\nfrom datetime import datetime, timezone\nfrom pathlib import Path\n\nfrom app.config import settings\n\n\ndef ensure_storage_dirs() -> None:\n    \"\"\"Create the output directory tree if it doesn't exist.\"\"\"\n    (settings.STORAGE_DIR / \"images\").mkdir(parents=True, exist_ok=True)\n    (settings.STORAGE_DIR / \"batches\").mkdir(parents=True, exist_ok=True)\n\n\ndef get_image_path(job_id: str, fmt: str, batch_id: str | None = None) -> Path:\n    \"\"\"Return the storage path for a generated image.\n\n    Layout:\n        output/batches/{batch_id}/{job_id}.{ext}   (if part of a batch)\n        output/images/{YYYY-MM-DD}/{job_id}.{ext}   (standalone)\n    \"\"\"\n    ext = fmt.lower()\n    if ext == \"jpeg\":\n        ext = \"jpg\"\n\n    if batch_id:\n        directory = settings.STORAGE_DIR / \"batches\" / batch_id\n    else:\n        date_str = datetime.now(timezone.utc).strftime(\"%Y-%m-%d\")\n        directory = settings.STORAGE_DIR / \"images\" / date_str\n\n    directory.mkdir(parents=True, exist_ok=True)\n    return directory / f\"{job_id}.{ext}\"\n\n\ndef image_url_path(job_id: str) -> str:\n    \"\"\"Return the API URL path to serve this image.\"\"\"\n    return f\"/jobs/{job_id}/image\"\n"
  },
  {
    "path": "app/worker/__init__.py",
    "content": ""
  },
  {
    "path": "app/worker/gpu_worker.py",
    "content": "\"\"\"GPU worker — async priority queue consumer driving the Flux engine.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nimport time\nfrom dataclasses import dataclass, field\nfrom datetime import datetime, timezone\nfrom typing import TYPE_CHECKING\n\nfrom sqlalchemy import select, func\n\nfrom app.database import async_session_factory\nfrom app.models.enums import ImageFormat, JobStatus\nfrom app.models.job import BatchJob, Job\nfrom app.utils.image_utils import save_image\nfrom app.utils.storage import get_image_path, image_url_path\n\nif TYPE_CHECKING:\n    from app.engine.flux_engine import FluxEngine\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass(order=True)\nclass _QueueItem:\n    \"\"\"Priority queue item. Lower priority value = higher urgency.\"\"\"\n\n    priority: int\n    timestamp: float = field(compare=True)\n    job_id: str = field(compare=False)\n\n\n@dataclass\nclass GPUWorker:\n    \"\"\"Consumes jobs from an async priority queue and runs inference.\"\"\"\n\n    engine: \"FluxEngine\"\n    _queue: asyncio.PriorityQueue[_QueueItem] = field(default_factory=asyncio.PriorityQueue)\n    _stop_event: asyncio.Event = field(default_factory=asyncio.Event)\n    _task: asyncio.Task | None = field(default=None, repr=False)\n    _jobs_processed: int = 0\n    _total_inference_time: float = 0.0\n\n    # ── Queue interface ──────────────────────────────────────────────\n\n    @staticmethod\n    def _normalize_priority(priority: int) -> int:\n        \"\"\"Convert enum/int-like priorities to a stable numeric queue value.\"\"\"\n        try:\n            value = int(priority)\n        except (TypeError, ValueError):\n            value = 1\n        return max(1, value)\n\n    async def enqueue(self, job_id: str, priority: int = 1) -> None:\n        normalized = self._normalize_priority(priority)\n        item = _QueueItem(priority=normalized, timestamp=time.time(), job_id=job_id)\n        await 
self._queue.put(item)\n        logger.debug(\"Enqueued job=%s priority=%d queue_depth=%d\", job_id, normalized, self.queue_depth)\n\n    @property\n    def queue_depth(self) -> int:\n        return self._queue.qsize()\n\n    # ── Lifecycle ────────────────────────────────────────────────────\n\n    def start(self) -> asyncio.Task:\n        \"\"\"Start the worker loop as a background task.\"\"\"\n        self._stop_event.clear()\n        self._task = asyncio.create_task(self._run(), name=\"gpu-worker\")\n        logger.info(\"GPU worker started.\")\n        return self._task\n\n    async def shutdown(self) -> None:\n        \"\"\"Signal the worker to stop and wait for it to finish the current job.\"\"\"\n        logger.info(\"Shutting down GPU worker (queue_depth=%d)...\", self.queue_depth)\n        self._stop_event.set()\n        if self._task and not self._task.done():\n            # Unblock a pending queue.get() by pushing a sentinel\n            await self._queue.put(_QueueItem(priority=999, timestamp=time.time(), job_id=\"__stop__\"))\n            await self._task\n        logger.info(\"GPU worker stopped. 
Processed %d jobs total.\", self._jobs_processed)\n\n    # ── Resume support ───────────────────────────────────────────────\n\n    async def resume_pending_jobs(self) -> int:\n        \"\"\"Re-enqueue jobs left in PENDING or PROCESSING state (crash recovery).\"\"\"\n        async with async_session_factory() as session:\n            result = await session.execute(\n                select(Job).where(Job.status.in_([\"pending\", \"processing\"]))\n            )\n            stale_jobs = result.scalars().all()\n\n            for job in stale_jobs:\n                job.status = \"pending\"\n                await self.enqueue(job.id, self._normalize_priority(job.priority))\n\n            await session.commit()\n\n        if stale_jobs:\n            logger.info(\"Resumed %d pending/stale jobs.\", len(stale_jobs))\n        return len(stale_jobs)\n\n    # ── Main loop ────────────────────────────────────────────────────\n\n    async def _run(self) -> None:\n        logger.info(\"GPU worker loop running.\")\n        while not self._stop_event.is_set():\n            try:\n                item = await asyncio.wait_for(self._queue.get(), timeout=1.0)\n            except asyncio.TimeoutError:\n                continue\n\n            if item.job_id == \"__stop__\":\n                break\n\n            await self._process_job(item.job_id)\n            self._queue.task_done()\n\n        logger.info(\"GPU worker loop exited.\")\n\n    async def _process_job(self, job_id: str) -> None:\n        async with async_session_factory() as session:\n            result = await session.execute(select(Job).where(Job.id == job_id))\n            job = result.scalar_one_or_none()\n            if job is None:\n                logger.warning(\"Job %s not found in DB, skipping.\", job_id)\n                return\n\n            # Skip cancelled jobs\n            if job.status == \"cancelled\":\n                logger.info(\"Job %s is cancelled, skipping.\", job_id)\n                return\n\n       
     # Mark as processing\n            job.status = \"processing\"\n            job.started_at = datetime.now(timezone.utc)\n            await session.commit()\n\n        try:\n            # Run inference in a thread to avoid blocking asyncio\n            gen_result = await asyncio.to_thread(\n                self.engine.generate,\n                prompt=job.prompt,\n                width=job.width,\n                height=job.height,\n                num_steps=job.num_steps,\n                guidance_scale=job.guidance_scale,\n                seed=job.seed,\n            )\n\n            # Save image to disk\n            fmt = ImageFormat(job.format)\n            file_path = get_image_path(job.id, job.format, batch_id=job.batch_id)\n            save_image(\n                gen_result.image,\n                file_path,\n                fmt,\n                prompt=job.prompt,\n                seed=gen_result.seed,\n            )\n\n            # Update job as completed\n            async with async_session_factory() as session:\n                result = await session.execute(select(Job).where(Job.id == job_id))\n                job = result.scalar_one()\n                job.status = \"completed\"\n                job.file_path = str(file_path)\n                job.seed = gen_result.seed\n                job.completed_at = datetime.now(timezone.utc)\n                await session.commit()\n\n                # Update batch counters if applicable\n                if job.batch_id:\n                    await self._update_batch_count(session, job.batch_id)\n\n            self._jobs_processed += 1\n            self._total_inference_time += gen_result.inference_time\n\n        except Exception as exc:\n            # Detect CUDA OOM without importing torch in mock/non-GPU environments.\n            if exc.__class__.__name__ == \"OutOfMemoryError\":\n                logger.error(\"OOM on job %s (%dx%d).\", job_id, job.width, job.height)\n                try:\n                 
   import torch\n\n                    torch.cuda.empty_cache()\n                except Exception:\n                    pass\n                await self._fail_job(job_id, \"GPU out of memory for this resolution\")\n                return\n\n            logger.exception(\"Error processing job %s\", job_id)\n            await self._fail_job(job_id, \"Internal generation error\")\n\n    async def _fail_job(self, job_id: str, error: str) -> None:\n        async with async_session_factory() as session:\n            result = await session.execute(select(Job).where(Job.id == job_id))\n            job = result.scalar_one_or_none()\n            if job:\n                job.status = \"failed\"\n                job.error_message = error\n                job.completed_at = datetime.now(timezone.utc)\n                await session.commit()\n\n                if job.batch_id:\n                    await self._update_batch_count(session, job.batch_id)\n\n    async def _update_batch_count(self, session, batch_id: str) -> None:\n        \"\"\"Recompute batch counters from individual jobs.\"\"\"\n        result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))\n        batch = result.scalar_one_or_none()\n        if not batch:\n            return\n\n        # Count completed and failed\n        completed = (\n            await session.execute(\n                select(func.count()).where(Job.batch_id == batch_id, Job.status == \"completed\")\n            )\n        ).scalar()\n        failed = (\n            await session.execute(\n                select(func.count()).where(Job.batch_id == batch_id, Job.status == \"failed\")\n            )\n        ).scalar()\n\n        batch.completed_count = completed or 0\n        batch.failed_count = failed or 0\n\n        done = (completed or 0) + (failed or 0)\n        if done >= batch.total_count:\n            batch.status = \"completed\" if (failed or 0) == 0 else \"failed\"\n            batch.completed_at = 
datetime.now(timezone.utc)\n        elif done > 0:\n            batch.status = \"processing\"\n\n        await session.commit()\n\n    # ── Stats ────────────────────────────────────────────────────────\n\n    def get_stats(self) -> dict:\n        avg = (\n            self._total_inference_time / self._jobs_processed\n            if self._jobs_processed > 0\n            else 0.0\n        )\n        return {\n            \"queue_depth\": self.queue_depth,\n            \"jobs_processed\": self._jobs_processed,\n            \"total_inference_time_s\": round(self._total_inference_time, 2),\n            \"avg_inference_time_s\": round(avg, 2),\n        }\n"
  },
  {
    "path": "flux_image_service.egg-info/PKG-INFO",
    "content": "Metadata-Version: 2.4\nName: flux-image-service\nVersion: 0.1.0\nSummary: Local image generation service using Flux.1 Schnell 12B\nRequires-Python: >=3.10\nRequires-Dist: fastapi>=0.115.0\nRequires-Dist: uvicorn[standard]>=0.30.0\nRequires-Dist: pydantic>=2.0\nRequires-Dist: pydantic-settings>=2.0\nRequires-Dist: sqlalchemy[asyncio]>=2.0\nRequires-Dist: aiosqlite>=0.20.0\nRequires-Dist: torch>=2.1\nRequires-Dist: diffusers>=0.30.0\nRequires-Dist: transformers>=4.40.0\nRequires-Dist: accelerate>=0.30.0\nRequires-Dist: sentencepiece>=0.2.0\nRequires-Dist: protobuf>=4.25.0\nRequires-Dist: Pillow>=10.0.0\nRequires-Dist: python-multipart>=0.0.9\nRequires-Dist: sse-starlette>=2.0.0\nRequires-Dist: httpx>=0.27.0\nProvides-Extra: dev\nRequires-Dist: pytest>=8.0; extra == \"dev\"\nRequires-Dist: pytest-asyncio>=0.23.0; extra == \"dev\"\nRequires-Dist: httpx>=0.27.0; extra == \"dev\"\nRequires-Dist: ruff>=0.4.0; extra == \"dev\"\n"
  },
  {
    "path": "flux_image_service.egg-info/SOURCES.txt",
    "content": "README.md\npyproject.toml\napp/__init__.py\napp/config.py\napp/database.py\napp/main.py\napp/api/__init__.py\napp/api/batch.py\napp/api/generate.py\napp/api/health.py\napp/api/jobs.py\napp/engine/__init__.py\napp/engine/engine_config.py\napp/engine/flux_engine.py\napp/models/__init__.py\napp/models/enums.py\napp/models/job.py\napp/schemas/__init__.py\napp/schemas/batch.py\napp/schemas/generate.py\napp/schemas/job.py\napp/services/__init__.py\napp/services/generation_service.py\napp/services/job_service.py\napp/services/moderation_service.py\napp/utils/__init__.py\napp/utils/image_utils.py\napp/utils/moderation.py\napp/utils/storage.py\napp/worker/__init__.py\napp/worker/gpu_worker.py\nflux_image_service.egg-info/PKG-INFO\nflux_image_service.egg-info/SOURCES.txt\nflux_image_service.egg-info/dependency_links.txt\nflux_image_service.egg-info/requires.txt\nflux_image_service.egg-info/top_level.txt\ntests/test_api_batch.py\ntests/test_api_generate.py\ntests/test_api_jobs.py\ntests/test_moderation.py\ntests/test_moderation_service.py\ntests/test_schemas.py\ntests/test_worker.py"
  },
  {
    "path": "flux_image_service.egg-info/dependency_links.txt",
    "content": "\n"
  },
  {
    "path": "flux_image_service.egg-info/requires.txt",
    "content": "fastapi>=0.115.0\nuvicorn[standard]>=0.30.0\npydantic>=2.0\npydantic-settings>=2.0\nsqlalchemy[asyncio]>=2.0\naiosqlite>=0.20.0\ntorch>=2.1\ndiffusers>=0.30.0\ntransformers>=4.40.0\naccelerate>=0.30.0\nsentencepiece>=0.2.0\nprotobuf>=4.25.0\nPillow>=10.0.0\npython-multipart>=0.0.9\nsse-starlette>=2.0.0\nhttpx>=0.27.0\n\n[dev]\npytest>=8.0\npytest-asyncio>=0.23.0\nhttpx>=0.27.0\nruff>=0.4.0\n"
  },
  {
    "path": "flux_image_service.egg-info/top_level.txt",
    "content": "app\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"flux-image-service\"\nversion = \"0.1.0\"\ndescription = \"Local image generation service using Flux.1 Schnell 12B\"\nrequires-python = \">=3.10\"\ndependencies = [\n    \"fastapi>=0.115.0\",\n    \"uvicorn[standard]>=0.30.0\",\n    \"pydantic>=2.0\",\n    \"pydantic-settings>=2.0\",\n    \"sqlalchemy[asyncio]>=2.0\",\n    \"aiosqlite>=0.20.0\",\n    \"torch>=2.1\",\n    \"diffusers>=0.30.0\",\n    \"transformers>=4.40.0\",\n    \"accelerate>=0.30.0\",\n    \"sentencepiece>=0.2.0\",\n    \"protobuf>=4.25.0\",\n    \"Pillow>=10.0.0\",\n    \"python-multipart>=0.0.9\",\n    \"sse-starlette>=2.0.0\",\n    \"httpx>=0.27.0\",\n]\n\n[project.optional-dependencies]\ndev = [\n    \"pytest>=8.0\",\n    \"pytest-asyncio>=0.23.0\",\n    \"httpx>=0.27.0\",\n    \"ruff>=0.4.0\",\n]\n\n[tool.pytest.ini_options]\nasyncio_mode = \"auto\"\ntestpaths = [\"tests\"]\n\n[tool.ruff]\ntarget-version = \"py310\"\nline-length = 100\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/conftest.py",
    "content": "\"\"\"Shared test fixtures.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nfrom unittest.mock import AsyncMock, MagicMock, patch\nfrom pathlib import Path\n\nimport pytest\nfrom httpx import ASGITransport, AsyncClient\nfrom PIL import Image\nfrom sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine\n\nfrom app.engine.flux_engine import GenerationResult\nfrom app.models.job import Base\n\n\n@pytest.fixture(scope=\"session\")\ndef event_loop():\n    loop = asyncio.new_event_loop()\n    yield loop\n    loop.close()\n\n\n@pytest.fixture\nasync def db_engine(tmp_path):\n    \"\"\"In-memory SQLite for tests.\"\"\"\n    engine = create_async_engine(\"sqlite+aiosqlite:///:memory:\", echo=False)\n    async with engine.begin() as conn:\n        await conn.run_sync(Base.metadata.create_all)\n    yield engine\n    await engine.dispose()\n\n\n@pytest.fixture\nasync def db_session(db_engine):\n    factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)\n    async with factory() as session:\n        yield session\n\n\n@pytest.fixture\ndef mock_engine():\n    \"\"\"A mock FluxEngine that returns a tiny test image.\"\"\"\n    engine = MagicMock()\n    engine.is_loaded = True\n    engine.get_vram_stats.return_value = {\"available\": True, \"allocated_gb\": 12.0}\n    engine.get_vram_usage_gb.return_value = 12.0\n\n    dummy_image = Image.new(\"RGB\", (64, 64), color=\"red\")\n    engine.generate.return_value = GenerationResult(\n        image=dummy_image, seed=42, inference_time=0.5\n    )\n    return engine\n\n\n@pytest.fixture\ndef mock_worker(mock_engine):\n    \"\"\"A mock GPUWorker that tracks enqueued job IDs.\"\"\"\n    worker = MagicMock()\n    worker.queue_depth = 0\n    worker.enqueue = AsyncMock()\n    worker.get_stats.return_value = {\n        \"queue_depth\": 0,\n        \"jobs_processed\": 0,\n        \"total_inference_time_s\": 0.0,\n        \"avg_inference_time_s\": 0.0,\n    
}\n    return worker\n\n\n@pytest.fixture\nasync def client(mock_engine, mock_worker, db_engine, tmp_path):\n    \"\"\"Test client with mocked engine/worker and in-memory DB.\"\"\"\n    from app.database import get_session\n    from app.models.job import Base\n\n    factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)\n\n    async def override_get_session():\n        async with factory() as session:\n            yield session\n\n    # Patch the database session factory used by the worker\n    with patch(\"app.worker.gpu_worker.async_session_factory\", factory), \\\n         patch(\"app.database.async_session_factory\", factory):\n\n        from app.main import app\n\n        app.dependency_overrides[get_session] = override_get_session\n        app.state.engine = mock_engine\n        app.state.worker = mock_worker\n\n        transport = ASGITransport(app=app)\n        async with AsyncClient(transport=transport, base_url=\"http://test\") as ac:\n            yield ac\n\n        app.dependency_overrides.clear()\n"
  },
  {
    "path": "tests/test_api_batch.py",
    "content": "\"\"\"Tests for batch API endpoints.\"\"\"\n\nimport pytest\n\n\n@pytest.mark.asyncio\nasync def test_create_batch(client, mock_worker):\n    resp = await client.post(\n        \"/batch\",\n        json={\n            \"name\": \"test-batch\",\n            \"prompts\": [\n                {\"prompt\": \"a red car\"},\n                {\"prompt\": \"a blue house\"},\n                {\"prompt\": \"a green tree\"},\n            ],\n        },\n    )\n    assert resp.status_code == 200\n    data = resp.json()\n    assert data[\"total\"] == 3\n    assert data[\"status\"] == \"pending\"\n    assert data[\"batch_id\"]\n    # Worker should have been called 3 times\n    assert mock_worker.enqueue.call_count == 3\n\n\n@pytest.mark.asyncio\nasync def test_get_batch_404(client):\n    resp = await client.get(\"/batch/nonexistent-id\")\n    assert resp.status_code == 404\n\n\n@pytest.mark.asyncio\nasync def test_cancel_batch(client, mock_worker):\n    # Create batch first\n    resp = await client.post(\n        \"/batch\",\n        json={\n            \"name\": \"cancel-test\",\n            \"prompts\": [{\"prompt\": \"p1\"}, {\"prompt\": \"p2\"}],\n        },\n    )\n    batch_id = resp.json()[\"batch_id\"]\n\n    # Cancel it\n    resp = await client.delete(f\"/batch/{batch_id}\")\n    assert resp.status_code == 200\n    data = resp.json()\n    assert data[\"status\"] == \"cancelled\"\n\n\n@pytest.mark.asyncio\nasync def test_batch_requires_prompts(client):\n    resp = await client.post(\"/batch\", json={\"name\": \"empty\", \"prompts\": []})\n    assert resp.status_code == 422\n"
  },
  {
    "path": "tests/test_api_generate.py",
    "content": "\"\"\"Tests for generation API endpoints.\"\"\"\n\nimport pytest\nfrom unittest.mock import AsyncMock, patch\nfrom sqlalchemy import select\nfrom sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker\n\nfrom app.models.job import Job\nfrom app.services.moderation_service import ModerationError\n\n\n@pytest.mark.asyncio\nasync def test_generate_async_returns_job_ids(client):\n    resp = await client.post(\"/generate\", json={\"prompt\": \"a beautiful sunset\"})\n    assert resp.status_code == 200\n    data = resp.json()\n    assert \"job_ids\" in data\n    assert len(data[\"job_ids\"]) == 1\n    assert data[\"status\"] == \"pending\"\n\n\n@pytest.mark.asyncio\nasync def test_generate_multiple_images(client, mock_worker):\n    resp = await client.post(\"/generate\", json={\"prompt\": \"test\", \"num_images\": 3})\n    assert resp.status_code == 200\n    data = resp.json()\n    assert len(data[\"job_ids\"]) == 3\n    assert mock_worker.enqueue.call_count == 3\n\n\n@pytest.mark.asyncio\nasync def test_generate_max_4_images(client):\n    resp = await client.post(\"/generate\", json={\"prompt\": \"test\", \"num_images\": 4})\n    assert resp.status_code == 200\n    assert len(resp.json()[\"job_ids\"]) == 4\n\n\n@pytest.mark.asyncio\nasync def test_generate_rejects_more_than_4_images(client):\n    resp = await client.post(\"/generate\", json={\"prompt\": \"test\", \"num_images\": 5})\n    assert resp.status_code == 422\n\n\n@pytest.mark.asyncio\nasync def test_generate_async_enqueues_work(client, mock_worker):\n    resp = await client.post(\"/generate\", json={\"prompt\": \"test prompt\"})\n    assert resp.status_code == 200\n    mock_worker.enqueue.assert_called_once()\n\n\n@pytest.mark.asyncio\nasync def test_generate_validates_resolution_multiple_of_8(client):\n    resp = await client.post(\n        \"/generate\",\n        json={\"prompt\": \"test\", \"width\": 513, \"height\": 1024},\n    )\n    assert resp.status_code == 422  # validation 
error\n\n\n@pytest.mark.asyncio\nasync def test_generate_rejects_empty_prompt(client):\n    resp = await client.post(\"/generate\", json={\"prompt\": \"\"})\n    assert resp.status_code == 422\n\n\n@pytest.mark.asyncio\nasync def test_generate_rejects_oversized_resolution(client):\n    resp = await client.post(\n        \"/generate\",\n        json={\"prompt\": \"test\", \"width\": 4096, \"height\": 4096},\n    )\n    assert resp.status_code == 422\n\n\n@pytest.mark.asyncio\nasync def test_generate_default_values(client, mock_worker):\n    resp = await client.post(\"/generate\", json={\"prompt\": \"a cat\"})\n    assert resp.status_code == 200\n\n\n@pytest.mark.asyncio\nasync def test_serve_image_404_for_missing_job(client):\n    resp = await client.get(\"/jobs/nonexistent-id/image\")\n    assert resp.status_code == 404\n\n\n@pytest.mark.asyncio\nasync def test_generate_stream_emits_done_for_completed_jobs(client, db_engine):\n    create = await client.post(\"/generate\", json={\"prompt\": \"test\", \"num_images\": 2})\n    assert create.status_code == 200\n    job_ids = create.json()[\"job_ids\"]\n\n    # Mark jobs as completed so the SSE stream terminates immediately.\n    factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)\n    async with factory() as session:\n        result = await session.execute(select(Job).where(Job.id.in_(job_ids)))\n        jobs = result.scalars().all()\n        for job in jobs:\n            job.status = \"completed\"\n            job.file_path = f\"output/images/{job.id}.png\"\n        await session.commit()\n\n    resp = await client.get(\n        \"/generate/stream\",\n        params=[(\"job_id\", job_ids[0]), (\"job_id\", job_ids[1])],\n    )\n    assert resp.status_code == 200\n    text = resp.text\n    assert \"event: progress\" in text\n    assert \"event: done\" in text\n    assert '\"completed\": 2' in text\n\n\n@pytest.mark.asyncio\nasync def 
test_generate_stream_returns_404_for_unknown_job(client):\n    resp = await client.get(\"/generate/stream\", params={\"job_id\": \"unknown-job\"})\n    assert resp.status_code == 404\n\n\n# ---------------------------------------------------------------------------\n# Content moderation integration tests\n# ---------------------------------------------------------------------------\n\n\n@pytest.mark.asyncio\nasync def test_generate_async_blocked_when_prompt_is_unsafe(client):\n    \"\"\"POST /generate returns 400 when the moderation service flags the prompt.\"\"\"\n    with patch(\n        \"app.api.generate.check_prompt_safety\",\n        new=AsyncMock(return_value=(False, \"sexual\")),\n    ):\n        resp = await client.post(\"/generate\", json={\"prompt\": \"explicit content\"})\n\n    assert resp.status_code == 400\n    assert \"unsafe\" in resp.json()[\"detail\"].lower()\n\n\n@pytest.mark.asyncio\nasync def test_generate_async_proceeds_when_prompt_is_safe(client, mock_worker):\n    \"\"\"POST /generate proceeds normally when the moderation service approves the prompt.\"\"\"\n    with patch(\n        \"app.api.generate.check_prompt_safety\",\n        new=AsyncMock(return_value=(True, None)),\n    ):\n        resp = await client.post(\"/generate\", json={\"prompt\": \"a beautiful landscape\"})\n\n    assert resp.status_code == 200\n    assert \"job_ids\" in resp.json()\n    mock_worker.enqueue.assert_called()\n\n\n@pytest.mark.asyncio\nasync def test_generate_async_returns_503_when_moderation_service_fails(client):\n    \"\"\"POST /generate returns 503 when the moderation API is unreachable.\"\"\"\n    with patch(\n        \"app.api.generate.check_prompt_safety\",\n        new=AsyncMock(side_effect=ModerationError(\"API unreachable\")),\n    ):\n        resp = await client.post(\"/generate\", json={\"prompt\": \"a beautiful landscape\"})\n\n    assert resp.status_code == 503\n\n\n@pytest.mark.asyncio\nasync def 
test_generate_sync_blocked_when_prompt_is_unsafe(client):\n    \"\"\"POST /generate/sync returns 400 when the moderation service flags the prompt.\"\"\"\n    with patch(\n        \"app.api.generate.check_prompt_safety\",\n        new=AsyncMock(return_value=(False, \"violence\")),\n    ):\n        resp = await client.post(\"/generate/sync\", json={\"prompt\": \"violent content\"})\n\n    assert resp.status_code == 400\n    assert \"unsafe\" in resp.json()[\"detail\"].lower()\n\n\n@pytest.mark.asyncio\nasync def test_generate_sync_returns_503_when_moderation_service_fails(client):\n    \"\"\"POST /generate/sync returns 503 when the moderation API is unreachable.\"\"\"\n    with patch(\n        \"app.api.generate.check_prompt_safety\",\n        new=AsyncMock(side_effect=ModerationError(\"timeout\")),\n    ):\n        resp = await client.post(\"/generate/sync\", json={\"prompt\": \"a sunset\"})\n\n    assert resp.status_code == 503\n"
  },
  {
    "path": "tests/test_api_jobs.py",
    "content": "\"\"\"Tests for job management endpoints.\"\"\"\n\nimport pytest\n\n\n@pytest.mark.asyncio\nasync def test_list_jobs_empty(client):\n    resp = await client.get(\"/jobs\")\n    assert resp.status_code == 200\n    data = resp.json()\n    assert data[\"total\"] == 0\n    assert data[\"jobs\"] == []\n\n\n@pytest.mark.asyncio\nasync def test_get_job_after_create(client, mock_worker):\n    # Create a job via generate\n    gen_resp = await client.post(\"/generate\", json={\"prompt\": \"test job\"})\n    job_id = gen_resp.json()[\"job_ids\"][0]\n\n    # Fetch it\n    resp = await client.get(f\"/jobs/{job_id}\")\n    assert resp.status_code == 200\n    data = resp.json()\n    assert data[\"id\"] == job_id\n    assert data[\"prompt\"] == \"test job\"\n    assert data[\"status\"] == \"pending\"\n    assert data[\"width\"] == 1024\n    assert data[\"height\"] == 1024\n\n\n@pytest.mark.asyncio\nasync def test_cancel_job(client, mock_worker):\n    gen_resp = await client.post(\"/generate\", json={\"prompt\": \"to cancel\"})\n    job_id = gen_resp.json()[\"job_ids\"][0]\n\n    resp = await client.delete(f\"/jobs/{job_id}\")\n    assert resp.status_code == 200\n    assert resp.json()[\"status\"] == \"cancelled\"\n\n\n@pytest.mark.asyncio\nasync def test_list_jobs_with_status_filter(client, mock_worker):\n    await client.post(\"/generate\", json={\"prompt\": \"j1\"})\n    await client.post(\"/generate\", json={\"prompt\": \"j2\"})\n\n    resp = await client.get(\"/jobs?status=pending\")\n    assert resp.status_code == 200\n    data = resp.json()\n    assert data[\"total\"] == 2\n\n\n@pytest.mark.asyncio\nasync def test_get_nonexistent_job(client):\n    resp = await client.get(\"/jobs/does-not-exist\")\n    assert resp.status_code == 404\n\n\n@pytest.mark.asyncio\nasync def test_health_endpoint(client):\n    resp = await client.get(\"/health\")\n    assert resp.status_code == 200\n    data = resp.json()\n    assert data[\"status\"] == \"ok\"\n    assert 
data[\"model_loaded\"] is True\n    assert \"vram\" in data\n    assert \"worker\" in data\n"
  },
  {
    "path": "tests/test_moderation.py",
    "content": "\"\"\"Tests for the content moderation module.\"\"\"\n\nfrom __future__ import annotations\n\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\nfrom pydantic import ValidationError\n\nfrom app.utils.moderation import ModerationEngine, ModerationResult, check_prompt\n\n\n# ── ModerationEngine unit tests ───────────────────────────────────────────────\n\n\nclass TestModerationEngine:\n    def test_is_loaded_false_before_load(self):\n        engine = ModerationEngine()\n        assert not engine.is_loaded\n\n    def test_load_failure_leaves_engine_unloaded(self):\n        engine = ModerationEngine()\n        with patch(\"app.utils.moderation.ModerationEngine.load\") as mock_load:\n            mock_load.side_effect = RuntimeError(\"no GPU\")\n            try:\n                engine.load(\"google/gemma-3-1b-it\")\n            except RuntimeError:\n                pass\n        assert not engine.is_loaded\n\n    def test_parse_pipeline_output_safe(self):\n        engine = ModerationEngine()\n        raw = [{\"generated_text\": [{\"role\": \"assistant\", \"content\": \"SAFE\"}]}]\n        assert engine._parse_pipeline_output(raw) == \"SAFE\"\n\n    def test_parse_pipeline_output_unsafe(self):\n        engine = ModerationEngine()\n        raw = [{\"generated_text\": [{\"role\": \"assistant\", \"content\": \"UNSAFE\"}]}]\n        assert engine._parse_pipeline_output(raw) == \"UNSAFE\"\n\n    def test_check_with_local_model_raises_when_not_loaded(self):\n        engine = ModerationEngine()\n        with pytest.raises(RuntimeError, match=\"not loaded\"):\n            engine.check_with_local_model(\"test prompt\")\n\n    def test_check_with_local_model_safe(self):\n        engine = ModerationEngine()\n        engine._pipeline = MagicMock(\n            return_value=[{\"generated_text\": [{\"role\": \"assistant\", \"content\": \"SAFE\"}]}]\n        )\n        result = engine.check_with_local_model(\"a beautiful sunset\")\n        assert result.is_safe 
is True\n        assert result.source == \"local\"\n\n    def test_check_with_local_model_unsafe(self):\n        engine = ModerationEngine()\n        engine._pipeline = MagicMock(\n            return_value=[{\"generated_text\": [{\"role\": \"assistant\", \"content\": \"UNSAFE\"}]}]\n        )\n        result = engine.check_with_local_model(\"explicit harmful content\")\n        assert result.is_safe is False\n        assert result.source == \"local\"\n        assert \"Flagged by local model\" in result.reason\n\n    def test_check_with_agent_api_safe(self):\n        engine = ModerationEngine()\n        engine._model_id = \"google/gemma-3-1b-it\"\n\n        mock_choice = MagicMock()\n        mock_choice.message.content = \"SAFE\"\n        mock_response = MagicMock()\n        mock_response.choices = [mock_choice]\n\n        mock_client = MagicMock()\n        mock_client.chat_completion.return_value = mock_response\n\n        with patch(\"huggingface_hub.InferenceClient\", return_value=mock_client):\n            result = engine.check_with_agent_api(\"a nice landscape\", token=\"hf_test\")\n\n        assert result.is_safe is True\n        assert result.source == \"agent_api\"\n\n    def test_check_with_agent_api_unsafe(self):\n        engine = ModerationEngine()\n        engine._model_id = \"google/gemma-3-1b-it\"\n\n        mock_choice = MagicMock()\n        mock_choice.message.content = \"UNSAFE\"\n        mock_response = MagicMock()\n        mock_response.choices = [mock_choice]\n\n        mock_client = MagicMock()\n        mock_client.chat_completion.return_value = mock_response\n\n        with patch(\"huggingface_hub.InferenceClient\", return_value=mock_client):\n            result = engine.check_with_agent_api(\"explicit content\", token=\"hf_test\")\n\n        assert result.is_safe is False\n        assert result.source == \"agent_api\"\n        assert \"Flagged by agent API\" in result.reason\n\n    def test_check_uses_local_model_first(self):\n        engine = 
ModerationEngine()\n        engine._pipeline = MagicMock(\n            return_value=[{\"generated_text\": [{\"role\": \"assistant\", \"content\": \"SAFE\"}]}]\n        )\n\n        with patch.object(engine, \"check_with_agent_api\") as mock_api:\n            result = engine.check(\"safe prompt\")\n\n        assert result.is_safe is True\n        assert result.source == \"local\"\n        mock_api.assert_not_called()\n\n    def test_check_falls_back_to_agent_api_when_local_fails(self):\n        engine = ModerationEngine()\n        engine._pipeline = MagicMock(side_effect=RuntimeError(\"inference error\"))\n        engine._model_id = \"google/gemma-3-1b-it\"\n\n        fallback_result = ModerationResult(is_safe=True, source=\"agent_api\")\n        with patch.object(engine, \"check_with_agent_api\", return_value=fallback_result) as mock_api:\n            result = engine.check(\"test prompt\", hf_token=\"hf_test\")\n\n        assert result.source == \"agent_api\"\n        mock_api.assert_called_once()\n\n    def test_check_fails_open_when_both_unavailable(self):\n        engine = ModerationEngine()\n        # Local model not loaded and no model_id set, so the agent API branch is skipped\n        result = engine.check(\"test prompt\")\n        assert result.is_safe is True\n        assert result.source == \"none\"\n\n    def test_check_fails_open_when_agent_api_also_fails(self):\n        engine = ModerationEngine()\n        engine._model_id = \"google/gemma-3-1b-it\"\n\n        with patch.object(engine, \"check_with_agent_api\", side_effect=Exception(\"API error\")):\n            result = engine.check(\"test prompt\", hf_token=\"hf_test\")\n\n        assert result.is_safe is True\n        assert result.source == \"none\"\n\n\n# ── check_prompt() integration tests ─────────────────────────────────────────\n\n\nclass TestCheckPrompt:\n    def test_allows_safe_prompt(self):\n        with patch(\"app.config.settings\") as mock_settings, \\\n             
patch(\"app.utils.moderation._moderation_engine\") as mock_engine:\n            mock_settings.MODERATION_ENABLED = True\n            mock_settings.HF_API_TOKEN = None\n            mock_engine.check.return_value = ModerationResult(is_safe=True, source=\"local\")\n            check_prompt(\"a beautiful sunset\")  # should not raise\n\n    def test_rejects_unsafe_prompt(self):\n        with patch(\"app.config.settings\") as mock_settings, \\\n             patch(\"app.utils.moderation._moderation_engine\") as mock_engine:\n            mock_settings.MODERATION_ENABLED = True\n            mock_settings.HF_API_TOKEN = None\n            mock_engine.check.return_value = ModerationResult(\n                is_safe=False, reason=\"Flagged by local model: UNSAFE\", source=\"local\"\n            )\n            with pytest.raises(ValueError, match=\"Flagged by local model\"):\n                check_prompt(\"explicit content\")\n\n    def test_skips_when_moderation_disabled(self):\n        with patch(\"app.config.settings\") as mock_settings, patch(\n            \"app.utils.moderation._moderation_engine\"\n        ) as mock_engine:\n            mock_settings.MODERATION_ENABLED = False\n            check_prompt(\"any content\")\n            mock_engine.check.assert_not_called()\n\n\n# ── ImageRequest schema integration ──────────────────────────────────────────\n\n\nclass TestImageRequestModeration:\n    def test_unsafe_prompt_raises_validation_error(self):\n        with patch(\"app.config.settings\") as mock_settings, \\\n             patch(\"app.utils.moderation._moderation_engine\") as mock_engine:\n            mock_settings.MODERATION_ENABLED = True\n            mock_settings.HF_API_TOKEN = None\n            mock_engine.check.return_value = ModerationResult(\n                is_safe=False, reason=\"Flagged by local model: UNSAFE\", source=\"local\"\n            )\n            with pytest.raises(ValidationError, match=\"content moderation|Flagged\"):\n                from 
app.schemas.generate import ImageRequest\n\n                ImageRequest(prompt=\"explicit content\")\n\n    def test_safe_prompt_passes_validation(self):\n        with patch(\"app.config.settings\") as mock_settings, \\\n             patch(\"app.utils.moderation._moderation_engine\") as mock_engine:\n            mock_settings.MODERATION_ENABLED = True\n            mock_settings.HF_API_TOKEN = None\n            mock_engine.check.return_value = ModerationResult(is_safe=True, source=\"local\")\n            from app.schemas.generate import ImageRequest\n\n            req = ImageRequest(prompt=\"a mountain landscape\")\n            assert req.prompt == \"a mountain landscape\"\n"
  },
  {
    "path": "tests/test_moderation_service.py",
    "content": "\"\"\"Tests for the content moderation service.\"\"\"\n\nfrom __future__ import annotations\n\nfrom unittest.mock import AsyncMock, MagicMock, patch\n\nimport httpx\nimport pytest\n\nfrom app.services.moderation_service import ModerationError, check_prompt_safety\n\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _make_moderation_response(flagged: bool, categories: dict[str, bool] | None = None):\n    \"\"\"Build a mock httpx.Response that mimics the OpenAI moderation format.\"\"\"\n    if categories is None:\n        categories = {}\n    payload = {\n        \"results\": [\n            {\n                \"flagged\": flagged,\n                \"categories\": categories,\n                \"category_scores\": {k: 0.9 if v else 0.01 for k, v in categories.items()},\n            }\n        ]\n    }\n    response = MagicMock(spec=httpx.Response)\n    response.json.return_value = payload\n    response.raise_for_status = MagicMock()\n    return response\n\n\n# ---------------------------------------------------------------------------\n# Tests: moderation disabled\n# ---------------------------------------------------------------------------\n\n\n@pytest.mark.asyncio\nasync def test_check_prompt_safety_disabled_skips_api():\n    \"\"\"When moderation is disabled the function returns safe without any API call.\"\"\"\n    with patch(\"app.services.moderation_service.settings\") as mock_settings:\n        mock_settings.MODERATION_ENABLED = False\n        is_safe, reason = await check_prompt_safety(\"a test prompt\")\n\n    assert is_safe is True\n    assert reason is None\n\n\n@pytest.mark.asyncio\nasync def test_check_prompt_safety_no_api_key_skips_api():\n    \"\"\"When enabled but no API key is set, the function returns safe with a warning.\"\"\"\n    with patch(\"app.services.moderation_service.settings\") as mock_settings:\n 
       mock_settings.MODERATION_ENABLED = True\n        mock_settings.MODERATION_API_KEY = \"\"\n        is_safe, reason = await check_prompt_safety(\"a test prompt\")\n\n    assert is_safe is True\n    assert reason is None\n\n\n# ---------------------------------------------------------------------------\n# Tests: safe prompt\n# ---------------------------------------------------------------------------\n\n\n@pytest.mark.asyncio\nasync def test_check_prompt_safety_safe_prompt():\n    \"\"\"A prompt that the API says is safe returns (True, None).\"\"\"\n    mock_response = _make_moderation_response(flagged=False)\n\n    with patch(\"app.services.moderation_service.settings\") as mock_settings, patch(\n        \"app.services.moderation_service.httpx.AsyncClient\"\n    ) as mock_client_cls:\n        mock_settings.MODERATION_ENABLED = True\n        mock_settings.MODERATION_API_KEY = \"test-key\"\n        mock_settings.MODERATION_API_URL = \"https://api.openai.com/v1/moderations\"\n        mock_settings.MODERATION_TIMEOUT = 10.0\n\n        mock_client = AsyncMock()\n        mock_client.post = AsyncMock(return_value=mock_response)\n        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)\n        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)\n\n        is_safe, reason = await check_prompt_safety(\"a beautiful sunset over the mountains\")\n\n    assert is_safe is True\n    assert reason is None\n\n\n# ---------------------------------------------------------------------------\n# Tests: unsafe / flagged prompt\n# ---------------------------------------------------------------------------\n\n\n@pytest.mark.asyncio\nasync def test_check_prompt_safety_flagged_prompt():\n    \"\"\"A flagged prompt returns (False, reason) with the triggered categories.\"\"\"\n    mock_response = _make_moderation_response(\n        flagged=True,\n        categories={\"sexual\": True, \"violence\": False, \"hate\": False},\n    )\n\n    with 
patch(\"app.services.moderation_service.settings\") as mock_settings, patch(\n        \"app.services.moderation_service.httpx.AsyncClient\"\n    ) as mock_client_cls:\n        mock_settings.MODERATION_ENABLED = True\n        mock_settings.MODERATION_API_KEY = \"test-key\"\n        mock_settings.MODERATION_API_URL = \"https://api.openai.com/v1/moderations\"\n        mock_settings.MODERATION_TIMEOUT = 10.0\n\n        mock_client = AsyncMock()\n        mock_client.post = AsyncMock(return_value=mock_response)\n        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)\n        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)\n\n        is_safe, reason = await check_prompt_safety(\"explicit nsfw content here\")\n\n    assert is_safe is False\n    assert reason is not None\n    assert \"sexual\" in reason\n\n\n@pytest.mark.asyncio\nasync def test_check_prompt_safety_flagged_multiple_categories():\n    \"\"\"Multiple triggered categories are all included in the reason string.\"\"\"\n    mock_response = _make_moderation_response(\n        flagged=True,\n        categories={\"sexual\": True, \"violence\": True, \"hate\": False},\n    )\n\n    with patch(\"app.services.moderation_service.settings\") as mock_settings, patch(\n        \"app.services.moderation_service.httpx.AsyncClient\"\n    ) as mock_client_cls:\n        mock_settings.MODERATION_ENABLED = True\n        mock_settings.MODERATION_API_KEY = \"test-key\"\n        mock_settings.MODERATION_API_URL = \"https://api.openai.com/v1/moderations\"\n        mock_settings.MODERATION_TIMEOUT = 10.0\n\n        mock_client = AsyncMock()\n        mock_client.post = AsyncMock(return_value=mock_response)\n        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)\n        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)\n\n        is_safe, reason = await check_prompt_safety(\"violent explicit content\")\n\n    assert is_safe is 
False\n    assert \"sexual\" in reason\n    assert \"violence\" in reason\n\n\n@pytest.mark.asyncio\nasync def test_check_prompt_safety_flagged_no_categories():\n    \"\"\"If the API flags the prompt but provides no categories, reason is generic.\"\"\"\n    mock_response = _make_moderation_response(flagged=True, categories={})\n\n    with patch(\"app.services.moderation_service.settings\") as mock_settings, patch(\n        \"app.services.moderation_service.httpx.AsyncClient\"\n    ) as mock_client_cls:\n        mock_settings.MODERATION_ENABLED = True\n        mock_settings.MODERATION_API_KEY = \"test-key\"\n        mock_settings.MODERATION_API_URL = \"https://api.openai.com/v1/moderations\"\n        mock_settings.MODERATION_TIMEOUT = 10.0\n\n        mock_client = AsyncMock()\n        mock_client.post = AsyncMock(return_value=mock_response)\n        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)\n        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)\n\n        is_safe, reason = await check_prompt_safety(\"some flagged content\")\n\n    assert is_safe is False\n    assert reason == \"unsafe content\"\n\n\n# ---------------------------------------------------------------------------\n# Tests: API failures\n# ---------------------------------------------------------------------------\n\n\n@pytest.mark.asyncio\nasync def test_check_prompt_safety_http_error_raises_moderation_error():\n    \"\"\"A network-level HTTP error raises ModerationError.\"\"\"\n    with patch(\"app.services.moderation_service.settings\") as mock_settings, patch(\n        \"app.services.moderation_service.httpx.AsyncClient\"\n    ) as mock_client_cls:\n        mock_settings.MODERATION_ENABLED = True\n        mock_settings.MODERATION_API_KEY = \"test-key\"\n        mock_settings.MODERATION_API_URL = \"https://api.openai.com/v1/moderations\"\n        mock_settings.MODERATION_TIMEOUT = 10.0\n\n        mock_client = AsyncMock()\n        
mock_client.post = AsyncMock(\n            side_effect=httpx.ConnectError(\"Connection refused\")\n        )\n        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)\n        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)\n\n        with pytest.raises(ModerationError):\n            await check_prompt_safety(\"test prompt\")\n\n\n@pytest.mark.asyncio\nasync def test_check_prompt_safety_http_status_error_raises_moderation_error():\n    \"\"\"A non-2xx API response raises ModerationError.\"\"\"\n    mock_response = MagicMock(spec=httpx.Response)\n    mock_response.status_code = 401\n    mock_response.text = \"Unauthorized\"\n    http_error = httpx.HTTPStatusError(\n        \"401 Unauthorized\",\n        request=MagicMock(),\n        response=mock_response,\n    )\n    mock_response.raise_for_status = MagicMock(side_effect=http_error)\n\n    with patch(\"app.services.moderation_service.settings\") as mock_settings, patch(\n        \"app.services.moderation_service.httpx.AsyncClient\"\n    ) as mock_client_cls:\n        mock_settings.MODERATION_ENABLED = True\n        mock_settings.MODERATION_API_KEY = \"bad-key\"\n        mock_settings.MODERATION_API_URL = \"https://api.openai.com/v1/moderations\"\n        mock_settings.MODERATION_TIMEOUT = 10.0\n\n        mock_client = AsyncMock()\n        mock_client.post = AsyncMock(return_value=mock_response)\n        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)\n        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)\n\n        with pytest.raises(ModerationError):\n            await check_prompt_safety(\"test prompt\")\n\n\n@pytest.mark.asyncio\nasync def test_check_prompt_safety_empty_results():\n    \"\"\"An API response with an empty results array is treated as safe.\"\"\"\n    mock_response = MagicMock(spec=httpx.Response)\n    mock_response.json.return_value = {\"results\": []}\n    mock_response.raise_for_status = 
MagicMock()\n\n    with patch(\"app.services.moderation_service.settings\") as mock_settings, patch(\n        \"app.services.moderation_service.httpx.AsyncClient\"\n    ) as mock_client_cls:\n        mock_settings.MODERATION_ENABLED = True\n        mock_settings.MODERATION_API_KEY = \"test-key\"\n        mock_settings.MODERATION_API_URL = \"https://api.openai.com/v1/moderations\"\n        mock_settings.MODERATION_TIMEOUT = 10.0\n\n        mock_client = AsyncMock()\n        mock_client.post = AsyncMock(return_value=mock_response)\n        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)\n        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)\n\n        is_safe, reason = await check_prompt_safety(\"a normal prompt\")\n\n    assert is_safe is True\n    assert reason is None\n"
  },
  {
    "path": "tests/test_schemas.py",
    "content": "\"\"\"Tests for schemas — resolution validation, format handling.\"\"\"\n\nimport pytest\nfrom pydantic import ValidationError\n\nfrom app.schemas.generate import ImageRequest\nfrom app.schemas.batch import BatchRequest\n\n\nclass TestImageRequest:\n    def test_defaults(self):\n        req = ImageRequest(prompt=\"test\")\n        assert req.width == 1024\n        assert req.height == 1024\n        assert req.num_steps == 4\n        assert req.guidance_scale == 0.0\n        assert req.format.value == \"png\"\n        assert req.seed is None\n        assert req.num_images == 1\n\n    def test_num_images_up_to_4(self):\n        req = ImageRequest(prompt=\"test\", num_images=4)\n        assert req.num_images == 4\n\n    def test_num_images_rejects_above_4(self):\n        with pytest.raises(ValidationError):\n            ImageRequest(prompt=\"test\", num_images=5)\n\n    def test_num_images_rejects_zero(self):\n        with pytest.raises(ValidationError):\n            ImageRequest(prompt=\"test\", num_images=0)\n\n    def test_width_must_be_multiple_of_8(self):\n        with pytest.raises(ValidationError, match=\"multiple of 8\"):\n            ImageRequest(prompt=\"test\", width=513)\n\n    def test_height_must_be_multiple_of_8(self):\n        with pytest.raises(ValidationError, match=\"multiple of 8\"):\n            ImageRequest(prompt=\"test\", height=100 + 3)\n\n    def test_valid_custom_resolution(self):\n        req = ImageRequest(prompt=\"test\", width=768, height=1344)\n        assert req.width == 768\n        assert req.height == 1344\n\n    def test_rejects_empty_prompt(self):\n        with pytest.raises(ValidationError):\n            ImageRequest(prompt=\"\")\n\n    def test_rejects_oversized_resolution(self):\n        with pytest.raises(ValidationError):\n            ImageRequest(prompt=\"test\", width=4096)\n\n    def test_seed_bounds(self):\n        req = ImageRequest(prompt=\"test\", seed=0)\n        assert req.seed == 0\n\n        req = 
ImageRequest(prompt=\"test\", seed=2**32 - 1)\n        assert req.seed == 2**32 - 1\n\n        with pytest.raises(ValidationError):\n            ImageRequest(prompt=\"test\", seed=-1)\n\n    def test_all_formats(self):\n        for fmt in (\"png\", \"jpeg\", \"webp\"):\n            req = ImageRequest(prompt=\"test\", format=fmt)\n            assert req.format.value == fmt\n\n\nclass TestBatchRequest:\n    def test_valid_batch(self):\n        batch = BatchRequest(\n            name=\"test\",\n            prompts=[ImageRequest(prompt=\"a\"), ImageRequest(prompt=\"b\")],\n        )\n        assert len(batch.prompts) == 2\n\n    def test_empty_prompts_rejected(self):\n        with pytest.raises(ValidationError):\n            BatchRequest(name=\"empty\", prompts=[])\n"
  },
  {
    "path": "tests/test_worker.py",
    "content": "\"\"\"Tests for the GPU worker queue and processing logic.\"\"\"\n\nimport asyncio\nfrom unittest.mock import AsyncMock, MagicMock, patch\nfrom datetime import datetime, timezone\n\nimport pytest\nfrom PIL import Image\nfrom sqlalchemy import select\nfrom sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker\n\nfrom app.engine.flux_engine import GenerationResult\nfrom app.models.enums import JobPriority\nfrom app.models.job import Base, Job\nfrom app.worker.gpu_worker import GPUWorker, _QueueItem\n\n\nclass TestQueuePriority:\n    def test_realtime_before_batch(self):\n        \"\"\"REALTIME (priority=1) items should sort before BATCH (priority=5).\"\"\"\n        rt = _QueueItem(priority=1, timestamp=100.0, job_id=\"rt\")\n        batch = _QueueItem(priority=5, timestamp=99.0, job_id=\"batch\")\n        assert rt < batch\n\n    def test_same_priority_fifo(self):\n        \"\"\"Same priority: earlier timestamp wins.\"\"\"\n        a = _QueueItem(priority=1, timestamp=100.0, job_id=\"a\")\n        b = _QueueItem(priority=1, timestamp=200.0, job_id=\"b\")\n        assert a < b\n\n\nclass TestWorkerEnqueue:\n    @pytest.mark.asyncio\n    async def test_enqueue_increases_depth(self):\n        engine = MagicMock()\n        worker = GPUWorker(engine=engine)\n        await worker.enqueue(\"job-1\", priority=1)\n        await worker.enqueue(\"job-2\", priority=5)\n        assert worker.queue_depth == 2\n\n    @pytest.mark.asyncio\n    async def test_realtime_preempts_queued_batch_jobs(self):\n        engine = MagicMock()\n        worker = GPUWorker(engine=engine)\n\n        await worker.enqueue(\"batch-1\", priority=JobPriority.BATCH)\n        await worker.enqueue(\"batch-2\", priority=JobPriority.BATCH)\n        await worker.enqueue(\"rt-1\", priority=JobPriority.REALTIME)\n\n        first = await worker._queue.get()\n        assert first.job_id == \"rt-1\"\n\n\nclass TestWorkerStats:\n    def test_initial_stats(self):\n        engine = 
MagicMock()\n        worker = GPUWorker(engine=engine)\n        stats = worker.get_stats()\n        assert stats[\"jobs_processed\"] == 0\n        assert stats[\"queue_depth\"] == 0\n        assert stats[\"avg_inference_time_s\"] == 0.0\n"
  }
]