Repository: sgshubham98/image-service-be
Branch: main
Commit: f9ddc6802cdf
Files: 48
Total size: 120.4 KB
Directory structure:
gitextract_xtvq1k3n/
├── .gitignore
├── Dockerfile
├── README.md
├── app/
│   ├── __init__.py
│   ├── api/
│   │   ├── __init__.py
│   │   ├── batch.py
│   │   ├── generate.py
│   │   ├── health.py
│   │   └── jobs.py
│   ├── config.py
│   ├── database.py
│   ├── engine/
│   │   ├── __init__.py
│   │   ├── engine_config.py
│   │   ├── flux_engine.py
│   │   └── mock_engine.py
│   ├── main.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── enums.py
│   │   └── job.py
│   ├── schemas/
│   │   ├── __init__.py
│   │   ├── batch.py
│   │   ├── generate.py
│   │   └── job.py
│   ├── services/
│   │   ├── __init__.py
│   │   ├── generation_service.py
│   │   ├── job_service.py
│   │   └── moderation_service.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── image_utils.py
│   │   ├── moderation.py
│   │   └── storage.py
│   └── worker/
│       ├── __init__.py
│       └── gpu_worker.py
├── flux_image_service.egg-info/
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   ├── dependency_links.txt
│   ├── requires.txt
│   └── top_level.txt
├── pyproject.toml
└── tests/
    ├── __init__.py
    ├── conftest.py
    ├── test_api_batch.py
    ├── test_api_generate.py
    ├── test_api_jobs.py
    ├── test_moderation.py
    ├── test_moderation_service.py
    ├── test_schemas.py
    └── test_worker.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled
__pycache__/
*.py[cod]
# Virtual env
.venv/
venv/
# Output
output/
# Logs
logs/
# IDE
.idea/
.vscode/
# Environment
.env
# SQLite
*.db
# OS
.DS_Store
================================================
FILE: Dockerfile
================================================
FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04
WORKDIR /app
# System deps
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.11 python3.11-venv python3-pip git && \
rm -rf /var/lib/apt/lists/*
# Create venv
RUN python3.11 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Install PyTorch with CUDA 12.4 first. If torch is also a project dependency,
# installing it afterwards would be a no-op ("requirement already satisfied").
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu124
# Install Python deps
COPY pyproject.toml .
RUN pip install --no-cache-dir ".[dev]"
# Copy application
COPY app/ app/
COPY tests/ tests/
# Create output directory
RUN mkdir -p /app/output
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
================================================
FILE: README.md
================================================
# Flux Image Service
Local image generation API powered by **Flux.1 Schnell 12B** running on an H100 GPU with FP8 quantization.
## Features
- **Single image generation** — async (returns job IDs) or sync (waits for result), up to 4 images per request
- **Batch/dataset generation** — submit up to 50K prompts via API or CSV/JSON file upload
- **Priority queue** — real-time requests always jump ahead of batch jobs
- **SSE progress streaming** — monitor batch progress in real-time
- **Crash recovery** — pending jobs auto-resume on server restart (persisted in SQLite)
- **Multiple formats** — PNG (with metadata), JPEG, WebP
- **Custom resolutions** — width and height can be any multiple of 8 from 64 to 2048 px
## Prerequisites
- **Python** 3.10+
- **NVIDIA GPU** with CUDA 12.x (tested on H100 20GB)
- **NVIDIA drivers** 535+ with CUDA toolkit
- **System RAM** 32 GB recommended (16 GB minimum)
- **Disk** ~12 GB for model weights + storage for generated images
## Setup
```bash
# 1. Clone / navigate to the project
cd flux-image-service
# 2. Create a virtual environment
python3 -m venv .venv
source .venv/bin/activate
# 3. Install dependencies (includes torch, diffusers, fastapi, etc.)
pip install -e ".[dev]"
# 4. (Optional) Create a .env file to override defaults
#    (model ID, storage path, etc.; see Configuration below).
#    The repository tree does not include a .env.example to copy from.
touch .env
```
## Running the Server
```bash
# Activate venv (if not already)
source .venv/bin/activate
# Start the server
uvicorn app.main:app --host 0.0.0.0 --port 8000
# Or with auto-reload for development (not recommended for GPU — reloads unload the model)
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
```
On startup the server will:
1. Download the Flux.1 Schnell model from HuggingFace (~12 GB; only needed on the first run, while the local HuggingFace cache persists)
2. Compile CUDA kernels via `torch.compile` (~30-60s, repeated on every restart)
3. Run a warm-up inference
Once you see `Flux Image Service ready`, the server is accepting requests.
## Running Tests
```bash
source .venv/bin/activate
# Run all tests
pytest tests/ -v
# Run a specific test file
pytest tests/test_api_generate.py -v
# Run with coverage
pip install pytest-cov
pytest tests/ -v --cov=app --cov-report=term-missing
```
## Docker
```bash
# Build
docker build -t flux-image-service .
# Run (requires nvidia-container-toolkit)
docker run --gpus all -p 8000:8000 -v ./output:/app/output flux-image-service
# Run with custom env
docker run --gpus all -p 8000:8000 \
-e MODEL_ID=black-forest-labs/FLUX.1-schnell \
-e STORAGE_DIR=/app/output \
-v ./output:/app/output \
flux-image-service
```
## Linting
```bash
# Check
ruff check app/ tests/
# Auto-fix
ruff check app/ tests/ --fix
# Format
ruff format app/ tests/
```
---
## API Reference
### `POST /generate` — Async Image Generation
Submit a generation request (1–4 images). Returns immediately with job IDs.
**Request:**
```bash
curl -X POST http://localhost:8000/generate \
-H "Content-Type: application/json" \
-d '{
"prompt": "a majestic mountain landscape at sunset",
"width": 1024,
"height": 1024,
"num_images": 2,
"seed": 42
}'
```
**Response** `200 OK`:
```json
{
"job_ids": ["f47ac10b-58cc-4372-a567-0e02b2c3d479", "a1b2c3d4-5678-9abc-def0-1234567890ab"],
"status": "pending",
"image_urls": []
}
```
**Validation error** `422`:
```json
{
"detail": [
{
"type": "value_error",
"loc": ["body", "width"],
"msg": "Value error, width must be a multiple of 8"
}
]
}
```
All request parameters for `ImageRequest`:
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `prompt` | string | *required* | Text prompt (1–2000 chars) |
| `negative_prompt` | string \| null | null | Negative prompt |
| `width` | int | 1024 | Image width (64–2048, multiple of 8) |
| `height` | int | 1024 | Image height (64–2048, multiple of 8) |
| `num_steps` | int | 4 | Inference steps (1–50) |
| `guidance_scale` | float | 0.0 | CFG scale (Schnell doesn't use CFG) |
| `seed` | int \| null | null | Random seed (0–4294967295) |
| `num_images` | int | 1 | Number of images (1–4) |
| `format` | string | "png" | Output format: `png`, `jpeg`, `webp` |
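
The constraints in the table can be checked client-side before a request is sent. The sketch below is a plain-Python approximation, not the service's `ImageRequest` schema (that lives in `app/schemas/generate.py` and is not reproduced here). `validate_request` is a hypothetical helper, and it assumes numeric fields are already integers.

```python
ALLOWED_FORMATS = {"png", "jpeg", "webp"}


def validate_request(params: dict) -> list[str]:
    """Return a list of problems; an empty list means the request looks valid."""
    errors: list[str] = []

    prompt = params.get("prompt")
    if not isinstance(prompt, str) or not 1 <= len(prompt) <= 2000:
        errors.append("prompt must be a string of 1-2000 characters")

    for field in ("width", "height"):
        value = params.get(field, 1024)  # documented default
        if not 64 <= value <= 2048:
            errors.append(f"{field} must be between 64 and 2048")
        elif value % 8 != 0:
            errors.append(f"{field} must be a multiple of 8")

    if not 1 <= params.get("num_steps", 4) <= 50:
        errors.append("num_steps must be between 1 and 50")

    if not 1 <= params.get("num_images", 1) <= 4:
        errors.append("num_images must be between 1 and 4")

    seed = params.get("seed")
    if seed is not None and not 0 <= seed <= 4294967295:
        errors.append("seed must be between 0 and 4294967295")

    if params.get("format", "png") not in ALLOWED_FORMATS:
        errors.append("format must be one of png, jpeg, webp")

    return errors
```

The server remains the source of truth: a request that passes this check can still be rejected with `422` if the real schema is stricter.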
---
### `GET /generate/stream` — SSE Progress Stream for Single Generate
Stream status updates for one or more job IDs returned by `/generate` until every job has reached a terminal state.
**Request:**
```bash
curl -N "http://localhost:8000/generate/stream?job_id=f47ac10b-58cc-4372-a567-0e02b2c3d479&job_id=a1b2c3d4-5678-9abc-def0-1234567890ab"
```
**Events:**
- `progress` — emitted periodically with per-job status and aggregate counters
- `done` — emitted once when all jobs are in `completed`, `failed`, or `cancelled`
**Example event payload:**
```json
{
"total": 2,
"completed": 1,
"failed": 0,
"cancelled": 0,
"pending": 1,
"jobs": [
{
"id": "f47ac10b-58cc-4372-a567-0e02b2c3d479",
"status": "completed",
"image_url": "/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image"
},
{
"id": "a1b2c3d4-5678-9abc-def0-1234567890ab",
"status": "processing",
"image_url": null
}
]
}
```
---
### `POST /generate/sync` — Synchronous Generation
Same request body as `/generate`. Blocks until all images are generated (up to 60s timeout).
**Request:**
```bash
curl -X POST http://localhost:8000/generate/sync \
-H "Content-Type: application/json" \
-d '{"prompt": "a cute cat wearing a hat", "num_images": 1}' \
--max-time 60
```
**Response** `200 OK`:
```json
{
"job_ids": ["f47ac10b-58cc-4372-a567-0e02b2c3d479"],
"status": "completed",
"image_urls": ["/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image"]
}
```
**Timeout** `408`:
```json
{"detail": "Generation timed out"}
```
---
### `GET /jobs/{job_id}/image` — Serve Generated Image
Returns the image file for a completed job.
**Request:**
```bash
curl http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image --output image.png
```
**Response:** Binary image file with appropriate `Content-Type` (`image/png`, `image/jpeg`, or `image/webp`).
**Not found** `404`:
```json
{"detail": "Image not available"}
```
---
### `GET /jobs/{job_id}` — Get Job Details
**Request:**
```bash
curl http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479
```
**Response** `200 OK`:
```json
{
"id": "f47ac10b-58cc-4372-a567-0e02b2c3d479",
"prompt": "a majestic mountain landscape at sunset",
"negative_prompt": null,
"width": 1024,
"height": 1024,
"num_steps": 4,
"guidance_scale": 0.0,
"seed": 42,
"status": "completed",
"priority": 1,
"format": "png",
"file_path": "output/images/2026-03-08/f47ac10b-58cc-4372-a567-0e02b2c3d479.png",
"image_url": "/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image",
"error_message": null,
"batch_id": null,
"created_at": "2026-03-08T12:00:00+00:00",
"started_at": "2026-03-08T12:00:01+00:00",
"completed_at": "2026-03-08T12:00:03+00:00"
}
```
---
### `GET /jobs` — List Jobs
**Request:**
```bash
# List all jobs (paginated)
curl "http://localhost:8000/jobs?page=1&page_size=20"
# Filter by status
curl "http://localhost:8000/jobs?status=completed"
# Filter by batch
curl "http://localhost:8000/jobs?batch_id=batch-uuid-here"
```
**Response** `200 OK`:
```json
{
"jobs": [
{
"id": "f47ac10b-...",
"prompt": "a mountain",
"status": "completed",
"width": 1024,
"height": 1024,
"image_url": "/jobs/f47ac10b-.../image",
"...": "..."
}
],
"total": 150,
"page": 1,
"page_size": 20
}
```
---
### `DELETE /jobs/{job_id}` — Cancel a Job
Cancels a pending/processing job. Already completed/failed jobs are returned as-is.
**Request:**
```bash
curl -X DELETE http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479
```
**Response** `200 OK`:
```json
{
"id": "f47ac10b-58cc-4372-a567-0e02b2c3d479",
"status": "cancelled",
"...": "..."
}
```
---
### `POST /batch` — Create Batch Job
Submit 1–50,000 prompts as a batch. All jobs get BATCH priority (lower than real-time).
**Request:**
```bash
curl -X POST http://localhost:8000/batch \
-H "Content-Type: application/json" \
-d '{
"name": "training-dataset-v1",
"prompts": [
{"prompt": "a red sports car", "width": 512, "height": 512},
{"prompt": "a blue ocean wave", "width": 1024, "height": 768},
{"prompt": "a forest path in autumn", "seed": 123}
]
}'
```
**Response** `200 OK`:
```json
{
"batch_id": "b1c2d3e4-5678-9abc-def0-1234567890ab",
"name": "training-dataset-v1",
"total": 3,
"completed": 0,
"failed": 0,
"cancelled": 0,
"pending": 3,
"status": "pending",
"estimated_remaining_seconds": null,
"created_at": "2026-03-08T12:00:00+00:00",
"completed_at": null
}
```
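
Datasets larger than the 50,000-prompt limit have to be split across several `POST /batch` calls. A minimal sketch, assuming the caller holds the prompts as a list of dicts; `chunk_prompts` and `build_batch_payloads` are hypothetical names and not part of the service.

```python
from typing import Iterator

MAX_BATCH_PROMPTS = 50_000  # documented upper bound for POST /batch


def chunk_prompts(prompts: list[dict], size: int = MAX_BATCH_PROMPTS) -> Iterator[list[dict]]:
    """Yield consecutive slices of at most `size` prompts."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(prompts), size):
        yield prompts[start:start + size]


def build_batch_payloads(name: str, prompts: list[dict]) -> list[dict]:
    """Build one request body per chunk, suffixing the name with a part number."""
    return [
        {"name": f"{name}-part{i}", "prompts": chunk}
        for i, chunk in enumerate(chunk_prompts(prompts), start=1)
    ]
```

Each resulting payload can be posted as its own batch and tracked through its own `batch_id`.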
---
### `POST /batch/from-file` — Create Batch from File Upload
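Upload a CSV or JSON file of prompts. The optional batch `name` is read from the query string (the endpoint declares it as a plain parameter, not a form field) and defaults to the uploaded filename.
**CSV upload:**
```bash
curl -X POST "http://localhost:8000/batch/from-file?name=my-dataset" \
  -F "file=@prompts.csv"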
```
CSV format (only `prompt` column is required):
```csv
prompt,width,height,num_steps,seed,format
a red car,512,512,4,42,png
a blue house,1024,1024,,,
a green tree,,,,,
```
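
A prompt file in this layout can be produced with the standard `csv` module. The sketch below is one way to do it; `write_prompt_csv` is a hypothetical helper, and cells are left empty for fields that should fall back to the server defaults.

```python
import csv
import io

COLUMNS = ["prompt", "width", "height", "num_steps", "seed", "format"]


def write_prompt_csv(rows: list[dict]) -> str:
    """Serialize prompt dicts to CSV text, leaving unspecified fields empty."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=COLUMNS, restval="", lineterminator="\n")
    writer.writeheader()
    for row in rows:
        if not row.get("prompt"):
            raise ValueError("every row needs a non-empty 'prompt'")
        writer.writerow(row)  # raises ValueError on keys outside COLUMNS
    return buffer.getvalue()


csv_text = write_prompt_csv([
    {"prompt": "a red car", "width": 512, "height": 512, "num_steps": 4, "seed": 42, "format": "png"},
    {"prompt": "a green tree"},
])
```

Writing through `csv.DictWriter` also takes care of quoting prompts that contain commas or quotes.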
**JSON upload:**
```bash
curl -X POST "http://localhost:8000/batch/from-file?name=my-dataset" \
  -F "file=@prompts.json"
```
JSON format:
```json
[
{"prompt": "a red car", "width": 512, "height": 512, "seed": 42},
{"prompt": "a blue house"},
{"prompt": "a green tree", "format": "jpeg"}
]
```
**Response:** Same as `POST /batch`.
---
### `GET /batch/{batch_id}` — Get Batch Progress
**Request:**
```bash
curl http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab
```
**Response** `200 OK`:
```json
{
"batch_id": "b1c2d3e4-5678-9abc-def0-1234567890ab",
"name": "training-dataset-v1",
"total": 1000,
"completed": 342,
"failed": 2,
"cancelled": 0,
"pending": 656,
"status": "processing",
"estimated_remaining_seconds": 1968.0,
"created_at": "2026-03-08T12:00:00+00:00",
"completed_at": null
}
```
---
### `GET /batch/{batch_id}/stream` — SSE Progress Stream
Real-time Server-Sent Events stream of batch progress. Updates every 2 seconds.
**Request:**
```bash
curl -N http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab/stream
```
**Response** (event stream):
```
event: progress
data: {"batch_id":"b1c2d3e4-...","total":1000,"completed":342,"failed":2,"pending":656,"status":"processing","estimated_remaining_seconds":1968.0}

event: progress
data: {"batch_id":"b1c2d3e4-...","total":1000,"completed":343,"failed":2,"pending":655,"status":"processing","estimated_remaining_seconds":1965.0}

event: done
data: {"batch_id":"b1c2d3e4-...","total":1000,"completed":998,"failed":2,"pending":0,"status":"failed"}
```
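
Clients without an SSE library can parse the stream by hand: fields arrive as `name: value` lines, and a blank line ends each event. The sketch below handles only the `event` and `data` fields used by this endpoint and skips comment lines; `parse_sse` is a hypothetical helper, not part of the service.

```python
import json
from typing import Iterable, Iterator


def parse_sse(lines: Iterable[str]) -> Iterator[tuple[str, dict]]:
    """Yield (event_name, decoded_json_data) for each event in an SSE line stream."""
    event, data = "message", []
    for raw in lines:
        line = raw.rstrip("\r\n")
        if line == "":                      # blank line: dispatch the buffered event
            if data:
                yield event, json.loads("\n".join(data))
            event, data = "message", []
        elif line.startswith(":"):          # comment or keep-alive ping
            continue
        elif line.startswith("event:"):
            event = line[len("event:"):].strip()
        elif line.startswith("data:"):
            data.append(line[len("data:"):].lstrip(" "))


stream = [
    "event: progress",
    'data: {"completed": 342, "pending": 656}',
    "",
    ": ping",
    "event: done",
    'data: {"completed": 998, "pending": 0}',
    "",
]
events = list(parse_sse(stream))
```

Any HTTP client that yields the response body line by line can feed this generator; a consumer would typically stop reading once it sees the `done` event.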
---
### `DELETE /batch/{batch_id}` — Cancel Batch
Cancels all remaining pending jobs in the batch.
**Request:**
```bash
curl -X DELETE http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab
```
**Response** `200 OK`:
```json
{
"batch_id": "b1c2d3e4-5678-9abc-def0-1234567890ab",
"total": 1000,
"completed": 342,
"failed": 2,
"cancelled": 656,
"pending": 0,
"status": "cancelled",
"...": "..."
}
```
---
### `POST /batch/{batch_id}/retry-failed` — Retry Failed Jobs
Re-enqueue only the failed jobs in a batch.
**Request:**
```bash
curl -X POST http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab/retry-failed
```
**Response** `200 OK`:
```json
{
"retried": 2,
"job_ids": ["job-uuid-1", "job-uuid-2"]
}
```
---
### `GET /health` — Health Check
**Request:**
```bash
curl http://localhost:8000/health
```
**Response** `200 OK`:
```json
{
"status": "ok",
"model_loaded": true,
"vram": {
"available": true,
"allocated_gb": 12.34,
"reserved_gb": 14.50,
"max_allocated_gb": 16.78,
"total_gb": 20.00,
"device_name": "NVIDIA H100"
},
"worker": {
"queue_depth": 5,
"jobs_processed": 142,
"total_inference_time_s": 398.50,
"avg_inference_time_s": 2.81
},
"uptime_seconds": 3600.0
}
```
---
## Architecture
```
Client → FastAPI → asyncio.PriorityQueue → GPU Worker → Flux Engine
            ↓                                  ↓
         SQLite                            Local FS
       (job state)                         (images)
```
- **Single GPU worker** — processes one image at a time (VRAM constraint)
- **Priority queue** — REALTIME(1) always before BATCH(5)
- **OOM recovery** — catches GPU OOM, clears cache, continues processing
- **Crash recovery** — on startup, all PENDING/PROCESSING jobs in SQLite are automatically re-enqueued
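
The ordering guarantee follows from how `asyncio.PriorityQueue` sorts its entries. The sketch below is illustrative only and is not the code in `app/worker/gpu_worker.py`. The `(priority, sequence, job_id)` tuple layout and the sequence counter are assumptions, added so that jobs of equal priority keep their submission order.

```python
import asyncio
import itertools

REALTIME, BATCH = 1, 5  # priority values from the list above; lower runs first


async def demo() -> list[str]:
    queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
    sequence = itertools.count()  # tie-breaker: preserves submission order

    async def enqueue(job_id: str, priority: int) -> None:
        await queue.put((priority, next(sequence), job_id))

    # Two batch jobs arrive first, then a real-time request.
    await enqueue("batch-1", BATCH)
    await enqueue("batch-2", BATCH)
    await enqueue("realtime-1", REALTIME)

    order = []
    while not queue.empty():
        _, _, job_id = await queue.get()
        order.append(job_id)
    return order


processing_order = asyncio.run(demo())
```

The real-time job is dequeued first even though it was submitted last. A job that is already running on the GPU is not preempted by this mechanism.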
### Queue Persistence
The in-memory `asyncio.PriorityQueue` is **lost on server restart**. However, **no work is lost** because:
1. All jobs are persisted to SQLite the moment they're created (before being enqueued)
2. On startup, `resume_pending_jobs()` scans for any jobs in `PENDING` or `PROCESSING` state and re-enqueues them
3. Completed jobs and their images are already on disk
This means a restart simply rebuilds the queue from the database. The only effect is a brief pause while the model reloads.
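
The recovery flow above can be sketched with the standard library. This is not the service's `resume_pending_jobs()`, whose implementation is not shown here. The table layout, the `resume_pending` name, the in-memory SQLite database, and the step that resets interrupted jobs to `pending` are all assumptions made to keep the example self-contained.

```python
import asyncio
import sqlite3


def open_db() -> sqlite3.Connection:
    db = sqlite3.connect(":memory:")
    db.execute("CREATE TABLE jobs (id TEXT PRIMARY KEY, status TEXT, priority INTEGER)")
    db.executemany(
        "INSERT INTO jobs VALUES (?, ?, ?)",
        [
            ("a", "completed", 1),   # finished before the crash: left alone
            ("b", "processing", 5),  # interrupted mid-run: must be redone
            ("c", "pending", 5),
            ("d", "pending", 1),
        ],
    )
    return db


async def resume_pending(db: sqlite3.Connection, queue: asyncio.PriorityQueue) -> int:
    """Reset interrupted jobs to pending and rebuild the in-memory queue."""
    db.execute("UPDATE jobs SET status = 'pending' WHERE status = 'processing'")
    rows = db.execute(
        "SELECT id, priority FROM jobs WHERE status = 'pending' ORDER BY rowid"
    ).fetchall()
    for seq, (job_id, priority) in enumerate(rows):
        await queue.put((priority, seq, job_id))
    return len(rows)


async def main() -> tuple[int, list[str]]:
    queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
    count = await resume_pending(open_db(), queue)
    order = [(await queue.get())[2] for _ in range(count)]
    return count, order


resumed, order = asyncio.run(main())
```

Three of the four jobs are re-enqueued, and the real-time job `d` is served before the two batch jobs.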
## Configuration
All settings can be overridden via environment variables or a `.env` file:
| Variable | Default | Description |
|----------|---------|-------------|
| `MODEL_ID` | `black-forest-labs/FLUX.1-schnell` | HuggingFace model ID |
| `DEVICE` | `cuda` | Torch device |
| `DTYPE` | `float8_e4m3fn` | Model precision |
| `ENABLE_TORCH_COMPILE` | `true` | Enable torch.compile optimization |
| `DEFAULT_NUM_STEPS` | `4` | Default inference steps |
| `DEFAULT_GUIDANCE_SCALE` | `0.0` | Default CFG scale |
| `MAX_RESOLUTION` | `2048` | Maximum allowed width/height |
| `MAX_IMMEDIATE_IMAGES` | `4` | Max images per /generate request |
| `STORAGE_DIR` | `./output` | Where images and DB are stored |
| `MAX_QUEUE_SIZE` | `10000` | Maximum queue capacity |
| `SYNC_TIMEOUT` | `60.0` | Timeout for /generate/sync (seconds) |
| `HOST` | `0.0.0.0` | Server bind address |
| `PORT` | `8000` | Server port |
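
The service reads these values through `pydantic-settings` (see `app/config.py`). The override order can be illustrated with a standard-library stand-in. `read_setting` and `parse_bool` below are hypothetical helpers that only mimic the "environment beats default" behaviour; they do not parse `.env` files and are not how the service is implemented.

```python
import os
from typing import Callable, TypeVar

T = TypeVar("T")


def parse_bool(raw: str) -> bool:
    """Interpret common truthy spellings; everything else is False."""
    return raw.strip().lower() in {"1", "true", "yes", "on"}


def read_setting(name: str, default: T, cast: Callable[[str], T]) -> T:
    """Return the environment value for `name` if set, else the documented default."""
    raw = os.environ.get(name)
    return default if raw is None else cast(raw)


os.environ["PORT"] = "9000"                    # simulate an override
os.environ["ENABLE_TORCH_COMPILE"] = "false"   # simulate an override
os.environ.pop("SYNC_TIMEOUT", None)           # make sure the default applies

port = read_setting("PORT", 8000, int)
compile_enabled = read_setting("ENABLE_TORCH_COMPILE", True, parse_bool)
sync_timeout = read_setting("SYNC_TIMEOUT", 60.0, float)
```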
## Performance (H100 20GB, FP8)
| Resolution | Time per image | 10K dataset |
|-----------|---------------|-------------|
| 512×512 | ~0.8–1.2s | ~2–3 hours |
| 1024×1024 | ~2–4s | ~6–10 hours |
| 1536×640 | ~2–3s | ~5–8 hours |
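
The dataset column is the per-image time multiplied by the number of images, since the single GPU worker processes one image at a time. A quick check of that arithmetic for the 512×512 row, ignoring queueing and I/O overhead (`dataset_hours` is a hypothetical helper):

```python
def dataset_hours(num_images: int, seconds_per_image: float) -> float:
    """Sequential wall-clock estimate in hours for a single worker."""
    return num_images * seconds_per_image / 3600


low = dataset_hours(10_000, 0.8)   # 512x512, fast end of the range
high = dataset_hours(10_000, 1.2)  # 512x512, slow end of the range
print(f"{low:.1f}-{high:.1f} hours")
```

This gives roughly 2.2 to 3.3 hours, in line with the "~2–3 hours" figure in the table.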
================================================
FILE: app/__init__.py
================================================
================================================
FILE: app/api/__init__.py
================================================
================================================
FILE: app/api/batch.py
================================================
"""Batch / dataset generation endpoints with SSE progress streaming."""
from __future__ import annotations
import asyncio
import csv
import io
import json
import os
import tempfile
import zipfile
from pathlib import Path
from typing import AsyncGenerator
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File
from fastapi.responses import StreamingResponse
from sse_starlette.sse import EventSourceResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_session
from app.models.job import Job
from app.models.enums import ImageFormat
from app.schemas.batch import BatchProgress, BatchRequest
from app.schemas.generate import ImageRequest
from app.services.generation_service import create_batch
from app.services.job_service import (
cancel_batch,
get_batch_progress,
retry_failed_in_batch,
)
router = APIRouter(prefix="/batch", tags=["batch"])
def _get_worker(request: Request):
return request.app.state.worker
@router.post("", response_model=BatchProgress)
async def create_batch_job(
body: BatchRequest,
request: Request,
session: AsyncSession = Depends(get_session),
):
"""Create a batch of image generation jobs from a list of prompts."""
worker = _get_worker(request)
batch_id = await create_batch(session, body.name, body.prompts, worker)
progress = await get_batch_progress(session, batch_id)
return BatchProgress(**progress)
@router.post("/from-file", response_model=BatchProgress)
async def create_batch_from_file(
request: Request,
file: UploadFile = File(...),
name: str | None = None,
session: AsyncSession = Depends(get_session),
):
"""Create a batch from a CSV or JSON file upload.
CSV format: prompt,width,height,num_steps,seed,format
JSON format: array of ImageRequest objects
"""
worker = _get_worker(request)
content = await file.read()
filename = file.filename or ""
if filename.endswith(".json"):
prompts = _parse_json(content)
elif filename.endswith(".csv"):
prompts = _parse_csv(content)
else:
raise HTTPException(
status_code=400,
detail="Unsupported file format. Use .csv or .json",
)
if not prompts:
raise HTTPException(status_code=400, detail="No valid prompts found in file")
batch_id = await create_batch(session, name or filename, prompts, worker)
progress = await get_batch_progress(session, batch_id)
return BatchProgress(**progress)
@router.get("/{batch_id}", response_model=BatchProgress)
async def get_batch_status(
batch_id: str,
session: AsyncSession = Depends(get_session),
):
progress = await get_batch_progress(session, batch_id)
if not progress:
raise HTTPException(status_code=404, detail="Batch not found")
return BatchProgress(**progress)
@router.get("/{batch_id}/stream")
async def stream_batch_progress(
batch_id: str,
request: Request,
session: AsyncSession = Depends(get_session),
):
"""SSE endpoint streaming batch progress updates until completion."""
progress = await get_batch_progress(session, batch_id)
if not progress:
raise HTTPException(status_code=404, detail="Batch not found")
async def event_generator() -> AsyncGenerator[dict, None]:
while True:
if await request.is_disconnected():
break
async with _fresh_session() as sess:
p = await get_batch_progress(sess, batch_id)
if p is None:
break
# Calculate ETA
eta = _estimate_eta(p, request.app.state.worker)
yield {
"event": "progress",
"data": json.dumps({**p, "estimated_remaining_seconds": eta}),
}
if p["status"] in ("completed", "failed", "cancelled"):
yield {"event": "done", "data": json.dumps(p)}
break
await asyncio.sleep(2)
return EventSourceResponse(event_generator())
@router.get("/{batch_id}/download")
async def download_batch_images(
batch_id: str,
session: AsyncSession = Depends(get_session),
):
"""Download all completed images in a batch as a ZIP archive.
Streams the ZIP incrementally so the full archive is never held
in memory. File I/O runs in a thread pool to avoid blocking the
async event loop.
"""
progress = await get_batch_progress(session, batch_id)
if not progress:
raise HTTPException(status_code=404, detail="Batch not found")
result = await session.execute(
select(Job).where(Job.batch_id == batch_id, Job.status == "completed")
)
jobs = result.scalars().all()
if not jobs:
raise HTTPException(status_code=404, detail="No completed images in this batch")
# Collect valid file paths up-front (lightweight check).
file_entries: list[tuple[str, Path]] = []
for job in jobs:
if not job.file_path:
continue
fp = Path(job.file_path)
if fp.is_file():
file_entries.append((fp.name, fp))
if not file_entries:
raise HTTPException(status_code=404, detail="No image files found on disk")
batch_name = progress.get("name") or batch_id[:12]
safe_name = "".join(
c if c.isalnum() or c in (" ", "-", "_") else "_" for c in batch_name
).strip()
filename = f"{safe_name}.zip"
async def _stream_zip() -> AsyncGenerator[bytes, None]:
"""Build a valid ZIP in a temp file (off the event loop), then
stream it in chunks so memory stays flat."""
loop = asyncio.get_running_loop()
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
os.close(tmp_fd)
def _build_zip() -> None:
with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_STORED) as zf:
for arc_name, file_path in file_entries:
zf.write(str(file_path), arc_name)
try:
# Build the ZIP in a worker thread — no event-loop blocking.
await loop.run_in_executor(None, _build_zip)
# Stream the temp file in 256 KB chunks.
chunk_size = 256 * 1024
with open(tmp_path, "rb") as f:
while True:
chunk = await loop.run_in_executor(None, f.read, chunk_size)
if not chunk:
break
yield chunk
finally:
try:
os.unlink(tmp_path)
except OSError:
pass
return StreamingResponse(
_stream_zip(),
media_type="application/zip",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
@router.delete("/{batch_id}", response_model=BatchProgress)
async def cancel_batch_job(
batch_id: str,
session: AsyncSession = Depends(get_session),
):
result = await cancel_batch(session, batch_id)
if not result:
raise HTTPException(status_code=404, detail="Batch not found")
return BatchProgress(**result)
@router.post("/{batch_id}/retry-failed")
async def retry_failed_jobs(
batch_id: str,
request: Request,
session: AsyncSession = Depends(get_session),
):
"""Re-enqueue all failed jobs in a batch."""
worker = _get_worker(request)
job_ids = await retry_failed_in_batch(session, batch_id)
if not job_ids:
raise HTTPException(status_code=404, detail="No failed jobs found in batch")
for jid in job_ids:
await worker.enqueue(jid, priority=5)
return {"retried": len(job_ids), "job_ids": job_ids}
# ── File parsing helpers ─────────────────────────────────────────
def _parse_json(content: bytes) -> list[ImageRequest]:
try:
data = json.loads(content)
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
if not isinstance(data, list):
raise HTTPException(status_code=400, detail="JSON must be an array of objects")
prompts = []
for i, item in enumerate(data):
try:
prompts.append(ImageRequest(**item))
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid entry at index {i}: {e}")
return prompts
def _parse_csv(content: bytes) -> list[ImageRequest]:
    try:
        text = content.decode("utf-8")
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="CSV must be UTF-8 encoded")
    reader = csv.DictReader(io.StringIO(text))
    prompts = []
    for i, row in enumerate(reader):
        try:
            kwargs: dict = {"prompt": row["prompt"]}
            # Optional integer columns: missing or empty cells keep the schema defaults.
            for field in ("width", "height", "num_steps", "seed"):
                if row.get(field):
                    kwargs[field] = int(row[field])
            if row.get("format"):
                kwargs["format"] = ImageFormat(row["format"].strip().lower())
            prompts.append(ImageRequest(**kwargs))
        except KeyError:
            raise HTTPException(
                status_code=400, detail=f"Row {i + 1}: 'prompt' column is required"
            )
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Row {i + 1}: {e}")
    return prompts
def _estimate_eta(progress: dict, worker) -> float | None:
"""Estimate remaining seconds based on average inference time."""
stats = worker.get_stats()
avg_time = stats.get("avg_inference_time_s", 0)
if avg_time <= 0:
return None
remaining = progress.get("pending", 0)
return round(remaining * avg_time, 1)
# Helper to get fresh session for SSE generator
from app.database import async_session_factory as _fresh_session # noqa: E402
================================================
FILE: app/api/generate.py
================================================
"""Image generation endpoints — async and sync modes."""
from __future__ import annotations
import asyncio
import json
from pathlib import Path
from typing import AsyncGenerator
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import FileResponse
from sse_starlette.sse import EventSourceResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import get_session
from app.models.job import Job
from app.schemas.generate import ImageRequest, ImageResponse
from app.services.generation_service import create_single_job
from app.services.job_service import get_job
from app.services.moderation_service import ModerationError, check_prompt_safety
from app.utils.storage import image_url_path
router = APIRouter(tags=["generate"])
_TERMINAL_JOB_STATUSES = {"completed", "failed", "cancelled"}
def _get_worker(request: Request):
return request.app.state.worker
@router.post("/generate", response_model=ImageResponse)
async def generate_async(
body: ImageRequest,
request: Request,
session: AsyncSession = Depends(get_session),
):
"""Submit an image generation request (1–4 images). Returns immediately with job IDs."""
try:
is_safe, reason = await check_prompt_safety(body.prompt)
except ModerationError:
raise HTTPException(status_code=503, detail="Content moderation service unavailable")
if not is_safe:
raise HTTPException(
status_code=400,
detail=f"Prompt contains unsafe or explicit content: {reason}",
)
worker = _get_worker(request)
job_ids = await create_single_job(session, body, worker)
return ImageResponse(job_ids=job_ids, status="pending")
@router.get("/generate/stream")
async def stream_generate_progress(
request: Request,
job_id: list[str] = Query(..., description="One or more job IDs to monitor"),
session: AsyncSession = Depends(get_session),
):
"""SSE stream for one or more generate job IDs until all reach terminal state."""
unique_job_ids = list(dict.fromkeys(job_id))
# Fail fast if any requested job ID does not exist.
checks = [await get_job(session, jid) for jid in unique_job_ids]
missing = [jid for jid, detail in zip(unique_job_ids, checks, strict=False) if detail is None]
if missing:
raise HTTPException(
status_code=404,
detail=f"Job not found: {', '.join(missing)}",
)
async def event_generator() -> AsyncGenerator[dict, None]:
while True:
if await request.is_disconnected():
break
# Expire identity map so concurrent worker updates are visible.
session.expire_all()
details = [await get_job(session, jid) for jid in unique_job_ids]
if any(d is None for d in details):
yield {
"event": "error",
"data": json.dumps({"detail": "Job no longer exists"}),
}
break
jobs = [d.model_dump() for d in details if d is not None]
total = len(jobs)
completed = sum(1 for d in details if d and d.status == "completed")
failed = sum(1 for d in details if d and d.status == "failed")
cancelled = sum(1 for d in details if d and d.status == "cancelled")
terminal = sum(1 for d in details if d and d.status in _TERMINAL_JOB_STATUSES)
payload = {
"total": total,
"completed": completed,
"failed": failed,
"cancelled": cancelled,
"pending": total - terminal,
"jobs": jobs,
}
yield {"event": "progress", "data": json.dumps(payload)}
if terminal == total:
yield {"event": "done", "data": json.dumps(payload)}
break
await asyncio.sleep(1.0)
return EventSourceResponse(event_generator())
@router.post("/generate/sync", response_model=ImageResponse)
async def generate_sync(
    body: ImageRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
):
    """Submit and wait for image(s) to be generated (up to SYNC_TIMEOUT seconds)."""
    try:
        is_safe, reason = await check_prompt_safety(body.prompt)
    except ModerationError:
        raise HTTPException(status_code=503, detail="Content moderation service unavailable")
    if not is_safe:
        raise HTTPException(
            status_code=400,
            detail=f"Prompt contains unsafe or explicit content: {reason}",
        )
    worker = _get_worker(request)
    job_ids = await create_single_job(session, body, worker)
    # Poll until all jobs reach a terminal state or the deadline passes.
    # get_running_loop() is the supported way to reach the loop inside a coroutine.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + settings.SYNC_TIMEOUT
    while loop.time() < deadline:
        await asyncio.sleep(0.3)
        # Expire identity map so concurrent worker updates are visible
        # (same approach as the /generate/stream endpoint).
        session.expire_all()
        details = [await get_job(session, jid) for jid in job_ids]
        if all(d and d.status in _TERMINAL_JOB_STATUSES for d in details):
            break
    else:
        # The while-else branch runs only when the deadline expired without a break.
        raise HTTPException(status_code=408, detail="Generation timed out")
    failed = [d for d in details if d and d.status == "failed"]
    if failed and len(failed) == len(details):
        raise HTTPException(
            status_code=500,
            detail=failed[0].error_message or "Generation failed",
        )
    return ImageResponse(
        job_ids=job_ids,
        status="completed",
        image_urls=[d.image_url if d and d.status == "completed" else None for d in details],
    )
@router.get("/jobs/{job_id}/image")
async def serve_image(
job_id: str,
session: AsyncSession = Depends(get_session),
):
"""Serve the generated image file for a completed job."""
result = await session.execute(select(Job).where(Job.id == job_id))
job = result.scalar_one_or_none()
if not job:
raise HTTPException(status_code=404, detail="Job not found")
if job.status != "completed" or not job.file_path:
raise HTTPException(status_code=404, detail="Image not available")
path = Path(job.file_path)
if not path.is_file():
raise HTTPException(status_code=404, detail="Image file missing from disk")
media_types = {
"png": "image/png",
"jpeg": "image/jpeg",
"jpg": "image/jpeg",
"webp": "image/webp",
}
ext = path.suffix.lstrip(".")
media_type = media_types.get(ext, "application/octet-stream")
return FileResponse(path, media_type=media_type, filename=path.name)
================================================
FILE: app/api/health.py
================================================
"""Health and status endpoint."""
from __future__ import annotations
import time
from fastapi import APIRouter, Request
router = APIRouter(tags=["health"])
# Monotonic timestamp recorded at startup; None until set_start_time() runs.
_start_time: float | None = None


def set_start_time() -> None:
    global _start_time
    _start_time = time.monotonic()


@router.get("/health")
async def health_check(request: Request):
    engine = request.app.state.engine
    worker = request.app.state.worker
    # time.monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    uptime = time.monotonic() - _start_time if _start_time is not None else 0.0
    return {
        "status": "ok",
        "model_loaded": engine.is_loaded,
        "vram": engine.get_vram_stats(),
        "worker": worker.get_stats(),
        "uptime_seconds": round(uptime, 1),
    }
================================================
FILE: app/api/jobs.py
================================================
"""Job management endpoints."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_session
from app.schemas.job import JobDetail, JobList
from app.services.job_service import cancel_job, get_job, list_jobs
router = APIRouter(prefix="/jobs", tags=["jobs"])
@router.get("", response_model=JobList)
async def list_all_jobs(
status: str | None = Query(None, description="Filter by status"),
batch_id: str | None = Query(None, description="Filter by batch ID"),
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=200),
session: AsyncSession = Depends(get_session),
):
jobs, total = await list_jobs(session, status=status, batch_id=batch_id, page=page, page_size=page_size)
return JobList(jobs=jobs, total=total, page=page, page_size=page_size)
@router.get("/{job_id}", response_model=JobDetail)
async def get_job_detail(
job_id: str,
session: AsyncSession = Depends(get_session),
):
job = await get_job(session, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
return job
@router.delete("/{job_id}", response_model=JobDetail)
async def cancel_job_endpoint(
job_id: str,
session: AsyncSession = Depends(get_session),
):
job = await cancel_job(session, job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
return job
================================================
FILE: app/config.py
================================================
from pydantic_settings import BaseSettings
from pathlib import Path
class Settings(BaseSettings):
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
# ── Dev / mock mode ──
MOCK_MODE: bool = False
# ── Model ──
MODEL_ID: str = "black-forest-labs/FLUX.1-schnell"
DEVICE: str = "cuda"
DTYPE: str = "float8_e4m3fn"
ENABLE_TORCH_COMPILE: bool = True
# ── Generation defaults ──
DEFAULT_NUM_STEPS: int = 4
DEFAULT_GUIDANCE_SCALE: float = 0.0
MAX_RESOLUTION: int = 2048
MIN_RESOLUTION: int = 64
# ── Storage ──
STORAGE_DIR: Path = Path("./output")
# ── Logging ──
LOG_DIR: Path = Path("./logs")
LOG_LEVEL: str = "INFO"
LOG_ROTATION_WHEN: str = "midnight" # midnight, h, m, s, etc.
LOG_RETENTION_COUNT: int = 30 # number of rotated files to keep
# ── Queue ──
MAX_QUEUE_SIZE: int = 10000
# ── Server ──
HOST: str = "0.0.0.0"
PORT: int = 8000
# ── Immediate generation ──
MAX_IMMEDIATE_IMAGES: int = 4
# ── Sync generation timeout (seconds) ──
SYNC_TIMEOUT: float = 60.0
# ── Content moderation ──
MODERATION_ENABLED: bool = True
MODERATION_MODEL_ID: str = "google/gemma-3-1b-it"
HF_API_TOKEN: str | None = None
MODERATION_API_KEY: str = ""
MODERATION_API_URL: str = "https://api.openai.com/v1/moderations"
MODERATION_TIMEOUT: float = 10.0
settings = Settings()
================================================
FILE: app/database.py
================================================
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.config import settings
DATABASE_URL = f"sqlite+aiosqlite:///{settings.STORAGE_DIR / 'flux_service.db'}"
engine = create_async_engine(DATABASE_URL, echo=False)
async_session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
async def init_db() -> None:
from app.models.job import Base
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def get_session() -> AsyncSession: # type: ignore[misc]
async with async_session_factory() as session:
yield session # type: ignore[misc]
================================================
FILE: app/engine/__init__.py
================================================
================================================
FILE: app/engine/engine_config.py
================================================
"""Resolution presets and engine configuration."""
PRESET_RESOLUTIONS: dict[str, tuple[int, int]] = {
"square_sm": (512, 512),
"square": (1024, 1024),
"landscape": (1344, 768),
"portrait": (768, 1344),
"wide": (1536, 640),
}
# Approximate peak VRAM (GB) by max dimension for FP8 Flux Schnell.
# Used as a safety guard — not an exact model; conservative estimates.
VRAM_ESTIMATES: dict[int, float] = {
512: 13.5,
768: 14.5,
1024: 16.0,
1536: 18.5,
2048: 20.0,
}
def estimated_vram_gb(width: int, height: int) -> float:
"""Return conservative VRAM estimate for a given resolution."""
max_dim = max(width, height)
for threshold in sorted(VRAM_ESTIMATES.keys()):
if max_dim <= threshold:
return VRAM_ESTIMATES[threshold]
return VRAM_ESTIMATES[max(VRAM_ESTIMATES.keys())]
================================================
FILE: app/engine/flux_engine.py
================================================
"""Flux.1 Schnell inference engine — model loading, generation, VRAM management."""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
import torch
from PIL import Image
from app.config import settings
logger = logging.getLogger(__name__)
# Dtype mapping
_DTYPE_MAP: dict[str, torch.dtype] = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float8_e4m3fn": torch.float8_e4m3fn,
"float32": torch.float32,
}
@dataclass
class GenerationResult:
image: Image.Image
seed: int
inference_time: float # seconds
@dataclass
class FluxEngine:
"""Engine wrapping the Flux.1 Schnell diffusion pipeline (the app creates one instance at startup)."""
_pipeline: object | None = field(default=None, repr=False)
_loaded: bool = False
_device: str = "cuda"
_dtype: torch.dtype = torch.bfloat16
def load_model(self) -> None:
"""Load the Flux pipeline onto the GPU."""
if self._loaded:
logger.info("Model already loaded, skipping.")
return
from diffusers import FluxPipeline
self._device = settings.DEVICE
self._dtype = _DTYPE_MAP.get(settings.DTYPE, torch.bfloat16)
logger.info(
"Loading Flux.1 Schnell model=%s device=%s dtype=%s",
settings.MODEL_ID,
self._device,
self._dtype,
)
load_kwargs: dict = {
"torch_dtype": self._dtype,
}
# For FP8 on H100, use the fp8 variant if available
if self._dtype == torch.float8_e4m3fn:
load_kwargs["variant"] = "fp8"
self._pipeline = FluxPipeline.from_pretrained(
settings.MODEL_ID,
**load_kwargs,
)
self._pipeline.to(self._device)
self._pipeline.set_progress_bar_config(disable=True)
# torch.compile for H100 Tensor Core acceleration
if settings.ENABLE_TORCH_COMPILE and self._device.startswith("cuda"):
logger.info("Compiling transformer with torch.compile (mode=max-autotune-no-cudagraphs)...")
self._pipeline.transformer = torch.compile(
self._pipeline.transformer,
mode="max-autotune-no-cudagraphs",
)
self._loaded = True
logger.info("Model loaded. VRAM: %.2f GB", self.get_vram_usage_gb())
def warmup(self) -> None:
"""Run a single throwaway inference to trigger torch.compile tracing."""
if not self._loaded:
raise RuntimeError("Model not loaded — call load_model() first")
logger.info("Running warm-up inference (first run compiles CUDA kernels)...")
start = time.perf_counter()
self.generate(prompt="warmup", width=512, height=512, num_steps=1, seed=0)
elapsed = time.perf_counter() - start
logger.info("Warm-up completed in %.1fs", elapsed)
@torch.inference_mode()
def generate(
self,
prompt: str,
width: int = 1024,
height: int = 1024,
num_steps: int = 4,
guidance_scale: float = 0.0,
seed: int | None = None,
) -> GenerationResult:
"""Generate a single image. Runs synchronously on the GPU."""
if not self._loaded:
raise RuntimeError("Model not loaded — call load_model() first")
if seed is None:
seed = torch.randint(0, 2**32, (1,)).item()
generator = torch.Generator(device=self._device).manual_seed(seed)
start = time.perf_counter()
output = self._pipeline(
prompt=prompt,
width=width,
height=height,
num_inference_steps=num_steps,
guidance_scale=guidance_scale,
generator=generator,
output_type="pil",
)
elapsed = time.perf_counter() - start
image: Image.Image = output.images[0]
logger.info(
"Generated %dx%d in %.2fs (steps=%d, seed=%d)",
width,
height,
elapsed,
num_steps,
seed,
)
return GenerationResult(image=image, seed=seed, inference_time=elapsed)
def get_vram_usage_gb(self) -> float:
"""Return current GPU memory allocated in GB."""
if not torch.cuda.is_available():
return 0.0
return torch.cuda.memory_allocated(self._device) / (1024**3)
def get_vram_stats(self) -> dict:
"""Return detailed VRAM statistics."""
if not torch.cuda.is_available():
return {"available": False}
return {
"available": True,
"allocated_gb": round(torch.cuda.memory_allocated(self._device) / (1024**3), 2),
"reserved_gb": round(torch.cuda.memory_reserved(self._device) / (1024**3), 2),
"max_allocated_gb": round(
torch.cuda.max_memory_allocated(self._device) / (1024**3), 2
),
"total_gb": round(torch.cuda.get_device_properties(0).total_memory / (1024**3), 2),
"device_name": torch.cuda.get_device_name(0),
}
def unload(self) -> None:
"""Free the model and clear CUDA cache."""
if self._pipeline is not None:
del self._pipeline
self._pipeline = None
self._loaded = False
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
logger.info("Model unloaded, CUDA cache cleared.")
@property
def is_loaded(self) -> bool:
return self._loaded
================================================
FILE: app/engine/mock_engine.py
================================================
"""Mock FluxEngine for local development without GPU or model weights."""
from __future__ import annotations
import logging
import random
import time
from dataclasses import dataclass
from PIL import Image, ImageDraw
logger = logging.getLogger(__name__)
@dataclass
class GenerationResult:
image: Image.Image
seed: int
inference_time: float
@dataclass
class MockFluxEngine:
"""Drop-in replacement for FluxEngine that generates gradient placeholder images."""
_loaded: bool = False
def load_model(self) -> None:
logger.info("MockFluxEngine: loaded (no actual model).")
self._loaded = True
def warmup(self) -> None:
logger.info("MockFluxEngine: warmup (no-op).")
def generate(
self,
prompt: str,
width: int = 1024,
height: int = 1024,
num_steps: int = 4,
guidance_scale: float = 0.0,
seed: int | None = None,
) -> GenerationResult:
if seed is None:
seed = random.randint(0, 2**32 - 1)
start = time.perf_counter()
rng = random.Random(seed)
r1, g1, b1 = rng.randint(40, 180), rng.randint(40, 180), rng.randint(40, 180)
r2, g2, b2 = rng.randint(40, 180), rng.randint(40, 180), rng.randint(40, 180)
img = Image.new("RGB", (width, height))
pixels = img.load()
for y in range(height):
t = y / max(height - 1, 1)
r = int(r1 + (r2 - r1) * t)
g = int(g1 + (g2 - g1) * t)
b = int(b1 + (b2 - b1) * t)
for x in range(width):
pixels[x, y] = (r, g, b)
draw = ImageDraw.Draw(img)
lines = [
f"[MOCK] {prompt[:60]}",
f"Seed: {seed} | {width}x{height} | steps={num_steps}",
]
draw.text((20, 20), "\n".join(lines), fill=(255, 255, 255))
# Artificial delay (roughly 2.05 to 2.2 s) to simulate inference latency
time.sleep(2.05 + rng.random() * 0.15)
elapsed = time.perf_counter() - start
logger.info(
"MockFluxEngine: generated %dx%d in %.2fs (seed=%d)",
width, height, elapsed, seed,
)
return GenerationResult(image=img, seed=seed, inference_time=elapsed)
def get_vram_usage_gb(self) -> float:
return 0.0
def get_vram_stats(self) -> dict:
return {"available": False, "mock_mode": True}
def unload(self) -> None:
self._loaded = False
logger.info("MockFluxEngine: unloaded.")
@property
def is_loaded(self) -> bool:
return self._loaded
================================================
FILE: app/main.py
================================================
"""FastAPI application — lifespan manages model loading, worker, and DB."""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from logging.handlers import TimedRotatingFileHandler
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import batch, generate, health, jobs
from app.config import settings
from app.database import init_db
from app.utils.moderation import get_moderation_engine
from app.utils.storage import ensure_storage_dirs
from app.worker.gpu_worker import GPUWorker
def _setup_logging() -> None:
"""Configure logging to both console and a time-rotating file."""
settings.LOG_DIR.mkdir(parents=True, exist_ok=True)
log_format = "%(asctime)s %(levelname)-8s %(name)s — %(message)s"
level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)
root = logging.getLogger()
root.setLevel(level)
# Console handler
console = logging.StreamHandler()
console.setLevel(level)
console.setFormatter(logging.Formatter(log_format))
root.addHandler(console)
# Time-rotating file handler
file_handler = TimedRotatingFileHandler(
filename=settings.LOG_DIR / "flux-service.log",
when=settings.LOG_ROTATION_WHEN,
backupCount=settings.LOG_RETENTION_COUNT,
encoding="utf-8",
)
file_handler.setLevel(level)
file_handler.setFormatter(logging.Formatter(log_format))
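# Daily suffix, e.g. flux-service.log.2026-03-08. This matches the handler default for
# when="midnight"; finer-grained LOG_ROTATION_WHEN values would need a matching suffix.
file_handler.suffix = "%Y-%m-%d"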
root.addHandler(file_handler)
_setup_logging()
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
# ── Startup ──────────────────────────────────────────────────
logger.info("Starting Flux Image Service...")
# Storage directories
ensure_storage_dirs()
# Database
await init_db()
logger.info("Database initialized.")
# Flux engine (mock or real)
if settings.MOCK_MODE:
from app.engine.mock_engine import MockFluxEngine
engine = MockFluxEngine()
else:
from app.engine.flux_engine import FluxEngine
engine = FluxEngine()
engine.load_model()
engine.warmup()
app.state.engine = engine
# Content moderation engine (skip model loading in mock mode)
moderation_engine = get_moderation_engine()
if not settings.MOCK_MODE:
moderation_engine.load(settings.MODERATION_MODEL_ID)
else:
logger.info("Mock mode: skipping moderation model load.")
app.state.moderation_engine = moderation_engine
# GPU worker
worker = GPUWorker(engine=engine)
worker.start()
resumed = await worker.resume_pending_jobs()
if resumed:
logger.info("Resumed %d pending jobs from previous session.", resumed)
app.state.worker = worker
# Start time for health endpoint
health.set_start_time()
logger.info("Flux Image Service ready. Queue depth: %d", worker.queue_depth)
yield
# ── Shutdown ─────────────────────────────────────────────────
logger.info("Shutting down Flux Image Service...")
await worker.shutdown()
engine.unload()
logger.info("Shutdown complete.")
app = FastAPI(
title="Flux Image Service",
description="Local image generation API powered by Flux.1 Schnell 12B",
version="0.1.0",
lifespan=lifespan,
)
# CORS — allow all for local development
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# Mount routers
app.include_router(generate.router)
app.include_router(batch.router)
app.include_router(jobs.router)
app.include_router(health.router)
================================================
FILE: app/models/__init__.py
================================================
================================================
FILE: app/models/enums.py
================================================
import enum
class JobStatus(str, enum.Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class JobPriority(int, enum.Enum):
REALTIME = 1
BATCH = 5
class ImageFormat(str, enum.Enum):
PNG = "png"
JPEG = "jpeg"
WEBP = "webp"
================================================
FILE: app/models/job.py
================================================
import uuid
from datetime import datetime, timezone
from sqlalchemy import (
Column,
DateTime,
Enum,
Float,
ForeignKey,
Integer,
String,
Text,
)
from sqlalchemy.orm import DeclarativeBase, relationship
class Base(DeclarativeBase):
pass
def _utcnow() -> datetime:
return datetime.now(timezone.utc)
class Job(Base):
__tablename__ = "jobs"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
prompt = Column(Text, nullable=False)
negative_prompt = Column(Text, nullable=True)
width = Column(Integer, nullable=False, default=1024)
height = Column(Integer, nullable=False, default=1024)
num_steps = Column(Integer, nullable=False, default=4)
guidance_scale = Column(Float, nullable=False, default=0.0)
seed = Column(Integer, nullable=True)
status = Column(
Enum("pending", "processing", "completed", "failed", "cancelled", name="job_status"),
nullable=False,
default="pending",
)
priority = Column(Integer, nullable=False, default=1)
format = Column(
Enum("png", "jpeg", "webp", name="image_format"),
nullable=False,
default="png",
)
file_path = Column(String(512), nullable=True)
error_message = Column(Text, nullable=True)
batch_id = Column(String(36), ForeignKey("batch_jobs.id"), nullable=True)
created_at = Column(DateTime(timezone=True), nullable=False, default=_utcnow)
started_at = Column(DateTime(timezone=True), nullable=True)
completed_at = Column(DateTime(timezone=True), nullable=True)
batch = relationship("BatchJob", back_populates="jobs")
class BatchJob(Base):
__tablename__ = "batch_jobs"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
name = Column(String(255), nullable=True)
total_count = Column(Integer, nullable=False, default=0)
completed_count = Column(Integer, nullable=False, default=0)
failed_count = Column(Integer, nullable=False, default=0)
status = Column(
Enum("pending", "processing", "completed", "failed", "cancelled", name="batch_status"),
nullable=False,
default="pending",
)
created_at = Column(DateTime(timezone=True), nullable=False, default=_utcnow)
completed_at = Column(DateTime(timezone=True), nullable=True)
jobs = relationship("Job", back_populates="batch", lazy="selectin")
================================================
FILE: app/schemas/__init__.py
================================================
================================================
FILE: app/schemas/batch.py
================================================
from pydantic import BaseModel, Field
from app.schemas.generate import ImageRequest
class BatchRequest(BaseModel):
name: str | None = None
prompts: list[ImageRequest] = Field(..., min_length=1, max_length=50000)
class BatchProgress(BaseModel):
batch_id: str
name: str | None
total: int
completed: int
failed: int
cancelled: int
pending: int
status: str
estimated_remaining_seconds: float | None = None
created_at: str
completed_at: str | None = None
================================================
FILE: app/schemas/generate.py
================================================
from pydantic import BaseModel, Field, field_validator, model_validator
from app.config import settings
from app.models.enums import ImageFormat
class ImageRequest(BaseModel):
prompt: str = Field(..., min_length=1, max_length=2000)
negative_prompt: str | None = None
width: int = Field(default=1024, ge=settings.MIN_RESOLUTION, le=settings.MAX_RESOLUTION)
height: int = Field(default=1024, ge=settings.MIN_RESOLUTION, le=settings.MAX_RESOLUTION)
num_steps: int = Field(default=settings.DEFAULT_NUM_STEPS, ge=1, le=50)
guidance_scale: float = Field(default=settings.DEFAULT_GUIDANCE_SCALE, ge=0.0, le=20.0)
seed: int | None = Field(default=None, ge=0, le=2**32 - 1)
num_images: int = Field(default=1, ge=1, le=settings.MAX_IMMEDIATE_IMAGES)
format: ImageFormat = ImageFormat.PNG
@field_validator("prompt")
@classmethod
def check_content_safety(cls, v: str) -> str:
from app.utils.moderation import check_prompt
check_prompt(v)
return v
@model_validator(mode="after")
def validate_resolution(self) -> "ImageRequest":
if self.width % 8 != 0:
raise ValueError("width must be a multiple of 8")
if self.height % 8 != 0:
raise ValueError("height must be a multiple of 8")
return self
class ImageResponse(BaseModel):
job_ids: list[str]
status: str
image_urls: list[str | None] = []
PRESET_RESOLUTIONS: dict[str, tuple[int, int]] = {
"square_sm": (512, 512),
"square": (1024, 1024),
"landscape": (1344, 768),
"portrait": (768, 1344),
"wide": (1536, 640),
}
================================================
FILE: app/schemas/job.py
================================================
from pydantic import BaseModel
class JobDetail(BaseModel):
id: str
prompt: str
negative_prompt: str | None
width: int
height: int
num_steps: int
guidance_scale: float
seed: int | None
status: str
priority: int
format: str
file_path: str | None
image_url: str | None
error_message: str | None
batch_id: str | None
created_at: str
started_at: str | None
completed_at: str | None
class JobList(BaseModel):
jobs: list[JobDetail]
total: int
page: int
page_size: int
================================================
FILE: app/services/__init__.py
================================================
================================================
FILE: app/services/generation_service.py
================================================
"""Orchestrates image generation: validate → persist → enqueue."""
from __future__ import annotations
import uuid
from typing import TYPE_CHECKING
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.enums import JobPriority
from app.models.job import BatchJob, Job
from app.schemas.generate import ImageRequest
if TYPE_CHECKING:
from app.worker.gpu_worker import GPUWorker
async def create_single_job(
session: AsyncSession,
request: ImageRequest,
worker: "GPUWorker",
priority: int = JobPriority.REALTIME,
) -> list[str]:
"""Create generation job(s) for the request. Returns list of job_ids.
Creates one job per requested image (num_images). Each job gets a unique
seed derived from the base seed (if provided) to ensure varied outputs.
"""
job_ids: list[str] = []
queue_priority = int(priority)
for i in range(request.num_images):
# Derive per-image seed: if user provided a seed, offset it per image
seed = None
if request.seed is not None:
seed = (request.seed + i) % (2**32)
job = Job(
id=str(uuid.uuid4()),
prompt=request.prompt,
negative_prompt=request.negative_prompt,
width=request.width,
height=request.height,
num_steps=request.num_steps,
guidance_scale=request.guidance_scale,
seed=seed,
format=request.format.value,
priority=queue_priority,
status="pending",
)
session.add(job)
job_ids.append(job.id)
await session.commit()
for jid in job_ids:
await worker.enqueue(jid, queue_priority)
return job_ids
async def create_batch(
session: AsyncSession,
name: str | None,
prompts: list[ImageRequest],
worker: "GPUWorker",
) -> str:
"""Create a batch of generation jobs. Returns batch_id."""
batch_id = str(uuid.uuid4())
batch = BatchJob(
id=batch_id,
name=name,
total_count=len(prompts),
status="pending",
)
session.add(batch)
jobs: list[Job] = []
for req in prompts:
job = Job(
id=str(uuid.uuid4()),
prompt=req.prompt,
negative_prompt=req.negative_prompt,
width=req.width,
height=req.height,
num_steps=req.num_steps,
guidance_scale=req.guidance_scale,
seed=req.seed,
format=req.format.value,
priority=int(JobPriority.BATCH),
status="pending",
batch_id=batch_id,
)
jobs.append(job)
session.add_all(jobs)
await session.commit()
# Enqueue all jobs
for job in jobs:
await worker.enqueue(job.id, int(JobPriority.BATCH))
return batch_id
================================================
FILE: app/services/job_service.py
================================================
"""Job and batch CRUD operations."""
from __future__ import annotations
from datetime import datetime, timezone
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.enums import JobStatus
from app.models.job import BatchJob, Job
from app.schemas.job import JobDetail
from app.utils.storage import image_url_path
def _job_to_detail(job: Job) -> JobDetail:
return JobDetail(
id=job.id,
prompt=job.prompt,
negative_prompt=job.negative_prompt,
width=job.width,
height=job.height,
num_steps=job.num_steps,
guidance_scale=job.guidance_scale,
seed=job.seed,
status=job.status,
priority=job.priority,
format=job.format,
file_path=job.file_path,
image_url=image_url_path(job.id) if job.status == "completed" else None,
error_message=job.error_message,
batch_id=job.batch_id,
created_at=job.created_at.isoformat() if job.created_at else None,
started_at=job.started_at.isoformat() if job.started_at else None,
completed_at=job.completed_at.isoformat() if job.completed_at else None,
)
async def get_job(session: AsyncSession, job_id: str) -> JobDetail | None:
result = await session.execute(select(Job).where(Job.id == job_id))
job = result.scalar_one_or_none()
if not job:
return None
return _job_to_detail(job)
async def list_jobs(
session: AsyncSession,
status: str | None = None,
batch_id: str | None = None,
page: int = 1,
page_size: int = 50,
) -> tuple[list[JobDetail], int]:
query = select(Job).order_by(Job.created_at.desc())
count_query = select(func.count()).select_from(Job)
if status:
query = query.where(Job.status == status)
count_query = count_query.where(Job.status == status)
if batch_id:
query = query.where(Job.batch_id == batch_id)
count_query = count_query.where(Job.batch_id == batch_id)
total = (await session.execute(count_query)).scalar() or 0
query = query.offset((page - 1) * page_size).limit(page_size)
result = await session.execute(query)
jobs = [_job_to_detail(j) for j in result.scalars().all()]
return jobs, total
async def cancel_job(session: AsyncSession, job_id: str) -> JobDetail | None:
result = await session.execute(select(Job).where(Job.id == job_id))
job = result.scalar_one_or_none()
if not job:
return None
if job.status in ("completed", "failed", "cancelled"):
return _job_to_detail(job)
job.status = "cancelled"
job.completed_at = datetime.now(timezone.utc)
await session.commit()
return _job_to_detail(job)
async def get_batch_progress(session: AsyncSession, batch_id: str) -> dict | None:
result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))
batch = result.scalar_one_or_none()
if not batch:
return None
cancelled = (
await session.execute(
select(func.count()).where(Job.batch_id == batch_id, Job.status == "cancelled")
)
).scalar() or 0
pending = batch.total_count - batch.completed_count - batch.failed_count - cancelled
return {
"batch_id": batch.id,
"name": batch.name,
"total": batch.total_count,
"completed": batch.completed_count,
"failed": batch.failed_count,
"cancelled": cancelled,
"pending": pending,
"status": batch.status,
"created_at": batch.created_at.isoformat() if batch.created_at else None,
"completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
}
async def cancel_batch(session: AsyncSession, batch_id: str) -> dict | None:
result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))
batch = result.scalar_one_or_none()
if not batch:
return None
# Cancel all pending jobs in this batch
pending_result = await session.execute(
select(Job).where(Job.batch_id == batch_id, Job.status == "pending")
)
now = datetime.now(timezone.utc)
for job in pending_result.scalars().all():
job.status = "cancelled"
job.completed_at = now
batch.status = "cancelled"
batch.completed_at = now
await session.commit()
return await get_batch_progress(session, batch_id)
async def retry_failed_in_batch(session: AsyncSession, batch_id: str) -> list[str]:
"""Re-set failed jobs in a batch back to pending. Returns list of job IDs."""
result = await session.execute(
select(Job).where(Job.batch_id == batch_id, Job.status == "failed")
)
job_ids = []
for job in result.scalars().all():
job.status = "pending"
job.error_message = None
job.started_at = None
job.completed_at = None
job_ids.append(job.id)
if job_ids:
# Reset batch status
batch_result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))
batch = batch_result.scalar_one_or_none()
if batch:
batch.status = "processing"
batch.completed_at = None
batch.failed_count = 0
await session.commit()
return job_ids
================================================
FILE: app/services/moderation_service.py
================================================
"""Content moderation service – checks prompt safety via a third-party API."""
from __future__ import annotations
import logging
import httpx
from app.config import settings
logger = logging.getLogger(__name__)
class ModerationError(Exception):
"""Raised when the moderation API call fails unexpectedly."""
async def check_prompt_safety(prompt: str) -> tuple[bool, str | None]:
"""Check whether a prompt is safe using the configured moderation API.
Calls the API specified by ``MODERATION_API_URL`` (default: OpenAI
moderation endpoint) and returns ``(is_safe, reason)``. When moderation
is disabled or no API key is configured the function returns ``(True, None)``
without making any network request.
Args:
prompt: The user-supplied generation prompt to evaluate.
Returns:
A ``(is_safe, reason)`` tuple. ``is_safe`` is ``True`` when the prompt
is considered safe. When flagged, ``reason`` contains the triggered
content categories (e.g. ``"sexual, violence"``).
Raises:
ModerationError: If the API request fails and the failure cannot be
handled gracefully.
"""
if not settings.MODERATION_ENABLED:
return True, None
if not settings.MODERATION_API_KEY:
logger.warning(
"Content moderation is enabled but MODERATION_API_KEY is not set; "
"skipping moderation check."
)
return True, None
try:
async with httpx.AsyncClient(timeout=settings.MODERATION_TIMEOUT) as client:
response = await client.post(
settings.MODERATION_API_URL,
headers={
"Authorization": f"Bearer {settings.MODERATION_API_KEY}",
"Content-Type": "application/json",
},
json={"input": prompt},
)
response.raise_for_status()
data = response.json()
except httpx.HTTPStatusError as exc:
logger.error(
"Moderation API returned an error status %s: %s",
exc.response.status_code,
exc.response.text,
)
raise ModerationError(
f"Moderation API returned HTTP {exc.response.status_code}"
) from exc
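except httpx.HTTPError as exc:
logger.error("Moderation API request failed: %s", exc)
raise ModerationError(f"Moderation API request failed: {exc}") from exc
except ValueError as exc: # response.json() raises on a non-JSON body
logger.error("Moderation API returned a non-JSON response: %s", exc)
raise ModerationError("Moderation API returned a non-JSON response") from exc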
# Parse OpenAI-compatible moderation response:
# {"results": [{"flagged": bool, "categories": {category: bool, ...}}]}
results = data.get("results", [])
if not results:
return True, None
result = results[0]
if result.get("flagged", False):
categories: dict[str, bool] = result.get("categories", {})
flagged_categories = [cat for cat, triggered in categories.items() if triggered]
reason = ", ".join(flagged_categories) if flagged_categories else "unsafe content"
return False, reason
return True, None
================================================
FILE: app/utils/__init__.py
================================================
================================================
FILE: app/utils/image_utils.py
================================================
"""Image saving, format conversion, and metadata embedding."""
from __future__ import annotations
from pathlib import Path
from PIL import Image, PngImagePlugin
from app.models.enums import ImageFormat
def save_image(
image: Image.Image,
path: Path,
fmt: ImageFormat,
quality: int = 95,
prompt: str | None = None,
seed: int | None = None,
) -> None:
"""Save a PIL Image in the requested format.
Prompt and seed are embedded as text metadata for PNG only; JPEG and WebP
are saved without them.
"""
if fmt == ImageFormat.PNG:
meta = PngImagePlugin.PngInfo()
if prompt:
meta.add_text("prompt", prompt)
if seed is not None:
meta.add_text("seed", str(seed))
image.save(path, format="PNG", pnginfo=meta)
elif fmt == ImageFormat.JPEG:
# JPEG doesn't support alpha
if image.mode in ("RGBA", "LA"):
image = image.convert("RGB")
image.save(path, format="JPEG", quality=quality)
elif fmt == ImageFormat.WEBP:
image.save(path, format="WEBP", quality=quality)
else:
image.save(path, format="PNG")
================================================
FILE: app/utils/moderation.py
================================================
"""Content moderation using Gemma-3-1b-it (local) with HF Inference API fallback."""

from __future__ import annotations

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

_SAFETY_SYSTEM_PROMPT = (
    "You are a content safety classifier. "
    "Determine whether the following text prompt for image generation contains "
    "harmful, illegal, sexually explicit, or violent content. "
    "Respond with exactly one word: 'SAFE' if the prompt is acceptable, "
    "or 'UNSAFE' if it should be rejected. Do not include any other text."
)


@dataclass
class ModerationResult:
    is_safe: bool
    reason: str = ""
    source: str = ""


class ModerationEngine:
    """Content moderation engine.

    Primary check: locally-loaded ``google/gemma-3-1b-it`` text-generation pipeline.
    Fallback check: HuggingFace Inference API (agent API) using the same model.
    If both are unavailable the engine fails open and logs a warning.
    """

    def __init__(self) -> None:
        self._pipeline = None
        self._model_id: str | None = None

    # ── Loading ──────────────────────────────────────────────────────────────

    def load(self, model_id: str = "google/gemma-3-1b-it") -> None:
        """Load *model_id* locally for synchronous inference."""
        try:
            import torch
            from transformers import pipeline

            logger.info("Loading moderation model %s …", model_id)
            self._pipeline = pipeline(
                "text-generation",
                model=model_id,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                max_new_tokens=10,
            )
            self._model_id = model_id
            logger.info("Moderation model %s loaded successfully.", model_id)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Failed to load moderation model %s: %s", model_id, exc)
            self._pipeline = None

    @property
    def is_loaded(self) -> bool:
        """Return True when the local pipeline is ready."""
        return self._pipeline is not None

    # ── Local-model check ────────────────────────────────────────────────────

    def _parse_pipeline_output(self, raw: object) -> str:
        """Extract the assistant reply text from a text-generation pipeline result."""
        # The pipeline returns a list of dicts like:
        # [{"generated_text": [{"role": "system", ...}, {"role": "assistant", "content": "SAFE"}]}]
        if isinstance(raw, list) and raw:
            inner = raw[0].get("generated_text", raw[0])
            if isinstance(inner, list) and inner:
                last = inner[-1]
                if isinstance(last, dict):
                    return last.get("content", "").strip().upper()
            return str(inner).strip().upper()
        return str(raw).strip().upper()

    def check_with_local_model(self, prompt: str) -> ModerationResult:
        """Run the locally-loaded Gemma pipeline and return a :class:`ModerationResult`."""
        if not self.is_loaded:
            raise RuntimeError("Local moderation model is not loaded")
        try:
            messages = [
                {"role": "system", "content": _SAFETY_SYSTEM_PROMPT},
                {"role": "user", "content": f"Prompt: {prompt}"},
            ]
            raw = self._pipeline(messages)
            response = self._parse_pipeline_output(raw)
            # Any reply that does not contain "UNSAFE" is treated as safe.
            is_safe = "UNSAFE" not in response
            return ModerationResult(
                is_safe=is_safe,
                reason="" if is_safe else f"Flagged by local model: {response}",
                source="local",
            )
        except Exception as exc:
            raise RuntimeError(f"Local model inference failed: {exc}") from exc

    # ── Agent-API fallback ───────────────────────────────────────────────────

    def check_with_agent_api(
        self,
        prompt: str,
        *,
        model_id: str | None = None,
        token: str | None = None,
    ) -> ModerationResult:
        """Use the HuggingFace Inference API (agent API) as a fallback safety check."""
        from huggingface_hub import InferenceClient

        _model = model_id or self._model_id or "google/gemma-3-1b-it"
        client = InferenceClient(model=_model, token=token)
        messages = [
            {"role": "system", "content": _SAFETY_SYSTEM_PROMPT},
            {"role": "user", "content": f"Prompt: {prompt}"},
        ]
        result = client.chat_completion(messages=messages, max_tokens=10)
        # The message content may be None; treat that as an empty reply.
        response = (result.choices[0].message.content or "").strip().upper()
        is_safe = "UNSAFE" not in response
        return ModerationResult(
            is_safe=is_safe,
            reason="" if is_safe else f"Flagged by agent API: {response}",
            source="agent_api",
        )

    # ── Unified entry-point ──────────────────────────────────────────────────

    def check(self, prompt: str, *, hf_token: str | None = None) -> ModerationResult:
        """Check *prompt* safety.

        Tries the local Gemma model first; falls back to the HF Inference API;
        if both are unavailable the call succeeds with a warning (fail-open).
        """
        # 1. Local model
        if self.is_loaded:
            try:
                return self.check_with_local_model(prompt)
            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "Local moderation check failed — falling back to agent API: %s", exc
                )
        # 2. Agent API fallback
        if hf_token or self._model_id:
            try:
                return self.check_with_agent_api(
                    prompt,
                    model_id=self._model_id,
                    token=hf_token,
                )
            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "Agent API moderation check failed — allowing prompt through: %s", exc
                )
        # 3. Both unavailable — fail open
        logger.warning(
            "No moderation backend available; prompt allowed through without safety check."
        )
        return ModerationResult(is_safe=True, reason="Moderation unavailable", source="none")


# ── Module-level singleton ────────────────────────────────────────────────────

_moderation_engine: ModerationEngine = ModerationEngine()


def get_moderation_engine() -> ModerationEngine:
    """Return the global :class:`ModerationEngine` singleton."""
    return _moderation_engine


def check_prompt(prompt: str) -> None:
    """Validate *prompt* against the moderation engine.

    Raises :class:`ValueError` if the prompt is rejected.
    Returns without checking when ``MODERATION_ENABLED`` is ``False``. When no
    moderation backend is available the engine fails open, so the prompt is
    accepted (this avoids breaking existing tests).
    """
    from app.config import settings

    if not settings.MODERATION_ENABLED:
        return
    result = _moderation_engine.check(prompt, hf_token=settings.HF_API_TOKEN)
    if not result.is_safe:
        raise ValueError(result.reason or "Prompt rejected by content moderation")
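

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module and not called by the
# service. ModerationEngine.check() tries the local model, then the HF
# Inference API, and finally fails open. Under the assumption that each backend
# can be modelled as a callable returning the raw verdict text, the same chain
# reduces to a loop. `SketchResult` and `check_with_fallback` are hypothetical
# names introduced only for this example.
# ---------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass
class SketchResult:
    is_safe: bool
    source: str


def check_with_fallback(
    prompt: str,
    backends: Sequence[tuple[str, Callable[[str], str]]],
) -> SketchResult:
    """Return the verdict of the first backend that answers without raising."""
    for name, classify in backends:
        try:
            verdict = classify(prompt).strip().upper()
        except Exception:  # noqa: BLE001 - a failing backend hands over to the next
            continue
        # Same rule as the engine: a reply without "UNSAFE" counts as safe.
        return SketchResult(is_safe="UNSAFE" not in verdict, source=name)
    # No backend answered: fail open, as the engine does.
    return SketchResult(is_safe=True, source="none")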
================================================
FILE: app/utils/storage.py
================================================
"""Local filesystem storage management for generated images."""

from __future__ import annotations

from datetime import datetime, timezone
from pathlib import Path

from app.config import settings


def ensure_storage_dirs() -> None:
    """Create the output directory tree if it doesn't exist."""
    (settings.STORAGE_DIR / "images").mkdir(parents=True, exist_ok=True)
    (settings.STORAGE_DIR / "batches").mkdir(parents=True, exist_ok=True)


def get_image_path(job_id: str, fmt: str, batch_id: str | None = None) -> Path:
    """Return the storage path for a generated image.

    Layout:
        output/batches/{batch_id}/{job_id}.{ext}    (if part of a batch)
        output/images/{YYYY-MM-DD}/{job_id}.{ext}   (standalone)
    """
    ext = fmt.lower()
    if ext == "jpeg":
        ext = "jpg"
    if batch_id:
        directory = settings.STORAGE_DIR / "batches" / batch_id
    else:
        date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        directory = settings.STORAGE_DIR / "images" / date_str
    directory.mkdir(parents=True, exist_ok=True)
    return directory / f"{job_id}.{ext}"


def image_url_path(job_id: str) -> str:
    """Return the API URL path to serve this image."""
    return f"/jobs/{job_id}/image"
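

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module. The layout documented
# in get_image_path() can be reproduced without the app settings or any
# filesystem access, which may be handy when reasoning about where a file will
# land. `sketch_image_path` and its `root` / `now` parameters are hypothetical
# stand-ins for settings.STORAGE_DIR and the system clock.
# ---------------------------------------------------------------------------
from datetime import datetime, timezone
from pathlib import Path


def sketch_image_path(root, job_id, fmt, batch_id=None, now=None) -> Path:
    """Compute the path only; unlike get_image_path() no directory is created."""
    ext = "jpg" if fmt.lower() == "jpeg" else fmt.lower()
    if batch_id:
        return Path(root) / "batches" / batch_id / f"{job_id}.{ext}"
    day = (now or datetime.now(timezone.utc)).strftime("%Y-%m-%d")
    return Path(root) / "images" / day / f"{job_id}.{ext}"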
================================================
FILE: app/worker/__init__.py
================================================
================================================
FILE: app/worker/gpu_worker.py
================================================
"""GPU worker — async priority queue consumer driving the Flux engine."""

from __future__ import annotations

import asyncio
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING

from sqlalchemy import select, func

from app.database import async_session_factory
from app.models.enums import ImageFormat, JobStatus
from app.models.job import BatchJob, Job
from app.utils.image_utils import save_image
from app.utils.storage import get_image_path, image_url_path

if TYPE_CHECKING:
    from app.engine.flux_engine import FluxEngine

logger = logging.getLogger(__name__)


@dataclass(order=True)
class _QueueItem:
    """Priority queue item. Lower priority value = higher urgency."""

    priority: int
    timestamp: float = field(compare=True)
    job_id: str = field(compare=False)


@dataclass
class GPUWorker:
    """Consumes jobs from an async priority queue and runs inference."""

    engine: "FluxEngine"
    _queue: asyncio.PriorityQueue[_QueueItem] = field(default_factory=asyncio.PriorityQueue)
    _stop_event: asyncio.Event = field(default_factory=asyncio.Event)
    _task: asyncio.Task | None = field(default=None, repr=False)
    _jobs_processed: int = 0
    _total_inference_time: float = 0.0

    # ── Queue interface ──────────────────────────────────────────────

    @staticmethod
    def _normalize_priority(priority: int) -> int:
        """Convert enum/int-like priorities to a stable numeric queue value."""
        try:
            value = int(priority)
        except (TypeError, ValueError):
            value = 1
        return max(1, value)

    async def enqueue(self, job_id: str, priority: int = 1) -> None:
        normalized = self._normalize_priority(priority)
        item = _QueueItem(priority=normalized, timestamp=time.time(), job_id=job_id)
        await self._queue.put(item)
        logger.debug("Enqueued job=%s priority=%d queue_depth=%d", job_id, normalized, self.queue_depth)

    @property
    def queue_depth(self) -> int:
        return self._queue.qsize()

    # ── Lifecycle ────────────────────────────────────────────────────

    def start(self) -> asyncio.Task:
        """Start the worker loop as a background task."""
        self._stop_event.clear()
        self._task = asyncio.create_task(self._run(), name="gpu-worker")
        logger.info("GPU worker started.")
        return self._task

    async def shutdown(self) -> None:
        """Signal the worker to stop and wait for it to finish the current job."""
        logger.info("Shutting down GPU worker (queue_depth=%d)...", self.queue_depth)
        self._stop_event.set()
        if self._task and not self._task.done():
            # Unblock a pending queue.get() by pushing a sentinel
            await self._queue.put(_QueueItem(priority=999, timestamp=time.time(), job_id="__stop__"))
            await self._task
        logger.info("GPU worker stopped. Processed %d jobs total.", self._jobs_processed)

    # ── Resume support ───────────────────────────────────────────────

    async def resume_pending_jobs(self) -> int:
        """Re-enqueue jobs left in PENDING or PROCESSING state (crash recovery)."""
        async with async_session_factory() as session:
            result = await session.execute(
                select(Job).where(Job.status.in_(["pending", "processing"]))
            )
            stale_jobs = result.scalars().all()
            for job in stale_jobs:
                job.status = "pending"
                await self.enqueue(job.id, self._normalize_priority(job.priority))
            await session.commit()
        if stale_jobs:
            logger.info("Resumed %d pending/stale jobs.", len(stale_jobs))
        return len(stale_jobs)

    # ── Main loop ────────────────────────────────────────────────────

    async def _run(self) -> None:
        logger.info("GPU worker loop running.")
        while not self._stop_event.is_set():
            try:
                item = await asyncio.wait_for(self._queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            if item.job_id == "__stop__":
                # Account for the sentinel so the queue's unfinished count stays balanced.
                self._queue.task_done()
                break
            try:
                await self._process_job(item.job_id)
            except Exception:
                # DB access or failure bookkeeping raised; keep the loop alive.
                logger.exception("Unhandled error while processing job %s", item.job_id)
            finally:
                self._queue.task_done()
        logger.info("GPU worker loop exited.")

    async def _process_job(self, job_id: str) -> None:
        async with async_session_factory() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if job is None:
                logger.warning("Job %s not found in DB, skipping.", job_id)
                return
            # Skip cancelled jobs
            if job.status == "cancelled":
                logger.info("Job %s is cancelled, skipping.", job_id)
                return
            # Mark as processing
            job.status = "processing"
            job.started_at = datetime.now(timezone.utc)
            await session.commit()

        try:
            # Run inference in a thread to avoid blocking asyncio
            gen_result = await asyncio.to_thread(
                self.engine.generate,
                prompt=job.prompt,
                width=job.width,
                height=job.height,
                num_steps=job.num_steps,
                guidance_scale=job.guidance_scale,
                seed=job.seed,
            )
            # Save image to disk
            fmt = ImageFormat(job.format)
            file_path = get_image_path(job.id, job.format, batch_id=job.batch_id)
            save_image(
                gen_result.image,
                file_path,
                fmt,
                prompt=job.prompt,
                seed=gen_result.seed,
            )
            # Update job as completed
            async with async_session_factory() as session:
                result = await session.execute(select(Job).where(Job.id == job_id))
                job = result.scalar_one()
                job.status = "completed"
                job.file_path = str(file_path)
                job.seed = gen_result.seed
                job.completed_at = datetime.now(timezone.utc)
                await session.commit()
                # Update batch counters if applicable
                if job.batch_id:
                    await self._update_batch_count(session, job.batch_id)
            self._jobs_processed += 1
            self._total_inference_time += gen_result.inference_time
        except Exception as exc:
            # Detect CUDA OOM without importing torch in mock/non-GPU environments.
            if exc.__class__.__name__ == "OutOfMemoryError":
                logger.error("OOM on job %s (%dx%d).", job_id, job.width, job.height)
                try:
                    import torch

                    torch.cuda.empty_cache()
                except Exception:
                    pass
                await self._fail_job(job_id, "GPU out of memory for this resolution")
                return
            logger.exception("Error processing job %s", job_id)
            await self._fail_job(job_id, "Internal generation error")

    async def _fail_job(self, job_id: str, error: str) -> None:
        async with async_session_factory() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if job:
                job.status = "failed"
                job.error_message = error
                job.completed_at = datetime.now(timezone.utc)
                await session.commit()
                if job.batch_id:
                    await self._update_batch_count(session, job.batch_id)

    async def _update_batch_count(self, session, batch_id: str) -> None:
        """Recompute batch counters from individual jobs."""
        result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))
        batch = result.scalar_one_or_none()
        if not batch:
            return
        # Count completed and failed
        completed = (
            await session.execute(
                select(func.count()).where(Job.batch_id == batch_id, Job.status == "completed")
            )
        ).scalar()
        failed = (
            await session.execute(
                select(func.count()).where(Job.batch_id == batch_id, Job.status == "failed")
            )
        ).scalar()
        batch.completed_count = completed or 0
        batch.failed_count = failed or 0
        done = (completed or 0) + (failed or 0)
        if done >= batch.total_count:
            batch.status = "completed" if (failed or 0) == 0 else "failed"
            batch.completed_at = datetime.now(timezone.utc)
        elif done > 0:
            batch.status = "processing"
        await session.commit()

    # ── Stats ────────────────────────────────────────────────────────

    def get_stats(self) -> dict:
        avg = (
            self._total_inference_time / self._jobs_processed
            if self._jobs_processed > 0
            else 0.0
        )
        return {
            "queue_depth": self.queue_depth,
            "jobs_processed": self._jobs_processed,
            "total_inference_time_s": round(self._total_inference_time, 2),
            "avg_inference_time_s": round(avg, 2),
        }
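

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module. It shows how an
# order=True dataclass behaves on an asyncio.PriorityQueue: items come out by
# ascending priority, ties are broken by timestamp, and a sentinel with a large
# priority value is dequeued last. `SketchItem` and `drain_in_order` are
# hypothetical names that mirror _QueueItem only for demonstration.
# ---------------------------------------------------------------------------
import asyncio
from dataclasses import dataclass, field


@dataclass(order=True)
class SketchItem:
    priority: int
    timestamp: float
    job_id: str = field(compare=False)  # excluded from ordering, as in _QueueItem


async def drain_in_order(items):
    """Put every item on a PriorityQueue and return job ids in dequeue order."""
    queue = asyncio.PriorityQueue()
    for item in items:
        await queue.put(item)
    order = []
    while not queue.empty():
        order.append((await queue.get()).job_id)
    return order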
================================================
FILE: flux_image_service.egg-info/PKG-INFO
================================================
Metadata-Version: 2.4
Name: flux-image-service
Version: 0.1.0
Summary: Local image generation service using Flux.1 Schnell 12B
Requires-Python: >=3.10
Requires-Dist: fastapi>=0.115.0
Requires-Dist: uvicorn[standard]>=0.30.0
Requires-Dist: pydantic>=2.0
Requires-Dist: pydantic-settings>=2.0
Requires-Dist: sqlalchemy[asyncio]>=2.0
Requires-Dist: aiosqlite>=0.20.0
Requires-Dist: torch>=2.1
Requires-Dist: diffusers>=0.30.0
Requires-Dist: transformers>=4.40.0
Requires-Dist: accelerate>=0.30.0
Requires-Dist: sentencepiece>=0.2.0
Requires-Dist: protobuf>=4.25.0
Requires-Dist: Pillow>=10.0.0
Requires-Dist: python-multipart>=0.0.9
Requires-Dist: sse-starlette>=2.0.0
Requires-Dist: httpx>=0.27.0
Provides-Extra: dev
Requires-Dist: pytest>=8.0; extra == "dev"
Requires-Dist: pytest-asyncio>=0.23.0; extra == "dev"
Requires-Dist: httpx>=0.27.0; extra == "dev"
Requires-Dist: ruff>=0.4.0; extra == "dev"
================================================
FILE: flux_image_service.egg-info/SOURCES.txt
================================================
README.md
pyproject.toml
app/__init__.py
app/config.py
app/database.py
app/main.py
app/api/__init__.py
app/api/batch.py
app/api/generate.py
app/api/health.py
app/api/jobs.py
app/engine/__init__.py
app/engine/engine_config.py
app/engine/flux_engine.py
app/models/__init__.py
app/models/enums.py
app/models/job.py
app/schemas/__init__.py
app/schemas/batch.py
app/schemas/generate.py
app/schemas/job.py
app/services/__init__.py
app/services/generation_service.py
app/services/job_service.py
app/services/moderation_service.py
app/utils/__init__.py
app/utils/image_utils.py
app/utils/moderation.py
app/utils/storage.py
app/worker/__init__.py
app/worker/gpu_worker.py
flux_image_service.egg-info/PKG-INFO
flux_image_service.egg-info/SOURCES.txt
flux_image_service.egg-info/dependency_links.txt
flux_image_service.egg-info/requires.txt
flux_image_service.egg-info/top_level.txt
tests/test_api_batch.py
tests/test_api_generate.py
tests/test_api_jobs.py
tests/test_moderation.py
tests/test_moderation_service.py
tests/test_schemas.py
tests/test_worker.py
================================================
FILE: flux_image_service.egg-info/dependency_links.txt
================================================
================================================
FILE: flux_image_service.egg-info/requires.txt
================================================
fastapi>=0.115.0
uvicorn[standard]>=0.30.0
pydantic>=2.0
pydantic-settings>=2.0
sqlalchemy[asyncio]>=2.0
aiosqlite>=0.20.0
torch>=2.1
diffusers>=0.30.0
transformers>=4.40.0
accelerate>=0.30.0
sentencepiece>=0.2.0
protobuf>=4.25.0
Pillow>=10.0.0
python-multipart>=0.0.9
sse-starlette>=2.0.0
httpx>=0.27.0
[dev]
pytest>=8.0
pytest-asyncio>=0.23.0
httpx>=0.27.0
ruff>=0.4.0
================================================
FILE: flux_image_service.egg-info/top_level.txt
================================================
app
================================================
FILE: pyproject.toml
================================================
[project]
name = "flux-image-service"
version = "0.1.0"
description = "Local image generation service using Flux.1 Schnell 12B"
requires-python = ">=3.10"
dependencies = [
"fastapi>=0.115.0",
"uvicorn[standard]>=0.30.0",
"pydantic>=2.0",
"pydantic-settings>=2.0",
"sqlalchemy[asyncio]>=2.0",
"aiosqlite>=0.20.0",
"torch>=2.1",
"diffusers>=0.30.0",
"transformers>=4.40.0",
"accelerate>=0.30.0",
"sentencepiece>=0.2.0",
"protobuf>=4.25.0",
"Pillow>=10.0.0",
"python-multipart>=0.0.9",
"sse-starlette>=2.0.0",
"httpx>=0.27.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0",
"pytest-asyncio>=0.23.0",
"httpx>=0.27.0",
"ruff>=0.4.0",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
[tool.ruff]
target-version = "py310"
line-length = 100
================================================
FILE: tests/__init__.py
================================================
================================================
FILE: tests/conftest.py
================================================
"""Shared test fixtures."""

from __future__ import annotations

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from pathlib import Path

import pytest
from httpx import ASGITransport, AsyncClient
from PIL import Image
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.engine.flux_engine import GenerationResult
from app.models.job import Base


@pytest.fixture(scope="session")
def event_loop():
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest.fixture
async def db_engine(tmp_path):
    """In-memory SQLite for tests."""
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest.fixture
async def db_session(db_engine):
    factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
    async with factory() as session:
        yield session


@pytest.fixture
def mock_engine():
    """A mock FluxEngine that returns a tiny test image."""
    engine = MagicMock()
    engine.is_loaded = True
    engine.get_vram_stats.return_value = {"available": True, "allocated_gb": 12.0}
    engine.get_vram_usage_gb.return_value = 12.0
    dummy_image = Image.new("RGB", (64, 64), color="red")
    engine.generate.return_value = GenerationResult(
        image=dummy_image, seed=42, inference_time=0.5
    )
    return engine


@pytest.fixture
def mock_worker(mock_engine):
    """A mock GPUWorker that tracks enqueued job IDs."""
    worker = MagicMock()
    worker.queue_depth = 0
    worker.enqueue = AsyncMock()
    worker.get_stats.return_value = {
        "queue_depth": 0,
        "jobs_processed": 0,
        "total_inference_time_s": 0.0,
        "avg_inference_time_s": 0.0,
    }
    return worker


@pytest.fixture
async def client(mock_engine, mock_worker, db_engine, tmp_path):
    """Test client with mocked engine/worker and in-memory DB."""
    from app.database import get_session
    from app.models.job import Base

    factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)

    async def override_get_session():
        async with factory() as session:
            yield session

    # Patch the database session factory used by the worker
    with patch("app.worker.gpu_worker.async_session_factory", factory), \
         patch("app.database.async_session_factory", factory):
        from app.main import app

        app.dependency_overrides[get_session] = override_get_session
        app.state.engine = mock_engine
        app.state.worker = mock_worker
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
        app.dependency_overrides.clear()
================================================
FILE: tests/test_api_batch.py
================================================
"""Tests for batch API endpoints."""

import pytest


@pytest.mark.asyncio
async def test_create_batch(client, mock_worker):
    resp = await client.post(
        "/batch",
        json={
            "name": "test-batch",
            "prompts": [
                {"prompt": "a red car"},
                {"prompt": "a blue house"},
                {"prompt": "a green tree"},
            ],
        },
    )
    assert resp.status_code == 200
    data = resp.json()
    assert data["total"] == 3
    assert data["status"] == "pending"
    assert data["batch_id"]
    # Worker should have been called 3 times
    assert mock_worker.enqueue.call_count == 3


@pytest.mark.asyncio
async def test_get_batch_404(client):
    resp = await client.get("/batch/nonexistent-id")
    assert resp.status_code == 404


@pytest.mark.asyncio
async def test_cancel_batch(client, mock_worker):
    # Create batch first
    resp = await client.post(
        "/batch",
        json={
            "name": "cancel-test",
            "prompts": [{"prompt": "p1"}, {"prompt": "p2"}],
        },
    )
    batch_id = resp.json()["batch_id"]
    # Cancel it
    resp = await client.delete(f"/batch/{batch_id}")
    assert resp.status_code == 200
    data = resp.json()
    assert data["status"] == "cancelled"


@pytest.mark.asyncio
async def test_batch_requires_prompts(client):
    resp = await client.post("/batch", json={"name": "empty", "prompts": []})
    assert resp.status_code == 422
================================================
FILE: tests/test_api_generate.py
================================================
"""Tests for generation API endpoints."""

import pytest
from unittest.mock import AsyncMock, patch

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.models.job import Job
from app.services.moderation_service import ModerationError


@pytest.mark.asyncio
async def test_generate_async_returns_job_ids(client):
    resp = await client.post("/generate", json={"prompt": "a beautiful sunset"})
    assert resp.status_code == 200
    data = resp.json()
    assert "job_ids" in data
    assert len(data["job_ids"]) == 1
    assert data["status"] == "pending"


@pytest.mark.asyncio
async def test_generate_multiple_images(client, mock_worker):
    resp = await client.post("/generate", json={"prompt": "test", "num_images": 3})
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["job_ids"]) == 3
    assert mock_worker.enqueue.call_count == 3


@pytest.mark.asyncio
async def test_generate_max_4_images(client):
    resp = await client.post("/generate", json={"prompt": "test", "num_images": 4})
    assert resp.status_code == 200
    assert len(resp.json()["job_ids"]) == 4


@pytest.mark.asyncio
async def test_generate_rejects_more_than_4_images(client):
    resp = await client.post("/generate", json={"prompt": "test", "num_images": 5})
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_generate_async_enqueues_work(client, mock_worker):
    resp = await client.post("/generate", json={"prompt": "test prompt"})
    assert resp.status_code == 200
    mock_worker.enqueue.assert_called_once()


@pytest.mark.asyncio
async def test_generate_validates_resolution_multiple_of_8(client):
    resp = await client.post(
        "/generate",
        json={"prompt": "test", "width": 513, "height": 1024},
    )
    assert resp.status_code == 422  # validation error


@pytest.mark.asyncio
async def test_generate_rejects_empty_prompt(client):
    resp = await client.post("/generate", json={"prompt": ""})
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_generate_rejects_oversized_resolution(client):
    resp = await client.post(
        "/generate",
        json={"prompt": "test", "width": 4096, "height": 4096},
    )
    assert resp.status_code == 422


@pytest.mark.asyncio
async def test_generate_default_values(client, mock_worker):
    resp = await client.post("/generate", json={"prompt": "a cat"})
    assert resp.status_code == 200


@pytest.mark.asyncio
async def test_serve_image_404_for_missing_job(client):
    resp = await client.get("/jobs/nonexistent-id/image")
    assert resp.status_code == 404


@pytest.mark.asyncio
async def test_generate_stream_emits_done_for_completed_jobs(client, db_engine):
    create = await client.post("/generate", json={"prompt": "test", "num_images": 2})
    assert create.status_code == 200
    job_ids = create.json()["job_ids"]
    # Mark jobs as completed so the SSE stream terminates immediately.
    factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False)
    async with factory() as session:
        result = await session.execute(select(Job).where(Job.id.in_(job_ids)))
        jobs = result.scalars().all()
        for job in jobs:
            job.status = "completed"
            job.file_path = f"output/images/{job.id}.png"
        await session.commit()
    resp = await client.get(
        "/generate/stream",
        params=[("job_id", job_ids[0]), ("job_id", job_ids[1])],
    )
    assert resp.status_code == 200
    text = resp.text
    assert "event: progress" in text
    assert "event: done" in text
    assert '"completed": 2' in text


@pytest.mark.asyncio
async def test_generate_stream_returns_404_for_unknown_job(client):
    resp = await client.get("/generate/stream", params={"job_id": "unknown-job"})
    assert resp.status_code == 404


# ---------------------------------------------------------------------------
# Content moderation integration tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_generate_async_blocked_when_prompt_is_unsafe(client):
    """POST /generate returns 400 when the moderation service flags the prompt."""
    with patch(
        "app.api.generate.check_prompt_safety",
        new=AsyncMock(return_value=(False, "sexual")),
    ):
        resp = await client.post("/generate", json={"prompt": "explicit content"})
    assert resp.status_code == 400
    assert "unsafe" in resp.json()["detail"].lower()


@pytest.mark.asyncio
async def test_generate_async_proceeds_when_prompt_is_safe(client, mock_worker):
    """POST /generate proceeds normally when the moderation service approves the prompt."""
    with patch(
        "app.api.generate.check_prompt_safety",
        new=AsyncMock(return_value=(True, None)),
    ):
        resp = await client.post("/generate", json={"prompt": "a beautiful landscape"})
    assert resp.status_code == 200
    assert "job_ids" in resp.json()
    mock_worker.enqueue.assert_called()


@pytest.mark.asyncio
async def test_generate_async_returns_503_when_moderation_service_fails(client):
    """POST /generate returns 503 when the moderation API is unreachable."""
    with patch(
        "app.api.generate.check_prompt_safety",
        new=AsyncMock(side_effect=ModerationError("API unreachable")),
    ):
        resp = await client.post("/generate", json={"prompt": "a beautiful landscape"})
    assert resp.status_code == 503


@pytest.mark.asyncio
async def test_generate_sync_blocked_when_prompt_is_unsafe(client):
    """POST /generate/sync returns 400 when the moderation service flags the prompt."""
    with patch(
        "app.api.generate.check_prompt_safety",
        new=AsyncMock(return_value=(False, "violence")),
    ):
        resp = await client.post("/generate/sync", json={"prompt": "violent content"})
    assert resp.status_code == 400
    assert "unsafe" in resp.json()["detail"].lower()


@pytest.mark.asyncio
async def test_generate_sync_returns_503_when_moderation_service_fails(client):
    """POST /generate/sync returns 503 when the moderation API is unreachable."""
    with patch(
        "app.api.generate.check_prompt_safety",
        new=AsyncMock(side_effect=ModerationError("timeout")),
    ):
        resp = await client.post("/generate/sync", json={"prompt": "a sunset"})
    assert resp.status_code == 503
================================================
FILE: tests/test_api_jobs.py
================================================
"""Tests for job management endpoints."""

import pytest


@pytest.mark.asyncio
async def test_list_jobs_empty(client):
    resp = await client.get("/jobs")
    assert resp.status_code == 200
    data = resp.json()
    assert data["total"] == 0
    assert data["jobs"] == []


@pytest.mark.asyncio
async def test_get_job_after_create(client, mock_worker):
    # Create a job via generate
    gen_resp = await client.post("/generate", json={"prompt": "test job"})
    job_id = gen_resp.json()["job_ids"][0]
    # Fetch it
    resp = await client.get(f"/jobs/{job_id}")
    assert resp.status_code == 200
    data = resp.json()
    assert data["id"] == job_id
    assert data["prompt"] == "test job"
    assert data["status"] == "pending"
    assert data["width"] == 1024
    assert data["height"] == 1024


@pytest.mark.asyncio
async def test_cancel_job(client, mock_worker):
    gen_resp = await client.post("/generate", json={"prompt": "to cancel"})
    job_id = gen_resp.json()["job_ids"][0]
    resp = await client.delete(f"/jobs/{job_id}")
    assert resp.status_code == 200
    assert resp.json()["status"] == "cancelled"


@pytest.mark.asyncio
async def test_list_jobs_with_status_filter(client, mock_worker):
    await client.post("/generate", json={"prompt": "j1"})
    await client.post("/generate", json={"prompt": "j2"})
    resp = await client.get("/jobs?status=pending")
    assert resp.status_code == 200
    data = resp.json()
    assert data["total"] == 2


@pytest.mark.asyncio
async def test_get_nonexistent_job(client):
    resp = await client.get("/jobs/does-not-exist")
    assert resp.status_code == 404


@pytest.mark.asyncio
async def test_health_endpoint(client):
    resp = await client.get("/health")
    assert resp.status_code == 200
    data = resp.json()
    assert data["status"] == "ok"
    assert data["model_loaded"] is True
    assert "vram" in data
    assert "worker" in data
================================================
FILE: tests/test_moderation.py
================================================
"""Tests for the content moderation module."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from app.utils.moderation import ModerationEngine, ModerationResult, check_prompt
# ── ModerationEngine unit tests ───────────────────────────────────────────────


class TestModerationEngine:
    def test_is_loaded_false_before_load(self):
        engine = ModerationEngine()
        assert not engine.is_loaded

    def test_load_failure_leaves_engine_unloaded(self):
        engine = ModerationEngine()
        with patch("app.utils.moderation.ModerationEngine.load") as mock_load:
            mock_load.side_effect = RuntimeError("no GPU")
            try:
                engine.load("google/gemma-3-1b-it")
            except RuntimeError:
                pass
        assert not engine.is_loaded

    def test_parse_pipeline_output_safe(self):
        engine = ModerationEngine()
        raw = [{"generated_text": [{"role": "assistant", "content": "SAFE"}]}]
        assert engine._parse_pipeline_output(raw) == "SAFE"

    def test_parse_pipeline_output_unsafe(self):
        engine = ModerationEngine()
        raw = [{"generated_text": [{"role": "assistant", "content": "UNSAFE"}]}]
        assert engine._parse_pipeline_output(raw) == "UNSAFE"

    def test_check_with_local_model_raises_when_not_loaded(self):
        engine = ModerationEngine()
        with pytest.raises(RuntimeError, match="not loaded"):
            engine.check_with_local_model("test prompt")

    def test_check_with_local_model_safe(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(
            return_value=[{"generated_text": [{"role": "assistant", "content": "SAFE"}]}]
        )
        result = engine.check_with_local_model("a beautiful sunset")
        assert result.is_safe is True
        assert result.source == "local"

    def test_check_with_local_model_unsafe(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(
            return_value=[{"generated_text": [{"role": "assistant", "content": "UNSAFE"}]}]
        )
        result = engine.check_with_local_model("explicit harmful content")
        assert result.is_safe is False
        assert result.source == "local"
        assert "Flagged by local model" in result.reason

    def test_check_with_agent_api_safe(self):
        engine = ModerationEngine()
        engine._model_id = "google/gemma-3-1b-it"
        mock_choice = MagicMock()
        mock_choice.message.content = "SAFE"
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_client = MagicMock()
        mock_client.chat_completion.return_value = mock_response
        with patch("huggingface_hub.InferenceClient", return_value=mock_client):
            result = engine.check_with_agent_api("a nice landscape", token="hf_test")
        assert result.is_safe is True
        assert result.source == "agent_api"

    def test_check_with_agent_api_unsafe(self):
        engine = ModerationEngine()
        engine._model_id = "google/gemma-3-1b-it"
        mock_choice = MagicMock()
        mock_choice.message.content = "UNSAFE"
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_client = MagicMock()
        mock_client.chat_completion.return_value = mock_response
        with patch("huggingface_hub.InferenceClient", return_value=mock_client):
            result = engine.check_with_agent_api("explicit content", token="hf_test")
        assert result.is_safe is False
        assert result.source == "agent_api"
        assert "Flagged by agent API" in result.reason

    def test_check_uses_local_model_first(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(
            return_value=[{"generated_text": [{"role": "assistant", "content": "SAFE"}]}]
        )
        with patch.object(engine, "check_with_agent_api") as mock_api:
            result = engine.check("safe prompt")
        assert result.is_safe is True
        assert result.source == "local"
        mock_api.assert_not_called()

    def test_check_falls_back_to_agent_api_when_local_fails(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(side_effect=RuntimeError("inference error"))
        engine._model_id = "google/gemma-3-1b-it"
        fallback_result = ModerationResult(is_safe=True, source="agent_api")
        with patch.object(engine, "check_with_agent_api", return_value=fallback_result) as mock_api:
            result = engine.check("test prompt", hf_token="hf_test")
        assert result.source == "agent_api"
        mock_api.assert_called_once()

    def test_check_fails_open_when_both_unavailable(self):
        engine = ModerationEngine()
        # Local not loaded, no model_id so agent API branch skipped
        result = engine.check("test prompt")
        assert result.is_safe is True
        assert result.source == "none"

    def test_check_falls_open_when_agent_api_also_fails(self):
        engine = ModerationEngine()
        engine._model_id = "google/gemma-3-1b-it"
        with patch.object(engine, "check_with_agent_api", side_effect=Exception("API error")):
            result = engine.check("test prompt", hf_token="hf_test")
        assert result.is_safe is True
        assert result.source == "none"
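

# ── Illustrative sketch (not part of the repository) ─────────────────────────
# The fallback tests above describe a chain of "local model first, agent API
# second, otherwise fail open". The helper below is a minimal sketch of that
# control flow only. Assumptions: _sketch_check and its local/remote callables
# are hypothetical names introduced here, and the real ModerationEngine.check
# in app/utils/moderation.py may differ in detail.
from typing import Callable, Optional, Tuple


def _sketch_check(
    prompt: str,
    local: Optional[Callable[[str], bool]] = None,
    remote: Optional[Callable[[str], bool]] = None,
) -> Tuple[bool, str]:
    """Return (is_safe, source), trying each available checker in order."""
    for source, checker in (("local", local), ("agent_api", remote)):
        if checker is None:
            continue  # this backend is not configured; try the next one
        try:
            return checker(prompt), source
        except Exception:
            continue  # this backend failed; fall through to the next one
    return True, "none"  # fail open: no backend produced a verdict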
# ── check_prompt() integration tests ─────────────────────────────────────────


class TestCheckPrompt:
    def test_allows_safe_prompt(self):
        with patch("app.config.settings") as mock_settings, \
             patch("app.utils.moderation._moderation_engine") as mock_engine:
            mock_settings.MODERATION_ENABLED = True
            mock_settings.HF_API_TOKEN = None
            mock_engine.check.return_value = ModerationResult(is_safe=True, source="local")
            check_prompt("a beautiful sunset")  # should not raise

    def test_rejects_unsafe_prompt(self):
        with patch("app.config.settings") as mock_settings, \
             patch("app.utils.moderation._moderation_engine") as mock_engine:
            mock_settings.MODERATION_ENABLED = True
            mock_settings.HF_API_TOKEN = None
            mock_engine.check.return_value = ModerationResult(
                is_safe=False, reason="Flagged by local model: UNSAFE", source="local"
            )
            with pytest.raises(ValueError, match="Flagged by local model"):
                check_prompt("explicit content")

    def test_skips_when_moderation_disabled(self):
        with patch("app.config.settings") as mock_settings, patch(
            "app.utils.moderation._moderation_engine"
        ) as mock_engine:
            mock_settings.MODERATION_ENABLED = False
            check_prompt("any content")
            mock_engine.check.assert_not_called()
# ── ImageRequest schema integration ──────────────────────────────────────────


class TestImageRequestModeration:
    def test_unsafe_prompt_raises_validation_error(self):
        with patch("app.config.settings") as mock_settings, \
             patch("app.utils.moderation._moderation_engine") as mock_engine:
            mock_settings.MODERATION_ENABLED = True
            mock_settings.HF_API_TOKEN = None
            mock_engine.check.return_value = ModerationResult(
                is_safe=False, reason="Flagged by local model: UNSAFE", source="local"
            )
            with pytest.raises(ValidationError, match="content moderation|Flagged"):
                from app.schemas.generate import ImageRequest

                ImageRequest(prompt="explicit content")

    def test_safe_prompt_passes_validation(self):
        with patch("app.config.settings") as mock_settings, \
             patch("app.utils.moderation._moderation_engine") as mock_engine:
            mock_settings.MODERATION_ENABLED = True
            mock_settings.HF_API_TOKEN = None
            mock_engine.check.return_value = ModerationResult(is_safe=True, source="local")
            from app.schemas.generate import ImageRequest

            req = ImageRequest(prompt="a mountain landscape")
            assert req.prompt == "a mountain landscape"
================================================
FILE: tests/test_moderation_service.py
================================================
"""Tests for the content moderation service."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from app.services.moderation_service import ModerationError, check_prompt_safety
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_moderation_response(flagged: bool, categories: dict[str, bool] | None = None):
    """Build a mock httpx.Response that mimics the OpenAI moderation format."""
    if categories is None:
        categories = {}
    payload = {
        "results": [
            {
                "flagged": flagged,
                "categories": categories,
                "category_scores": {k: 0.9 if v else 0.01 for k, v in categories.items()},
            }
        ]
    }
    response = MagicMock(spec=httpx.Response)
    response.json.return_value = payload
    response.raise_for_status = MagicMock()
    return response
# ---------------------------------------------------------------------------
# Tests: moderation disabled
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_check_prompt_safety_disabled_skips_api():
    """When moderation is disabled the function returns safe without any API call."""
    with patch("app.services.moderation_service.settings") as mock_settings:
        mock_settings.MODERATION_ENABLED = False
        is_safe, reason = await check_prompt_safety("a test prompt")

    assert is_safe is True
    assert reason is None


@pytest.mark.asyncio
async def test_check_prompt_safety_no_api_key_skips_api():
    """When enabled but no API key is set, the function returns safe with a warning."""
    with patch("app.services.moderation_service.settings") as mock_settings:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = ""
        is_safe, reason = await check_prompt_safety("a test prompt")

    assert is_safe is True
    assert reason is None
# ---------------------------------------------------------------------------
# Tests: safe prompt
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_check_prompt_safety_safe_prompt():
    """A prompt that the API says is safe returns (True, None)."""
    mock_response = _make_moderation_response(flagged=False)

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        is_safe, reason = await check_prompt_safety("a beautiful sunset over the mountains")

    assert is_safe is True
    assert reason is None
# ---------------------------------------------------------------------------
# Tests: unsafe / flagged prompt
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_check_prompt_safety_flagged_prompt():
    """A flagged prompt returns (False, reason) with the triggered categories."""
    mock_response = _make_moderation_response(
        flagged=True,
        categories={"sexual": True, "violence": False, "hate": False},
    )

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        is_safe, reason = await check_prompt_safety("explicit nsfw content here")

    assert is_safe is False
    assert reason is not None
    assert "sexual" in reason


@pytest.mark.asyncio
async def test_check_prompt_safety_flagged_multiple_categories():
    """Multiple triggered categories are all included in the reason string."""
    mock_response = _make_moderation_response(
        flagged=True,
        categories={"sexual": True, "violence": True, "hate": False},
    )

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        is_safe, reason = await check_prompt_safety("violent explicit content")

    assert is_safe is False
    assert "sexual" in reason
    assert "violence" in reason


@pytest.mark.asyncio
async def test_check_prompt_safety_flagged_no_categories():
    """If the API flags the prompt but provides no categories, reason is generic."""
    mock_response = _make_moderation_response(flagged=True, categories={})

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        is_safe, reason = await check_prompt_safety("some flagged content")

    assert is_safe is False
    assert reason == "unsafe content"
# ---------------------------------------------------------------------------
# Tests: API failures
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_check_prompt_safety_http_error_raises_moderation_error():
    """A transport-level failure (httpx.ConnectError) raises ModerationError."""
    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(
            side_effect=httpx.ConnectError("Connection refused")
        )
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        with pytest.raises(ModerationError):
            await check_prompt_safety("test prompt")


@pytest.mark.asyncio
async def test_check_prompt_safety_http_status_error_raises_moderation_error():
    """A non-2xx API response raises ModerationError."""
    mock_response = MagicMock(spec=httpx.Response)
    mock_response.status_code = 401
    mock_response.text = "Unauthorized"
    http_error = httpx.HTTPStatusError(
        "401 Unauthorized",
        request=MagicMock(),
        response=mock_response,
    )
    mock_response.raise_for_status = MagicMock(side_effect=http_error)

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "bad-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        with pytest.raises(ModerationError):
            await check_prompt_safety("test prompt")


@pytest.mark.asyncio
async def test_check_prompt_safety_empty_results():
    """An API response with an empty results array is treated as safe."""
    mock_response = MagicMock(spec=httpx.Response)
    mock_response.json.return_value = {"results": []}
    mock_response.raise_for_status = MagicMock()

    with patch("app.services.moderation_service.settings") as mock_settings, patch(
        "app.services.moderation_service.httpx.AsyncClient"
    ) as mock_client_cls:
        mock_settings.MODERATION_ENABLED = True
        mock_settings.MODERATION_API_KEY = "test-key"
        mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations"
        mock_settings.MODERATION_TIMEOUT = 10.0

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_response)
        mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)

        is_safe, reason = await check_prompt_safety("a normal prompt")

    assert is_safe is True
    assert reason is None
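

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository)
# ---------------------------------------------------------------------------
# Taken together, the tests in this file pin down how a moderation payload is
# expected to map to an (is_safe, reason) pair. The function below is a minimal
# sketch of that mapping only. Assumptions: _sketch_interpret is a hypothetical
# name, and the ", " separator is a guess, since the tests only require each
# triggered category name to appear in the reason. The real logic lives in
# app/services/moderation_service.py and may differ.
from typing import Optional, Tuple


def _sketch_interpret(payload: dict) -> Tuple[bool, Optional[str]]:
    """Map an OpenAI-style moderation payload to (is_safe, reason)."""
    results = payload.get("results") or []
    if not results or not results[0].get("flagged", False):
        return True, None  # empty or unflagged results are treated as safe
    categories = results[0].get("categories", {})
    triggered = [name for name, hit in categories.items() if hit]
    reason = ", ".join(triggered) if triggered else "unsafe content"
    return False, reason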
================================================
FILE: tests/test_schemas.py
================================================
"""Tests for schemas — resolution validation, format handling."""
import pytest
from pydantic import ValidationError
from app.schemas.generate import ImageRequest
from app.schemas.batch import BatchRequest


class TestImageRequest:
    def test_defaults(self):
        req = ImageRequest(prompt="test")
        assert req.width == 1024
        assert req.height == 1024
        assert req.num_steps == 4
        assert req.guidance_scale == 0.0
        assert req.format.value == "png"
        assert req.seed is None
        assert req.num_images == 1

    def test_num_images_up_to_4(self):
        req = ImageRequest(prompt="test", num_images=4)
        assert req.num_images == 4

    def test_num_images_rejects_above_4(self):
        with pytest.raises(ValidationError):
            ImageRequest(prompt="test", num_images=5)

    def test_num_images_rejects_zero(self):
        with pytest.raises(ValidationError):
            ImageRequest(prompt="test", num_images=0)

    def test_width_must_be_multiple_of_8(self):
        with pytest.raises(ValidationError, match="multiple of 8"):
            ImageRequest(prompt="test", width=513)

    def test_height_must_be_multiple_of_8(self):
        with pytest.raises(ValidationError, match="multiple of 8"):
            ImageRequest(prompt="test", height=100 + 3)

    def test_valid_custom_resolution(self):
        req = ImageRequest(prompt="test", width=768, height=1344)
        assert req.width == 768
        assert req.height == 1344

    def test_rejects_empty_prompt(self):
        with pytest.raises(ValidationError):
            ImageRequest(prompt="")

    def test_rejects_oversized_resolution(self):
        with pytest.raises(ValidationError):
            ImageRequest(prompt="test", width=4096)

    def test_seed_bounds(self):
        req = ImageRequest(prompt="test", seed=0)
        assert req.seed == 0
        req = ImageRequest(prompt="test", seed=2**32 - 1)
        assert req.seed == 2**32 - 1
        with pytest.raises(ValidationError):
            ImageRequest(prompt="test", seed=-1)

    def test_all_formats(self):
        for fmt in ("png", "jpeg", "webp"):
            req = ImageRequest(prompt="test", format=fmt)
            assert req.format.value == fmt


class TestBatchRequest:
    def test_valid_batch(self):
        batch = BatchRequest(
            name="test",
            prompts=[ImageRequest(prompt="a"), ImageRequest(prompt="b")],
        )
        assert len(batch.prompts) == 2

    def test_empty_prompts_rejected(self):
        with pytest.raises(ValidationError):
            BatchRequest(name="empty", prompts=[])
================================================
FILE: tests/test_worker.py
================================================
"""Tests for the GPU worker queue and processing logic."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime, timezone
import pytest
from PIL import Image
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.engine.flux_engine import GenerationResult
from app.models.enums import JobPriority
from app.models.job import Base, Job
from app.worker.gpu_worker import GPUWorker, _QueueItem


class TestQueuePriority:
    def test_realtime_before_batch(self):
        """REALTIME (priority=1) items should sort before BATCH (priority=5)."""
        rt = _QueueItem(priority=1, timestamp=100.0, job_id="rt")
        batch = _QueueItem(priority=5, timestamp=99.0, job_id="batch")
        assert rt < batch

    def test_same_priority_fifo(self):
        """Same priority: earlier timestamp wins."""
        a = _QueueItem(priority=1, timestamp=100.0, job_id="a")
        b = _QueueItem(priority=1, timestamp=200.0, job_id="b")
        assert a < b
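

# Illustrative sketch (not part of the repository's test suite).
# The two tests above only pin down the ordering contract of _QueueItem. One
# minimal way to obtain that contract, assuming an order=True dataclass whose
# job_id is excluded from comparison, is sketched below. _SketchItem and
# _sketch_drain_order are hypothetical names; the real _QueueItem in
# app/worker/gpu_worker.py may be implemented differently.
import asyncio
from dataclasses import dataclass, field
from typing import List


@dataclass(order=True)
class _SketchItem:
    priority: int  # compared first: lower value is served earlier
    timestamp: float  # tie-breaker: earlier enqueue time wins (FIFO)
    job_id: str = field(compare=False)  # payload only, ignored when ordering


async def _sketch_drain_order() -> List[str]:
    """Return job ids in the order an asyncio.PriorityQueue releases them."""
    queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
    for item in (
        _SketchItem(5, 1.0, "batch-1"),
        _SketchItem(5, 2.0, "batch-2"),
        _SketchItem(1, 3.0, "rt-1"),
    ):
        await queue.put(item)
    return [(await queue.get()).job_id for _ in range(3)]


# Usage: asyncio.run(_sketch_drain_order()) returns ["rt-1", "batch-1", "batch-2"],
# i.e. the realtime item jumps ahead of both batch items queued before it.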


class TestWorkerEnqueue:
    @pytest.mark.asyncio
    async def test_enqueue_increases_depth(self):
        engine = MagicMock()
        worker = GPUWorker(engine=engine)
        await worker.enqueue("job-1", priority=1)
        await worker.enqueue("job-2", priority=5)
        assert worker.queue_depth == 2

    @pytest.mark.asyncio
    async def test_realtime_preempts_queued_batch_jobs(self):
        engine = MagicMock()
        worker = GPUWorker(engine=engine)

        await worker.enqueue("batch-1", priority=JobPriority.BATCH)
        await worker.enqueue("batch-2", priority=JobPriority.BATCH)
        await worker.enqueue("rt-1", priority=JobPriority.REALTIME)

        first = await worker._queue.get()
        assert first.job_id == "rt-1"


class TestWorkerStats:
    def test_initial_stats(self):
        engine = MagicMock()
        worker = GPUWorker(engine=engine)
        stats = worker.get_stats()
        assert stats["jobs_processed"] == 0
        assert stats["queue_depth"] == 0
        assert stats["avg_inference_time_s"] == 0.0
SYMBOL INDEX (186 symbols across 30 files)
FILE: app/api/batch.py
function _get_worker (line 36) | def _get_worker(request: Request):
function create_batch_job (line 41) | async def create_batch_job(
function create_batch_from_file (line 54) | async def create_batch_from_file(
function get_batch_status (line 88) | async def get_batch_status(
function stream_batch_progress (line 99) | async def stream_batch_progress(
function download_batch_images (line 138) | async def download_batch_images(
function cancel_batch_job (line 216) | async def cancel_batch_job(
function retry_failed_jobs (line 227) | async def retry_failed_jobs(
function _parse_json (line 247) | def _parse_json(content: bytes) -> list[ImageRequest]:
function _parse_csv (line 265) | def _parse_csv(content: bytes) -> list[ImageRequest]:
function _estimate_eta (line 296) | def _estimate_eta(progress: dict, worker) -> float | None:
FILE: app/api/generate.py
function _get_worker (line 30) | def _get_worker(request: Request):
function generate_async (line 35) | async def generate_async(
function stream_generate_progress (line 57) | async def stream_generate_progress(
function generate_sync (line 117) | async def generate_sync(
function serve_image (line 161) | async def serve_image(
FILE: app/api/health.py
function set_start_time (line 14) | def set_start_time() -> None:
function health_check (line 20) | async def health_check(request: Request):
FILE: app/api/jobs.py
function list_all_jobs (line 16) | async def list_all_jobs(
function get_job_detail (line 28) | async def get_job_detail(
function cancel_job_endpoint (line 39) | async def cancel_job_endpoint(
FILE: app/config.py
class Settings (line 5) | class Settings(BaseSettings):
FILE: app/database.py
function init_db (line 11) | async def init_db() -> None:
function get_session (line 18) | async def get_session() -> AsyncSession: # type: ignore[misc]
FILE: app/engine/engine_config.py
function estimated_vram_gb (line 22) | def estimated_vram_gb(width: int, height: int) -> float:
FILE: app/engine/flux_engine.py
class GenerationResult (line 26) | class GenerationResult:
class FluxEngine (line 33) | class FluxEngine:
method load_model (line 41) | def load_model(self) -> None:
method warmup (line 85) | def warmup(self) -> None:
method generate (line 97) | def generate(
method get_vram_usage_gb (line 141) | def get_vram_usage_gb(self) -> float:
method get_vram_stats (line 147) | def get_vram_stats(self) -> dict:
method unload (line 162) | def unload(self) -> None:
method is_loaded (line 174) | def is_loaded(self) -> bool:
FILE: app/engine/mock_engine.py
class GenerationResult (line 16) | class GenerationResult:
class MockFluxEngine (line 23) | class MockFluxEngine:
method load_model (line 28) | def load_model(self) -> None:
method warmup (line 32) | def warmup(self) -> None:
method generate (line 35) | def generate(
method get_vram_usage_gb (line 80) | def get_vram_usage_gb(self) -> float:
method get_vram_stats (line 83) | def get_vram_stats(self) -> dict:
method unload (line 86) | def unload(self) -> None:
method is_loaded (line 91) | def is_loaded(self) -> bool:
FILE: app/main.py
function _setup_logging (line 20) | def _setup_logging() -> None:
function lifespan (line 54) | async def lifespan(app: FastAPI):
FILE: app/models/enums.py
class JobStatus (line 4) | class JobStatus(str, enum.Enum):
class JobPriority (line 12) | class JobPriority(int, enum.Enum):
class ImageFormat (line 17) | class ImageFormat(str, enum.Enum):
FILE: app/models/job.py
class Base (line 17) | class Base(DeclarativeBase):
function _utcnow (line 21) | def _utcnow() -> datetime:
class Job (line 25) | class Job(Base):
class BatchJob (line 57) | class BatchJob(Base):
FILE: app/schemas/batch.py
class BatchRequest (line 6) | class BatchRequest(BaseModel):
class BatchProgress (line 11) | class BatchProgress(BaseModel):
FILE: app/schemas/generate.py
class ImageRequest (line 7) | class ImageRequest(BaseModel):
method check_content_safety (line 20) | def check_content_safety(cls, v: str) -> str:
method validate_resolution (line 27) | def validate_resolution(self) -> "ImageRequest":
class ImageResponse (line 35) | class ImageResponse(BaseModel):
FILE: app/schemas/job.py
class JobDetail (line 4) | class JobDetail(BaseModel):
class JobList (line 25) | class JobList(BaseModel):
FILE: app/services/generation_service.py
function create_single_job (line 18) | async def create_single_job(
function create_batch (line 63) | async def create_batch(
FILE: app/services/job_service.py
function _job_to_detail (line 16) | def _job_to_detail(job: Job) -> JobDetail:
function get_job (line 39) | async def get_job(session: AsyncSession, job_id: str) -> JobDetail | None:
function list_jobs (line 47) | async def list_jobs(
function cancel_job (line 71) | async def cancel_job(session: AsyncSession, job_id: str) -> JobDetail | ...
function get_batch_progress (line 85) | async def get_batch_progress(session: AsyncSession, batch_id: str) -> di...
function cancel_batch (line 113) | async def cancel_batch(session: AsyncSession, batch_id: str) -> dict | N...
function retry_failed_in_batch (line 135) | async def retry_failed_in_batch(session: AsyncSession, batch_id: str) ->...
FILE: app/services/moderation_service.py
class ModerationError (line 14) | class ModerationError(Exception):
function check_prompt_safety (line 18) | async def check_prompt_safety(prompt: str) -> tuple[bool, str | None]:
FILE: app/utils/image_utils.py
function save_image (line 12) | def save_image(
FILE: app/utils/moderation.py
class ModerationResult (line 20) | class ModerationResult:
class ModerationEngine (line 26) | class ModerationEngine:
method __init__ (line 34) | def __init__(self) -> None:
method load (line 40) | def load(self, model_id: str = "google/gemma-3-1b-it") -> None:
method is_loaded (line 61) | def is_loaded(self) -> bool:
method _parse_pipeline_output (line 67) | def _parse_pipeline_output(self, raw: object) -> str:
method check_with_local_model (line 80) | def check_with_local_model(self, prompt: str) -> ModerationResult:
method check_with_agent_api (line 102) | def check_with_agent_api(
method check (line 129) | def check(self, prompt: str, *, hf_token: str | None = None) -> Modera...
function get_moderation_engine (line 169) | def get_moderation_engine() -> ModerationEngine:
function check_prompt (line 174) | def check_prompt(prompt: str) -> None:
FILE: app/utils/storage.py
function ensure_storage_dirs (line 11) | def ensure_storage_dirs() -> None:
function get_image_path (line 17) | def get_image_path(job_id: str, fmt: str, batch_id: str | None = None) -...
function image_url_path (line 38) | def image_url_path(job_id: str) -> str:
FILE: app/worker/gpu_worker.py
class _QueueItem (line 27) | class _QueueItem:
class GPUWorker (line 36) | class GPUWorker:
method _normalize_priority (line 49) | def _normalize_priority(priority: int) -> int:
method enqueue (line 57) | async def enqueue(self, job_id: str, priority: int = 1) -> None:
method queue_depth (line 64) | def queue_depth(self) -> int:
method start (line 69) | def start(self) -> asyncio.Task:
method shutdown (line 76) | async def shutdown(self) -> None:
method resume_pending_jobs (line 88) | async def resume_pending_jobs(self) -> int:
method _run (line 108) | async def _run(self) -> None:
method _process_job (line 124) | async def _process_job(self, job_id: str) -> None:
method _fail_job (line 198) | async def _fail_job(self, job_id: str, error: str) -> None:
method _update_batch_count (line 211) | async def _update_batch_count(self, session, batch_id: str) -> None:
method get_stats (line 244) | def get_stats(self) -> dict:
FILE: tests/conftest.py
function event_loop (line 19) | def event_loop():
function db_engine (line 26) | async def db_engine(tmp_path):
function db_session (line 36) | async def db_session(db_engine):
function mock_engine (line 43) | def mock_engine():
function mock_worker (line 58) | def mock_worker(mock_engine):
function client (line 73) | async def client(mock_engine, mock_worker, db_engine, tmp_path):
FILE: tests/test_api_batch.py
function test_create_batch (line 7) | async def test_create_batch(client, mock_worker):
function test_get_batch_404 (line 29) | async def test_get_batch_404(client):
function test_cancel_batch (line 35) | async def test_cancel_batch(client, mock_worker):
function test_batch_requires_prompts (line 54) | async def test_batch_requires_prompts(client):
FILE: tests/test_api_generate.py
function test_generate_async_returns_job_ids (line 13) | async def test_generate_async_returns_job_ids(client):
function test_generate_multiple_images (line 23) | async def test_generate_multiple_images(client, mock_worker):
function test_generate_max_4_images (line 32) | async def test_generate_max_4_images(client):
function test_generate_rejects_more_than_4_images (line 39) | async def test_generate_rejects_more_than_4_images(client):
function test_generate_async_enqueues_work (line 45) | async def test_generate_async_enqueues_work(client, mock_worker):
function test_generate_validates_resolution_multiple_of_8 (line 52) | async def test_generate_validates_resolution_multiple_of_8(client):
function test_generate_rejects_empty_prompt (line 61) | async def test_generate_rejects_empty_prompt(client):
function test_generate_rejects_oversized_resolution (line 67) | async def test_generate_rejects_oversized_resolution(client):
function test_generate_default_values (line 76) | async def test_generate_default_values(client, mock_worker):
function test_serve_image_404_for_missing_job (line 82) | async def test_serve_image_404_for_missing_job(client):
function test_generate_stream_emits_done_for_completed_jobs (line 88) | async def test_generate_stream_emits_done_for_completed_jobs(client, db_...
function test_generate_stream_returns_404_for_unknown_job (line 115) | async def test_generate_stream_returns_404_for_unknown_job(client):
function test_generate_async_blocked_when_prompt_is_unsafe (line 126) | async def test_generate_async_blocked_when_prompt_is_unsafe(client):
function test_generate_async_proceeds_when_prompt_is_safe (line 139) | async def test_generate_async_proceeds_when_prompt_is_safe(client, mock_...
function test_generate_async_returns_503_when_moderation_service_fails (line 153) | async def test_generate_async_returns_503_when_moderation_service_fails(...
function test_generate_sync_blocked_when_prompt_is_unsafe (line 165) | async def test_generate_sync_blocked_when_prompt_is_unsafe(client):
function test_generate_sync_returns_503_when_moderation_service_fails (line 178) | async def test_generate_sync_returns_503_when_moderation_service_fails(c...
FILE: tests/test_api_jobs.py
function test_list_jobs_empty (line 7) | async def test_list_jobs_empty(client):
function test_get_job_after_create (line 16) | async def test_get_job_after_create(client, mock_worker):
function test_cancel_job (line 33) | async def test_cancel_job(client, mock_worker):
function test_list_jobs_with_status_filter (line 43) | async def test_list_jobs_with_status_filter(client, mock_worker):
function test_get_nonexistent_job (line 54) | async def test_get_nonexistent_job(client):
function test_health_endpoint (line 60) | async def test_health_endpoint(client):
FILE: tests/test_moderation.py
class TestModerationEngine (line 16) | class TestModerationEngine:
method test_is_loaded_false_before_load (line 17) | def test_is_loaded_false_before_load(self):
method test_load_failure_leaves_engine_unloaded (line 21) | def test_load_failure_leaves_engine_unloaded(self):
method test_parse_pipeline_output_safe (line 31) | def test_parse_pipeline_output_safe(self):
method test_parse_pipeline_output_unsafe (line 36) | def test_parse_pipeline_output_unsafe(self):
method test_check_with_local_model_raises_when_not_loaded (line 41) | def test_check_with_local_model_raises_when_not_loaded(self):
method test_check_with_local_model_safe (line 46) | def test_check_with_local_model_safe(self):
method test_check_with_local_model_unsafe (line 55) | def test_check_with_local_model_unsafe(self):
method test_check_with_agent_api_safe (line 65) | def test_check_with_agent_api_safe(self):
method test_check_with_agent_api_unsafe (line 83) | def test_check_with_agent_api_unsafe(self):
method test_check_uses_local_model_first (line 102) | def test_check_uses_local_model_first(self):
method test_check_falls_back_to_agent_api_when_local_fails (line 115) | def test_check_falls_back_to_agent_api_when_local_fails(self):
method test_check_fails_open_when_both_unavailable (line 127) | def test_check_fails_open_when_both_unavailable(self):
method test_check_falls_open_when_agent_api_also_fails (line 134) | def test_check_falls_open_when_agent_api_also_fails(self):
class TestCheckPrompt (line 148) | class TestCheckPrompt:
method test_allows_safe_prompt (line 149) | def test_allows_safe_prompt(self):
method test_rejects_unsafe_prompt (line 157) | def test_rejects_unsafe_prompt(self):
method test_skips_when_moderation_disabled (line 168) | def test_skips_when_moderation_disabled(self):
class TestImageRequestModeration (line 180) | class TestImageRequestModeration:
method test_unsafe_prompt_raises_validation_error (line 181) | def test_unsafe_prompt_raises_validation_error(self):
method test_safe_prompt_passes_validation (line 194) | def test_safe_prompt_passes_validation(self):
FILE: tests/test_moderation_service.py
function _make_moderation_response (line 18) | def _make_moderation_response(flagged: bool, categories: dict[str, bool]...
function test_check_prompt_safety_disabled_skips_api (line 43) | async def test_check_prompt_safety_disabled_skips_api():
function test_check_prompt_safety_no_api_key_skips_api (line 54) | async def test_check_prompt_safety_no_api_key_skips_api():
function test_check_prompt_safety_safe_prompt (line 71) | async def test_check_prompt_safety_safe_prompt():
function test_check_prompt_safety_flagged_prompt (line 100) | async def test_check_prompt_safety_flagged_prompt():
function test_check_prompt_safety_flagged_multiple_categories (line 128) | async def test_check_prompt_safety_flagged_multiple_categories():
function test_check_prompt_safety_flagged_no_categories (line 156) | async def test_check_prompt_safety_flagged_no_categories():
function test_check_prompt_safety_http_error_raises_moderation_error (line 185) | async def test_check_prompt_safety_http_error_raises_moderation_error():
function test_check_prompt_safety_http_status_error_raises_moderation_error (line 207) | async def test_check_prompt_safety_http_status_error_raises_moderation_e...
function test_check_prompt_safety_empty_results (line 237) | async def test_check_prompt_safety_empty_results():
FILE: tests/test_schemas.py
class TestImageRequest (line 10) | class TestImageRequest:
method test_defaults (line 11) | def test_defaults(self):
method test_num_images_up_to_4 (line 21) | def test_num_images_up_to_4(self):
method test_num_images_rejects_above_4 (line 25) | def test_num_images_rejects_above_4(self):
method test_num_images_rejects_zero (line 29) | def test_num_images_rejects_zero(self):
method test_width_must_be_multiple_of_8 (line 33) | def test_width_must_be_multiple_of_8(self):
method test_height_must_be_multiple_of_8 (line 37) | def test_height_must_be_multiple_of_8(self):
method test_valid_custom_resolution (line 41) | def test_valid_custom_resolution(self):
method test_rejects_empty_prompt (line 46) | def test_rejects_empty_prompt(self):
method test_rejects_oversized_resolution (line 50) | def test_rejects_oversized_resolution(self):
method test_seed_bounds (line 54) | def test_seed_bounds(self):
method test_all_formats (line 64) | def test_all_formats(self):
class TestBatchRequest (line 70) | class TestBatchRequest:
method test_valid_batch (line 71) | def test_valid_batch(self):
method test_empty_prompts_rejected (line 78) | def test_empty_prompts_rejected(self):
FILE: tests/test_worker.py
class TestQueuePriority (line 18) | class TestQueuePriority:
method test_realtime_before_batch (line 19) | def test_realtime_before_batch(self):
method test_same_priority_fifo (line 25) | def test_same_priority_fifo(self):
class TestWorkerEnqueue (line 32) | class TestWorkerEnqueue:
method test_enqueue_increases_depth (line 34) | async def test_enqueue_increases_depth(self):
method test_realtime_preempts_queued_batch_jobs (line 42) | async def test_realtime_preempts_queued_batch_jobs(self):
class TestWorkerStats (line 54) | class TestWorkerStats:
method test_initial_stats (line 55) | def test_initial_stats(self):
Condensed preview: 48 files (132K chars in full), each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 173,
"preview": "# Byte-compiled\n__pycache__/\n*.py[cod]\n\n# Virtual env\n.venv/\nvenv/\n\n# Output\noutput/\n\n# Logs\nlogs/\n\n# IDE\n.idea/\n.vscode"
},
{
"path": "Dockerfile",
"chars": 778,
"preview": "FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04\n\nWORKDIR /app\n\n# System deps\nRUN apt-get update && apt-get install -y --no-i"
},
{
"path": "README.md",
"chars": 14176,
"preview": "# Flux Image Service\n\nLocal image generation API powered by **Flux.1 Schnell 12B** running on an H100 GPU with FP8 quant"
},
{
"path": "app/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "app/api/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "app/api/batch.py",
"chars": 10067,
"preview": "\"\"\"Batch / dataset generation endpoints with SSE progress streaming.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyn"
},
{
"path": "app/api/generate.py",
"chars": 6655,
"preview": "\"\"\"Image generation endpoints — async and sync modes.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport json"
},
{
"path": "app/api/health.py",
"chars": 685,
"preview": "\"\"\"Health and status endpoint.\"\"\"\n\nfrom __future__ import annotations\n\nimport time\n\nfrom fastapi import APIRouter, Reque"
},
{
"path": "app/api/jobs.py",
"chars": 1499,
"preview": "\"\"\"Job management endpoints.\"\"\"\n\nfrom __future__ import annotations\n\nfrom fastapi import APIRouter, Depends, HTTPExcepti"
},
{
"path": "app/config.py",
"chars": 1429,
"preview": "from pydantic_settings import BaseSettings\nfrom pathlib import Path\n\n\nclass Settings(BaseSettings):\n model_config = {"
},
{
"path": "app/database.py",
"chars": 677,
"preview": "from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine\n\nfrom app.config import setting"
},
{
"path": "app/engine/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "app/engine/engine_config.py",
"chars": 847,
"preview": "\"\"\"Resolution presets and engine configuration.\"\"\"\n\nPRESET_RESOLUTIONS: dict[str, tuple[int, int]] = {\n \"square_sm\": "
},
{
"path": "app/engine/flux_engine.py",
"chars": 5547,
"preview": "\"\"\"Flux.1 Schnell inference engine — model loading, generation, VRAM management.\"\"\"\n\nfrom __future__ import annotations\n"
},
{
"path": "app/engine/mock_engine.py",
"chars": 2567,
"preview": "\"\"\"Mock FluxEngine for local development without GPU or model weights.\"\"\"\n\nfrom __future__ import annotations\n\nimport lo"
},
{
"path": "app/main.py",
"chars": 3687,
"preview": "\"\"\"FastAPI application — lifespan manages model loading, worker, and DB.\"\"\"\n\nfrom __future__ import annotations\n\nimport "
},
{
"path": "app/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "app/models/enums.py",
"chars": 336,
"preview": "import enum\n\n\nclass JobStatus(str, enum.Enum):\n PENDING = \"pending\"\n PROCESSING = \"processing\"\n COMPLETED = \"co"
},
{
"path": "app/models/job.py",
"chars": 2424,
"preview": "import uuid\nfrom datetime import datetime, timezone\n\nfrom sqlalchemy import (\n Column,\n DateTime,\n Enum,\n Fl"
},
{
"path": "app/schemas/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "app/schemas/batch.py",
"chars": 507,
"preview": "from pydantic import BaseModel, Field\n\nfrom app.schemas.generate import ImageRequest\n\n\nclass BatchRequest(BaseModel):\n "
},
{
"path": "app/schemas/generate.py",
"chars": 1618,
"preview": "from pydantic import BaseModel, Field, field_validator, model_validator\n\nfrom app.config import settings\nfrom app.models"
},
{
"path": "app/schemas/job.py",
"chars": 553,
"preview": "from pydantic import BaseModel\n\n\nclass JobDetail(BaseModel):\n id: str\n prompt: str\n negative_prompt: str | None"
},
{
"path": "app/services/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "app/services/generation_service.py",
"chars": 2811,
"preview": "\"\"\"Orchestrates image generation: validate → persist → enqueue.\"\"\"\n\nfrom __future__ import annotations\n\nimport uuid\nfrom"
},
{
"path": "app/services/job_service.py",
"chars": 5233,
"preview": "\"\"\"Job and batch CRUD operations.\"\"\"\n\nfrom __future__ import annotations\n\nfrom datetime import datetime, timezone\n\nfrom "
},
{
"path": "app/services/moderation_service.py",
"chars": 3025,
"preview": "\"\"\"Content moderation service – checks prompt safety via a third-party API.\"\"\"\n\nfrom __future__ import annotations\n\nimpo"
},
{
"path": "app/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "app/utils/image_utils.py",
"chars": 1069,
"preview": "\"\"\"Image saving, format conversion, and metadata embedding.\"\"\"\n\nfrom __future__ import annotations\n\nfrom pathlib import "
},
{
"path": "app/utils/moderation.py",
"chars": 7135,
"preview": "\"\"\"Content moderation using Gemma-3-1b-it (local) with HF Inference API fallback.\"\"\"\n\nfrom __future__ import annotations"
},
{
"path": "app/utils/storage.py",
"chars": 1250,
"preview": "\"\"\"Local filesystem storage management for generated images.\"\"\"\n\nfrom __future__ import annotations\n\nfrom datetime impor"
},
{
"path": "app/worker/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "app/worker/gpu_worker.py",
"chars": 9371,
"preview": "\"\"\"GPU worker — async priority queue consumer driving the Flux engine.\"\"\"\n\nfrom __future__ import annotations\n\nimport as"
},
{
"path": "flux_image_service.egg-info/PKG-INFO",
"chars": 900,
"preview": "Metadata-Version: 2.4\nName: flux-image-service\nVersion: 0.1.0\nSummary: Local image generation service using Flux.1 Schne"
},
{
"path": "flux_image_service.egg-info/SOURCES.txt",
"chars": 1046,
"preview": "README.md\npyproject.toml\napp/__init__.py\napp/config.py\napp/database.py\napp/main.py\napp/api/__init__.py\napp/api/batch.py\n"
},
{
"path": "flux_image_service.egg-info/dependency_links.txt",
"chars": 1,
"preview": "\n"
},
{
"path": "flux_image_service.egg-info/requires.txt",
"chars": 372,
"preview": "fastapi>=0.115.0\nuvicorn[standard]>=0.30.0\npydantic>=2.0\npydantic-settings>=2.0\nsqlalchemy[asyncio]>=2.0\naiosqlite>=0.20"
},
{
"path": "flux_image_service.egg-info/top_level.txt",
"chars": 4,
"preview": "app\n"
},
{
"path": "pyproject.toml",
"chars": 849,
"preview": "[project]\nname = \"flux-image-service\"\nversion = \"0.1.0\"\ndescription = \"Local image generation service using Flux.1 Schne"
},
{
"path": "tests/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "tests/conftest.py",
"chars": 2884,
"preview": "\"\"\"Shared test fixtures.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nfrom unittest.mock import AsyncMock, Mag"
},
{
"path": "tests/test_api_batch.py",
"chars": 1469,
"preview": "\"\"\"Tests for batch API endpoints.\"\"\"\n\nimport pytest\n\n\n@pytest.mark.asyncio\nasync def test_create_batch(client, mock_work"
},
{
"path": "tests/test_api_generate.py",
"chars": 6424,
"preview": "\"\"\"Tests for generation API endpoints.\"\"\"\n\nimport pytest\nfrom unittest.mock import AsyncMock, patch\nfrom sqlalchemy impo"
},
{
"path": "tests/test_api_jobs.py",
"chars": 1917,
"preview": "\"\"\"Tests for job management endpoints.\"\"\"\n\nimport pytest\n\n\n@pytest.mark.asyncio\nasync def test_list_jobs_empty(client):\n"
},
{
"path": "tests/test_moderation.py",
"chars": 8221,
"preview": "\"\"\"Tests for the content moderation module.\"\"\"\n\nfrom __future__ import annotations\n\nfrom unittest.mock import MagicMock,"
},
{
"path": "tests/test_moderation_service.py",
"chars": 10353,
"preview": "\"\"\"Tests for the content moderation service.\"\"\"\n\nfrom __future__ import annotations\n\nfrom unittest.mock import AsyncMock"
},
{
"path": "tests/test_schemas.py",
"chars": 2623,
"preview": "\"\"\"Tests for schemas — resolution validation, format handling.\"\"\"\n\nimport pytest\nfrom pydantic import ValidationError\n\nf"
},
{
"path": "tests/test_worker.py",
"chars": 2105,
"preview": "\"\"\"Tests for the GPU worker queue and processing logic.\"\"\"\n\nimport asyncio\nfrom unittest.mock import AsyncMock, MagicMoc"
}
]
About this extraction
This page contains the full source code of the sgshubham98/image-service-be GitHub repository, extracted and formatted as plain text. The extraction includes 48 files (120.4 KB), approximately 30.8k tokens, and a symbol index with 186 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract, built by Nikandr Surkov.