Full Code of sgshubham98/image-service-be for AI

main f9ddc6802cdf cached
48 files
120.4 KB
30.8k tokens
186 symbols
1 requests
Download .txt
Repository: sgshubham98/image-service-be
Branch: main
Commit: f9ddc6802cdf
Files: 48
Total size: 120.4 KB

Directory structure:
gitextract_xtvq1k3n/

├── .gitignore
├── Dockerfile
├── README.md
├── app/
│   ├── __init__.py
│   ├── api/
│   │   ├── __init__.py
│   │   ├── batch.py
│   │   ├── generate.py
│   │   ├── health.py
│   │   └── jobs.py
│   ├── config.py
│   ├── database.py
│   ├── engine/
│   │   ├── __init__.py
│   │   ├── engine_config.py
│   │   ├── flux_engine.py
│   │   └── mock_engine.py
│   ├── main.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── enums.py
│   │   └── job.py
│   ├── schemas/
│   │   ├── __init__.py
│   │   ├── batch.py
│   │   ├── generate.py
│   │   └── job.py
│   ├── services/
│   │   ├── __init__.py
│   │   ├── generation_service.py
│   │   ├── job_service.py
│   │   └── moderation_service.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── image_utils.py
│   │   ├── moderation.py
│   │   └── storage.py
│   └── worker/
│       ├── __init__.py
│       └── gpu_worker.py
├── flux_image_service.egg-info/
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   ├── dependency_links.txt
│   ├── requires.txt
│   └── top_level.txt
├── pyproject.toml
└── tests/
    ├── __init__.py
    ├── conftest.py
    ├── test_api_batch.py
    ├── test_api_generate.py
    ├── test_api_jobs.py
    ├── test_moderation.py
    ├── test_moderation_service.py
    ├── test_schemas.py
    └── test_worker.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled
__pycache__/
*.py[cod]

# Virtual env
.venv/
venv/

# Output
output/

# Logs
logs/

# IDE
.idea/
.vscode/

# Environment
.env

# SQLite
*.db

# OS
.DS_Store


================================================
FILE: Dockerfile
================================================
FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04

WORKDIR /app

# System deps
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3.11 python3.11-venv python3-pip git && \
    rm -rf /var/lib/apt/lists/*

# Create venv
RUN python3.11 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"

# Install Python deps
COPY pyproject.toml .
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir ".[dev]"

# Install PyTorch with CUDA 12.4 (override index for torch)
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu124

# Copy application
COPY app/ app/
COPY tests/ tests/

# Create output directory
RUN mkdir -p /app/output

EXPOSE 8000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]


================================================
FILE: README.md
================================================
# Flux Image Service

Local image generation API powered by **Flux.1 Schnell 12B** running on an H100 GPU with FP8 quantization.

## Features

- **Single image generation** — async (returns job IDs) or sync (waits for result), up to 4 images per request
- **Batch/dataset generation** — submit up to 50K prompts via API or CSV/JSON file upload
- **Priority queue** — real-time requests always jump ahead of batch jobs
- **SSE progress streaming** — monitor batch progress in real-time
- **Crash recovery** — pending jobs auto-resume on server restart (persisted in SQLite)
- **Multiple formats** — PNG (with metadata), JPEG, WebP
- **Custom resolutions** — any multiple of 8, up to 2048px

## Prerequisites

- **Python** 3.10+
- **NVIDIA GPU** with CUDA 12.x (tested on H100 20GB)
- **NVIDIA drivers** 535+ with CUDA toolkit
- **System RAM** 32 GB recommended (16 GB minimum)
- **Disk** ~12 GB for model weights + storage for generated images

## Setup

```bash
# 1. Clone / navigate to the project
cd flux-image-service

# 2. Create a virtual environment
python3 -m venv .venv
source .venv/bin/activate

# 3. Install dependencies (includes torch, diffusers, fastapi, etc.)
pip install -e ".[dev]"

# 4. Copy and configure environment variables
cp .env.example .env
# Edit .env if needed (model ID, storage path, etc.)
```

## Running the Server

```bash
# Activate venv (if not already)
source .venv/bin/activate

# Start the server
uvicorn app.main:app --host 0.0.0.0 --port 8000

# Or with auto-reload for development (not recommended for GPU — reloads unload the model)
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
```

The first startup will:
1. Download the Flux.1 Schnell model from HuggingFace (~12 GB)
2. Compile CUDA kernels via `torch.compile` (~30-60s one-time cost per restart)
3. Run a warm-up inference

Once you see `Flux Image Service ready`, the server is accepting requests.

## Running Tests

```bash
source .venv/bin/activate

# Run all tests
pytest tests/ -v

# Run a specific test file
pytest tests/test_api_generate.py -v

# Run with coverage
pip install pytest-cov
pytest tests/ -v --cov=app --cov-report=term-missing
```

## Docker

```bash
# Build
docker build -t flux-image-service .

# Run (requires nvidia-container-toolkit)
docker run --gpus all -p 8000:8000 -v ./output:/app/output flux-image-service

# Run with custom env
docker run --gpus all -p 8000:8000 \
  -e MODEL_ID=black-forest-labs/FLUX.1-schnell \
  -e STORAGE_DIR=/app/output \
  -v ./output:/app/output \
  flux-image-service
```

## Linting

```bash
# Check
ruff check app/ tests/

# Auto-fix
ruff check app/ tests/ --fix

# Format
ruff format app/ tests/
```

---

## API Reference

### `POST /generate` — Async Image Generation

Submit a generation request (1–4 images). Returns immediately with job IDs.

**Request:**
```bash
curl -X POST http://localhost:8000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "a majestic mountain landscape at sunset",
    "width": 1024,
    "height": 1024,
    "num_images": 2,
    "seed": 42
  }'
```

**Response** `200 OK`:
```json
{
  "job_ids": ["f47ac10b-58cc-4372-a567-0e02b2c3d479", "a1b2c3d4-5678-9abc-def0-1234567890ab"],
  "status": "pending",
  "image_urls": []
}
```

**Validation error** `422`:
```json
{
  "detail": [
    {
      "type": "value_error",
      "loc": ["body", "width"],
      "msg": "Value error, width must be a multiple of 8"
    }
  ]
}
```

All request parameters for `ImageRequest`:

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `prompt` | string | *required* | Text prompt (1–2000 chars) |
| `negative_prompt` | string \| null | null | Negative prompt |
| `width` | int | 1024 | Image width (64–2048, multiple of 8) |
| `height` | int | 1024 | Image height (64–2048, multiple of 8) |
| `num_steps` | int | 4 | Inference steps (1–50) |
| `guidance_scale` | float | 0.0 | CFG scale (Schnell doesn't use CFG) |
| `seed` | int \| null | null | Random seed (0–4294967295) |
| `num_images` | int | 1 | Number of images (1–4) |
| `format` | string | "png" | Output format: `png`, `jpeg`, `webp` |

---

### `GET /generate/stream` — SSE Progress Stream for Single Generate

Stream status updates for one or more job IDs returned by `/generate` until all are terminal.

**Request:**
```bash
curl -N "http://localhost:8000/generate/stream?job_id=f47ac10b-58cc-4372-a567-0e02b2c3d479&job_id=a1b2c3d4-5678-9abc-def0-1234567890ab"
```

**Events:**
- `progress` — emitted periodically with per-job status and aggregate counters
- `done` — emitted once when all jobs are in `completed`, `failed`, or `cancelled`

**Example event payload:**
```json
{
  "total": 2,
  "completed": 1,
  "failed": 0,
  "cancelled": 0,
  "pending": 1,
  "jobs": [
    {
      "id": "f47ac10b-58cc-4372-a567-0e02b2c3d479",
      "status": "completed",
      "image_url": "/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image"
    },
    {
      "id": "a1b2c3d4-5678-9abc-def0-1234567890ab",
      "status": "processing",
      "image_url": null
    }
  ]
}
```

---

### `POST /generate/sync` — Synchronous Generation

Same request body as `/generate`. Blocks until all images are generated (up to 60s timeout).

**Request:**
```bash
curl -X POST http://localhost:8000/generate/sync \
  -H "Content-Type: application/json" \
  -d '{"prompt": "a cute cat wearing a hat", "num_images": 1}' \
  --max-time 60
```

**Response** `200 OK`:
```json
{
  "job_ids": ["f47ac10b-58cc-4372-a567-0e02b2c3d479"],
  "status": "completed",
  "image_urls": ["/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image"]
}
```

**Timeout** `408`:
```json
{"detail": "Generation timed out"}
```

---

### `GET /jobs/{job_id}/image` — Serve Generated Image

Returns the image file for a completed job.

**Request:**
```bash
curl http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image --output image.png
```

**Response:** Binary image file with appropriate `Content-Type` (`image/png`, `image/jpeg`, or `image/webp`).

**Not found** `404`:
```json
{"detail": "Image not available"}
```

---

### `GET /jobs/{job_id}` — Get Job Details

**Request:**
```bash
curl http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479
```

**Response** `200 OK`:
```json
{
  "id": "f47ac10b-58cc-4372-a567-0e02b2c3d479",
  "prompt": "a majestic mountain landscape at sunset",
  "negative_prompt": null,
  "width": 1024,
  "height": 1024,
  "num_steps": 4,
  "guidance_scale": 0.0,
  "seed": 42,
  "status": "completed",
  "priority": 1,
  "format": "png",
  "file_path": "output/images/2026-03-08/f47ac10b-58cc-4372-a567-0e02b2c3d479.png",
  "image_url": "/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image",
  "error_message": null,
  "batch_id": null,
  "created_at": "2026-03-08T12:00:00+00:00",
  "started_at": "2026-03-08T12:00:01+00:00",
  "completed_at": "2026-03-08T12:00:03+00:00"
}
```

---

### `GET /jobs` — List Jobs

**Request:**
```bash
# List all jobs (paginated)
curl "http://localhost:8000/jobs?page=1&page_size=20"

# Filter by status
curl "http://localhost:8000/jobs?status=completed"

# Filter by batch
curl "http://localhost:8000/jobs?batch_id=batch-uuid-here"
```

**Response** `200 OK`:
```json
{
  "jobs": [
    {
      "id": "f47ac10b-...",
      "prompt": "a mountain",
      "status": "completed",
      "width": 1024,
      "height": 1024,
      "image_url": "/jobs/f47ac10b-.../image",
      "...": "..."
    }
  ],
  "total": 150,
  "page": 1,
  "page_size": 20
}
```

---

### `DELETE /jobs/{job_id}` — Cancel a Job

Cancels a pending/processing job. Already completed/failed jobs are returned as-is.

**Request:**
```bash
curl -X DELETE http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479
```

**Response** `200 OK`:
```json
{
  "id": "f47ac10b-58cc-4372-a567-0e02b2c3d479",
  "status": "cancelled",
  "...": "..."
}
```

---

### `POST /batch` — Create Batch Job

Submit 1–50,000 prompts as a batch. All jobs get BATCH priority (lower than real-time).

**Request:**
```bash
curl -X POST http://localhost:8000/batch \
  -H "Content-Type: application/json" \
  -d '{
    "name": "training-dataset-v1",
    "prompts": [
      {"prompt": "a red sports car", "width": 512, "height": 512},
      {"prompt": "a blue ocean wave", "width": 1024, "height": 768},
      {"prompt": "a forest path in autumn", "seed": 123}
    ]
  }'
```

**Response** `200 OK`:
```json
{
  "batch_id": "b1c2d3e4-5678-9abc-def0-1234567890ab",
  "name": "training-dataset-v1",
  "total": 3,
  "completed": 0,
  "failed": 0,
  "cancelled": 0,
  "pending": 3,
  "status": "pending",
  "estimated_remaining_seconds": null,
  "created_at": "2026-03-08T12:00:00+00:00",
  "completed_at": null
}
```

---

### `POST /batch/from-file` — Create Batch from File Upload

Upload a CSV or JSON file of prompts.

**CSV upload:**
```bash
curl -X POST http://localhost:8000/batch/from-file \
  -F "file=@prompts.csv" \
  -F "name=my-dataset"
```

CSV format (only `prompt` column is required):
```csv
prompt,width,height,num_steps,seed,format
a red car,512,512,4,42,png
a blue house,1024,1024,,,
a green tree,,,,,
```

**JSON upload:**
```bash
curl -X POST http://localhost:8000/batch/from-file \
  -F "file=@prompts.json" \
  -F "name=my-dataset"
```

JSON format:
```json
[
  {"prompt": "a red car", "width": 512, "height": 512, "seed": 42},
  {"prompt": "a blue house"},
  {"prompt": "a green tree", "format": "jpeg"}
]
```

**Response:** Same as `POST /batch`.

---

### `GET /batch/{batch_id}` — Get Batch Progress

**Request:**
```bash
curl http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab
```

**Response** `200 OK`:
```json
{
  "batch_id": "b1c2d3e4-5678-9abc-def0-1234567890ab",
  "name": "training-dataset-v1",
  "total": 1000,
  "completed": 342,
  "failed": 2,
  "cancelled": 0,
  "pending": 656,
  "status": "processing",
  "estimated_remaining_seconds": 1968.0,
  "created_at": "2026-03-08T12:00:00+00:00",
  "completed_at": null
}
```

---

### `GET /batch/{batch_id}/stream` — SSE Progress Stream

Real-time Server-Sent Events stream of batch progress. Updates every 2 seconds.

**Request:**
```bash
curl -N http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab/stream
```

**Response** (event stream):
```
event: progress
data: {"batch_id":"b1c2d3e4-...","total":1000,"completed":342,"failed":2,"pending":656,"status":"processing","estimated_remaining_seconds":1968.0}

event: progress
data: {"batch_id":"b1c2d3e4-...","total":1000,"completed":343,"failed":2,"pending":655,"status":"processing","estimated_remaining_seconds":1965.0}

event: done
data: {"batch_id":"b1c2d3e4-...","total":1000,"completed":998,"failed":2,"pending":0,"status":"failed"}
```

---

### `DELETE /batch/{batch_id}` — Cancel Batch

Cancels all remaining pending jobs in the batch.

**Request:**
```bash
curl -X DELETE http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab
```

**Response** `200 OK`:
```json
{
  "batch_id": "b1c2d3e4-5678-9abc-def0-1234567890ab",
  "total": 1000,
  "completed": 342,
  "failed": 2,
  "cancelled": 656,
  "pending": 0,
  "status": "cancelled",
  "...": "..."
}
```

---

### `POST /batch/{batch_id}/retry-failed` — Retry Failed Jobs

Re-enqueue only the failed jobs in a batch.

**Request:**
```bash
curl -X POST http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab/retry-failed
```

**Response** `200 OK`:
```json
{
  "retried": 2,
  "job_ids": ["job-uuid-1", "job-uuid-2"]
}
```

---

### `GET /health` — Health Check

**Request:**
```bash
curl http://localhost:8000/health
```

**Response** `200 OK`:
```json
{
  "status": "ok",
  "model_loaded": true,
  "vram": {
    "available": true,
    "allocated_gb": 12.34,
    "reserved_gb": 14.50,
    "max_allocated_gb": 16.78,
    "total_gb": 20.00,
    "device_name": "NVIDIA H100"
  },
  "worker": {
    "queue_depth": 5,
    "jobs_processed": 142,
    "total_inference_time_s": 398.50,
    "avg_inference_time_s": 2.81
  },
  "uptime_seconds": 3600.0
}
```

---

## Architecture

```
Client → FastAPI → asyncio.PriorityQueue → GPU Worker → Flux Engine
                       ↓                        ↓
                    SQLite                   Local FS
                  (job state)              (images)
```

- **Single GPU worker** — processes one image at a time (VRAM constraint)
- **Priority queue** — REALTIME(1) always before BATCH(5)
- **OOM recovery** — catches GPU OOM, clears cache, continues processing
- **Crash recovery** — on startup, all PENDING/PROCESSING jobs in SQLite are automatically re-enqueued

### Queue Persistence

The in-memory `asyncio.PriorityQueue` is **lost on server restart**. However, **no work is lost** because:

1. All jobs are persisted to SQLite the moment they're created (before being enqueued)
2. On startup, `resume_pending_jobs()` scans for any jobs in `PENDING` or `PROCESSING` state and re-enqueues them
3. Completed jobs and their images are already on disk

This means a restart simply rebuilds the queue from the database. The only effect is a brief pause while the model reloads.

## Configuration

All settings can be overridden via environment variables or `.env` file:

| Variable | Default | Description |
|----------|---------|-------------|
| `MODEL_ID` | `black-forest-labs/FLUX.1-schnell` | HuggingFace model ID |
| `DEVICE` | `cuda` | Torch device |
| `DTYPE` | `float8_e4m3fn` | Model precision |
| `ENABLE_TORCH_COMPILE` | `true` | Enable torch.compile optimization |
| `DEFAULT_NUM_STEPS` | `4` | Default inference steps |
| `DEFAULT_GUIDANCE_SCALE` | `0.0` | Default CFG scale |
| `MAX_RESOLUTION` | `2048` | Maximum allowed width/height |
| `MAX_IMMEDIATE_IMAGES` | `4` | Max images per /generate request |
| `STORAGE_DIR` | `./output` | Where images and DB are stored |
| `MAX_QUEUE_SIZE` | `10000` | Maximum queue capacity |
| `SYNC_TIMEOUT` | `60.0` | Timeout for /generate/sync (seconds) |
| `HOST` | `0.0.0.0` | Server bind address |
| `PORT` | `8000` | Server port |

## Performance (H100 20GB, FP8)

| Resolution | Time per image | 10K dataset |
|-----------|---------------|-------------|
| 512×512 | ~0.8–1.2s | ~2–3 hours |
| 1024×1024 | ~2–4s | ~6–10 hours |
| 1536×640 | ~2–3s | ~5–8 hours |


================================================
FILE: app/__init__.py
================================================


================================================
FILE: app/api/__init__.py
================================================


================================================
FILE: app/api/batch.py
================================================
"""Batch / dataset generation endpoints with SSE progress streaming."""

from __future__ import annotations

import asyncio
import csv
import io
import json
import os
import tempfile
import zipfile
from pathlib import Path
from typing import AsyncGenerator

from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File
from fastapi.responses import StreamingResponse
from sse_starlette.sse import EventSourceResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_session
from app.models.job import Job
from app.models.enums import ImageFormat
from app.schemas.batch import BatchProgress, BatchRequest
from app.schemas.generate import ImageRequest
from app.services.generation_service import create_batch
from app.services.job_service import (
    cancel_batch,
    get_batch_progress,
    retry_failed_in_batch,
)

router = APIRouter(prefix="/batch", tags=["batch"])


def _get_worker(request: Request):
    return request.app.state.worker


@router.post("", response_model=BatchProgress)
async def create_batch_job(
    body: BatchRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
):
    """Create a batch of image generation jobs from a list of prompts."""
    worker = _get_worker(request)
    batch_id = await create_batch(session, body.name, body.prompts, worker)
    progress = await get_batch_progress(session, batch_id)
    return BatchProgress(**progress)


@router.post("/from-file", response_model=BatchProgress)
async def create_batch_from_file(
    request: Request,
    file: UploadFile = File(...),
    name: str | None = None,
    session: AsyncSession = Depends(get_session),
):
    """Create a batch from a CSV or JSON file upload.

    CSV format: prompt,width,height,num_steps,seed,format
    JSON format: array of ImageRequest objects
    """
    worker = _get_worker(request)
    content = await file.read()

    filename = file.filename or ""
    if filename.endswith(".json"):
        prompts = _parse_json(content)
    elif filename.endswith(".csv"):
        prompts = _parse_csv(content)
    else:
        raise HTTPException(
            status_code=400,
            detail="Unsupported file format. Use .csv or .json",
        )

    if not prompts:
        raise HTTPException(status_code=400, detail="No valid prompts found in file")

    batch_id = await create_batch(session, name or filename, prompts, worker)
    progress = await get_batch_progress(session, batch_id)
    return BatchProgress(**progress)


@router.get("/{batch_id}", response_model=BatchProgress)
async def get_batch_status(
    batch_id: str,
    session: AsyncSession = Depends(get_session),
):
    progress = await get_batch_progress(session, batch_id)
    if not progress:
        raise HTTPException(status_code=404, detail="Batch not found")
    return BatchProgress(**progress)


@router.get("/{batch_id}/stream")
async def stream_batch_progress(
    batch_id: str,
    request: Request,
    session: AsyncSession = Depends(get_session),
):
    """SSE endpoint streaming batch progress updates until completion."""
    progress = await get_batch_progress(session, batch_id)
    if not progress:
        raise HTTPException(status_code=404, detail="Batch not found")

    async def event_generator() -> AsyncGenerator[dict, None]:
        while True:
            if await request.is_disconnected():
                break

            async with _fresh_session() as sess:
                p = await get_batch_progress(sess, batch_id)

            if p is None:
                break

            # Calculate ETA
            eta = _estimate_eta(p, request.app.state.worker)

            yield {
                "event": "progress",
                "data": json.dumps({**p, "estimated_remaining_seconds": eta}),
            }

            if p["status"] in ("completed", "failed", "cancelled"):
                yield {"event": "done", "data": json.dumps(p)}
                break

            await asyncio.sleep(2)

    return EventSourceResponse(event_generator())


@router.get("/{batch_id}/download")
async def download_batch_images(
    batch_id: str,
    session: AsyncSession = Depends(get_session),
):
    """Download all completed images in a batch as a ZIP archive.

    Streams the ZIP incrementally so the full archive is never held
    in memory.  File I/O runs in a thread pool to avoid blocking the
    async event loop.
    """
    progress = await get_batch_progress(session, batch_id)
    if not progress:
        raise HTTPException(status_code=404, detail="Batch not found")

    result = await session.execute(
        select(Job).where(Job.batch_id == batch_id, Job.status == "completed")
    )
    jobs = result.scalars().all()

    if not jobs:
        raise HTTPException(status_code=404, detail="No completed images in this batch")

    # Collect valid file paths up-front (lightweight check).
    file_entries: list[tuple[str, Path]] = []
    for job in jobs:
        if not job.file_path:
            continue
        fp = Path(job.file_path)
        if fp.is_file():
            file_entries.append((fp.name, fp))

    if not file_entries:
        raise HTTPException(status_code=404, detail="No image files found on disk")

    batch_name = progress.get("name") or batch_id[:12]
    safe_name = "".join(
        c if c.isalnum() or c in (" ", "-", "_") else "_" for c in batch_name
    ).strip()
    filename = f"{safe_name}.zip"

    async def _stream_zip() -> AsyncGenerator[bytes, None]:
        """Build a valid ZIP in a temp file (off the event loop), then
        stream it in chunks so memory stays flat."""
        loop = asyncio.get_running_loop()
        tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
        os.close(tmp_fd)

        def _build_zip() -> None:
            with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_STORED) as zf:
                for arc_name, file_path in file_entries:
                    zf.write(str(file_path), arc_name)

        try:
            # Build the ZIP in a worker thread — no event-loop blocking.
            await loop.run_in_executor(None, _build_zip)

            # Stream the temp file in 256 KB chunks.
            chunk_size = 256 * 1024
            with open(tmp_path, "rb") as f:
                while True:
                    chunk = await loop.run_in_executor(None, f.read, chunk_size)
                    if not chunk:
                        break
                    yield chunk
        finally:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    return StreamingResponse(
        _stream_zip(),
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )


@router.delete("/{batch_id}", response_model=BatchProgress)
async def cancel_batch_job(
    batch_id: str,
    session: AsyncSession = Depends(get_session),
):
    result = await cancel_batch(session, batch_id)
    if not result:
        raise HTTPException(status_code=404, detail="Batch not found")
    return BatchProgress(**result)


@router.post("/{batch_id}/retry-failed")
async def retry_failed_jobs(
    batch_id: str,
    request: Request,
    session: AsyncSession = Depends(get_session),
):
    """Re-enqueue all failed jobs in a batch."""
    worker = _get_worker(request)
    job_ids = await retry_failed_in_batch(session, batch_id)
    if not job_ids:
        raise HTTPException(status_code=404, detail="No failed jobs found in batch")

    for jid in job_ids:
        await worker.enqueue(jid, priority=5)

    return {"retried": len(job_ids), "job_ids": job_ids}


# ── File parsing helpers ─────────────────────────────────────────


def _parse_json(content: bytes) -> list[ImageRequest]:
    try:
        data = json.loads(content)
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")

    if not isinstance(data, list):
        raise HTTPException(status_code=400, detail="JSON must be an array of objects")

    prompts = []
    for i, item in enumerate(data):
        try:
            prompts.append(ImageRequest(**item))
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid entry at index {i}: {e}")
    return prompts


def _parse_csv(content: bytes) -> list[ImageRequest]:
    try:
        text = content.decode("utf-8")
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="CSV must be UTF-8 encoded")

    reader = csv.DictReader(io.StringIO(text))
    prompts = []
    for i, row in enumerate(reader):
        try:
            kwargs: dict = {"prompt": row["prompt"]}
            if "width" in row and row["width"]:
                kwargs["width"] = int(row["width"])
            if "height" in row and row["height"]:
                kwargs["height"] = int(row["height"])
            if "num_steps" in row and row["num_steps"]:
                kwargs["num_steps"] = int(row["num_steps"])
            if "seed" in row and row["seed"]:
                kwargs["seed"] = int(row["seed"])
            if "format" in row and row["format"]:
                kwargs["format"] = ImageFormat(row["format"].strip().lower())
            prompts.append(ImageRequest(**kwargs))
        except KeyError:
            raise HTTPException(
                status_code=400, detail=f"Row {i + 1}: 'prompt' column is required"
            )
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Row {i + 1}: {e}")
    return prompts


def _estimate_eta(progress: dict, worker) -> float | None:
    """Estimate remaining seconds based on average inference time."""
    stats = worker.get_stats()
    avg_time = stats.get("avg_inference_time_s", 0)
    if avg_time <= 0:
        return None
    remaining = progress.get("pending", 0)
    return round(remaining * avg_time, 1)


# Helper to get fresh session for SSE generator
from app.database import async_session_factory as _fresh_session  # noqa: E402


================================================
FILE: app/api/generate.py
================================================
"""Image generation endpoints — async and sync modes."""

from __future__ import annotations

import asyncio
import json
from pathlib import Path
from typing import AsyncGenerator

from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import FileResponse
from sse_starlette.sse import EventSourceResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.database import get_session
from app.models.job import Job
from app.schemas.generate import ImageRequest, ImageResponse
from app.services.generation_service import create_single_job
from app.services.job_service import get_job
from app.services.moderation_service import ModerationError, check_prompt_safety
from app.utils.storage import image_url_path

router = APIRouter(tags=["generate"])

_TERMINAL_JOB_STATUSES = {"completed", "failed", "cancelled"}


def _get_worker(request: Request):
    return request.app.state.worker


@router.post("/generate", response_model=ImageResponse)
async def generate_async(
    body: ImageRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
):
    """Submit an image generation request (1–4 images). Returns immediately with job IDs."""
    try:
        is_safe, reason = await check_prompt_safety(body.prompt)
    except ModerationError:
        raise HTTPException(status_code=503, detail="Content moderation service unavailable")
    if not is_safe:
        raise HTTPException(
            status_code=400,
            detail=f"Prompt contains unsafe or explicit content: {reason}",
        )

    worker = _get_worker(request)
    job_ids = await create_single_job(session, body, worker)
    return ImageResponse(job_ids=job_ids, status="pending")


@router.get("/generate/stream")
async def stream_generate_progress(
    request: Request,
    job_id: list[str] = Query(..., description="One or more job IDs to monitor"),
    session: AsyncSession = Depends(get_session),
):
    """SSE stream for one or more generate job IDs until all reach terminal state."""
    unique_job_ids = list(dict.fromkeys(job_id))

    # Fail fast if any requested job ID does not exist.
    checks = [await get_job(session, jid) for jid in unique_job_ids]
    missing = [jid for jid, detail in zip(unique_job_ids, checks, strict=False) if detail is None]
    if missing:
        raise HTTPException(
            status_code=404,
            detail=f"Job not found: {', '.join(missing)}",
        )

    async def event_generator() -> AsyncGenerator[dict, None]:
        while True:
            if await request.is_disconnected():
                break

            # Expire identity map so concurrent worker updates are visible.
            session.expire_all()
            details = [await get_job(session, jid) for jid in unique_job_ids]

            if any(d is None for d in details):
                yield {
                    "event": "error",
                    "data": json.dumps({"detail": "Job no longer exists"}),
                }
                break

            jobs = [d.model_dump() for d in details if d is not None]
            total = len(jobs)
            completed = sum(1 for d in details if d and d.status == "completed")
            failed = sum(1 for d in details if d and d.status == "failed")
            cancelled = sum(1 for d in details if d and d.status == "cancelled")
            terminal = sum(1 for d in details if d and d.status in _TERMINAL_JOB_STATUSES)

            payload = {
                "total": total,
                "completed": completed,
                "failed": failed,
                "cancelled": cancelled,
                "pending": total - terminal,
                "jobs": jobs,
            }
            yield {"event": "progress", "data": json.dumps(payload)}

            if terminal == total:
                yield {"event": "done", "data": json.dumps(payload)}
                break

            await asyncio.sleep(1.0)

    return EventSourceResponse(event_generator())


@router.post("/generate/sync", response_model=ImageResponse)
async def generate_sync(
    body: ImageRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
):
    """Submit and wait for image(s) to be generated (up to SYNC_TIMEOUT seconds)."""
    try:
        is_safe, reason = await check_prompt_safety(body.prompt)
    except ModerationError:
        raise HTTPException(status_code=503, detail="Content moderation service unavailable")
    if not is_safe:
        raise HTTPException(
            status_code=400,
            detail=f"Prompt contains unsafe or explicit content: {reason}",
        )

    worker = _get_worker(request)
    job_ids = await create_single_job(session, body, worker)

    # Poll until all jobs complete or timeout
    deadline = asyncio.get_event_loop().time() + settings.SYNC_TIMEOUT
    while asyncio.get_event_loop().time() < deadline:
        await asyncio.sleep(0.3)
        details = [await get_job(session, jid) for jid in job_ids]
        if all(d and d.status in ("completed", "failed", "cancelled") for d in details):
            break
    else:
        raise HTTPException(status_code=408, detail="Generation timed out")

    failed = [d for d in details if d and d.status == "failed"]
    if failed and len(failed) == len(details):
        raise HTTPException(
            status_code=500,
            detail=failed[0].error_message or "Generation failed",
        )

    return ImageResponse(
        job_ids=job_ids,
        status="completed",
        image_urls=[d.image_url if d and d.status == "completed" else None for d in details],
    )


@router.get("/jobs/{job_id}/image")
async def serve_image(
    job_id: str,
    session: AsyncSession = Depends(get_session),
):
    """Serve the generated image file for a completed job."""
    result = await session.execute(select(Job).where(Job.id == job_id))
    job = result.scalar_one_or_none()
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.status != "completed" or not job.file_path:
        raise HTTPException(status_code=404, detail="Image not available")

    path = Path(job.file_path)
    if not path.is_file():
        raise HTTPException(status_code=404, detail="Image file missing from disk")

    media_types = {
        "png": "image/png",
        "jpeg": "image/jpeg",
        "jpg": "image/jpeg",
        "webp": "image/webp",
    }
    ext = path.suffix.lstrip(".")
    media_type = media_types.get(ext, "application/octet-stream")

    return FileResponse(path, media_type=media_type, filename=path.name)


================================================
FILE: app/api/health.py
================================================
"""Health and status endpoint."""

from __future__ import annotations

import time

from fastapi import APIRouter, Request

router = APIRouter(tags=["health"])

_start_time: float = 0.0


def set_start_time() -> None:
    global _start_time
    _start_time = time.time()


@router.get("/health")
async def health_check(request: Request):
    engine = request.app.state.engine
    worker = request.app.state.worker

    uptime = time.time() - _start_time if _start_time else 0

    return {
        "status": "ok",
        "model_loaded": engine.is_loaded,
        "vram": engine.get_vram_stats(),
        "worker": worker.get_stats(),
        "uptime_seconds": round(uptime, 1),
    }


================================================
FILE: app/api/jobs.py
================================================
"""Job management endpoints."""

from __future__ import annotations

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_session
from app.schemas.job import JobDetail, JobList
from app.services.job_service import cancel_job, get_job, list_jobs

router = APIRouter(prefix="/jobs", tags=["jobs"])


@router.get("", response_model=JobList)
async def list_all_jobs(
    status: str | None = Query(None, description="Filter by status"),
    batch_id: str | None = Query(None, description="Filter by batch ID"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
    session: AsyncSession = Depends(get_session),
):
    jobs, total = await list_jobs(session, status=status, batch_id=batch_id, page=page, page_size=page_size)
    return JobList(jobs=jobs, total=total, page=page, page_size=page_size)


@router.get("/{job_id}", response_model=JobDetail)
async def get_job_detail(
    job_id: str,
    session: AsyncSession = Depends(get_session),
):
    job = await get_job(session, job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    return job


@router.delete("/{job_id}", response_model=JobDetail)
async def cancel_job_endpoint(
    job_id: str,
    session: AsyncSession = Depends(get_session),
):
    job = await cancel_job(session, job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    return job


================================================
FILE: app/config.py
================================================
from pydantic_settings import BaseSettings
from pathlib import Path


class Settings(BaseSettings):
    model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}

    # ── Dev / mock mode ──
    MOCK_MODE: bool = False

    # ── Model ──
    MODEL_ID: str = "black-forest-labs/FLUX.1-schnell"
    DEVICE: str = "cuda"
    DTYPE: str = "float8_e4m3fn"
    ENABLE_TORCH_COMPILE: bool = True

    # ── Generation defaults ──
    DEFAULT_NUM_STEPS: int = 4
    DEFAULT_GUIDANCE_SCALE: float = 0.0
    MAX_RESOLUTION: int = 2048
    MIN_RESOLUTION: int = 64

    # ── Storage ──
    STORAGE_DIR: Path = Path("./output")

    # ── Logging ──
    LOG_DIR: Path = Path("./logs")
    LOG_LEVEL: str = "INFO"
    LOG_ROTATION_WHEN: str = "midnight"  # midnight, h, m, s, etc.
    LOG_RETENTION_COUNT: int = 30  # number of rotated files to keep

    # ── Queue ──
    MAX_QUEUE_SIZE: int = 10000

    # ── Server ──
    HOST: str = "0.0.0.0"
    PORT: int = 8000

    # ── Immediate generation ──
    MAX_IMMEDIATE_IMAGES: int = 4

    # ── Sync generation timeout (seconds) ──
    SYNC_TIMEOUT: float = 60.0

    # ── Content moderation ──
    MODERATION_ENABLED: bool = True
    MODERATION_MODEL_ID: str = "google/gemma-3-1b-it"
    HF_API_TOKEN: str | None = None
    MODERATION_API_KEY: str = ""
    MODERATION_API_URL: str = "https://api.openai.com/v1/moderations"
    MODERATION_TIMEOUT: float = 10.0


settings = Settings()


================================================
FILE: app/database.py
================================================
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.config import settings

DATABASE_URL = f"sqlite+aiosqlite:///{settings.STORAGE_DIR / 'flux_service.db'}"

engine = create_async_engine(DATABASE_URL, echo=False)
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)


async def init_db() -> None:
    from app.models.job import Base

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)


async def get_session() -> AsyncSession:  # type: ignore[misc]
    async with async_session_factory() as session:
        yield session  # type: ignore[misc]


================================================
FILE: app/engine/__init__.py
================================================


================================================
FILE: app/engine/engine_config.py
================================================
"""Resolution presets and engine configuration."""

PRESET_RESOLUTIONS: dict[str, tuple[int, int]] = {
    "square_sm": (512, 512),
    "square": (1024, 1024),
    "landscape": (1344, 768),
    "portrait": (768, 1344),
    "wide": (1536, 640),
}

# Approximate peak VRAM (GB) by max dimension for FP8 Flux Schnell.
# Used as a safety guard — not an exact model; conservative estimates.
VRAM_ESTIMATES: dict[int, float] = {
    512: 13.5,
    768: 14.5,
    1024: 16.0,
    1536: 18.5,
    2048: 20.0,
}


def estimated_vram_gb(width: int, height: int) -> float:
    """Return conservative VRAM estimate for a given resolution."""
    max_dim = max(width, height)
    for threshold in sorted(VRAM_ESTIMATES.keys()):
        if max_dim <= threshold:
            return VRAM_ESTIMATES[threshold]
    return VRAM_ESTIMATES[max(VRAM_ESTIMATES.keys())]


================================================
FILE: app/engine/flux_engine.py
================================================
"""Flux.1 Schnell inference engine — model loading, generation, VRAM management."""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass, field

import torch
from PIL import Image

from app.config import settings

logger = logging.getLogger(__name__)

# Dtype mapping
_DTYPE_MAP: dict[str, torch.dtype] = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float8_e4m3fn": torch.float8_e4m3fn,
    "float32": torch.float32,
}


@dataclass
class GenerationResult:
    image: Image.Image
    seed: int
    inference_time: float  # seconds


@dataclass
class FluxEngine:
    """Singleton engine wrapping the Flux.1 Schnell diffusion pipeline."""

    _pipeline: object | None = field(default=None, repr=False)
    _loaded: bool = False
    _device: str = "cuda"
    _dtype: torch.dtype = torch.bfloat16

    def load_model(self) -> None:
        """Load the Flux pipeline onto the GPU."""
        if self._loaded:
            logger.info("Model already loaded, skipping.")
            return

        from diffusers import FluxPipeline

        self._device = settings.DEVICE
        self._dtype = _DTYPE_MAP.get(settings.DTYPE, torch.bfloat16)

        logger.info(
            "Loading Flux.1 Schnell model=%s device=%s dtype=%s",
            settings.MODEL_ID,
            self._device,
            self._dtype,
        )

        load_kwargs: dict = {
            "torch_dtype": self._dtype,
        }

        # For FP8 on H100, use the fp8 variant if available
        if self._dtype == torch.float8_e4m3fn:
            load_kwargs["variant"] = "fp8"

        self._pipeline = FluxPipeline.from_pretrained(
            settings.MODEL_ID,
            **load_kwargs,
        )
        self._pipeline.to(self._device)
        self._pipeline.set_progress_bar_config(disable=True)

        # torch.compile for H100 Tensor Core acceleration
        if settings.ENABLE_TORCH_COMPILE and self._device == "cuda":
            logger.info("Compiling transformer with torch.compile (mode=max-autotune-no-cudagraphs)...")
            self._pipeline.transformer = torch.compile(
                self._pipeline.transformer,
                mode="max-autotune-no-cudagraphs",
            )

        self._loaded = True
        logger.info("Model loaded. VRAM: %.2f GB", self.get_vram_usage_gb())

    def warmup(self) -> None:
        """Run a single throwaway inference to trigger torch.compile tracing."""
        if not self._loaded:
            raise RuntimeError("Model not loaded — call load_model() first")

        logger.info("Running warm-up inference (first run compiles CUDA kernels)...")
        start = time.perf_counter()
        self.generate(prompt="warmup", width=512, height=512, num_steps=1, seed=0)
        elapsed = time.perf_counter() - start
        logger.info("Warm-up completed in %.1fs", elapsed)

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        num_steps: int = 4,
        guidance_scale: float = 0.0,
        seed: int | None = None,
    ) -> GenerationResult:
        """Generate a single image. Runs synchronously on the GPU."""
        if not self._loaded:
            raise RuntimeError("Model not loaded — call load_model() first")

        if seed is None:
            seed = torch.randint(0, 2**32, (1,)).item()

        generator = torch.Generator(device=self._device).manual_seed(seed)

        start = time.perf_counter()

        output = self._pipeline(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            output_type="pil",
        )

        elapsed = time.perf_counter() - start
        image: Image.Image = output.images[0]

        logger.info(
            "Generated %dx%d in %.2fs (steps=%d, seed=%d)",
            width,
            height,
            elapsed,
            num_steps,
            seed,
        )

        return GenerationResult(image=image, seed=seed, inference_time=elapsed)

    def get_vram_usage_gb(self) -> float:
        """Return current GPU memory allocated in GB."""
        if not torch.cuda.is_available():
            return 0.0
        return torch.cuda.memory_allocated(self._device) / (1024**3)

    def get_vram_stats(self) -> dict:
        """Return detailed VRAM statistics."""
        if not torch.cuda.is_available():
            return {"available": False}
        return {
            "available": True,
            "allocated_gb": round(torch.cuda.memory_allocated(self._device) / (1024**3), 2),
            "reserved_gb": round(torch.cuda.memory_reserved(self._device) / (1024**3), 2),
            "max_allocated_gb": round(
                torch.cuda.max_memory_allocated(self._device) / (1024**3), 2
            ),
            "total_gb": round(torch.cuda.get_device_properties(0).total_mem / (1024**3), 2),
            "device_name": torch.cuda.get_device_name(0),
        }

    def unload(self) -> None:
        """Free the model and clear CUDA cache."""
        if self._pipeline is not None:
            del self._pipeline
            self._pipeline = None
        self._loaded = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        logger.info("Model unloaded, CUDA cache cleared.")

    @property
    def is_loaded(self) -> bool:
        return self._loaded


================================================
FILE: app/engine/mock_engine.py
================================================
"""Mock FluxEngine for local development without GPU or model weights."""

from __future__ import annotations

import logging
import random
import time
from dataclasses import dataclass

from PIL import Image, ImageDraw

logger = logging.getLogger(__name__)


@dataclass
class GenerationResult:
    image: Image.Image
    seed: int
    inference_time: float


@dataclass
class MockFluxEngine:
    """Drop-in replacement for FluxEngine that generates gradient placeholder images."""

    _loaded: bool = False

    def load_model(self) -> None:
        logger.info("MockFluxEngine: loaded (no actual model).")
        self._loaded = True

    def warmup(self) -> None:
        logger.info("MockFluxEngine: warmup (no-op).")

    def generate(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        num_steps: int = 4,
        guidance_scale: float = 0.0,
        seed: int | None = None,
    ) -> GenerationResult:
        if seed is None:
            seed = random.randint(0, 2**32 - 1)

        start = time.perf_counter()

        rng = random.Random(seed)
        r1, g1, b1 = rng.randint(40, 180), rng.randint(40, 180), rng.randint(40, 180)
        r2, g2, b2 = rng.randint(40, 180), rng.randint(40, 180), rng.randint(40, 180)

        img = Image.new("RGB", (width, height))
        pixels = img.load()
        for y in range(height):
            t = y / max(height - 1, 1)
            r = int(r1 + (r2 - r1) * t)
            g = int(g1 + (g2 - g1) * t)
            b = int(b1 + (b2 - b1) * t)
            for x in range(width):
                pixels[x, y] = (r, g, b)

        draw = ImageDraw.Draw(img)
        lines = [
            f"[MOCK] {prompt[:60]}",
            f"Seed: {seed}  |  {width}x{height}  |  steps={num_steps}",
        ]
        draw.text((20, 20), "\n".join(lines), fill=(255, 255, 255))

        # Small artificial delay to simulate inference
        time.sleep(2.05 + rng.random() * 0.15)
        elapsed = time.perf_counter() - start

        logger.info(
            "MockFluxEngine: generated %dx%d in %.2fs (seed=%d)",
            width, height, elapsed, seed,
        )
        return GenerationResult(image=img, seed=seed, inference_time=elapsed)

    def get_vram_usage_gb(self) -> float:
        return 0.0

    def get_vram_stats(self) -> dict:
        return {"available": False, "mock_mode": True}

    def unload(self) -> None:
        self._loaded = False
        logger.info("MockFluxEngine: unloaded.")

    @property
    def is_loaded(self) -> bool:
        return self._loaded


================================================
FILE: app/main.py
================================================
"""FastAPI application — lifespan manages model loading, worker, and DB."""

from __future__ import annotations

import logging
from contextlib import asynccontextmanager
from logging.handlers import TimedRotatingFileHandler

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.api import batch, generate, health, jobs
from app.config import settings
from app.database import init_db
from app.utils.moderation import get_moderation_engine
from app.utils.storage import ensure_storage_dirs
from app.worker.gpu_worker import GPUWorker


def _setup_logging() -> None:
    """Configure logging to both console and a time-rotating file."""
    settings.LOG_DIR.mkdir(parents=True, exist_ok=True)

    log_format = "%(asctime)s %(levelname)-8s %(name)s — %(message)s"
    level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)

    root = logging.getLogger()
    root.setLevel(level)

    # Console handler
    console = logging.StreamHandler()
    console.setLevel(level)
    console.setFormatter(logging.Formatter(log_format))
    root.addHandler(console)

    # Time-rotating file handler
    file_handler = TimedRotatingFileHandler(
        filename=settings.LOG_DIR / "flux-service.log",
        when=settings.LOG_ROTATION_WHEN,
        backupCount=settings.LOG_RETENTION_COUNT,
        encoding="utf-8",
    )
    file_handler.setLevel(level)
    file_handler.setFormatter(logging.Formatter(log_format))
    file_handler.suffix = "%Y-%m-%d"  # rotated files: flux-service.log.2026-03-08
    root.addHandler(file_handler)


_setup_logging()
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # ── Startup ──────────────────────────────────────────────────
    logger.info("Starting Flux Image Service...")

    # Storage directories
    ensure_storage_dirs()

    # Database
    await init_db()
    logger.info("Database initialized.")

    # Flux engine (mock or real)
    if settings.MOCK_MODE:
        from app.engine.mock_engine import MockFluxEngine
        engine = MockFluxEngine()
    else:
        from app.engine.flux_engine import FluxEngine
        engine = FluxEngine()
    engine.load_model()
    engine.warmup()
    app.state.engine = engine

    # Content moderation engine (skip model loading in mock mode)
    moderation_engine = get_moderation_engine()
    if not settings.MOCK_MODE:
        moderation_engine.load(settings.MODERATION_MODEL_ID)
    else:
        logger.info("Mock mode: skipping moderation model load.")
    app.state.moderation_engine = moderation_engine

    # GPU worker
    worker = GPUWorker(engine=engine)
    worker.start()
    resumed = await worker.resume_pending_jobs()
    if resumed:
        logger.info("Resumed %d pending jobs from previous session.", resumed)
    app.state.worker = worker

    # Start time for health endpoint
    health.set_start_time()

    logger.info("Flux Image Service ready. Queue depth: %d", worker.queue_depth)

    yield

    # ── Shutdown ─────────────────────────────────────────────────
    logger.info("Shutting down Flux Image Service...")
    await worker.shutdown()
    engine.unload()
    logger.info("Shutdown complete.")


app = FastAPI(
    title="Flux Image Service",
    description="Local image generation API powered by Flux.1 Schnell 12B",
    version="0.1.0",
    lifespan=lifespan,
)

# CORS — allow all for local development
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount routers
app.include_router(generate.router)
app.include_router(batch.router)
app.include_router(jobs.router)
app.include_router(health.router)


================================================
FILE: app/models/__init__.py
================================================


================================================
FILE: app/models/enums.py
================================================
import enum


class JobStatus(str, enum.Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class JobPriority(int, enum.Enum):
    REALTIME = 1
    BATCH = 5


class ImageFormat(str, enum.Enum):
    PNG = "png"
    JPEG = "jpeg"
    WEBP = "webp"


================================================
FILE: app/models/job.py
================================================
import uuid
from datetime import datetime, timezone

from sqlalchemy import (
    Column,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Integer,
    String,
    Text,
)
from sqlalchemy.orm import DeclarativeBase, relationship


class Base(DeclarativeBase):
    pass


def _utcnow() -> datetime:
    return datetime.now(timezone.utc)


class Job(Base):
    __tablename__ = "jobs"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    prompt = Column(Text, nullable=False)
    negative_prompt = Column(Text, nullable=True)
    width = Column(Integer, nullable=False, default=1024)
    height = Column(Integer, nullable=False, default=1024)
    num_steps = Column(Integer, nullable=False, default=4)
    guidance_scale = Column(Float, nullable=False, default=0.0)
    seed = Column(Integer, nullable=True)
    status = Column(
        Enum("pending", "processing", "completed", "failed", "cancelled", name="job_status"),
        nullable=False,
        default="pending",
    )
    priority = Column(Integer, nullable=False, default=1)
    format = Column(
        Enum("png", "jpeg", "webp", name="image_format"),
        nullable=False,
        default="png",
    )
    file_path = Column(String(512), nullable=True)
    error_message = Column(Text, nullable=True)
    batch_id = Column(String(36), ForeignKey("batch_jobs.id"), nullable=True)
    created_at = Column(DateTime(timezone=True), nullable=False, default=_utcnow)
    started_at = Column(DateTime(timezone=True), nullable=True)
    completed_at = Column(DateTime(timezone=True), nullable=True)

    batch = relationship("BatchJob", back_populates="jobs")


class BatchJob(Base):
    __tablename__ = "batch_jobs"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    name = Column(String(255), nullable=True)
    total_count = Column(Integer, nullable=False, default=0)
    completed_count = Column(Integer, nullable=False, default=0)
    failed_count = Column(Integer, nullable=False, default=0)
    status = Column(
        Enum("pending", "processing", "completed", "failed", "cancelled", name="batch_status"),
        nullable=False,
        default="pending",
    )
    created_at = Column(DateTime(timezone=True), nullable=False, default=_utcnow)
    completed_at = Column(DateTime(timezone=True), nullable=True)

    jobs = relationship("Job", back_populates="batch", lazy="selectin")


================================================
FILE: app/schemas/__init__.py
================================================


================================================
FILE: app/schemas/batch.py
================================================
from pydantic import BaseModel, Field

from app.schemas.generate import ImageRequest


class BatchRequest(BaseModel):
    name: str | None = None
    prompts: list[ImageRequest] = Field(..., min_length=1, max_length=50000)


class BatchProgress(BaseModel):
    batch_id: str
    name: str | None
    total: int
    completed: int
    failed: int
    cancelled: int
    pending: int
    status: str
    estimated_remaining_seconds: float | None = None
    created_at: str
    completed_at: str | None = None


================================================
FILE: app/schemas/generate.py
================================================
from pydantic import BaseModel, Field, field_validator, model_validator

from app.config import settings
from app.models.enums import ImageFormat


class ImageRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=2000)
    negative_prompt: str | None = None
    width: int = Field(default=1024, ge=settings.MIN_RESOLUTION, le=settings.MAX_RESOLUTION)
    height: int = Field(default=1024, ge=settings.MIN_RESOLUTION, le=settings.MAX_RESOLUTION)
    num_steps: int = Field(default=settings.DEFAULT_NUM_STEPS, ge=1, le=50)
    guidance_scale: float = Field(default=settings.DEFAULT_GUIDANCE_SCALE, ge=0.0, le=20.0)
    seed: int | None = Field(default=None, ge=0, le=2**32 - 1)
    num_images: int = Field(default=1, ge=1, le=settings.MAX_IMMEDIATE_IMAGES)
    format: ImageFormat = ImageFormat.PNG

    @field_validator("prompt")
    @classmethod
    def check_content_safety(cls, v: str) -> str:
        from app.utils.moderation import check_prompt

        check_prompt(v)
        return v

    @model_validator(mode="after")
    def validate_resolution(self) -> "ImageRequest":
        if self.width % 8 != 0:
            raise ValueError("width must be a multiple of 8")
        if self.height % 8 != 0:
            raise ValueError("height must be a multiple of 8")
        return self


class ImageResponse(BaseModel):
    job_ids: list[str]
    status: str
    image_urls: list[str | None] = []


PRESET_RESOLUTIONS: dict[str, tuple[int, int]] = {
    "square_sm": (512, 512),
    "square": (1024, 1024),
    "landscape": (1344, 768),
    "portrait": (768, 1344),
    "wide": (1536, 640),
}


================================================
FILE: app/schemas/job.py
================================================
from pydantic import BaseModel


class JobDetail(BaseModel):
    id: str
    prompt: str
    negative_prompt: str | None
    width: int
    height: int
    num_steps: int
    guidance_scale: float
    seed: int | None
    status: str
    priority: int
    format: str
    file_path: str | None
    image_url: str | None
    error_message: str | None
    batch_id: str | None
    created_at: str
    started_at: str | None
    completed_at: str | None


class JobList(BaseModel):
    jobs: list[JobDetail]
    total: int
    page: int
    page_size: int


================================================
FILE: app/services/__init__.py
================================================


================================================
FILE: app/services/generation_service.py
================================================
"""Orchestrates image generation: validate → persist → enqueue."""

from __future__ import annotations

import uuid
from typing import TYPE_CHECKING

from sqlalchemy.ext.asyncio import AsyncSession

from app.models.enums import JobPriority
from app.models.job import BatchJob, Job
from app.schemas.generate import ImageRequest

if TYPE_CHECKING:
    from app.worker.gpu_worker import GPUWorker


async def create_single_job(
    session: AsyncSession,
    request: ImageRequest,
    worker: "GPUWorker",
    priority: int = JobPriority.REALTIME,
) -> list[str]:
    """Create generation job(s) for the request. Returns list of job_ids.

    Creates one job per requested image (num_images). Each job gets a unique
    seed derived from the base seed (if provided) to ensure varied outputs.
    """
    job_ids: list[str] = []

    queue_priority = int(priority)

    for i in range(request.num_images):
        # Derive per-image seed: if user provided a seed, offset it per image
        seed = None
        if request.seed is not None:
            seed = (request.seed + i) % (2**32)

        job = Job(
            id=str(uuid.uuid4()),
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            num_steps=request.num_steps,
            guidance_scale=request.guidance_scale,
            seed=seed,
            format=request.format.value,
            priority=queue_priority,
            status="pending",
        )
        session.add(job)
        job_ids.append(job.id)

    await session.commit()

    for jid in job_ids:
        await worker.enqueue(jid, queue_priority)

    return job_ids


async def create_batch(
    session: AsyncSession,
    name: str | None,
    prompts: list[ImageRequest],
    worker: "GPUWorker",
) -> str:
    """Create a batch of generation jobs. Returns batch_id."""
    batch_id = str(uuid.uuid4())
    batch = BatchJob(
        id=batch_id,
        name=name,
        total_count=len(prompts),
        status="pending",
    )
    session.add(batch)

    jobs: list[Job] = []
    for req in prompts:
        job = Job(
            id=str(uuid.uuid4()),
            prompt=req.prompt,
            negative_prompt=req.negative_prompt,
            width=req.width,
            height=req.height,
            num_steps=req.num_steps,
            guidance_scale=req.guidance_scale,
            seed=req.seed,
            format=req.format.value,
            priority=int(JobPriority.BATCH),
            status="pending",
            batch_id=batch_id,
        )
        jobs.append(job)

    session.add_all(jobs)
    await session.commit()

    # Enqueue all jobs
    for job in jobs:
        await worker.enqueue(job.id, int(JobPriority.BATCH))

    return batch_id


================================================
FILE: app/services/job_service.py
================================================
"""Job and batch CRUD operations."""

from __future__ import annotations

from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.enums import JobStatus
from app.models.job import BatchJob, Job
from app.schemas.job import JobDetail
from app.utils.storage import image_url_path


def _job_to_detail(job: Job) -> JobDetail:
    return JobDetail(
        id=job.id,
        prompt=job.prompt,
        negative_prompt=job.negative_prompt,
        width=job.width,
        height=job.height,
        num_steps=job.num_steps,
        guidance_scale=job.guidance_scale,
        seed=job.seed,
        status=job.status,
        priority=job.priority,
        format=job.format,
        file_path=job.file_path,
        image_url=image_url_path(job.id) if job.status == "completed" else None,
        error_message=job.error_message,
        batch_id=job.batch_id,
        created_at=job.created_at.isoformat() if job.created_at else None,
        started_at=job.started_at.isoformat() if job.started_at else None,
        completed_at=job.completed_at.isoformat() if job.completed_at else None,
    )


async def get_job(session: AsyncSession, job_id: str) -> JobDetail | None:
    result = await session.execute(select(Job).where(Job.id == job_id))
    job = result.scalar_one_or_none()
    if not job:
        return None
    return _job_to_detail(job)


async def list_jobs(
    session: AsyncSession,
    status: str | None = None,
    batch_id: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[JobDetail], int]:
    query = select(Job).order_by(Job.created_at.desc())
    count_query = select(func.count()).select_from(Job)

    if status:
        query = query.where(Job.status == status)
        count_query = count_query.where(Job.status == status)
    if batch_id:
        query = query.where(Job.batch_id == batch_id)
        count_query = count_query.where(Job.batch_id == batch_id)

    total = (await session.execute(count_query)).scalar() or 0
    query = query.offset((page - 1) * page_size).limit(page_size)
    result = await session.execute(query)
    jobs = [_job_to_detail(j) for j in result.scalars().all()]
    return jobs, total


async def cancel_job(session: AsyncSession, job_id: str) -> JobDetail | None:
    result = await session.execute(select(Job).where(Job.id == job_id))
    job = result.scalar_one_or_none()
    if not job:
        return None
    if job.status in ("completed", "failed", "cancelled"):
        return _job_to_detail(job)

    job.status = "cancelled"
    job.completed_at = datetime.now(timezone.utc)
    await session.commit()
    return _job_to_detail(job)


async def get_batch_progress(session: AsyncSession, batch_id: str) -> dict | None:
    result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))
    batch = result.scalar_one_or_none()
    if not batch:
        return None

    cancelled = (
        await session.execute(
            select(func.count()).where(Job.batch_id == batch_id, Job.status == "cancelled")
        )
    ).scalar() or 0

    pending = batch.total_count - batch.completed_count - batch.failed_count - cancelled

    return {
        "batch_id": batch.id,
        "name": batch.name,
        "total": batch.total_count,
        "completed": batch.completed_count,
        "failed": batch.failed_count,
        "cancelled": cancelled,
        "pending": pending,
        "status": batch.status,
        "created_at": batch.created_at.isoformat() if batch.created_at else None,
        "completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
    }


async def cancel_batch(session: AsyncSession, batch_id: str) -> dict | None:
    result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))
    batch = result.scalar_one_or_none()
    if not batch:
        return None

    # Cancel all pending jobs in this batch
    pending_result = await session.execute(
        select(Job).where(Job.batch_id == batch_id, Job.status == "pending")
    )
    now = datetime.now(timezone.utc)
    for job in pending_result.scalars().all():
        job.status = "cancelled"
        job.completed_at = now

    batch.status = "cancelled"
    batch.completed_at = now
    await session.commit()

    return await get_batch_progress(session, batch_id)


async def retry_failed_in_batch(session: AsyncSession, batch_id: str) -> list[str]:
    """Re-set failed jobs in a batch back to pending. Returns list of job IDs."""
    result = await session.execute(
        select(Job).where(Job.batch_id == batch_id, Job.status == "failed")
    )
    job_ids = []
    for job in result.scalars().all():
        job.status = "pending"
        job.error_message = None
        job.completed_at = None
        job_ids.append(job.id)

    if job_ids:
        # Reset batch status
        batch_result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))
        batch = batch_result.scalar_one_or_none()
        if batch:
            batch.status = "processing"
            batch.completed_at = None
            batch.failed_count = 0

        await session.commit()

    return job_ids


================================================
FILE: app/services/moderation_service.py
================================================
"""Content moderation service – checks prompt safety via a third-party API."""

from __future__ import annotations

import logging

import httpx

from app.config import settings

logger = logging.getLogger(__name__)


class ModerationError(Exception):
    """Raised when the moderation API call fails unexpectedly."""


async def check_prompt_safety(prompt: str) -> tuple[bool, str | None]:
    """Check whether a prompt is safe using the configured moderation API.

    Calls the API specified by ``MODERATION_API_URL`` (default: OpenAI
    moderation endpoint) and returns ``(is_safe, reason)``.  When moderation
    is disabled or no API key is configured the function returns ``(True, None)``
    without making any network request.

    Args:
        prompt: The user-supplied generation prompt to evaluate.

    Returns:
        A ``(is_safe, reason)`` tuple.  ``is_safe`` is ``True`` when the prompt
        is considered safe.  When flagged, ``reason`` contains the triggered
        content categories (e.g. ``"sexual, violence"``).

    Raises:
        ModerationError: If the API request fails and the failure cannot be
            handled gracefully.
    """
    if not settings.MODERATION_ENABLED:
        return True, None

    if not settings.MODERATION_API_KEY:
        logger.warning(
            "Content moderation is enabled but MODERATION_API_KEY is not set; "
            "skipping moderation check."
        )
        return True, None

    try:
        async with httpx.AsyncClient(timeout=settings.MODERATION_TIMEOUT) as client:
            response = await client.post(
                settings.MODERATION_API_URL,
                headers={
                    "Authorization": f"Bearer {settings.MODERATION_API_KEY}",
                    "Content-Type": "application/json",
                },
                json={"input": prompt},
            )
            response.raise_for_status()
            data = response.json()
    except httpx.HTTPStatusError as exc:
        logger.error(
            "Moderation API returned an error status %s: %s",
            exc.response.status_code,
            exc.response.text,
        )
        raise ModerationError(
            f"Moderation API returned HTTP {exc.response.status_code}"
        ) from exc
    except httpx.HTTPError as exc:
        logger.error("Moderation API request failed: %s", exc)
        raise ModerationError(f"Moderation API request failed: {exc}") from exc

    # Parse OpenAI-compatible moderation response:
    # {"results": [{"flagged": bool, "categories": {category: bool, ...}}]}
    results = data.get("results", [])
    if not results:
        return True, None

    result = results[0]
    if result.get("flagged", False):
        categories: dict[str, bool] = result.get("categories", {})
        flagged_categories = [cat for cat, triggered in categories.items() if triggered]
        reason = ", ".join(flagged_categories) if flagged_categories else "unsafe content"
        return False, reason

    return True, None


================================================
FILE: app/utils/__init__.py
================================================


================================================
FILE: app/utils/image_utils.py
================================================
"""Image saving, format conversion, and metadata embedding."""

from __future__ import annotations

from pathlib import Path

from PIL import Image, PngImagePlugin

from app.models.enums import ImageFormat


def save_image(
    image: Image.Image,
    path: Path,
    fmt: ImageFormat,
    quality: int = 95,
    prompt: str | None = None,
    seed: int | None = None,
) -> None:
    """Save a PIL Image in the requested format with embedded metadata."""
    if fmt == ImageFormat.PNG:
        meta = PngImagePlugin.PngInfo()
        if prompt:
            meta.add_text("prompt", prompt)
        if seed is not None:
            meta.add_text("seed", str(seed))
        image.save(path, format="PNG", pnginfo=meta)

    elif fmt == ImageFormat.JPEG:
        # JPEG doesn't support alpha
        if image.mode in ("RGBA", "LA"):
            image = image.convert("RGB")
        image.save(path, format="JPEG", quality=quality)

    elif fmt == ImageFormat.WEBP:
        image.save(path, format="WEBP", quality=quality)

    else:
        image.save(path, format="PNG")


================================================
FILE: app/utils/moderation.py
================================================
"""Content moderation using Gemma-3-1b-it (local) with HF Inference API fallback."""

from __future__ import annotations

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

_SAFETY_SYSTEM_PROMPT = (
    "You are a content safety classifier. "
    "Determine whether the following text prompt for image generation contains "
    "harmful, illegal, sexually explicit, or violent content. "
    "Respond with exactly one word: 'SAFE' if the prompt is acceptable, "
    "or 'UNSAFE' if it should be rejected. Do not include any other text."
)


@dataclass
class ModerationResult:
    is_safe: bool
    reason: str = ""
    source: str = ""


class ModerationEngine:
    """Content moderation engine.

    Primary check: locally-loaded ``google/gemma-3-1b-it`` text-generation pipeline.
    Fallback check: HuggingFace Inference API (agent API) using the same model.
    If both are unavailable the engine fails open and logs a warning.
    """

    def __init__(self) -> None:
        self._pipeline = None
        self._model_id: str | None = None

    # ── Loading ──────────────────────────────────────────────────────────────

    def load(self, model_id: str = "google/gemma-3-1b-it") -> None:
        """Load *model_id* locally for synchronous inference."""
        try:
            import torch
            from transformers import pipeline

            logger.info("Loading moderation model %s …", model_id)
            self._pipeline = pipeline(
                "text-generation",
                model=model_id,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                max_new_tokens=10,
            )
            self._model_id = model_id
            logger.info("Moderation model %s loaded successfully.", model_id)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Failed to load moderation model %s: %s", model_id, exc)
            self._pipeline = None

    @property
    def is_loaded(self) -> bool:
        """Return True when the local pipeline is ready."""
        return self._pipeline is not None

    # ── Local-model check ────────────────────────────────────────────────────

    def _parse_pipeline_output(self, raw: object) -> str:
        """Extract the assistant reply text from a text-generation pipeline result."""
        # The pipeline returns a list of dicts like:
        # [{"generated_text": [{"role": "system", ...}, {"role": "assistant", "content": "SAFE"}]}]
        if isinstance(raw, list) and raw:
            inner = raw[0].get("generated_text", raw[0])
            if isinstance(inner, list) and inner:
                last = inner[-1]
                if isinstance(last, dict):
                    return last.get("content", "").strip().upper()
            return str(inner).strip().upper()
        return str(raw).strip().upper()

    def check_with_local_model(self, prompt: str) -> ModerationResult:
        """Run the locally-loaded Gemma pipeline and return a :class:`ModerationResult`."""
        if not self.is_loaded:
            raise RuntimeError("Local moderation model is not loaded")
        try:
            messages = [
                {"role": "system", "content": _SAFETY_SYSTEM_PROMPT},
                {"role": "user", "content": f"Prompt: {prompt}"},
            ]
            raw = self._pipeline(messages)
            response = self._parse_pipeline_output(raw)
            is_safe = "UNSAFE" not in response
            return ModerationResult(
                is_safe=is_safe,
                reason="" if is_safe else f"Flagged by local model: {response}",
                source="local",
            )
        except Exception as exc:
            raise RuntimeError(f"Local model inference failed: {exc}") from exc

    # ── Agent-API fallback ───────────────────────────────────────────────────

    def check_with_agent_api(
        self,
        prompt: str,
        *,
        model_id: str | None = None,
        token: str | None = None,
    ) -> ModerationResult:
        """Use the HuggingFace Inference API (agent API) as a fallback safety check."""
        from huggingface_hub import InferenceClient

        _model = model_id or self._model_id or "google/gemma-3-1b-it"
        client = InferenceClient(model=_model, token=token)
        messages = [
            {"role": "system", "content": _SAFETY_SYSTEM_PROMPT},
            {"role": "user", "content": f"Prompt: {prompt}"},
        ]
        result = client.chat_completion(messages=messages, max_tokens=10)
        response = result.choices[0].message.content.strip().upper()
        is_safe = "UNSAFE" not in response
        return ModerationResult(
            is_safe=is_safe,
            reason="" if is_safe else f"Flagged by agent API: {response}",
            source="agent_api",
        )

    # ── Unified entry-point ──────────────────────────────────────────────────

    def check(self, prompt: str, *, hf_token: str | None = None) -> ModerationResult:
        """Check *prompt* safety.

        Tries the local Gemma model first; falls back to the HF Inference API;
        if both are unavailable the call succeeds with a warning (fail-open).
        """
        # 1. Local model
        if self.is_loaded:
            try:
                return self.check_with_local_model(prompt)
            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "Local moderation check failed — falling back to agent API: %s", exc
                )

        # 2. Agent API fallback
        if hf_token or self._model_id:
            try:
                return self.check_with_agent_api(
                    prompt,
                    model_id=self._model_id,
                    token=hf_token,
                )
            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "Agent API moderation check failed — allowing prompt through: %s", exc
                )

        # 3. Both unavailable — fail open
        logger.warning(
            "No moderation backend available; prompt allowed through without safety check."
        )
        return ModerationResult(is_safe=True, reason="Moderation unavailable", source="none")


# ── Module-level singleton ────────────────────────────────────────────────────

_moderation_engine: ModerationEngine = ModerationEngine()


def get_moderation_engine() -> ModerationEngine:
    """Return the global :class:`ModerationEngine` singleton."""
    return _moderation_engine


def check_prompt(prompt: str) -> None:
    """Validate *prompt* against the moderation engine.

    Raises :class:`ValueError` if the prompt is rejected.
    Does nothing when ``MODERATION_ENABLED`` is ``False`` or when no moderation
    backend is configured (fail-open to avoid breaking existing tests).
    """
    from app.config import settings

    if not settings.MODERATION_ENABLED:
        return

    result = _moderation_engine.check(prompt, hf_token=settings.HF_API_TOKEN)
    if not result.is_safe:
        raise ValueError(result.reason or "Prompt rejected by content moderation")


================================================
FILE: app/utils/storage.py
================================================
"""Local filesystem storage management for generated images."""

from __future__ import annotations

from datetime import datetime, timezone
from pathlib import Path

from app.config import settings


def ensure_storage_dirs() -> None:
    """Create the output directory tree if it doesn't exist."""
    (settings.STORAGE_DIR / "images").mkdir(parents=True, exist_ok=True)
    (settings.STORAGE_DIR / "batches").mkdir(parents=True, exist_ok=True)


def get_image_path(job_id: str, fmt: str, batch_id: str | None = None) -> Path:
    """Return the storage path for a generated image.

    Layout:
        output/batches/{batch_id}/{job_id}.{ext}   (if part of a batch)
        output/images/{YYYY-MM-DD}/{job_id}.{ext}   (standalone)
    """
    ext = fmt.lower()
    if ext == "jpeg":
        ext = "jpg"

    if batch_id:
        directory = settings.STORAGE_DIR / "batches" / batch_id
    else:
        date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        directory = settings.STORAGE_DIR / "images" / date_str

    directory.mkdir(parents=True, exist_ok=True)
    return directory / f"{job_id}.{ext}"


def image_url_path(job_id: str) -> str:
    """Return the API URL path to serve this image."""
    return f"/jobs/{job_id}/image"


================================================
FILE: app/worker/__init__.py
================================================


================================================
FILE: app/worker/gpu_worker.py
================================================
"""GPU worker — async priority queue consumer driving the Flux engine."""

from __future__ import annotations

import asyncio
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING

from sqlalchemy import select, func

from app.database import async_session_factory
from app.models.enums import ImageFormat, JobStatus
from app.models.job import BatchJob, Job
from app.utils.image_utils import save_image
from app.utils.storage import get_image_path, image_url_path

if TYPE_CHECKING:
    from app.engine.flux_engine import FluxEngine

logger = logging.getLogger(__name__)


@dataclass(order=True)
class _QueueItem:
    """Priority queue item. Lower priority value = higher urgency."""

    priority: int
    timestamp: float = field(compare=True)
    job_id: str = field(compare=False)


@dataclass
class GPUWorker:
    """Consumes jobs from an async priority queue and runs inference."""

    engine: "FluxEngine"
    _queue: asyncio.PriorityQueue[_QueueItem] = field(default_factory=asyncio.PriorityQueue)
    _stop_event: asyncio.Event = field(default_factory=asyncio.Event)
    _task: asyncio.Task | None = field(default=None, repr=False)
    _jobs_processed: int = 0
    _total_inference_time: float = 0.0

    # ── Queue interface ──────────────────────────────────────────────

    @staticmethod
    def _normalize_priority(priority: int) -> int:
        """Convert enum/int-like priorities to a stable numeric queue value."""
        try:
            value = int(priority)
        except (TypeError, ValueError):
            value = 1
        return max(1, value)

    async def enqueue(self, job_id: str, priority: int = 1) -> None:
        normalized = self._normalize_priority(priority)
        item = _QueueItem(priority=normalized, timestamp=time.time(), job_id=job_id)
        await self._queue.put(item)
        logger.debug("Enqueued job=%s priority=%d queue_depth=%d", job_id, normalized, self.queue_depth)

    @property
    def queue_depth(self) -> int:
        return self._queue.qsize()

    # ── Lifecycle ────────────────────────────────────────────────────

    def start(self) -> asyncio.Task:
        """Start the worker loop as a background task."""
        self._stop_event.clear()
        self._task = asyncio.create_task(self._run(), name="gpu-worker")
        logger.info("GPU worker started.")
        return self._task

    async def shutdown(self) -> None:
        """Signal the worker to stop and wait for it to finish the current job."""
        logger.info("Shutting down GPU worker (queue_depth=%d)...", self.queue_depth)
        self._stop_event.set()
        if self._task and not self._task.done():
            # Unblock a pending queue.get() by pushing a sentinel
            await self._queue.put(_QueueItem(priority=999, timestamp=time.time(), job_id="__stop__"))
            await self._task
        logger.info("GPU worker stopped. Processed %d jobs total.", self._jobs_processed)

    # ── Resume support ───────────────────────────────────────────────

    async def resume_pending_jobs(self) -> int:
        """Re-enqueue jobs left in PENDING or PROCESSING state (crash recovery)."""
        async with async_session_factory() as session:
            result = await session.execute(
                select(Job).where(Job.status.in_(["pending", "processing"]))
            )
            stale_jobs = result.scalars().all()

            for job in stale_jobs:
                job.status = "pending"
                await self.enqueue(job.id, self._normalize_priority(job.priority))

            await session.commit()

        if stale_jobs:
            logger.info("Resumed %d pending/stale jobs.", len(stale_jobs))
        return len(stale_jobs)

    # ── Main loop ────────────────────────────────────────────────────

    async def _run(self) -> None:
        logger.info("GPU worker loop running.")
        while not self._stop_event.is_set():
            try:
                item = await asyncio.wait_for(self._queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue

            if item.job_id == "__stop__":
                break

            await self._process_job(item.job_id)
            self._queue.task_done()

        logger.info("GPU worker loop exited.")

    async def _process_job(self, job_id: str) -> None:
        async with async_session_factory() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if job is None:
                logger.warning("Job %s not found in DB, skipping.", job_id)
                return

            # Skip cancelled jobs
            if job.status == "cancelled":
                logger.info("Job %s is cancelled, skipping.", job_id)
                return

            # Mark as processing
            job.status = "processing"
            job.started_at = datetime.now(timezone.utc)
            await session.commit()

        try:
            # Run inference in a thread to avoid blocking asyncio
            gen_result = await asyncio.to_thread(
                self.engine.generate,
                prompt=job.prompt,
                width=job.width,
                height=job.height,
                num_steps=job.num_steps,
                guidance_scale=job.guidance_scale,
                seed=job.seed,
            )

            # Save image to disk
            fmt = ImageFormat(job.format)
            file_path = get_image_path(job.id, job.format, batch_id=job.batch_id)
            save_image(
                gen_result.image,
                file_path,
                fmt,
                prompt=job.prompt,
                seed=gen_result.seed,
            )

            # Update job as completed
            async with async_session_factory() as session:
                result = await session.execute(select(Job).where(Job.id == job_id))
                job = result.scalar_one()
                job.status = "completed"
                job.file_path = str(file_path)
                job.seed = gen_result.seed
                job.completed_at = datetime.now(timezone.utc)
                await session.commit()

                # Update batch counters if applicable
                if job.batch_id:
                    await self._update_batch_count(session, job.batch_id)

            self._jobs_processed += 1
            self._total_inference_time += gen_result.inference_time

        except Exception as exc:
            # Detect CUDA OOM without importing torch in mock/non-GPU environments.
            if exc.__class__.__name__ == "OutOfMemoryError":
                logger.error("OOM on job %s (%dx%d).", job_id, job.width, job.height)
                try:
                    import torch

                    torch.cuda.empty_cache()
                except Exception:
                    pass
                await self._fail_job(job_id, "GPU out of memory for this resolution")
                return

            logger.exception("Error processing job %s", job_id)
            await self._fail_job(job_id, "Internal generation error")

    async def _fail_job(self, job_id: str, error: str) -> None:
        async with async_session_factory() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if job:
                job.status = "failed"
                job.error_message = error
                job.completed_at = datetime.now(timezone.utc)
                await session.commit()

                if job.batch_id:
                    await self._update_batch_count(session, job.batch_id)

    async def _update_batch_count(self, session, batch_id: str) -> None:
        """Recompute batch counters from individual jobs."""
        result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))
        batch = result.scalar_one_or_none()
        if not batch:
            return

        # Count completed and failed
        completed = (
            await session.execute(
                select(func.count()).where(Job.batch_id == batch_id, Job.status == "completed")
            )
        ).scalar()
        failed = (
            await session.execute(
                select(func.count()).where(Job.batch_id == batch_id, Job.status == "failed")
            )
        ).scalar()

        batch.completed_count = completed or 0
        batch.failed_count = failed or 0

        done = (completed or 0) + (failed or 0)
        if done >= batch.total_count:
            batch.status = "completed" if (failed or 0) == 0 else "failed"
            batch.completed_at = datetime.now(timezone.utc)
        elif done > 0:
            batch.status = "processing"

        await session.commit()

    # ── Stats ────────────────────────────────────────────────────────

    def get_stats(self) -> dict:
        avg = (
            self._total_inference_time / self._jobs_processed
            if self._jobs_processed > 0
            else 0.0
        )
        return {
            "queue_depth": self.queue_depth,
            "jobs_processed": self._jobs_processed,
            "total_inference_time_s": round(self._total_inference_time, 2),
            "avg_inference_time_s": round(avg, 2),
        }


================================================
FILE: flux_image_service.egg-info/PKG-INFO
================================================
Metadata-Version: 2.4
Name: flux-image-service
Version: 0.1.0
Summary: Local image generation service using Flux.1 Schnell 12B
Requires-Python: >=3.10
Requires-Dist: fastapi>=0.115.0
Requires-Dist: uvicorn[standard]>=0.30.0
Requires-Dist: pydantic>=2.0
Requires-Dist: pydantic-settings>=2.0
Requires-Dist: sqlalchemy[asyncio]>=2.0
Requires-Dist: aiosqlite>=0.20.0
Requires-Dist: torch>=2.1
Requires-Dist: diffusers>=0.30.0
Requires-Dist: transformers>=4.40.0
Requires-Dist: accelerate>=0.30.0
Requires-Dist: sentencepiece>=0.2.0
Requires-Dist: protobuf>=4.25.0
Requires-Dist: Pillow>=10.0.0
Requires-Dist: python-multipart>=0.0.9
Requires-Dist: sse-starlette>=2.0.0
Requires-Dist: httpx>=0.27.0
Provides-Extra: dev
Requires-Dist: pytest>=8.0; extra == "dev"
Requires-Dist: pytest-asyncio>=0.23.0; extra == "dev"
Requires-Dist: httpx>=0.27.0; extra == "dev"
Requires-Dist: ruff>=0.4.0; extra == "dev"


================================================
FILE: flux_image_service.egg-info/SOURCES.txt
================================================
README.md
pyproject.toml
app/__init__.py
app/config.py
app/database.py
app/main.py
app/api/__init__.py
app/api/batch.py
app/api/generate.py
app/api/health.py
app/api/jobs.py
app/engine/__init__.py
app/engine/engine_config.py
app/engine/flux_engine.py
app/models/__init__.py
app/models/enums.py
app/models/job.py
app/schemas/__init__.py
app/schemas/batch.py
app/schemas/generate.py
app/schemas/job.py
app/services/__init__.py
app/services/generation_service.py
app/services/job_service.py
app/services/moderation_service.py
app/utils/__init__.py
app/utils/image_utils.py
app/utils/moderation.py
app/utils/storage.py
app/worker/__init__.py
app/worker/gpu_worker.py
flux_image_service.egg-info/PKG-INFO
flux_image_service.egg-info/SOURCES.txt
flux_image_service.egg-info/dependency_links.txt
flux_image_service.egg-info/requires.txt
flux_image_service.egg-info/top_level.txt
tests/test_api_batch.py
tests/test_api_generate.py
tests/test_api_jobs.py
tests/test_moderation.py
tests/test_moderation_service.py
tests/test_schemas.py
tests/test_worker.py

================================================
FILE: flux_image_service.egg-info/dependency_links.txt
================================================



================================================
FILE: flux_image_service.egg-info/requires.txt
================================================
fastapi>=0.115.0
uvicorn[standard]>=0.30.0
pydantic>=2.0
pydantic-settings>=2.0
sqlalchemy[asyncio]>=2.0
aiosqlite>=0.20.0
torch>=2.1
diffusers>=0.30.0
transformers>=4.40.0
accelerate>=0.30.0
sentencepiece>=0.2.0
protobuf>=4.25.0
Pillow>=10.0.0
python-multipart>=0.0.9
sse-starlette>=2.0.0
httpx>=0.27.0

[dev]
pytest>=8.0
pytest-asyncio>=0.23.0
httpx>=0.27.0
ruff>=0.4.0


================================================
FILE: flux_image_service.egg-info/top_level.txt
================================================
app


================================================
FILE: pyproject.toml
================================================
[project]
name = "flux-image-service"
version = "0.1.0"
description = "Local image generation service using Flux.1 Schnell 12B"
requires-python = ">=3.10"
dependencies = [
    "fastapi>=0.115.0",
    "uvicorn[standard]>=0.30.0",
    "pydantic>=2.0",
    "pydantic-settings>=2.0",
    "sqlalchemy[asyncio]>=2.0",
    "aiosqlite>=0.20.0",
    "torch>=2.1",
    "diffusers>=0.30.0",
    "transformers>=4.40.0",
    "accelerate>=0.30.0",
    "sentencepiece>=0.2.0",
    "protobuf>=4.25.0",
    "Pillow>=10.0.0",
    "python-multipart>=0.0.9",
    "sse-starlette>=2.0.0",
    "httpx>=0.27.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0",
    "pytest-asyncio>=0.23.0",
    "httpx>=0.27.0",
    "ruff>=0.4.0",
]

[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]

[tool.ruff]
target-version = "py310"
line-length = 100


================================================
FILE: tests/__init__.py
================================================


================================================
FILE: tests/conftest.py
================================================
"""Shared test fixtures."""

from __future__ import annotations

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from pathlib import Path

import pytest
from httpx import ASGITransport, AsyncClient
from PIL import Image
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.engine.flux_engine import GenerationResult
from app.models.job import Base


@pytest.fixture(scope="session")
def event_loop():
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest.fixture
async def db_engine(tmp_path):
    """In-memory SQLite for tests."""
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest.fixture
async def db_session(db_engine):
    factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
    async with factory() as session:
        yield session


@pytest.fixture
def mock_engine():
    """A mock FluxEngine that returns a tiny test image."""
    engine = MagicMock()
    engine.is_loaded = True
    engine.get_vram_stats.return_value = {"available": True, "allocated_gb": 12.0}
    engine.get_vram_usage_gb.return_value = 12.0

    dummy_image = Image.new("RGB", (64, 64), color="red")
    engine.generate.return_value = GenerationResult(
        image=dummy_image, seed=42, inference_time=0.5
    )
    return engine


@pytest.fixture
def mock_worker(mock_engine):
    """A mock GPUWorker that tracks enqueued job IDs."""
    worker = MagicMock()
    worker.queue_depth = 0
    worker.enqueue = AsyncMock()
    worker.get_stats.return_value = {
        "queue_depth": 0,
        "jobs_processed": 0,
        "total_inference_time_s": 0.0,
        "avg_inference_time_s": 0.0,
    }
    return worker


@pytest.fixture
async def client(mock_engine, mock_worker, db_engine, tmp_path):
    """Test client with mocked engine/worker and in-memory DB."""
    from app.database import get_session
    from app.models.job import Base

    factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)

    async def override_get_session():
        async with factory() as session:
            yield session

    # Patch the database session factory used by the worker
    with patch("app.worker.gpu_worker.async_session_factory", factory), \
         patch("app.database.async_session_factory", factory):

        from app.main import app

        app.dependency_overrides[get_session] = override_get_session
        app.state.engine = mock_engine
        app.state.worker = mock_worker

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac

        app.dependency_overrides.clear()


================================================
FILE: tests/test_api_batch.py
================================================
"""Tests for batch API endpoints."""

import pytest


@pytest.mark.asyncio
async def test_create_batch(client, mock_worker):
    resp = await client.post(
        "/batch",
        json={
            "name": "test-batch",
            "prompts": [
                {"prompt": "a red car"},
                {"prompt": "a blue house"},
                {"prompt": "a green tree"},
            ],
        },
    )
    assert resp.status_code == 200
    data = resp.json()
    assert data["total"] == 3
    assert data["status"] == "pending"
    assert data["batch_id"]
    # Worker should have been called 3 times
    assert mock_worker.enqueue.call_count == 3


@pytest.mark.asyncio
async def test_get_batch_404(client):
    resp = await client.get("/batch/nonexistent-id")
    assert resp.status_code == 404


@pytest.mark.asyncio
async def test_cancel_batch(client, mock_worker):
    # Create batch first
    resp = await client.post(
        "/batch",
        json={
            "name": "cancel-test",
            "prompts": [{"prompt": "p1"}, {"prompt": "p2"}],
        },
    )
    batch_id = resp.json()["batch_id"]

    # Cancel it
    resp = await client.delete(f"/batch/{batch_id}")
    assert resp.status_code == 200
    data = resp.json()
    assert data["status"] == "cancelled"


@pytest.mark.asyncio
async def test_batch_requires_prompts(client):
    resp = await client.post("/batch", json={"name": "empty", "prompts": []})
    assert resp.status_code == 422


================================================
FILE: tests/test_api_generate.py
================================================
"""Tests for generation API endpoints."""

import pytest
from unittest.mock import AsyncMock, patch
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.models.job import Job
from app.services.moderation_service import ModerationError


@pytest.mark.asyncio
async def test_generate_async_returns_job_ids(client):
    resp = await client.post("/generate", json={"prompt": "a beautiful sunset"})
    assert resp.status_code == 200
    data = resp.json()
    assert "job_ids" in data
    assert len(data["job_ids"]) == 1
    assert data["status"] == "pending"


@pytest.mark.asyncio
async def test_generate_multiple_images(client, mock_worker):
    resp = await client.post("/generate", json={"prompt": "test", "num_images": 3})
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["job_ids"]) == 3
    assert mock_worker.enqueue.call_count == 3


@pytest.mark.asyncio
async def test_generate_max_4_images(client):
    resp = await client.post("/generate", json={"prompt": "test", "num_images": 4})
    assert resp.status_code == 200
    assert len(resp.json()["job_ids"]) == 4


@pytest.mark.asyncio
async def test_generate_rejects_more_than_4_images(client):
    resp = await client.post("/generate", json={"prompt": "test", "num_images": 5})
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_generate_async_enqueues_work(client, mock_worker):
    resp = await client.post("/generate", json={"prompt": "test prompt"})
    assert resp.status_code == 200
    mock_worker.enqueue.assert_called_once()


@pytest.mark.asyncio
async def test_generate_validates_resolution_multiple_of_8(client):
    resp = await client.post(
        "/generate",
        json={"prompt": "test", "width": 513, "height": 1024},
    )
    assert resp.status_code == 422  # validation error


@pytest.mark.asyncio
async def test_generate_rejects_empty_prompt(client):
    resp = await client.post("/generate", json={"prompt": ""})
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_generate_rejects_oversized_resolution(client):
    resp = await client.post(
        "/generate",
        json={"prompt": "test", "width": 4096, "height": 4096},
    )
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_generate_default_values(client, mock_worker):
    resp = await client.post("/generate", json={"prompt": "a cat"})
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_serve_image_404_for_missing_job(client):
    resp = await client.get("/jobs/nonexistent-id/image")
    assert resp.status_code == 404


@pytest.mark.asyncio
async def test_generate_stream_emits_done_for_completed_jobs(client, db_engine):
    create = await client.post("/generate", json={"prompt": "test", "num_images": 2})
    assert create.status_code == 200
    job_ids = create.json()["job_ids"]

    # Mark jobs as completed so the SSE stream terminates immediately.
    factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
    async with factory() as session:
        result = await session.execute(select(Job).where(Job.id.in_(job_ids)))
        jobs = result.scalars().all()
        for job in jobs:
            job.status = "completed"
            job.file_path = f"output/images/{job.id}.png"
        await session.commit()

    resp = await client.get(
        "/generate/stream",
        params=[("job_id", job_ids[0]), ("job_id", job_ids[1])],
    )
    assert resp.status_code == 200
    text = resp.text
    assert "event: progress" in text
    assert "event: done" in text
    assert '"completed": 2' in text


@pytest.mark.asyncio
async def test_generate_stream_returns_404_for_unknown_job(client):
    resp = await client.get("/generate/stream", params={"job_id": "unknown-job"})
    assert resp.status_code == 404


# ---------------------------------------------------------------------------
# Content moderation integration tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_generate_async_blocked_when_prompt_is_unsafe(client):
    """POST /generate returns 400 when the moderation service flags the prompt."""
    with patch(
        "app.api.generate.check_prompt_safety",
        new=AsyncMock(return_value=(False, "sexual")),
    ):
        resp = await client.post("/generate", json={"prompt": "explicit content"})

    assert resp.status_code == 400
    assert "unsafe" in resp.json()["detail"].lower()


@pytest.mark.asyncio
async def test_generate_async_proceeds_when_prompt_is_safe(client, mock_worker):
    """POST /generate proceeds normally when the moderation service approves the prompt."""
    with patch(
        "app.api.generate.check_prompt_safety",
        new=AsyncMock(return_value=(True, None)),
    ):
        resp = await client.post("/generate", json={"prompt": "a beautiful landscape"})

    assert resp.status_code == 200
    assert "job_ids" in resp.json()
    mock_worker.enqueue.assert_called()


@pytest.mark.asyncio
async def test_generate_async_returns_503_when_moderation_service_fails(client):
    """POST /generate returns 503 when the moderation API is unreachable."""
    with patch(
        "app.api.generate.check_prompt_safety",
        new=AsyncMock(side_effect=ModerationError("API unreachable")),
    ):
        resp = await client.post("/generate", json={"prompt": "a beautiful landscape"})

    assert resp.status_code == 503


@pytest.mark.asyncio
async def test_generate_sync_blocked_when_prompt_is_unsafe(client):
    """POST /generate/sync returns 400 when the moderation service flags the prompt."""
    with patch(
        "app.api.generate.check_prompt_safety",
        new=AsyncMock(return_value=(False, "violence")),
    ):
        resp = await client.post("/generate/sync", json={"prompt": "violent content"})

    assert resp.status_code == 400
    assert "unsafe" in resp.json()["detail"].lower()


@pytest.mark.asyncio
async def test_generate_sync_returns_503_when_moderation_service_fails(client):
    """POST /generate/sync returns 503 when the moderation API is unreachable."""
    with patch(
        "app.api.generate.check_prompt_safety",
        new=AsyncMock(side_effect=ModerationError("timeout")),
    ):
        resp = await client.post("/generate/sync", json={"prompt": "a sunset"})

    assert resp.status_code == 503


================================================
FILE: tests/test_api_jobs.py
================================================
"""Tests for job management endpoints."""

import pytest


@pytest.mark.asyncio
async def test_list_jobs_empty(client):
    resp = await client.get("/jobs")
    assert resp.status_code == 200
    data = resp.json()
    assert data["total"] == 0
    assert data["jobs"] == []


@pytest.mark.asyncio
async def test_get_job_after_create(client, mock_worker):
    # Create a job via generate
    gen_resp = await client.post("/generate", json={"prompt": "test job"})
    job_id = gen_resp.json()["job_ids"][0]

    # Fetch it
    resp = await client.get(f"/jobs/{job_id}")
    assert resp.status_code == 200
    data = resp.json()
    assert data["id"] == job_id
    assert data["prompt"] == "test job"
    assert data["status"] == "pending"
    assert data["width"] == 1024
    assert data["height"] == 1024


@pytest.mark.asyncio
async def test_cancel_job(client, mock_worker):
    gen_resp = await client.post("/generate", json={"prompt": "to cancel"})
    job_id = gen_resp.json()["job_ids"][0]

    resp = await client.delete(f"/jobs/{job_id}")
    assert resp.status_code == 200
    assert resp.json()["status"] == "cancelled"


@pytest.mark.asyncio
async def test_list_jobs_with_status_filter(client, mock_worker):
    await client.post("/generate", json={"prompt": "j1"})
    await client.post("/generate", json={"prompt": "j2"})

    resp = await client.get("/jobs?status=pending")
    assert resp.status_code == 200
    data = resp.json()
    assert data["total"] == 2


@pytest.mark.asyncio
async def test_get_nonexistent_job(client):
    resp = await client.get("/jobs/does-not-exist")
    assert resp.status_code == 404


@pytest.mark.asyncio
async def test_health_endpoint(client):
    resp = await client.get("/health")
    assert resp.status_code == 200
    data = resp.json()
    assert data["status"] == "ok"
    assert data["model_loaded"] is True
    assert "vram" in data
    assert "worker" in data


================================================
FILE: tests/test_moderation.py
================================================
"""Tests for the content moderation module."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

import pytest
from pydantic import ValidationError

from app.utils.moderation import ModerationEngine, ModerationResult, check_prompt


# ── ModerationEngine unit tests ───────────────────────────────────────────────


class TestModerationEngine:
    def test_is_loaded_false_before_load(self):
        engine = ModerationEngine()
        assert not engine.is_loaded

    def test_load_failure_leaves_engine_unloaded(self):
        engine = ModerationEngine()
        with patch("app.utils.moderation.ModerationEngine.load") as mock_load:
            mock_load.side_effect = RuntimeError("no GPU")
            try:
                engine.load("google/gemma-3-1b-it")
            except RuntimeError:
                pass
        assert not engine.is_loaded

    def test_parse_pipeline_output_safe(self):
        engine = ModerationEngine()
        raw = [{"generated_text": [{"role": "assistant", "content": "SAFE"}]}]
        assert engine._parse_pipeline_output(raw) == "SAFE"

    def test_parse_pipeline_output_unsafe(self):
        engine = ModerationEngine()
        raw = [{"generated_text": [{"role": "assistant", "content": "UNSAFE"}]}]
        assert engine._parse_pipeline_output(raw) == "UNSAFE"

    def test_check_with_local_model_raises_when_not_loaded(self):
        engine = ModerationEngine()
        with pytest.raises(RuntimeError, match="not loaded"):
            engine.check_with_local_model("test prompt")

    def test_check_with_local_model_safe(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(
            return_value=[{"generated_text": [{"role": "assistant", "content": "SAFE"}]}]
        )
        result = engine.check_with_local_model("a beautiful sunset")
        assert result.is_safe is True
        assert result.source == "local"

    def test_check_with_local_model_unsafe(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(
            return_value=[{"generated_text": [{"role": "assistant", "content": "UNSAFE"}]}]
        )
        result = engine.check_with_local_model("explicit harmful content")
        assert result.is_safe is False
        assert result.source == "local"
        assert "Flagged by local model" in result.reason

    def test_check_with_agent_api_safe(self):
        engine = ModerationEngine()
        engine._model_id = "google/gemma-3-1b-it"

        mock_choice = MagicMock()
        mock_choice.message.content = "SAFE"
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]

        mock_client = MagicMock()
        mock_client.chat_completion.return_value = mock_response

        with patch("huggingface_hub.InferenceClient", return_value=mock_client):
            result = engine.check_with_agent_api("a nice landscape", token="hf_test")

        assert result.is_safe is True
        assert result.source == "agent_api"

    def test_check_with_agent_api_unsafe(self):
        engine = ModerationEngine()
        engine._model_id = "google/gemma-3-1b-it"

        mock_choice = MagicMock()
        mock_choice.message.content = "UNSAFE"
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]

        mock_client = MagicMock()
        mock_client.chat_completion.return_value = mock_response

        with patch("huggingface_hub.InferenceClient", return_value=mock_client):
            result = engine.check_with_agent_api("explicit content", token="hf_test")

        assert result.is_safe is False
        assert result.source == "agent_api"
        assert "Flagged by agent API" in result.reason

    def test_check_uses_local_model_first(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(
            return_value=[{"generated_text": [{"role": "assistant", "content": "SAFE"}]}]
        )

        with patch.object(engine, "check_with_agent_api") as mock_api:
            result = engine.check("safe prompt")

        assert result.is_safe is True
        assert result.source == "local"
        mock_api.assert_not_called()

    def test_check_falls_back_to_agent_api_when_local_fails(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(side_effect=RuntimeError("inference error"))
        engine._model_id = "google/gemma-3-1b-it"

        fallback_result = ModerationResult(is_safe=True, source="agent_api")
        with patch.object(engine, "check_with_agent_api", return_value=fallback_result) as mock_api:
            result = engine.check("test prompt", hf_token="hf_test")

        assert result.source == "agent_api"
        mock_api.assert_called_once()

    def test_check_fails_open_when_both_unavailable(self):
        engine = ModerationEngine()
        # Local not loaded, no model_id so agent API branch skipped
        result = engine.check("test prompt")
        assert result.is_safe is True
        assert result.source == "none"

    def test_check_falls_open_when_agent_api_also_fails(self):
        engine = ModerationEngine()
        engine._model_id = "google/gemma-3-1b-it"

        with patch.object(engine, "check_with_agent_api", side_effect=Exception("API error")):
            result = engine.check("test prompt", hf_token="hf_test")

        assert result.is_safe is True
        assert result.source == "none"


# ── check_prompt() integration tests ─────────────────────────────────────────


class TestCheckPrompt:
    def test_allows_safe_prompt(self):
        with patch("app.config.settings") as mock_settings, \
             patch("app.utils.moderation._moderation_engine") as mock_engine:
            mock_settings.MODERATION_ENABLED = True
            mock_settings.HF_API_TOKEN = None
            mock_engine.check.return_value = ModerationResult(is_safe=True, source="local")
            check_prompt("a beautiful sunset")  # should not raise

    def test_rejects_unsafe_prompt(self):
        with patch("app.config.settings") as mock_settings, \
             patch("app.utils.moderation._moderation_engine") as mock_engine:
            mock_settings.MODERATION_ENABLED = True
            mock_settings.HF_API_TOKEN = None
            mock_engine.check.return_value = ModerationResult(
                is_safe=False, reason="Flagged by local model: UNSAFE", source="local"
            )
            with pytest.raises(ValueError, match="Flagged by local model"):
                check_prompt("explicit content")

    def test_skips_when_moderation_disabled(self):
        with patch("app.config.settings") as mock_settings, patch(
            "app.utils.moderation._moderation_engine"
        ) as mock_engine:
            mock_settings.MODERATION_ENABLED = False
            check_prompt("any content")
            mock_engine.check.assert_not_called()


# ── ImageRequest schema integration ──────────────────────────────────────────


class TestImageRequestModeration:
    def test_unsafe_prompt_raises_validation_error(self):
        with patch("app.config.settings") as mock_settings, \
             patch("app.utils.moderation._moderation_engine") as mock_engine:
            mock_settings.MODERATION_ENABLED = True
            mock_settings.HF_API_TOKEN = None
            mock_engine.check.return_value = ModerationResult(
                is_safe=False, reason="Flagged by local model: UNSAFE", source="local"
            )
            with pytest.raises(ValidationError, match="content moderation|Flagged"):
                from app.schemas.generate import ImageRequest

                ImageRequest(prompt="explicit content")

    def test_safe_prompt_passes_validation(self):
        with patch("app.config.settings") as mock_settings, \
             patch("app.utils.moderation._moderation_engine") as mock_engine:
            mock_settings.MODERATION_ENABLED = True
            mock_settings.HF_API_TOKEN = None
            mock_engine.check.return_value = ModerationResult(is_safe=True, source="local")
            from app.schemas.generate import ImageRequest

            req = ImageRequest(prompt="a mountain landscape")
            assert req.prompt == "a mountain landscape"


================================================
FILE: tests/test_moderation_service.py
================================================
"""Tests for the content moderation service."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

from app.services.moderation_service import ModerationError, check_prompt_safety


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_moderation_response(flagged: bool, categories: dict[str, bool] | None = None):
    """Build a mock httpx.Response that mimics the OpenAI moderation format."""
    if categories is None:
        categories = {}
    payload = {
        "results": [
            {
                "flagged": flagged,
                "categories": categories,
                "category_scores": {k: 0.9 if v else 0.01 for k, v in categories.items()},
            }
        ]
    }
    response = MagicMock(spec=httpx.Response)
    response.json.return_value = payload
    response.raise_for_status = MagicMock()
    return response


# ---------------------------------------------------------------------------
# Tests: moderation disabled
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_check_prompt_safety_disabled_skips_api():
    """When moderation is disabled the function returns safe without any API call."""
    with patch("app.services.moderation_service.settings") as mock_settings:
        mock_settings.MODERATION_ENABLED = False
        is_safe, reason = await check_prompt_safety("a test prompt")

    assert is_safe is True
    assert reason is None


@pytest.mark.asyncio
async def test_check_prompt_safety_no_api_key_skips_api():
    """When enabled but no API key is set, the function returns safe with a warning."""
    with patch("app.services.moderation_service.settings") as mock_settings:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = ""
        is_safe, reason = await check_prompt_safety("a test prompt")

    assert is_safe is True
    assert reason is None


# ---------------------------------------------------------------------------
# Tests: safe prompt
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_check_prompt_safety_safe_prompt():
    """A prompt that the API says is safe returns (True, None)."""
    mock_response = _make_moderation_response(flagged=False)

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        is_safe, reason = await check_prompt_safety("a beautiful sunset over the mountains")

    assert is_safe is True
    assert reason is None


# ---------------------------------------------------------------------------
# Tests: unsafe / flagged prompt
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_check_prompt_safety_flagged_prompt():
    """A flagged prompt returns (False, reason) with the triggered categories."""
    mock_response = _make_moderation_response(
        flagged=True,
        categories={"sexual": True, "violence": False, "hate": False},
    )

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        is_safe, reason = await check_prompt_safety("explicit nsfw content here")

    assert is_safe is False
    assert reason is not None
    assert "sexual" in reason


@pytest.mark.asyncio
async def test_check_prompt_safety_flagged_multiple_categories():
    """Multiple triggered categories are all included in the reason string."""
    mock_response = _make_moderation_response(
        flagged=True,
        categories={"sexual": True, "violence": True, "hate": False},
    )

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        is_safe, reason = await check_prompt_safety("violent explicit content")

    assert is_safe is False
    assert "sexual" in reason
    assert "violence" in reason


@pytest.mark.asyncio
async def test_check_prompt_safety_flagged_no_categories():
    """If the API flags the prompt but provides no categories, reason is generic."""
    mock_response = _make_moderation_response(flagged=True, categories={})

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        is_safe, reason = await check_prompt_safety("some flagged content")

    assert is_safe is False
    assert reason == "unsafe content"


# ---------------------------------------------------------------------------
# Tests: API failures
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_check_prompt_safety_http_error_raises_moderation_error():
    """A network-level HTTP error raises ModerationError."""
    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(
            side_effect=httpx.ConnectError("Connection refused")
        )
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        with pytest.raises(ModerationError):
            await check_prompt_safety("test prompt")


@pytest.mark.asyncio
async def test_check_prompt_safety_http_status_error_raises_moderation_error():
    """A non-2xx API response raises ModerationError."""
    mock_response = MagicMock(spec=httpx.Response)
    mock_response.status_code = 401
    mock_response.text = "Unauthorized"
    http_error = httpx.HTTPStatusError(
        "401 Unauthorized",
        request=MagicMock(),
        response=mock_response,
    )
    mock_response.raise_for_status = MagicMock(side_effect=http_error)

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "bad-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        with pytest.raises(ModerationError):
            await check_prompt_safety("test prompt")


@pytest.mark.asyncio
async def test_check_prompt_safety_empty_results():
    """An API response with no results array is treated as safe."""
    mock_response = MagicMock(spec=httpx.Response)
    mock_response.json.return_value = {"results": []}
    mock_response.raise_for_status = MagicMock()

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        is_safe, reason = await check_prompt_safety("a normal prompt")

    assert is_safe is True
    assert reason is None


================================================
FILE: tests/test_schemas.py
================================================
"""Tests for schemas — resolution validation, format handling."""

import pytest
from pydantic import ValidationError

from app.schemas.generate import ImageRequest
from app.schemas.batch import BatchRequest


class TestImageRequest:
    def test_defaults(self):
        req = ImageRequest(prompt="test")
        assert req.width == 1024
        assert req.height == 1024
        assert req.num_steps == 4
        assert req.guidance_scale == 0.0
        assert req.format.value == "png"
        assert req.seed is None
        assert req.num_images == 1

    def test_num_images_up_to_4(self):
        req = ImageRequest(prompt="test", num_images=4)
        assert req.num_images == 4

    def test_num_images_rejects_above_4(self):
        with pytest.raises(ValidationError):
            ImageRequest(prompt="test", num_images=5)

    def test_num_images_rejects_zero(self):
        with pytest.raises(ValidationError):
            ImageRequest(prompt="test", num_images=0)

    def test_width_must_be_multiple_of_8(self):
        with pytest.raises(ValidationError, match="multiple of 8"):
            ImageRequest(prompt="test", width=513)

    def test_height_must_be_multiple_of_8(self):
        with pytest.raises(ValidationError, match="multiple of 8"):
            ImageRequest(prompt="test", height=100 + 3)

    def test_valid_custom_resolution(self):
        req = ImageRequest(prompt="test", width=768, height=1344)
        assert req.width == 768
        assert req.height == 1344

    def test_rejects_empty_prompt(self):
        with pytest.raises(ValidationError):
            ImageRequest(prompt="")

    def test_rejects_oversized_resolution(self):
        with pytest.raises(ValidationError):
            ImageRequest(prompt="test", width=4096)

    def test_seed_bounds(self):
        req = ImageRequest(prompt="test", seed=0)
        assert req.seed == 0

        req = ImageRequest(prompt="test", seed=2**32 - 1)
        assert req.seed == 2**32 - 1

        with pytest.raises(ValidationError):
            ImageRequest(prompt="test", seed=-1)

    def test_all_formats(self):
        for fmt in ("png", "jpeg", "webp"):
            req = ImageRequest(prompt="test", format=fmt)
            assert req.format.value == fmt


class TestBatchRequest:
    def test_valid_batch(self):
        batch = BatchRequest(
            name="test",
            prompts=[ImageRequest(prompt="a"), ImageRequest(prompt="b")],
        )
        assert len(batch.prompts) == 2

    def test_empty_prompts_rejected(self):
        with pytest.raises(ValidationError):
            BatchRequest(name="empty", prompts=[])


================================================
FILE: tests/test_worker.py
================================================
"""Tests for the GPU worker queue and processing logic."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime, timezone

import pytest
from PIL import Image
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.engine.flux_engine import GenerationResult
from app.models.enums import JobPriority
from app.models.job import Base, Job
from app.worker.gpu_worker import GPUWorker, _QueueItem


class TestQueuePriority:
    def test_realtime_before_batch(self):
        """REALTIME (priority=1) items should sort before BATCH (priority=5)."""
        rt = _QueueItem(priority=1, timestamp=100.0, job_id="rt")
        batch = _QueueItem(priority=5, timestamp=99.0, job_id="batch")
        assert rt < batch

    def test_same_priority_fifo(self):
        """Same priority: earlier timestamp wins."""
        a = _QueueItem(priority=1, timestamp=100.0, job_id="a")
        b = _QueueItem(priority=1, timestamp=200.0, job_id="b")
        assert a < b


class TestWorkerEnqueue:
    @pytest.mark.asyncio
    async def test_enqueue_increases_depth(self):
        engine = MagicMock()
        worker = GPUWorker(engine=engine)
        await worker.enqueue("job-1", priority=1)
        await worker.enqueue("job-2", priority=5)
        assert worker.queue_depth == 2

    @pytest.mark.asyncio
    async def test_realtime_preempts_queued_batch_jobs(self):
        engine = MagicMock()
        worker = GPUWorker(engine=engine)

        await worker.enqueue("batch-1", priority=JobPriority.BATCH)
        await worker.enqueue("batch-2", priority=JobPriority.BATCH)
        await worker.enqueue("rt-1", priority=JobPriority.REALTIME)

        first = await worker._queue.get()
        assert first.job_id == "rt-1"


class TestWorkerStats:
    def test_initial_stats(self):
        engine = MagicMock()
        worker = GPUWorker(engine=engine)
        stats = worker.get_stats()
        assert stats["jobs_processed"] == 0
        assert stats["queue_depth"] == 0
        assert stats["avg_inference_time_s"] == 0.0
Download .txt
gitextract_xtvq1k3n/

├── .gitignore
├── Dockerfile
├── README.md
├── app/
│   ├── __init__.py
│   ├── api/
│   │   ├── __init__.py
│   │   ├── batch.py
│   │   ├── generate.py
│   │   ├── health.py
│   │   └── jobs.py
│   ├── config.py
│   ├── database.py
│   ├── engine/
│   │   ├── __init__.py
│   │   ├── engine_config.py
│   │   ├── flux_engine.py
│   │   └── mock_engine.py
│   ├── main.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── enums.py
│   │   └── job.py
│   ├── schemas/
│   │   ├── __init__.py
│   │   ├── batch.py
│   │   ├── generate.py
│   │   └── job.py
│   ├── services/
│   │   ├── __init__.py
│   │   ├── generation_service.py
│   │   ├── job_service.py
│   │   └── moderation_service.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── image_utils.py
│   │   ├── moderation.py
│   │   └── storage.py
│   └── worker/
│       ├── __init__.py
│       └── gpu_worker.py
├── flux_image_service.egg-info/
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   ├── dependency_links.txt
│   ├── requires.txt
│   └── top_level.txt
├── pyproject.toml
└── tests/
    ├── __init__.py
    ├── conftest.py
    ├── test_api_batch.py
    ├── test_api_generate.py
    ├── test_api_jobs.py
    ├── test_moderation.py
    ├── test_moderation_service.py
    ├── test_schemas.py
    └── test_worker.py
Download .txt
SYMBOL INDEX (186 symbols across 30 files)

FILE: app/api/batch.py
  function _get_worker (line 36) | def _get_worker(request: Request):
  function create_batch_job (line 41) | async def create_batch_job(
  function create_batch_from_file (line 54) | async def create_batch_from_file(
  function get_batch_status (line 88) | async def get_batch_status(
  function stream_batch_progress (line 99) | async def stream_batch_progress(
  function download_batch_images (line 138) | async def download_batch_images(
  function cancel_batch_job (line 216) | async def cancel_batch_job(
  function retry_failed_jobs (line 227) | async def retry_failed_jobs(
  function _parse_json (line 247) | def _parse_json(content: bytes) -> list[ImageRequest]:
  function _parse_csv (line 265) | def _parse_csv(content: bytes) -> list[ImageRequest]:
  function _estimate_eta (line 296) | def _estimate_eta(progress: dict, worker) -> float | None:

FILE: app/api/generate.py
  function _get_worker (line 30) | def _get_worker(request: Request):
  function generate_async (line 35) | async def generate_async(
  function stream_generate_progress (line 57) | async def stream_generate_progress(
  function generate_sync (line 117) | async def generate_sync(
  function serve_image (line 161) | async def serve_image(

FILE: app/api/health.py
  function set_start_time (line 14) | def set_start_time() -> None:
  function health_check (line 20) | async def health_check(request: Request):

FILE: app/api/jobs.py
  function list_all_jobs (line 16) | async def list_all_jobs(
  function get_job_detail (line 28) | async def get_job_detail(
  function cancel_job_endpoint (line 39) | async def cancel_job_endpoint(

FILE: app/config.py
  class Settings (line 5) | class Settings(BaseSettings):

FILE: app/database.py
  function init_db (line 11) | async def init_db() -> None:
  function get_session (line 18) | async def get_session() -> AsyncSession:  # type: ignore[misc]

FILE: app/engine/engine_config.py
  function estimated_vram_gb (line 22) | def estimated_vram_gb(width: int, height: int) -> float:

FILE: app/engine/flux_engine.py
  class GenerationResult (line 26) | class GenerationResult:
  class FluxEngine (line 33) | class FluxEngine:
    method load_model (line 41) | def load_model(self) -> None:
    method warmup (line 85) | def warmup(self) -> None:
    method generate (line 97) | def generate(
    method get_vram_usage_gb (line 141) | def get_vram_usage_gb(self) -> float:
    method get_vram_stats (line 147) | def get_vram_stats(self) -> dict:
    method unload (line 162) | def unload(self) -> None:
    method is_loaded (line 174) | def is_loaded(self) -> bool:

FILE: app/engine/mock_engine.py
  class GenerationResult (line 16) | class GenerationResult:
  class MockFluxEngine (line 23) | class MockFluxEngine:
    method load_model (line 28) | def load_model(self) -> None:
    method warmup (line 32) | def warmup(self) -> None:
    method generate (line 35) | def generate(
    method get_vram_usage_gb (line 80) | def get_vram_usage_gb(self) -> float:
    method get_vram_stats (line 83) | def get_vram_stats(self) -> dict:
    method unload (line 86) | def unload(self) -> None:
    method is_loaded (line 91) | def is_loaded(self) -> bool:

FILE: app/main.py
  function _setup_logging (line 20) | def _setup_logging() -> None:
  function lifespan (line 54) | async def lifespan(app: FastAPI):

FILE: app/models/enums.py
  class JobStatus (line 4) | class JobStatus(str, enum.Enum):
  class JobPriority (line 12) | class JobPriority(int, enum.Enum):
  class ImageFormat (line 17) | class ImageFormat(str, enum.Enum):

FILE: app/models/job.py
  class Base (line 17) | class Base(DeclarativeBase):
  function _utcnow (line 21) | def _utcnow() -> datetime:
  class Job (line 25) | class Job(Base):
  class BatchJob (line 57) | class BatchJob(Base):

FILE: app/schemas/batch.py
  class BatchRequest (line 6) | class BatchRequest(BaseModel):
  class BatchProgress (line 11) | class BatchProgress(BaseModel):

FILE: app/schemas/generate.py
  class ImageRequest (line 7) | class ImageRequest(BaseModel):
    method check_content_safety (line 20) | def check_content_safety(cls, v: str) -> str:
    method validate_resolution (line 27) | def validate_resolution(self) -> "ImageRequest":
  class ImageResponse (line 35) | class ImageResponse(BaseModel):

FILE: app/schemas/job.py
  class JobDetail (line 4) | class JobDetail(BaseModel):
  class JobList (line 25) | class JobList(BaseModel):

FILE: app/services/generation_service.py
  function create_single_job (line 18) | async def create_single_job(
  function create_batch (line 63) | async def create_batch(

FILE: app/services/job_service.py
  function _job_to_detail (line 16) | def _job_to_detail(job: Job) -> JobDetail:
  function get_job (line 39) | async def get_job(session: AsyncSession, job_id: str) -> JobDetail | None:
  function list_jobs (line 47) | async def list_jobs(
  function cancel_job (line 71) | async def cancel_job(session: AsyncSession, job_id: str) -> JobDetail | ...
  function get_batch_progress (line 85) | async def get_batch_progress(session: AsyncSession, batch_id: str) -> di...
  function cancel_batch (line 113) | async def cancel_batch(session: AsyncSession, batch_id: str) -> dict | N...
  function retry_failed_in_batch (line 135) | async def retry_failed_in_batch(session: AsyncSession, batch_id: str) ->...

FILE: app/services/moderation_service.py
  class ModerationError (line 14) | class ModerationError(Exception):
  function check_prompt_safety (line 18) | async def check_prompt_safety(prompt: str) -> tuple[bool, str | None]:

FILE: app/utils/image_utils.py
  function save_image (line 12) | def save_image(

FILE: app/utils/moderation.py
  class ModerationResult (line 20) | class ModerationResult:
  class ModerationEngine (line 26) | class ModerationEngine:
    method __init__ (line 34) | def __init__(self) -> None:
    method load (line 40) | def load(self, model_id: str = "google/gemma-3-1b-it") -> None:
    method is_loaded (line 61) | def is_loaded(self) -> bool:
    method _parse_pipeline_output (line 67) | def _parse_pipeline_output(self, raw: object) -> str:
    method check_with_local_model (line 80) | def check_with_local_model(self, prompt: str) -> ModerationResult:
    method check_with_agent_api (line 102) | def check_with_agent_api(
    method check (line 129) | def check(self, prompt: str, *, hf_token: str | None = None) -> Modera...
  function get_moderation_engine (line 169) | def get_moderation_engine() -> ModerationEngine:
  function check_prompt (line 174) | def check_prompt(prompt: str) -> None:

FILE: app/utils/storage.py
  function ensure_storage_dirs (line 11) | def ensure_storage_dirs() -> None:
  function get_image_path (line 17) | def get_image_path(job_id: str, fmt: str, batch_id: str | None = None) -...
  function image_url_path (line 38) | def image_url_path(job_id: str) -> str:

FILE: app/worker/gpu_worker.py
  class _QueueItem (line 27) | class _QueueItem:
  class GPUWorker (line 36) | class GPUWorker:
    method _normalize_priority (line 49) | def _normalize_priority(priority: int) -> int:
    method enqueue (line 57) | async def enqueue(self, job_id: str, priority: int = 1) -> None:
    method queue_depth (line 64) | def queue_depth(self) -> int:
    method start (line 69) | def start(self) -> asyncio.Task:
    method shutdown (line 76) | async def shutdown(self) -> None:
    method resume_pending_jobs (line 88) | async def resume_pending_jobs(self) -> int:
    method _run (line 108) | async def _run(self) -> None:
    method _process_job (line 124) | async def _process_job(self, job_id: str) -> None:
    method _fail_job (line 198) | async def _fail_job(self, job_id: str, error: str) -> None:
    method _update_batch_count (line 211) | async def _update_batch_count(self, session, batch_id: str) -> None:
    method get_stats (line 244) | def get_stats(self) -> dict:

FILE: tests/conftest.py
  function event_loop (line 19) | def event_loop():
  function db_engine (line 26) | async def db_engine(tmp_path):
  function db_session (line 36) | async def db_session(db_engine):
  function mock_engine (line 43) | def mock_engine():
  function mock_worker (line 58) | def mock_worker(mock_engine):
  function client (line 73) | async def client(mock_engine, mock_worker, db_engine, tmp_path):

FILE: tests/test_api_batch.py
  function test_create_batch (line 7) | async def test_create_batch(client, mock_worker):
  function test_get_batch_404 (line 29) | async def test_get_batch_404(client):
  function test_cancel_batch (line 35) | async def test_cancel_batch(client, mock_worker):
  function test_batch_requires_prompts (line 54) | async def test_batch_requires_prompts(client):

FILE: tests/test_api_generate.py
  function test_generate_async_returns_job_ids (line 13) | async def test_generate_async_returns_job_ids(client):
  function test_generate_multiple_images (line 23) | async def test_generate_multiple_images(client, mock_worker):
  function test_generate_max_4_images (line 32) | async def test_generate_max_4_images(client):
  function test_generate_rejects_more_than_4_images (line 39) | async def test_generate_rejects_more_than_4_images(client):
  function test_generate_async_enqueues_work (line 45) | async def test_generate_async_enqueues_work(client, mock_worker):
  function test_generate_validates_resolution_multiple_of_8 (line 52) | async def test_generate_validates_resolution_multiple_of_8(client):
  function test_generate_rejects_empty_prompt (line 61) | async def test_generate_rejects_empty_prompt(client):
  function test_generate_rejects_oversized_resolution (line 67) | async def test_generate_rejects_oversized_resolution(client):
  function test_generate_default_values (line 76) | async def test_generate_default_values(client, mock_worker):
  function test_serve_image_404_for_missing_job (line 82) | async def test_serve_image_404_for_missing_job(client):
  function test_generate_stream_emits_done_for_completed_jobs (line 88) | async def test_generate_stream_emits_done_for_completed_jobs(client, db_...
  function test_generate_stream_returns_404_for_unknown_job (line 115) | async def test_generate_stream_returns_404_for_unknown_job(client):
  function test_generate_async_blocked_when_prompt_is_unsafe (line 126) | async def test_generate_async_blocked_when_prompt_is_unsafe(client):
  function test_generate_async_proceeds_when_prompt_is_safe (line 139) | async def test_generate_async_proceeds_when_prompt_is_safe(client, mock_...
  function test_generate_async_returns_503_when_moderation_service_fails (line 153) | async def test_generate_async_returns_503_when_moderation_service_fails(...
  function test_generate_sync_blocked_when_prompt_is_unsafe (line 165) | async def test_generate_sync_blocked_when_prompt_is_unsafe(client):
  function test_generate_sync_returns_503_when_moderation_service_fails (line 178) | async def test_generate_sync_returns_503_when_moderation_service_fails(c...

FILE: tests/test_api_jobs.py
  function test_list_jobs_empty (line 7) | async def test_list_jobs_empty(client):
  function test_get_job_after_create (line 16) | async def test_get_job_after_create(client, mock_worker):
  function test_cancel_job (line 33) | async def test_cancel_job(client, mock_worker):
  function test_list_jobs_with_status_filter (line 43) | async def test_list_jobs_with_status_filter(client, mock_worker):
  function test_get_nonexistent_job (line 54) | async def test_get_nonexistent_job(client):
  function test_health_endpoint (line 60) | async def test_health_endpoint(client):

FILE: tests/test_moderation.py
  class TestModerationEngine (line 16) | class TestModerationEngine:
    method test_is_loaded_false_before_load (line 17) | def test_is_loaded_false_before_load(self):
    method test_load_failure_leaves_engine_unloaded (line 21) | def test_load_failure_leaves_engine_unloaded(self):
    method test_parse_pipeline_output_safe (line 31) | def test_parse_pipeline_output_safe(self):
    method test_parse_pipeline_output_unsafe (line 36) | def test_parse_pipeline_output_unsafe(self):
    method test_check_with_local_model_raises_when_not_loaded (line 41) | def test_check_with_local_model_raises_when_not_loaded(self):
    method test_check_with_local_model_safe (line 46) | def test_check_with_local_model_safe(self):
    method test_check_with_local_model_unsafe (line 55) | def test_check_with_local_model_unsafe(self):
    method test_check_with_agent_api_safe (line 65) | def test_check_with_agent_api_safe(self):
    method test_check_with_agent_api_unsafe (line 83) | def test_check_with_agent_api_unsafe(self):
    method test_check_uses_local_model_first (line 102) | def test_check_uses_local_model_first(self):
    method test_check_falls_back_to_agent_api_when_local_fails (line 115) | def test_check_falls_back_to_agent_api_when_local_fails(self):
    method test_check_fails_open_when_both_unavailable (line 127) | def test_check_fails_open_when_both_unavailable(self):
    method test_check_falls_open_when_agent_api_also_fails (line 134) | def test_check_falls_open_when_agent_api_also_fails(self):
  class TestCheckPrompt (line 148) | class TestCheckPrompt:
    method test_allows_safe_prompt (line 149) | def test_allows_safe_prompt(self):
    method test_rejects_unsafe_prompt (line 157) | def test_rejects_unsafe_prompt(self):
    method test_skips_when_moderation_disabled (line 168) | def test_skips_when_moderation_disabled(self):
  class TestImageRequestModeration (line 180) | class TestImageRequestModeration:
    method test_unsafe_prompt_raises_validation_error (line 181) | def test_unsafe_prompt_raises_validation_error(self):
    method test_safe_prompt_passes_validation (line 194) | def test_safe_prompt_passes_validation(self):

FILE: tests/test_moderation_service.py
  function _make_moderation_response (line 18) | def _make_moderation_response(flagged: bool, categories: dict[str, bool]...
  function test_check_prompt_safety_disabled_skips_api (line 43) | async def test_check_prompt_safety_disabled_skips_api():
  function test_check_prompt_safety_no_api_key_skips_api (line 54) | async def test_check_prompt_safety_no_api_key_skips_api():
  function test_check_prompt_safety_safe_prompt (line 71) | async def test_check_prompt_safety_safe_prompt():
  function test_check_prompt_safety_flagged_prompt (line 100) | async def test_check_prompt_safety_flagged_prompt():
  function test_check_prompt_safety_flagged_multiple_categories (line 128) | async def test_check_prompt_safety_flagged_multiple_categories():
  function test_check_prompt_safety_flagged_no_categories (line 156) | async def test_check_prompt_safety_flagged_no_categories():
  function test_check_prompt_safety_http_error_raises_moderation_error (line 185) | async def test_check_prompt_safety_http_error_raises_moderation_error():
  function test_check_prompt_safety_http_status_error_raises_moderation_error (line 207) | async def test_check_prompt_safety_http_status_error_raises_moderation_e...
  function test_check_prompt_safety_empty_results (line 237) | async def test_check_prompt_safety_empty_results():

FILE: tests/test_schemas.py
  class TestImageRequest (line 10) | class TestImageRequest:
    method test_defaults (line 11) | def test_defaults(self):
    method test_num_images_up_to_4 (line 21) | def test_num_images_up_to_4(self):
    method test_num_images_rejects_above_4 (line 25) | def test_num_images_rejects_above_4(self):
    method test_num_images_rejects_zero (line 29) | def test_num_images_rejects_zero(self):
    method test_width_must_be_multiple_of_8 (line 33) | def test_width_must_be_multiple_of_8(self):
    method test_height_must_be_multiple_of_8 (line 37) | def test_height_must_be_multiple_of_8(self):
    method test_valid_custom_resolution (line 41) | def test_valid_custom_resolution(self):
    method test_rejects_empty_prompt (line 46) | def test_rejects_empty_prompt(self):
    method test_rejects_oversized_resolution (line 50) | def test_rejects_oversized_resolution(self):
    method test_seed_bounds (line 54) | def test_seed_bounds(self):
    method test_all_formats (line 64) | def test_all_formats(self):
  class TestBatchRequest (line 70) | class TestBatchRequest:
    method test_valid_batch (line 71) | def test_valid_batch(self):
    method test_empty_prompts_rejected (line 78) | def test_empty_prompts_rejected(self):

FILE: tests/test_worker.py
  class TestQueuePriority (line 18) | class TestQueuePriority:
    method test_realtime_before_batch (line 19) | def test_realtime_before_batch(self):
    method test_same_priority_fifo (line 25) | def test_same_priority_fifo(self):
  class TestWorkerEnqueue (line 32) | class TestWorkerEnqueue:
    method test_enqueue_increases_depth (line 34) | async def test_enqueue_increases_depth(self):
    method test_realtime_preempts_queued_batch_jobs (line 42) | async def test_realtime_preempts_queued_batch_jobs(self):
  class TestWorkerStats (line 54) | class TestWorkerStats:
    method test_initial_stats (line 55) | def test_initial_stats(self):
Condensed preview — 48 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (132K chars).
[
  {
    "path": ".gitignore",
    "chars": 173,
    "preview": "# Byte-compiled\n__pycache__/\n*.py[cod]\n\n# Virtual env\n.venv/\nvenv/\n\n# Output\noutput/\n\n# Logs\nlogs/\n\n# IDE\n.idea/\n.vscode"
  },
  {
    "path": "Dockerfile",
    "chars": 778,
    "preview": "FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04\n\nWORKDIR /app\n\n# System deps\nRUN apt-get update && apt-get install -y --no-i"
  },
  {
    "path": "README.md",
    "chars": 14176,
    "preview": "# Flux Image Service\n\nLocal image generation API powered by **Flux.1 Schnell 12B** running on an H100 GPU with FP8 quant"
  },
  {
    "path": "app/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "app/api/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "app/api/batch.py",
    "chars": 10067,
    "preview": "\"\"\"Batch / dataset generation endpoints with SSE progress streaming.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyn"
  },
  {
    "path": "app/api/generate.py",
    "chars": 6655,
    "preview": "\"\"\"Image generation endpoints — async and sync modes.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport json"
  },
  {
    "path": "app/api/health.py",
    "chars": 685,
    "preview": "\"\"\"Health and status endpoint.\"\"\"\n\nfrom __future__ import annotations\n\nimport time\n\nfrom fastapi import APIRouter, Reque"
  },
  {
    "path": "app/api/jobs.py",
    "chars": 1499,
    "preview": "\"\"\"Job management endpoints.\"\"\"\n\nfrom __future__ import annotations\n\nfrom fastapi import APIRouter, Depends, HTTPExcepti"
  },
  {
    "path": "app/config.py",
    "chars": 1429,
    "preview": "from pydantic_settings import BaseSettings\nfrom pathlib import Path\n\n\nclass Settings(BaseSettings):\n    model_config = {"
  },
  {
    "path": "app/database.py",
    "chars": 677,
    "preview": "from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine\n\nfrom app.config import setting"
  },
  {
    "path": "app/engine/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "app/engine/engine_config.py",
    "chars": 847,
    "preview": "\"\"\"Resolution presets and engine configuration.\"\"\"\n\nPRESET_RESOLUTIONS: dict[str, tuple[int, int]] = {\n    \"square_sm\": "
  },
  {
    "path": "app/engine/flux_engine.py",
    "chars": 5547,
    "preview": "\"\"\"Flux.1 Schnell inference engine — model loading, generation, VRAM management.\"\"\"\n\nfrom __future__ import annotations\n"
  },
  {
    "path": "app/engine/mock_engine.py",
    "chars": 2567,
    "preview": "\"\"\"Mock FluxEngine for local development without GPU or model weights.\"\"\"\n\nfrom __future__ import annotations\n\nimport lo"
  },
  {
    "path": "app/main.py",
    "chars": 3687,
    "preview": "\"\"\"FastAPI application — lifespan manages model loading, worker, and DB.\"\"\"\n\nfrom __future__ import annotations\n\nimport "
  },
  {
    "path": "app/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "app/models/enums.py",
    "chars": 336,
    "preview": "import enum\n\n\nclass JobStatus(str, enum.Enum):\n    PENDING = \"pending\"\n    PROCESSING = \"processing\"\n    COMPLETED = \"co"
  },
  {
    "path": "app/models/job.py",
    "chars": 2424,
    "preview": "import uuid\nfrom datetime import datetime, timezone\n\nfrom sqlalchemy import (\n    Column,\n    DateTime,\n    Enum,\n    Fl"
  },
  {
    "path": "app/schemas/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "app/schemas/batch.py",
    "chars": 507,
    "preview": "from pydantic import BaseModel, Field\n\nfrom app.schemas.generate import ImageRequest\n\n\nclass BatchRequest(BaseModel):\n  "
  },
  {
    "path": "app/schemas/generate.py",
    "chars": 1618,
    "preview": "from pydantic import BaseModel, Field, field_validator, model_validator\n\nfrom app.config import settings\nfrom app.models"
  },
  {
    "path": "app/schemas/job.py",
    "chars": 553,
    "preview": "from pydantic import BaseModel\n\n\nclass JobDetail(BaseModel):\n    id: str\n    prompt: str\n    negative_prompt: str | None"
  },
  {
    "path": "app/services/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "app/services/generation_service.py",
    "chars": 2811,
    "preview": "\"\"\"Orchestrates image generation: validate → persist → enqueue.\"\"\"\n\nfrom __future__ import annotations\n\nimport uuid\nfrom"
  },
  {
    "path": "app/services/job_service.py",
    "chars": 5233,
    "preview": "\"\"\"Job and batch CRUD operations.\"\"\"\n\nfrom __future__ import annotations\n\nfrom datetime import datetime, timezone\n\nfrom "
  },
  {
    "path": "app/services/moderation_service.py",
    "chars": 3025,
    "preview": "\"\"\"Content moderation service – checks prompt safety via a third-party API.\"\"\"\n\nfrom __future__ import annotations\n\nimpo"
  },
  {
    "path": "app/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "app/utils/image_utils.py",
    "chars": 1069,
    "preview": "\"\"\"Image saving, format conversion, and metadata embedding.\"\"\"\n\nfrom __future__ import annotations\n\nfrom pathlib import "
  },
  {
    "path": "app/utils/moderation.py",
    "chars": 7135,
    "preview": "\"\"\"Content moderation using Gemma-3-1b-it (local) with HF Inference API fallback.\"\"\"\n\nfrom __future__ import annotations"
  },
  {
    "path": "app/utils/storage.py",
    "chars": 1250,
    "preview": "\"\"\"Local filesystem storage management for generated images.\"\"\"\n\nfrom __future__ import annotations\n\nfrom datetime impor"
  },
  {
    "path": "app/worker/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "app/worker/gpu_worker.py",
    "chars": 9371,
    "preview": "\"\"\"GPU worker — async priority queue consumer driving the Flux engine.\"\"\"\n\nfrom __future__ import annotations\n\nimport as"
  },
  {
    "path": "flux_image_service.egg-info/PKG-INFO",
    "chars": 900,
    "preview": "Metadata-Version: 2.4\nName: flux-image-service\nVersion: 0.1.0\nSummary: Local image generation service using Flux.1 Schne"
  },
  {
    "path": "flux_image_service.egg-info/SOURCES.txt",
    "chars": 1046,
    "preview": "README.md\npyproject.toml\napp/__init__.py\napp/config.py\napp/database.py\napp/main.py\napp/api/__init__.py\napp/api/batch.py\n"
  },
  {
    "path": "flux_image_service.egg-info/dependency_links.txt",
    "chars": 1,
    "preview": "\n"
  },
  {
    "path": "flux_image_service.egg-info/requires.txt",
    "chars": 372,
    "preview": "fastapi>=0.115.0\nuvicorn[standard]>=0.30.0\npydantic>=2.0\npydantic-settings>=2.0\nsqlalchemy[asyncio]>=2.0\naiosqlite>=0.20"
  },
  {
    "path": "flux_image_service.egg-info/top_level.txt",
    "chars": 4,
    "preview": "app\n"
  },
  {
    "path": "pyproject.toml",
    "chars": 849,
    "preview": "[project]\nname = \"flux-image-service\"\nversion = \"0.1.0\"\ndescription = \"Local image generation service using Flux.1 Schne"
  },
  {
    "path": "tests/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "tests/conftest.py",
    "chars": 2884,
    "preview": "\"\"\"Shared test fixtures.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nfrom unittest.mock import AsyncMock, Mag"
  },
  {
    "path": "tests/test_api_batch.py",
    "chars": 1469,
    "preview": "\"\"\"Tests for batch API endpoints.\"\"\"\n\nimport pytest\n\n\n@pytest.mark.asyncio\nasync def test_create_batch(client, mock_work"
  },
  {
    "path": "tests/test_api_generate.py",
    "chars": 6424,
    "preview": "\"\"\"Tests for generation API endpoints.\"\"\"\n\nimport pytest\nfrom unittest.mock import AsyncMock, patch\nfrom sqlalchemy impo"
  },
  {
    "path": "tests/test_api_jobs.py",
    "chars": 1917,
    "preview": "\"\"\"Tests for job management endpoints.\"\"\"\n\nimport pytest\n\n\n@pytest.mark.asyncio\nasync def test_list_jobs_empty(client):\n"
  },
  {
    "path": "tests/test_moderation.py",
    "chars": 8221,
    "preview": "\"\"\"Tests for the content moderation module.\"\"\"\n\nfrom __future__ import annotations\n\nfrom unittest.mock import MagicMock,"
  },
  {
    "path": "tests/test_moderation_service.py",
    "chars": 10353,
    "preview": "\"\"\"Tests for the content moderation service.\"\"\"\n\nfrom __future__ import annotations\n\nfrom unittest.mock import AsyncMock"
  },
  {
    "path": "tests/test_schemas.py",
    "chars": 2623,
    "preview": "\"\"\"Tests for schemas — resolution validation, format handling.\"\"\"\n\nimport pytest\nfrom pydantic import ValidationError\n\nf"
  },
  {
    "path": "tests/test_worker.py",
    "chars": 2105,
    "preview": "\"\"\"Tests for the GPU worker queue and processing logic.\"\"\"\n\nimport asyncio\nfrom unittest.mock import AsyncMock, MagicMoc"
  }
]

About this extraction

This page contains the full source code of the sgshubham98/image-service-be GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 48 files (120.4 KB), approximately 30.8k tokens, and a symbol index with 186 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!