Repository: sgshubham98/image-service-be Branch: main Commit: f9ddc6802cdf Files: 48 Total size: 120.4 KB Directory structure: gitextract_xtvq1k3n/ ├── .gitignore ├── Dockerfile ├── README.md ├── app/ │ ├── __init__.py │ ├── api/ │ │ ├── __init__.py │ │ ├── batch.py │ │ ├── generate.py │ │ ├── health.py │ │ └── jobs.py │ ├── config.py │ ├── database.py │ ├── engine/ │ │ ├── __init__.py │ │ ├── engine_config.py │ │ ├── flux_engine.py │ │ └── mock_engine.py │ ├── main.py │ ├── models/ │ │ ├── __init__.py │ │ ├── enums.py │ │ └── job.py │ ├── schemas/ │ │ ├── __init__.py │ │ ├── batch.py │ │ ├── generate.py │ │ └── job.py │ ├── services/ │ │ ├── __init__.py │ │ ├── generation_service.py │ │ ├── job_service.py │ │ └── moderation_service.py │ ├── utils/ │ │ ├── __init__.py │ │ ├── image_utils.py │ │ ├── moderation.py │ │ └── storage.py │ └── worker/ │ ├── __init__.py │ └── gpu_worker.py ├── flux_image_service.egg-info/ │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── requires.txt │ └── top_level.txt ├── pyproject.toml └── tests/ ├── __init__.py ├── conftest.py ├── test_api_batch.py ├── test_api_generate.py ├── test_api_jobs.py ├── test_moderation.py ├── test_moderation_service.py ├── test_schemas.py └── test_worker.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled __pycache__/ *.py[cod] # Virtual env .venv/ venv/ # Output output/ # Logs logs/ # IDE .idea/ .vscode/ # Environment .env # SQLite *.db # OS .DS_Store ================================================ FILE: Dockerfile ================================================ FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04 WORKDIR /app # System deps RUN apt-get update && apt-get install -y --no-install-recommends \ python3.11 python3.11-venv python3-pip git && \ rm -rf /var/lib/apt/lists/* # Create venv RUN 
python3.11 -m venv /opt/venv ENV PATH="/opt/venv/bin:$PATH" # Install Python deps COPY pyproject.toml . RUN pip install --no-cache-dir --upgrade pip && \ pip install --no-cache-dir ".[dev]" # Install PyTorch with CUDA 12.4 (override index for torch) RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu124 # Copy application COPY app/ app/ COPY tests/ tests/ # Create output directory RUN mkdir -p /app/output EXPOSE 8000 CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] ================================================ FILE: README.md ================================================ # Flux Image Service Local image generation API powered by **Flux.1 Schnell 12B** running on an H100 GPU with FP8 quantization. ## Features - **Single image generation** — async (returns job IDs) or sync (waits for result), up to 4 images per request - **Batch/dataset generation** — submit up to 50K prompts via API or CSV/JSON file upload - **Priority queue** — real-time requests always jump ahead of batch jobs - **SSE progress streaming** — monitor batch progress in real-time - **Crash recovery** — pending jobs auto-resume on server restart (persisted in SQLite) - **Multiple formats** — PNG (with metadata), JPEG, WebP - **Custom resolutions** — any multiple of 8, up to 2048px ## Prerequisites - **Python** 3.10+ - **NVIDIA GPU** with CUDA 12.x (tested on H100 20GB) - **NVIDIA drivers** 535+ with CUDA toolkit - **System RAM** 32 GB recommended (16 GB minimum) - **Disk** ~12 GB for model weights + storage for generated images ## Setup ```bash # 1. Clone / navigate to the project cd flux-image-service # 2. Create a virtual environment python3 -m venv .venv source .venv/bin/activate # 3. Install dependencies (includes torch, diffusers, fastapi, etc.) pip install -e ".[dev]" # 4. (Optional) Create a .env file to override defaults — the repository does not ship a .env.example template touch .env # Add only the variables you want to change (model ID, storage path, etc.); see the Configuration section 
``` ## Running the Server ```bash # Activate venv (if not already) source .venv/bin/activate # Start the server uvicorn app.main:app --host 0.0.0.0 --port 8000 # Or with auto-reload for development (not recommended for GPU — reloads unload the model) uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload ``` The first startup will: 1. Download the Flux.1 Schnell model from HuggingFace (~12 GB) 2. Compile CUDA kernels via `torch.compile` (~30-60s one-time cost per restart) 3. Run a warm-up inference Once you see `Flux Image Service ready`, the server is accepting requests. ## Running Tests ```bash source .venv/bin/activate # Run all tests pytest tests/ -v # Run a specific test file pytest tests/test_api_generate.py -v # Run with coverage pip install pytest-cov pytest tests/ -v --cov=app --cov-report=term-missing ``` ## Docker ```bash # Build docker build -t flux-image-service . # Run (requires nvidia-container-toolkit) docker run --gpus all -p 8000:8000 -v ./output:/app/output flux-image-service # Run with custom env docker run --gpus all -p 8000:8000 \ -e MODEL_ID=black-forest-labs/FLUX.1-schnell \ -e STORAGE_DIR=/app/output \ -v ./output:/app/output \ flux-image-service ``` ## Linting ```bash # Check ruff check app/ tests/ # Auto-fix ruff check app/ tests/ --fix # Format ruff format app/ tests/ ``` --- ## API Reference ### `POST /generate` — Async Image Generation Submit a generation request (1–4 images). Returns immediately with job IDs. 
**Request:** ```bash curl -X POST http://localhost:8000/generate \ -H "Content-Type: application/json" \ -d '{ "prompt": "a majestic mountain landscape at sunset", "width": 1024, "height": 1024, "num_images": 2, "seed": 42 }' ``` **Response** `200 OK`: ```json { "job_ids": ["f47ac10b-58cc-4372-a567-0e02b2c3d479", "a1b2c3d4-5678-9abc-def0-1234567890ab"], "status": "pending", "image_urls": [] } ``` **Validation error** `422`: ```json { "detail": [ { "type": "value_error", "loc": ["body", "width"], "msg": "Value error, width must be a multiple of 8" } ] } ``` All request parameters for `ImageRequest`: | Field | Type | Default | Description | |-------|------|---------|-------------| | `prompt` | string | *required* | Text prompt (1–2000 chars) | | `negative_prompt` | string \| null | null | Negative prompt | | `width` | int | 1024 | Image width (64–2048, multiple of 8) | | `height` | int | 1024 | Image height (64–2048, multiple of 8) | | `num_steps` | int | 4 | Inference steps (1–50) | | `guidance_scale` | float | 0.0 | CFG scale (Schnell doesn't use CFG) | | `seed` | int \| null | null | Random seed (0–4294967295) | | `num_images` | int | 1 | Number of images (1–4) | | `format` | string | "png" | Output format: `png`, `jpeg`, `webp` | --- ### `GET /generate/stream` — SSE Progress Stream for Single Generate Stream status updates for one or more job IDs returned by `/generate` until all are terminal. 
**Request:** ```bash curl -N "http://localhost:8000/generate/stream?job_id=f47ac10b-58cc-4372-a567-0e02b2c3d479&job_id=a1b2c3d4-5678-9abc-def0-1234567890ab" ``` **Events:** - `progress` — emitted periodically with per-job status and aggregate counters - `done` — emitted once when all jobs are in `completed`, `failed`, or `cancelled` **Example event payload:** ```json { "total": 2, "completed": 1, "failed": 0, "cancelled": 0, "pending": 1, "jobs": [ { "id": "f47ac10b-58cc-4372-a567-0e02b2c3d479", "status": "completed", "image_url": "/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image" }, { "id": "a1b2c3d4-5678-9abc-def0-1234567890ab", "status": "processing", "image_url": null } ] } ``` --- ### `POST /generate/sync` — Synchronous Generation Same request body as `/generate`. Blocks until all images are generated (up to 60s timeout). **Request:** ```bash curl -X POST http://localhost:8000/generate/sync \ -H "Content-Type: application/json" \ -d '{"prompt": "a cute cat wearing a hat", "num_images": 1}' \ --max-time 60 ``` **Response** `200 OK`: ```json { "job_ids": ["f47ac10b-58cc-4372-a567-0e02b2c3d479"], "status": "completed", "image_urls": ["/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image"] } ``` **Timeout** `408`: ```json {"detail": "Generation timed out"} ``` --- ### `GET /jobs/{job_id}/image` — Serve Generated Image Returns the image file for a completed job. **Request:** ```bash curl http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image --output image.png ``` **Response:** Binary image file with appropriate `Content-Type` (`image/png`, `image/jpeg`, or `image/webp`). 
**Not found** `404`: ```json {"detail": "Image not available"} ``` --- ### `GET /jobs/{job_id}` — Get Job Details **Request:** ```bash curl http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479 ``` **Response** `200 OK`: ```json { "id": "f47ac10b-58cc-4372-a567-0e02b2c3d479", "prompt": "a majestic mountain landscape at sunset", "negative_prompt": null, "width": 1024, "height": 1024, "num_steps": 4, "guidance_scale": 0.0, "seed": 42, "status": "completed", "priority": 1, "format": "png", "file_path": "output/images/2026-03-08/f47ac10b-58cc-4372-a567-0e02b2c3d479.png", "image_url": "/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479/image", "error_message": null, "batch_id": null, "created_at": "2026-03-08T12:00:00+00:00", "started_at": "2026-03-08T12:00:01+00:00", "completed_at": "2026-03-08T12:00:03+00:00" } ``` --- ### `GET /jobs` — List Jobs **Request:** ```bash # List all jobs (paginated) curl "http://localhost:8000/jobs?page=1&page_size=20" # Filter by status curl "http://localhost:8000/jobs?status=completed" # Filter by batch curl "http://localhost:8000/jobs?batch_id=batch-uuid-here" ``` **Response** `200 OK`: ```json { "jobs": [ { "id": "f47ac10b-...", "prompt": "a mountain", "status": "completed", "width": 1024, "height": 1024, "image_url": "/jobs/f47ac10b-.../image", "...": "..." } ], "total": 150, "page": 1, "page_size": 20 } ``` --- ### `DELETE /jobs/{job_id}` — Cancel a Job Cancels a pending/processing job. Already completed/failed jobs are returned as-is. **Request:** ```bash curl -X DELETE http://localhost:8000/jobs/f47ac10b-58cc-4372-a567-0e02b2c3d479 ``` **Response** `200 OK`: ```json { "id": "f47ac10b-58cc-4372-a567-0e02b2c3d479", "status": "cancelled", "...": "..." } ``` --- ### `POST /batch` — Create Batch Job Submit 1–50,000 prompts as a batch. All jobs get BATCH priority (lower than real-time). 
**Request:** ```bash curl -X POST http://localhost:8000/batch \ -H "Content-Type: application/json" \ -d '{ "name": "training-dataset-v1", "prompts": [ {"prompt": "a red sports car", "width": 512, "height": 512}, {"prompt": "a blue ocean wave", "width": 1024, "height": 768}, {"prompt": "a forest path in autumn", "seed": 123} ] }' ``` **Response** `200 OK`: ```json { "batch_id": "b1c2d3e4-5678-9abc-def0-1234567890ab", "name": "training-dataset-v1", "total": 3, "completed": 0, "failed": 0, "cancelled": 0, "pending": 3, "status": "pending", "estimated_remaining_seconds": null, "created_at": "2026-03-08T12:00:00+00:00", "completed_at": null } ``` --- ### `POST /batch/from-file` — Create Batch from File Upload Upload a CSV or JSON file of prompts. **CSV upload:** ```bash curl -X POST http://localhost:8000/batch/from-file \ -F "file=@prompts.csv" \ -F "name=my-dataset" ``` CSV format (only `prompt` column is required): ```csv prompt,width,height,num_steps,seed,format a red car,512,512,4,42,png a blue house,1024,1024,,, a green tree,,,,, ``` **JSON upload:** ```bash curl -X POST http://localhost:8000/batch/from-file \ -F "file=@prompts.json" \ -F "name=my-dataset" ``` JSON format: ```json [ {"prompt": "a red car", "width": 512, "height": 512, "seed": 42}, {"prompt": "a blue house"}, {"prompt": "a green tree", "format": "jpeg"} ] ``` **Response:** Same as `POST /batch`. --- ### `GET /batch/{batch_id}` — Get Batch Progress **Request:** ```bash curl http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab ``` **Response** `200 OK`: ```json { "batch_id": "b1c2d3e4-5678-9abc-def0-1234567890ab", "name": "training-dataset-v1", "total": 1000, "completed": 342, "failed": 2, "cancelled": 0, "pending": 656, "status": "processing", "estimated_remaining_seconds": 1968.0, "created_at": "2026-03-08T12:00:00+00:00", "completed_at": null } ``` --- ### `GET /batch/{batch_id}/stream` — SSE Progress Stream Real-time Server-Sent Events stream of batch progress. 
Updates every 2 seconds. **Request:** ```bash curl -N http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab/stream ``` **Response** (event stream): ``` event: progress data: {"batch_id":"b1c2d3e4-...","total":1000,"completed":342,"failed":2,"pending":656,"status":"processing","estimated_remaining_seconds":1968.0} event: progress data: {"batch_id":"b1c2d3e4-...","total":1000,"completed":343,"failed":2,"pending":655,"status":"processing","estimated_remaining_seconds":1965.0} event: done data: {"batch_id":"b1c2d3e4-...","total":1000,"completed":998,"failed":2,"pending":0,"status":"failed"} ``` --- ### `DELETE /batch/{batch_id}` — Cancel Batch Cancels all remaining pending jobs in the batch. **Request:** ```bash curl -X DELETE http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab ``` **Response** `200 OK`: ```json { "batch_id": "b1c2d3e4-5678-9abc-def0-1234567890ab", "total": 1000, "completed": 342, "failed": 2, "cancelled": 656, "pending": 0, "status": "cancelled", "...": "..." } ``` --- ### `POST /batch/{batch_id}/retry-failed` — Retry Failed Jobs Re-enqueue only the failed jobs in a batch. 
**Request:** ```bash curl -X POST http://localhost:8000/batch/b1c2d3e4-5678-9abc-def0-1234567890ab/retry-failed ``` **Response** `200 OK`: ```json { "retried": 2, "job_ids": ["job-uuid-1", "job-uuid-2"] } ``` --- ### `GET /health` — Health Check **Request:** ```bash curl http://localhost:8000/health ``` **Response** `200 OK`: ```json { "status": "ok", "model_loaded": true, "vram": { "available": true, "allocated_gb": 12.34, "reserved_gb": 14.50, "max_allocated_gb": 16.78, "total_gb": 20.00, "device_name": "NVIDIA H100" }, "worker": { "queue_depth": 5, "jobs_processed": 142, "total_inference_time_s": 398.50, "avg_inference_time_s": 2.81 }, "uptime_seconds": 3600.0 } ``` --- ## Architecture ``` Client → FastAPI → asyncio.PriorityQueue → GPU Worker → Flux Engine ↓ ↓ SQLite Local FS (job state) (images) ``` - **Single GPU worker** — processes one image at a time (VRAM constraint) - **Priority queue** — REALTIME(1) always before BATCH(5) - **OOM recovery** — catches GPU OOM, clears cache, continues processing - **Crash recovery** — on startup, all PENDING/PROCESSING jobs in SQLite are automatically re-enqueued ### Queue Persistence The in-memory `asyncio.PriorityQueue` is **lost on server restart**. However, **no work is lost** because: 1. All jobs are persisted to SQLite the moment they're created (before being enqueued) 2. On startup, `resume_pending_jobs()` scans for any jobs in `PENDING` or `PROCESSING` state and re-enqueues them 3. Completed jobs and their images are already on disk This means a restart simply rebuilds the queue from the database. The only effect is a brief pause while the model reloads. 
## Configuration All settings can be overridden via environment variables or `.env` file: | Variable | Default | Description | |----------|---------|-------------| | `MODEL_ID` | `black-forest-labs/FLUX.1-schnell` | HuggingFace model ID | | `DEVICE` | `cuda` | Torch device | | `DTYPE` | `float8_e4m3fn` | Model precision | | `ENABLE_TORCH_COMPILE` | `true` | Enable torch.compile optimization | | `DEFAULT_NUM_STEPS` | `4` | Default inference steps | | `DEFAULT_GUIDANCE_SCALE` | `0.0` | Default CFG scale | | `MAX_RESOLUTION` | `2048` | Maximum allowed width/height | | `MAX_IMMEDIATE_IMAGES` | `4` | Max images per /generate request | | `STORAGE_DIR` | `./output` | Where images and DB are stored | | `MAX_QUEUE_SIZE` | `10000` | Maximum queue capacity | | `SYNC_TIMEOUT` | `60.0` | Timeout for /generate/sync (seconds) | | `HOST` | `0.0.0.0` | Server bind address | | `PORT` | `8000` | Server port | ## Performance (H100 20GB, FP8) | Resolution | Time per image | 10K dataset | |-----------|---------------|-------------| | 512×512 | ~0.8–1.2s | ~2–3 hours | | 1024×1024 | ~2–4s | ~6–10 hours | | 1536×640 | ~2–3s | ~5–8 hours | ================================================ FILE: app/__init__.py ================================================ ================================================ FILE: app/api/__init__.py ================================================ ================================================ FILE: app/api/batch.py ================================================ """Batch / dataset generation endpoints with SSE progress streaming.""" from __future__ import annotations import asyncio import csv import io import json import os import tempfile import zipfile from pathlib import Path from typing import AsyncGenerator from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File from fastapi.responses import StreamingResponse from sse_starlette.sse import EventSourceResponse from sqlalchemy import select from sqlalchemy.ext.asyncio import 
@router.post("", response_model=BatchProgress)
async def create_batch_job(
    body: BatchRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
):
    """Create a batch of image generation jobs from a list of prompts."""
    # The worker lives on the application state; it is handed to
    # ``create_batch`` together with the batch name and prompts.
    new_batch_id = await create_batch(
        session,
        body.name,
        body.prompts,
        request.app.state.worker,
    )

    # Re-read the batch so the response carries the same progress fields that
    # the other batch endpoints build from ``get_batch_progress``.
    snapshot = await get_batch_progress(session, new_batch_id)
    return BatchProgress(**snapshot)
@router.get("/{batch_id}/download")
async def download_batch_images(
    batch_id: str,
    session: AsyncSession = Depends(get_session),
):
    """Download all completed images in a batch as a ZIP archive.

    Streams the ZIP incrementally so the full archive is never held in memory.
    File I/O runs in a thread pool to avoid blocking the async event loop.
    """
    # Responds 404 when the batch is unknown, when it has no completed jobs,
    # or when none of the completed jobs still has its file on disk.
    progress = await get_batch_progress(session, batch_id)
    if not progress:
        raise HTTPException(status_code=404, detail="Batch not found")

    result = await session.execute(
        select(Job).where(Job.batch_id == batch_id, Job.status == "completed")
    )
    jobs = result.scalars().all()
    if not jobs:
        raise HTTPException(status_code=404, detail="No completed images in this batch")

    # Collect valid file paths up-front (lightweight check).
    file_entries: list[tuple[str, Path]] = []
    for job in jobs:
        if not job.file_path:
            continue
        fp = Path(job.file_path)
        if fp.is_file():
            file_entries.append((fp.name, fp))

    if not file_entries:
        raise HTTPException(status_code=404, detail="No image files found on disk")

    def _header_safe(text: str) -> str:
        """Keep ASCII letters/digits, space, '-' and '_'; replace the rest with '_'."""
        # ``str.isalnum()`` alone also accepts non-ASCII letters and digits.
        # Those would land in the Content-Disposition header, which has to be
        # latin-1 encodable, so restrict the whitelist to ASCII.
        return "".join(
            c if (c.isascii() and c.isalnum()) or c in (" ", "-", "_") else "_"
            for c in text
        ).strip()

    # Prefer the batch name, then a prefix of the batch id, then a constant,
    # so the download never ends up being called just ".zip".
    safe_name = (
        _header_safe(progress.get("name") or "")
        or _header_safe(batch_id[:12])
        or "batch"
    )
    filename = f"{safe_name}.zip"

    async def _stream_zip() -> AsyncGenerator[bytes, None]:
        """Build a valid ZIP in a temp file (off the event loop), then
        stream it in chunks so memory stays flat."""
        loop = asyncio.get_running_loop()
        tmp_fd, tmp_path = tempfile.mkstemp(suffix=".zip")
        # Only the path is needed; zipfile reopens the file by name.
        os.close(tmp_fd)

        def _build_zip() -> None:
            # ZIP_STORED: entries are written without compression.
            with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_STORED) as zf:
                for arc_name, file_path in file_entries:
                    zf.write(str(file_path), arc_name)

        try:
            # Build the ZIP in a worker thread — no event-loop blocking.
            await loop.run_in_executor(None, _build_zip)

            # Stream the temp file in 256 KB chunks.
            chunk_size = 256 * 1024
            with open(tmp_path, "rb") as f:
                while True:
                    chunk = await loop.run_in_executor(None, f.read, chunk_size)
                    if not chunk:
                        break
                    yield chunk
        finally:
            # Best-effort cleanup: the temp file may already be gone.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    return StreamingResponse(
        _stream_zip(),
        media_type="application/zip",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
enumerate(reader): try: kwargs: dict = {"prompt": row["prompt"]} if "width" in row and row["width"]: kwargs["width"] = int(row["width"]) if "height" in row and row["height"]: kwargs["height"] = int(row["height"]) if "num_steps" in row and row["num_steps"]: kwargs["num_steps"] = int(row["num_steps"]) if "seed" in row and row["seed"]: kwargs["seed"] = int(row["seed"]) if "format" in row and row["format"]: kwargs["format"] = ImageFormat(row["format"].strip().lower()) prompts.append(ImageRequest(**kwargs)) except KeyError: raise HTTPException( status_code=400, detail=f"Row {i + 1}: 'prompt' column is required" ) except Exception as e: raise HTTPException(status_code=400, detail=f"Row {i + 1}: {e}") return prompts def _estimate_eta(progress: dict, worker) -> float | None: """Estimate remaining seconds based on average inference time.""" stats = worker.get_stats() avg_time = stats.get("avg_inference_time_s", 0) if avg_time <= 0: return None remaining = progress.get("pending", 0) return round(remaining * avg_time, 1) # Helper to get fresh session for SSE generator from app.database import async_session_factory as _fresh_session # noqa: E402 ================================================ FILE: app/api/generate.py ================================================ """Image generation endpoints — async and sync modes.""" from __future__ import annotations import asyncio import json from pathlib import Path from typing import AsyncGenerator from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi.responses import FileResponse from sse_starlette.sse import EventSourceResponse from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.config import settings from app.database import get_session from app.models.job import Job from app.schemas.generate import ImageRequest, ImageResponse from app.services.generation_service import create_single_job from app.services.job_service import get_job from app.services.moderation_service 
@router.post("/generate", response_model=ImageResponse)
async def generate_async(
    body: ImageRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
):
    """Submit an image generation request (1–4 images). Returns immediately with job IDs."""
    # Moderation gate: a moderation failure maps to 503, a rejected prompt to
    # 400. No job is created in either case.
    try:
        verdict = await check_prompt_safety(body.prompt)
    except ModerationError:
        raise HTTPException(status_code=503, detail="Content moderation service unavailable")
    else:
        prompt_ok, rejection_reason = verdict
        if not prompt_ok:
            raise HTTPException(
                status_code=400,
                detail=f"Prompt contains unsafe or explicit content: {rejection_reason}",
            )

    # Hand the request to the worker stored on the application state and
    # answer right away; the response does not wait for any image.
    queued_ids = await create_single_job(session, body, request.app.state.worker)
    return ImageResponse(job_ids=queued_ids, status="pending")
@router.post("/generate/sync", response_model=ImageResponse)
async def generate_sync(
    body: ImageRequest,
    request: Request,
    session: AsyncSession = Depends(get_session),
):
    """Submit and wait for image(s) to be generated (up to SYNC_TIMEOUT seconds)."""
    # Responses: 503 when moderation is unavailable, 400 for a rejected
    # prompt, 408 when the jobs are not terminal before the deadline, and 500
    # when every job failed.
    try:
        is_safe, reason = await check_prompt_safety(body.prompt)
    except ModerationError as exc:
        # Chain the cause so the underlying moderation failure stays in logs.
        raise HTTPException(
            status_code=503, detail="Content moderation service unavailable"
        ) from exc
    if not is_safe:
        raise HTTPException(
            status_code=400,
            detail=f"Prompt contains unsafe or explicit content: {reason}",
        )

    worker = _get_worker(request)
    job_ids = await create_single_job(session, body, worker)

    # Poll until all jobs complete or timeout. Inside a coroutine the running
    # loop is the right clock source.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + settings.SYNC_TIMEOUT
    while loop.time() < deadline:
        await asyncio.sleep(0.3)
        # Expire the identity map so status changes written elsewhere are
        # re-read instead of served from this session's cached objects, the
        # same precaution stream_generate_progress takes.
        session.expire_all()
        details = [await get_job(session, jid) for jid in job_ids]
        if all(d and d.status in _TERMINAL_JOB_STATUSES for d in details):
            break
    else:
        # The loop ran out of time without hitting ``break``.
        raise HTTPException(status_code=408, detail="Generation timed out")

    # Only a total failure is an error; partial results are still returned.
    failed = [d for d in details if d and d.status == "failed"]
    if failed and len(failed) == len(details):
        raise HTTPException(
            status_code=500,
            detail=failed[0].error_message or "Generation failed",
        )

    # NOTE(review): the status is reported as "completed" even when some jobs
    # failed or were cancelled; those positions carry ``None`` in image_urls.
    return ImageResponse(
        job_ids=job_ids,
        status="completed",
        image_urls=[d.image_url if d and d.status == "completed" else None for d in details],
    )
@router.get("", response_model=JobList)
async def list_all_jobs(
    status: str | None = Query(None, description="Filter by status"),
    batch_id: str | None = Query(None, description="Filter by batch ID"),
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=200),
    session: AsyncSession = Depends(get_session),
):
    # Paginated job listing. ``status`` and ``batch_id`` are optional filters
    # passed through unchanged to ``list_jobs``. ``page`` starts at 1 and
    # ``page_size`` is limited to 1..200 by the Query constraints above.
    paging = {"page": page, "page_size": page_size}
    page_items, match_count = await list_jobs(
        session,
        status=status,
        batch_id=batch_id,
        **paging,
    )

    # Echo the paging parameters back next to the rows and the total that
    # ``list_jobs`` returned.
    return JobList(jobs=page_items, total=match_count, **paging)
async def get_session() -> AsyncSession:  # type: ignore[misc]
    """Yield a database session created by ``async_session_factory``.

    This is an async generator: the session is opened with ``async with`` and
    handed to the consumer via ``yield``, so the context manager exits (and
    releases the session) once the consumer is finished with it.

    NOTE(review): the return annotation names the yielded type rather than an
    async-iterator type, which is why the ``type: ignore`` markers are present.
    """
    async with async_session_factory() as session:
        yield session  # type: ignore[misc]
@dataclass
class FluxEngine:
    """Wrapper around the Flux.1 Schnell diffusion pipeline.

    Handles model loading, a warm-up pass, single-image generation and VRAM
    reporting. All state is held on the instance: the pipeline object, a
    loaded flag, and the device / dtype resolved from ``settings`` at load
    time.
    """

    _pipeline: object | None = field(default=None, repr=False)
    _loaded: bool = False
    _device: str = "cuda"
    _dtype: torch.dtype = torch.bfloat16

    def load_model(self) -> None:
        """Load the Flux pipeline onto the configured device.

        A second call on an already-loaded engine is a no-op. Device, dtype
        and model id are read from ``settings``; an unknown dtype name falls
        back to ``torch.bfloat16``.
        """
        if self._loaded:
            logger.info("Model already loaded, skipping.")
            return

        # Imported lazily so the module can be imported without diffusers.
        from diffusers import FluxPipeline

        self._device = settings.DEVICE
        self._dtype = _DTYPE_MAP.get(settings.DTYPE, torch.bfloat16)

        logger.info(
            "Loading Flux.1 Schnell model=%s device=%s dtype=%s",
            settings.MODEL_ID,
            self._device,
            self._dtype,
        )

        load_kwargs: dict = {
            "torch_dtype": self._dtype,
        }

        # For FP8 on H100, use the fp8 variant if available
        if self._dtype == torch.float8_e4m3fn:
            load_kwargs["variant"] = "fp8"

        self._pipeline = FluxPipeline.from_pretrained(
            settings.MODEL_ID,
            **load_kwargs,
        )
        self._pipeline.to(self._device)
        self._pipeline.set_progress_bar_config(disable=True)

        # torch.compile for H100 Tensor Core acceleration
        if settings.ENABLE_TORCH_COMPILE and self._device == "cuda":
            logger.info("Compiling transformer with torch.compile (mode=max-autotune-no-cudagraphs)...")
            self._pipeline.transformer = torch.compile(
                self._pipeline.transformer,
                mode="max-autotune-no-cudagraphs",
            )

        self._loaded = True
        logger.info("Model loaded. VRAM: %.2f GB", self.get_vram_usage_gb())

    def warmup(self) -> None:
        """Run a single throwaway inference to trigger torch.compile tracing.

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self._loaded:
            raise RuntimeError("Model not loaded — call load_model() first")

        logger.info("Running warm-up inference (first run compiles CUDA kernels)...")
        start = time.perf_counter()
        self.generate(prompt="warmup", width=512, height=512, num_steps=1, seed=0)
        elapsed = time.perf_counter() - start
        logger.info("Warm-up completed in %.1fs", elapsed)

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        width: int = 1024,
        height: int = 1024,
        num_steps: int = 4,
        guidance_scale: float = 0.0,
        seed: int | None = None,
    ) -> GenerationResult:
        """Generate a single image. Runs synchronously on the GPU.

        Args:
            prompt: Text prompt passed to the pipeline.
            width: Output width in pixels.
            height: Output height in pixels.
            num_steps: Number of inference steps.
            guidance_scale: Guidance scale passed to the pipeline.
            seed: Seed for the generator; a random one is drawn when ``None``.

        Returns:
            The first image produced, the seed actually used, and the elapsed
            pipeline time in seconds.

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if not self._loaded:
            raise RuntimeError("Model not loaded — call load_model() first")

        if seed is None:
            seed = torch.randint(0, 2**32, (1,)).item()

        generator = torch.Generator(device=self._device).manual_seed(seed)

        start = time.perf_counter()
        output = self._pipeline(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            output_type="pil",
        )
        elapsed = time.perf_counter() - start

        image: Image.Image = output.images[0]
        logger.info(
            "Generated %dx%d in %.2fs (steps=%d, seed=%d)",
            width,
            height,
            elapsed,
            num_steps,
            seed,
        )
        return GenerationResult(image=image, seed=seed, inference_time=elapsed)

    def get_vram_usage_gb(self) -> float:
        """Return current GPU memory allocated in GB (0.0 without CUDA)."""
        if not torch.cuda.is_available():
            return 0.0
        return torch.cuda.memory_allocated(self._device) / (1024**3)

    def get_vram_stats(self) -> dict:
        """Return detailed VRAM statistics for the engine's device.

        Returns:
            ``{"available": False}`` when CUDA is not available; otherwise a
            dict with allocated / reserved / peak-allocated / total memory in
            GB (rounded to two decimals) and the device name.
        """
        if not torch.cuda.is_available():
            return {"available": False}

        gib = 1024**3
        # Query the same device as the allocation counters instead of a
        # hard-coded index 0, so all figures describe one GPU.
        props = torch.cuda.get_device_properties(self._device)
        return {
            "available": True,
            "allocated_gb": round(torch.cuda.memory_allocated(self._device) / gib, 2),
            "reserved_gb": round(torch.cuda.memory_reserved(self._device) / gib, 2),
            "max_allocated_gb": round(torch.cuda.max_memory_allocated(self._device) / gib, 2),
            # The properties object exposes `total_memory`; `total_mem` does
            # not exist and raised AttributeError on every call.
            "total_gb": round(props.total_memory / gib, 2),
            "device_name": torch.cuda.get_device_name(self._device),
        }

    def unload(self) -> None:
        """Drop the pipeline reference and clear the CUDA cache."""
        self._pipeline = None
        self._loaded = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
        logger.info("Model unloaded, CUDA cache cleared.")

    @property
    def is_loaded(self) -> bool:
        """Whether ``load_model`` has completed successfully."""
        return self._loaded
@dataclass
class GenerationResult:
    """Value returned by ``MockFluxEngine.generate`` for one image."""

    # The rendered placeholder image.
    image: Image.Image
    # Seed that was used (either the caller's or a randomly drawn one).
    seed: int
    # Elapsed time of the generate call in seconds, measured with perf_counter.
    inference_time: float
def _setup_logging() -> None:
    """Configure root logging to both the console and a time-rotating file.

    Reads ``LOG_DIR``, ``LOG_LEVEL``, ``LOG_ROTATION_WHEN`` and
    ``LOG_RETENTION_COUNT`` from ``settings``. The log directory is created
    if missing, and an unknown level name falls back to ``INFO``.
    """
    settings.LOG_DIR.mkdir(parents=True, exist_ok=True)

    formatter = logging.Formatter("%(asctime)s %(levelname)-8s %(name)s — %(message)s")
    level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)

    root = logging.getLogger()
    root.setLevel(level)

    # Console handler
    console = logging.StreamHandler()
    console.setLevel(level)
    console.setFormatter(formatter)
    root.addHandler(console)

    # Time-rotating file handler
    file_handler = TimedRotatingFileHandler(
        filename=settings.LOG_DIR / "flux-service.log",
        when=settings.LOG_ROTATION_WHEN,
        backupCount=settings.LOG_RETENTION_COUNT,
        encoding="utf-8",
    )
    file_handler.setLevel(level)
    file_handler.setFormatter(formatter)
    # Keep the suffix the handler derives from `when`. For "midnight" that is
    # already "%Y-%m-%d" (flux-service.log.2026-03-08). Forcing a date-only
    # suffix for hourly/minutely rotation made same-day rollovers overwrite
    # each other and stopped backupCount pruning, because the handler's
    # filename matcher no longer matched the rotated files.
    root.addHandler(file_handler)
class JobStatus(str, enum.Enum):
    """String-valued enumeration of job lifecycle states.

    Because the class also derives from ``str``, each member compares equal
    to its lowercase string value.
    """

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class Job(Base):
    """ORM model for a single image-generation job (table ``jobs``)."""

    __tablename__ = "jobs"

    # Primary key: a freshly generated UUID4 rendered as a string.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))

    # Generation parameters.
    prompt = Column(Text, nullable=False)
    negative_prompt = Column(Text, nullable=True)
    width = Column(Integer, nullable=False, default=1024)
    height = Column(Integer, nullable=False, default=1024)
    num_steps = Column(Integer, nullable=False, default=4)
    guidance_scale = Column(Float, nullable=False, default=0.0)
    # NOTE(review): nullable seed; confirm Integer is wide enough for the
    # largest seed the API accepts on the target database.
    seed = Column(Integer, nullable=True)

    # Lifecycle state, restricted to the listed string values.
    status = Column(
        Enum("pending", "processing", "completed", "failed", "cancelled", name="job_status"),
        nullable=False,
        default="pending",
    )
    priority = Column(Integer, nullable=False, default=1)
    # Output image format, restricted to the listed string values.
    format = Column(
        Enum("png", "jpeg", "webp", name="image_format"),
        nullable=False,
        default="png",
    )

    # Result fields; both are nullable.
    file_path = Column(String(512), nullable=True)
    error_message = Column(Text, nullable=True)

    # Optional link to the owning batch (foreign key to batch_jobs.id).
    batch_id = Column(String(36), ForeignKey("batch_jobs.id"), nullable=True)

    # Timezone-aware timestamps; created_at defaults to the current UTC time.
    created_at = Column(DateTime(timezone=True), nullable=False, default=_utcnow)
    started_at = Column(DateTime(timezone=True), nullable=True)
    completed_at = Column(DateTime(timezone=True), nullable=True)

    # Counterpart of BatchJob.jobs.
    batch = relationship("BatchJob", back_populates="jobs")
class ImageRequest(BaseModel):
    """Validated parameters for one image-generation request.

    Numeric bounds for resolution, step count and image count come from
    ``settings``. The prompt is passed through content moderation, and both
    dimensions must be multiples of 8.
    """

    prompt: str = Field(..., min_length=1, max_length=2000)
    negative_prompt: str | None = None
    width: int = Field(default=1024, ge=settings.MIN_RESOLUTION, le=settings.MAX_RESOLUTION)
    height: int = Field(default=1024, ge=settings.MIN_RESOLUTION, le=settings.MAX_RESOLUTION)
    num_steps: int = Field(default=settings.DEFAULT_NUM_STEPS, ge=1, le=50)
    guidance_scale: float = Field(default=settings.DEFAULT_GUIDANCE_SCALE, ge=0.0, le=20.0)
    seed: int | None = Field(default=None, ge=0, le=2**32 - 1)
    num_images: int = Field(default=1, ge=1, le=settings.MAX_IMMEDIATE_IMAGES)
    format: ImageFormat = ImageFormat.PNG

    @field_validator("prompt")
    @classmethod
    def check_content_safety(cls, v: str) -> str:
        """Run the prompt through moderation; a rejection surfaces as a validation error."""
        # Imported at call time, as in the rest of the schema module.
        from app.utils import moderation

        moderation.check_prompt(v)
        return v

    @model_validator(mode="after")
    def validate_resolution(self) -> "ImageRequest":
        """Reject dimensions that are not multiples of 8 (width is checked first)."""
        checks = (
            (self.width, "width must be a multiple of 8"),
            (self.height, "height must be a multiple of 8"),
        )
        for dimension, message in checks:
            if dimension % 8:
                raise ValueError(message)
        return self
async def create_batch(
    session: AsyncSession,
    name: str | None,
    prompts: list[ImageRequest],
    worker: "GPUWorker",
) -> str:
    """Persist a batch with one pending job per prompt, then enqueue them.

    The batch row and all job rows are committed in a single transaction
    before anything is handed to the worker. Every job is created and
    enqueued with batch priority.

    Returns:
        The id of the newly created batch.
    """
    batch_id = str(uuid.uuid4())
    batch_priority = int(JobPriority.BATCH)

    session.add(
        BatchJob(
            id=batch_id,
            name=name,
            total_count=len(prompts),
            status="pending",
        )
    )

    pending_jobs: list[Job] = [
        Job(
            id=str(uuid.uuid4()),
            prompt=item.prompt,
            negative_prompt=item.negative_prompt,
            width=item.width,
            height=item.height,
            num_steps=item.num_steps,
            guidance_scale=item.guidance_scale,
            seed=item.seed,
            format=item.format.value,
            priority=batch_priority,
            status="pending",
            batch_id=batch_id,
        )
        for item in prompts
    ]
    session.add_all(pending_jobs)
    await session.commit()

    # Enqueue only after the rows are committed.
    for queued in pending_jobs:
        await worker.enqueue(queued.id, batch_priority)

    return batch_id
async def list_jobs(
    session: AsyncSession,
    status: str | None = None,
    batch_id: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[JobDetail], int]:
    """Return one page of jobs (newest first) plus the total match count.

    ``status`` and ``batch_id`` are optional equality filters; a falsy value
    disables the filter. The total counts all matching rows, not just the
    returned page.
    """
    # Collect the active filters once and apply them to both statements.
    criteria = []
    if status:
        criteria.append(Job.status == status)
    if batch_id:
        criteria.append(Job.batch_id == batch_id)

    page_stmt = select(Job).order_by(Job.created_at.desc())
    count_stmt = select(func.count()).select_from(Job)
    for clause in criteria:
        page_stmt = page_stmt.where(clause)
        count_stmt = count_stmt.where(clause)

    count_result = await session.execute(count_stmt)
    total = count_result.scalar() or 0

    skip = (page - 1) * page_size
    rows = await session.execute(page_stmt.offset(skip).limit(page_size))
    details = [_job_to_detail(row) for row in rows.scalars().all()]
    return details, total
async def check_prompt_safety(prompt: str) -> tuple[bool, str | None]:
    """Check whether a prompt is safe using the configured moderation API.

    Calls the API specified by ``MODERATION_API_URL`` (default: OpenAI
    moderation endpoint) and returns ``(is_safe, reason)``. When moderation
    is disabled or no API key is configured the function returns
    ``(True, None)`` without making any network request.

    Args:
        prompt: The user-supplied generation prompt to evaluate.

    Returns:
        A ``(is_safe, reason)`` tuple. ``is_safe`` is ``True`` when the
        prompt is considered safe. When flagged, ``reason`` contains the
        triggered content categories (e.g. ``"sexual, violence"``), or
        ``"unsafe content"`` when no category is reported.

    Raises:
        ModerationError: If the request fails, the API answers with an error
            status, or the response body is not JSON of the expected shape.
    """
    if not settings.MODERATION_ENABLED:
        return True, None

    if not settings.MODERATION_API_KEY:
        logger.warning(
            "Content moderation is enabled but MODERATION_API_KEY is not set; "
            "skipping moderation check."
        )
        return True, None

    try:
        async with httpx.AsyncClient(timeout=settings.MODERATION_TIMEOUT) as client:
            response = await client.post(
                settings.MODERATION_API_URL,
                headers={
                    "Authorization": f"Bearer {settings.MODERATION_API_KEY}",
                    "Content-Type": "application/json",
                },
                json={"input": prompt},
            )
            response.raise_for_status()
            data = response.json()
    except httpx.HTTPStatusError as exc:
        logger.error(
            "Moderation API returned an error status %s: %s",
            exc.response.status_code,
            exc.response.text,
        )
        raise ModerationError(
            f"Moderation API returned HTTP {exc.response.status_code}"
        ) from exc
    except httpx.HTTPError as exc:
        logger.error("Moderation API request failed: %s", exc)
        raise ModerationError(f"Moderation API request failed: {exc}") from exc
    except ValueError as exc:
        # A 2xx answer with a non-JSON body makes response.json() raise a
        # ValueError subclass; report it through the documented error type.
        logger.error("Moderation API returned a non-JSON response: %s", exc)
        raise ModerationError("Moderation API returned a non-JSON response") from exc

    # Parse OpenAI-compatible moderation response:
    # {"results": [{"flagged": bool, "categories": {category: bool, ...}}]}
    if not isinstance(data, dict):
        raise ModerationError("Moderation API returned an unexpected payload")

    results = data.get("results") or []
    if not results:
        return True, None
    if not isinstance(results, list) or not isinstance(results[0], dict):
        raise ModerationError("Moderation API returned an unexpected payload")

    result = results[0]
    if result.get("flagged", False):
        categories: dict[str, bool] = result.get("categories") or {}
        flagged_categories = [cat for cat, triggered in categories.items() if triggered]
        reason = ", ".join(flagged_categories) if flagged_categories else "unsafe content"
        return False, reason

    return True, None
class ModerationEngine:
    """Content moderation engine.

    Primary check: locally-loaded ``google/gemma-3-1b-it`` text-generation pipeline.
    Fallback check: HuggingFace Inference API (agent API) using the same model.
    If both are unavailable the engine fails open and logs a warning.
    """

    def __init__(self) -> None:
        # Local transformers pipeline; stays None until load() succeeds.
        self._pipeline = None
        # Model id recorded by a successful load(); reused by the API fallback.
        self._model_id: str | None = None

    # ── Loading ──────────────────────────────────────────────────────────

    def load(self, model_id: str = "google/gemma-3-1b-it") -> None:
        """Load *model_id* locally for synchronous inference.

        Any failure (missing packages, download error, no device, ...) is
        logged and leaves the engine unloaded instead of raising.
        """
        try:
            import torch
            from transformers import pipeline

            logger.info("Loading moderation model %s …", model_id)
            self._pipeline = pipeline(
                "text-generation",
                model=model_id,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                max_new_tokens=10,
            )
            self._model_id = model_id
            logger.info("Moderation model %s loaded successfully.", model_id)
        except Exception as exc:  # noqa: BLE001
            logger.warning("Failed to load moderation model %s: %s", model_id, exc)
            self._pipeline = None

    @property
    def is_loaded(self) -> bool:
        """Return True when the local pipeline is ready."""
        return self._pipeline is not None

    # ── Local-model check ────────────────────────────────────────────────

    def _parse_pipeline_output(self, raw: object) -> str:
        """Extract the assistant reply from a text-generation pipeline result.

        The expected shape is a list of dicts such as
        ``[{"generated_text": [{"role": ...}, {"role": "assistant", "content": "SAFE"}]}]``,
        but a plain-string ``generated_text``, a bare first element, or a
        non-list *raw* are also accepted.

        Returns:
            The reply text, stripped and upper-cased.

        Raises:
            TypeError: If the last chat message carries a non-string
                ``content``; there is then no verdict to read.
        """
        if isinstance(raw, list) and raw:
            first = raw[0]
            # Only dicts can carry a "generated_text" entry; anything else is
            # treated as the generated payload itself.
            inner = first.get("generated_text", first) if isinstance(first, dict) else first
            if isinstance(inner, list) and inner:
                last = inner[-1]
                if isinstance(last, dict):
                    content = last.get("content", "")
                    if not isinstance(content, str):
                        # Raise explicitly so the caller falls back, rather
                        # than reading a missing verdict as SAFE.
                        raise TypeError(
                            f"Assistant message has non-text content: {content!r}"
                        )
                    return content.strip().upper()
            return str(inner).strip().upper()
        return str(raw).strip().upper()

    def check_with_local_model(self, prompt: str) -> ModerationResult:
        """Run the locally-loaded Gemma pipeline and return a :class:`ModerationResult`.

        Raises:
            RuntimeError: If the local model is not loaded, or if inference
                or output parsing fails (the original error is chained).
        """
        if not self.is_loaded:
            raise RuntimeError("Local moderation model is not loaded")
        try:
            messages = [
                {"role": "system", "content": _SAFETY_SYSTEM_PROMPT},
                {"role": "user", "content": f"Prompt: {prompt}"},
            ]
            raw = self._pipeline(messages)
            response = self._parse_pipeline_output(raw)
            # Only a reply containing "UNSAFE" rejects the prompt.
            is_safe = "UNSAFE" not in response
            return ModerationResult(
                is_safe=is_safe,
                reason="" if is_safe else f"Flagged by local model: {response}",
                source="local",
            )
        except Exception as exc:
            raise RuntimeError(f"Local model inference failed: {exc}") from exc

    # ── Agent-API fallback ───────────────────────────────────────────────

    def check_with_agent_api(
        self,
        prompt: str,
        *,
        model_id: str | None = None,
        token: str | None = None,
    ) -> ModerationResult:
        """Use the HuggingFace Inference API (agent API) as a fallback safety check.

        Args:
            prompt: Text to classify.
            model_id: Model to query; defaults to the locally loaded model id,
                then to ``google/gemma-3-1b-it``.
            token: API token handed to ``InferenceClient``.

        Errors raised by the client are not caught here.
        """
        from huggingface_hub import InferenceClient

        _model = model_id or self._model_id or "google/gemma-3-1b-it"
        client = InferenceClient(model=_model, token=token)
        messages = [
            {"role": "system", "content": _SAFETY_SYSTEM_PROMPT},
            {"role": "user", "content": f"Prompt: {prompt}"},
        ]
        result = client.chat_completion(messages=messages, max_tokens=10)
        response = result.choices[0].message.content.strip().upper()
        is_safe = "UNSAFE" not in response
        return ModerationResult(
            is_safe=is_safe,
            reason="" if is_safe else f"Flagged by agent API: {response}",
            source="agent_api",
        )

    # ── Unified entry-point ──────────────────────────────────────────────

    def check(self, prompt: str, *, hf_token: str | None = None) -> ModerationResult:
        """Check *prompt* safety.

        Tries the local Gemma model first; falls back to the HF Inference API;
        if both are unavailable the call succeeds with a warning (fail-open).
        """
        # 1. Local model
        if self.is_loaded:
            try:
                return self.check_with_local_model(prompt)
            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "Local moderation check failed — falling back to agent API: %s", exc
                )

        # 2. Agent API fallback (only when a token or a known model id exists)
        if hf_token or self._model_id:
            try:
                return self.check_with_agent_api(
                    prompt,
                    model_id=self._model_id,
                    token=hf_token,
                )
            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "Agent API moderation check failed — allowing prompt through: %s", exc
                )

        # 3. Both unavailable — fail open
        logger.warning(
            "No moderation backend available; prompt allowed through without safety check."
        )
        return ModerationResult(is_safe=True, reason="Moderation unavailable", source="none")
def image_url_path(job_id: str) -> str:
    """Build the API route under which the image for *job_id* is served."""
    return "/jobs/{}/image".format(job_id)
@dataclass
class GPUWorker:
    """Consumes jobs from an async priority queue and runs inference."""

    # Engine whose ``generate`` method is run in a worker thread per job.
    engine: "FluxEngine"
    # Pending work, ordered by (priority, timestamp).
    _queue: asyncio.PriorityQueue[_QueueItem] = field(default_factory=asyncio.PriorityQueue)
    # Set by shutdown() to make the loop exit after the current job.
    _stop_event: asyncio.Event = field(default_factory=asyncio.Event)
    # Background task created by start().
    _task: asyncio.Task | None = field(default=None, repr=False)
    # Counters reported by get_stats().
    _jobs_processed: int = 0
    _total_inference_time: float = 0.0

    # ── Queue interface ──────────────────────────────────────────────

    @staticmethod
    def _normalize_priority(priority: int) -> int:
        """Convert enum/int-like priorities to a stable numeric queue value.

        Values that cannot be converted with ``int()`` become 1, and the
        result is clamped to a minimum of 1.
        """
        try:
            value = int(priority)
        except (TypeError, ValueError):
            value = 1
        return max(1, value)

    async def enqueue(self, job_id: str, priority: int = 1) -> None:
        """Put *job_id* on the queue with the normalized *priority*."""
        normalized = self._normalize_priority(priority)
        item = _QueueItem(priority=normalized, timestamp=time.time(), job_id=job_id)
        await self._queue.put(item)
        logger.debug(
            "Enqueued job=%s priority=%d queue_depth=%d", job_id, normalized, self.queue_depth
        )

    @property
    def queue_depth(self) -> int:
        """Number of items currently waiting in the queue."""
        return self._queue.qsize()

    # ── Lifecycle ────────────────────────────────────────────────────

    def start(self) -> asyncio.Task:
        """Start the worker loop as a background task."""
        self._stop_event.clear()
        self._task = asyncio.create_task(self._run(), name="gpu-worker")
        logger.info("GPU worker started.")
        return self._task

    async def shutdown(self) -> None:
        """Signal the worker to stop and wait for it to finish the current job."""
        logger.info("Shutting down GPU worker (queue_depth=%d)...", self.queue_depth)
        self._stop_event.set()
        if self._task and not self._task.done():
            # Unblock a pending queue.get() by pushing a sentinel
            await self._queue.put(
                _QueueItem(priority=999, timestamp=time.time(), job_id="__stop__")
            )
            await self._task
        logger.info("GPU worker stopped. Processed %d jobs total.", self._jobs_processed)

    # ── Resume support ───────────────────────────────────────────────

    async def resume_pending_jobs(self) -> int:
        """Re-enqueue jobs left in PENDING or PROCESSING state (crash recovery).

        Returns:
            The number of jobs that were reset to ``pending`` and enqueued.
        """
        async with async_session_factory() as session:
            result = await session.execute(
                select(Job).where(Job.status.in_(["pending", "processing"]))
            )
            stale_jobs = result.scalars().all()
            for job in stale_jobs:
                job.status = "pending"
                await self.enqueue(job.id, self._normalize_priority(job.priority))
            await session.commit()
        if stale_jobs:
            logger.info("Resumed %d pending/stale jobs.", len(stale_jobs))
        return len(stale_jobs)

    # ── Main loop ────────────────────────────────────────────────────

    async def _run(self) -> None:
        """Consume queue items until the stop event is set or the sentinel arrives.

        An error escaping ``_process_job`` is logged and the loop continues,
        so one bad job cannot kill the worker task. ``task_done()`` is called
        for every item taken from the queue, including the stop sentinel.
        """
        logger.info("GPU worker loop running.")
        while not self._stop_event.is_set():
            try:
                # Time out periodically so the stop event is re-checked.
                item = await asyncio.wait_for(self._queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            try:
                if item.job_id == "__stop__":
                    break
                await self._process_job(item.job_id)
            except Exception:
                # Typically a database error outside _process_job's own handler.
                logger.exception("Unhandled error while processing job %s", item.job_id)
            finally:
                self._queue.task_done()
        logger.info("GPU worker loop exited.")

    async def _process_job(self, job_id: str) -> None:
        """Run one job: mark it processing, generate, save, mark completed.

        Missing or cancelled jobs are skipped. Any exception raised during
        generation, saving or the completion update marks the job as failed.
        """
        async with async_session_factory() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if job is None:
                logger.warning("Job %s not found in DB, skipping.", job_id)
                return

            # Skip cancelled jobs
            if job.status == "cancelled":
                logger.info("Job %s is cancelled, skipping.", job_id)
                return

            # Mark as processing
            job.status = "processing"
            job.started_at = datetime.now(timezone.utc)
            await session.commit()

        try:
            # Run inference in a thread to avoid blocking asyncio
            gen_result = await asyncio.to_thread(
                self.engine.generate,
                prompt=job.prompt,
                width=job.width,
                height=job.height,
                num_steps=job.num_steps,
                guidance_scale=job.guidance_scale,
                seed=job.seed,
            )

            # Save image to disk
            fmt = ImageFormat(job.format)
            file_path = get_image_path(job.id, job.format, batch_id=job.batch_id)
            save_image(
                gen_result.image,
                file_path,
                fmt,
                prompt=job.prompt,
                seed=gen_result.seed,
            )

            # Update job as completed
            async with async_session_factory() as session:
                result = await session.execute(select(Job).where(Job.id == job_id))
                job = result.scalar_one()
                job.status = "completed"
                job.file_path = str(file_path)
                job.seed = gen_result.seed
                job.completed_at = datetime.now(timezone.utc)
                await session.commit()

                # Update batch counters if applicable
                if job.batch_id:
                    await self._update_batch_count(session, job.batch_id)

            self._jobs_processed += 1
            self._total_inference_time += gen_result.inference_time

        except Exception as exc:
            # Detect CUDA OOM without importing torch in mock/non-GPU environments.
            if exc.__class__.__name__ == "OutOfMemoryError":
                logger.error("OOM on job %s (%dx%d).", job_id, job.width, job.height)
                try:
                    import torch

                    torch.cuda.empty_cache()
                except Exception:
                    # Best effort: torch may be missing or CUDA unavailable.
                    pass
                await self._fail_job(job_id, "GPU out of memory for this resolution")
                return

            logger.exception("Error processing job %s", job_id)
            await self._fail_job(job_id, "Internal generation error")

    async def _fail_job(self, job_id: str, error: str) -> None:
        """Mark *job_id* as failed with *error* and refresh its batch counters."""
        async with async_session_factory() as session:
            result = await session.execute(select(Job).where(Job.id == job_id))
            job = result.scalar_one_or_none()
            if job:
                job.status = "failed"
                job.error_message = error
                job.completed_at = datetime.now(timezone.utc)
                await session.commit()

                if job.batch_id:
                    await self._update_batch_count(session, job.batch_id)

    async def _update_batch_count(self, session, batch_id: str) -> None:
        """Recompute batch counters from individual jobs.

        Once completed + failed jobs reach ``total_count`` the batch is
        closed: ``completed`` if nothing failed, otherwise ``failed``.
        """
        result = await session.execute(select(BatchJob).where(BatchJob.id == batch_id))
        batch = result.scalar_one_or_none()
        if not batch:
            return

        # Count completed and failed
        completed = (
            await session.execute(
                select(func.count()).where(Job.batch_id == batch_id, Job.status == "completed")
            )
        ).scalar()
        failed = (
            await session.execute(
                select(func.count()).where(Job.batch_id == batch_id, Job.status == "failed")
            )
        ).scalar()

        batch.completed_count = completed or 0
        batch.failed_count = failed or 0

        done = (completed or 0) + (failed or 0)
        if done >= batch.total_count:
            batch.status = "completed" if (failed or 0) == 0 else "failed"
            batch.completed_at = datetime.now(timezone.utc)
        elif done > 0:
            batch.status = "processing"

        await session.commit()

    # ── Stats ────────────────────────────────────────────────────────

    def get_stats(self) -> dict:
        """Return queue depth, processed-job count and inference timings."""
        avg = (
            self._total_inference_time / self._jobs_processed
            if self._jobs_processed > 0
            else 0.0
        )
        return {
            "queue_depth": self.queue_depth,
            "jobs_processed": self._jobs_processed,
            "total_inference_time_s": round(self._total_inference_time, 2),
            "avg_inference_time_s": round(avg, 2),
        }
app/api/generate.py app/api/health.py app/api/jobs.py app/engine/__init__.py app/engine/engine_config.py app/engine/flux_engine.py app/models/__init__.py app/models/enums.py app/models/job.py app/schemas/__init__.py app/schemas/batch.py app/schemas/generate.py app/schemas/job.py app/services/__init__.py app/services/generation_service.py app/services/job_service.py app/services/moderation_service.py app/utils/__init__.py app/utils/image_utils.py app/utils/moderation.py app/utils/storage.py app/worker/__init__.py app/worker/gpu_worker.py flux_image_service.egg-info/PKG-INFO flux_image_service.egg-info/SOURCES.txt flux_image_service.egg-info/dependency_links.txt flux_image_service.egg-info/requires.txt flux_image_service.egg-info/top_level.txt tests/test_api_batch.py tests/test_api_generate.py tests/test_api_jobs.py tests/test_moderation.py tests/test_moderation_service.py tests/test_schemas.py tests/test_worker.py ================================================ FILE: flux_image_service.egg-info/dependency_links.txt ================================================ ================================================ FILE: flux_image_service.egg-info/requires.txt ================================================ fastapi>=0.115.0 uvicorn[standard]>=0.30.0 pydantic>=2.0 pydantic-settings>=2.0 sqlalchemy[asyncio]>=2.0 aiosqlite>=0.20.0 torch>=2.1 diffusers>=0.30.0 transformers>=4.40.0 accelerate>=0.30.0 sentencepiece>=0.2.0 protobuf>=4.25.0 Pillow>=10.0.0 python-multipart>=0.0.9 sse-starlette>=2.0.0 httpx>=0.27.0 [dev] pytest>=8.0 pytest-asyncio>=0.23.0 httpx>=0.27.0 ruff>=0.4.0 ================================================ FILE: flux_image_service.egg-info/top_level.txt ================================================ app ================================================ FILE: pyproject.toml ================================================ [project] name = "flux-image-service" version = "0.1.0" description = "Local image generation service using Flux.1 Schnell 12B" 
@pytest.fixture
async def db_engine(tmp_path):
    """Yield an async engine backed by an in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the engine
    is handed to the test; the engine is disposed afterwards.
    """
    sqlite_engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    async with sqlite_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield sqlite_engine

    await sqlite_engine.dispose()
@pytest.fixture
def mock_worker(mock_engine):
    """A stand-in GPUWorker with an empty queue, awaitable ``enqueue`` and zeroed stats."""
    idle_stats = {
        "queue_depth": 0,
        "jobs_processed": 0,
        "total_inference_time_s": 0.0,
        "avg_inference_time_s": 0.0,
    }

    fake_worker = MagicMock()
    fake_worker.queue_depth = 0
    # AsyncMock so that ``await worker.enqueue(...)`` works and calls are recorded.
    fake_worker.enqueue = AsyncMock()
    fake_worker.get_stats.return_value = idle_stats
    return fake_worker
@pytest.mark.asyncio
async def test_cancel_batch(client, mock_worker):
    """DELETE /batch/{id} reports a freshly created batch as cancelled."""
    # Create batch first
    created = await client.post(
        "/batch",
        json={
            "name": "cancel-test",
            "prompts": [{"prompt": "p1"}, {"prompt": "p2"}],
        },
    )
    # Check the setup call so a failed create is reported as such, not as a
    # KeyError on "batch_id".
    assert created.status_code == 200
    batch_id = created.json()["batch_id"]

    # Cancel it
    resp = await client.delete(f"/batch/{batch_id}")
    assert resp.status_code == 200
    data = resp.json()
    assert data["status"] == "cancelled"
@pytest.mark.asyncio
async def test_generate_default_values(client, mock_worker):
    """A prompt-only request is accepted and stored with the default resolution."""
    resp = await client.post("/generate", json={"prompt": "a cat"})
    assert resp.status_code == 200
    job_ids = resp.json()["job_ids"]
    assert len(job_ids) == 1

    # Read the stored job back to check the defaults that were applied.
    job_resp = await client.get(f"/jobs/{job_ids[0]}")
    assert job_resp.status_code == 200
    job = job_resp.json()
    assert job["prompt"] == "a cat"
    assert job["width"] == 1024
    assert job["height"] == 1024
@pytest.mark.asyncio
async def test_generate_async_proceeds_when_prompt_is_safe(client, mock_worker):
    """POST /generate proceeds normally when the moderation service approves the prompt."""
    approving_check = AsyncMock(return_value=(True, None))
    with patch("app.api.generate.check_prompt_safety", new=approving_check):
        resp = await client.post("/generate", json={"prompt": "a beautiful landscape"})

    assert resp.status_code == 200
    assert "job_ids" in resp.json()
    mock_worker.enqueue.assert_called()
@pytest.mark.asyncio
async def test_get_job_after_create(client, mock_worker):
    """A job created via /generate can be fetched back with its stored fields."""
    # Create a job via generate
    created = await client.post("/generate", json={"prompt": "test job"})
    job_id = created.json()["job_ids"][0]

    # Fetch it
    fetched = await client.get(f"/jobs/{job_id}")
    assert fetched.status_code == 200

    body = fetched.json()
    expected_fields = {
        "id": job_id,
        "prompt": "test job",
        "status": "pending",
        "width": 1024,
        "height": 1024,
    }
    for key, expected in expected_fields.items():
        assert body[key] == expected
class TestModerationEngine:
    """Unit tests for ModerationEngine: loading, parsing, both backends, fallback order."""

    def test_is_loaded_false_before_load(self):
        engine = ModerationEngine()
        assert not engine.is_loaded

    def test_load_failure_leaves_engine_unloaded(self):
        """load() must swallow a pipeline construction error and stay unloaded."""
        engine = ModerationEngine()
        # Stand-ins for the heavy packages that load() imports lazily; the
        # pipeline factory fails the way a missing GPU or model would.
        fake_transformers = MagicMock()
        fake_transformers.pipeline.side_effect = RuntimeError("no GPU")
        with patch.dict(
            "sys.modules", {"torch": MagicMock(), "transformers": fake_transformers}
        ):
            engine.load("google/gemma-3-1b-it")  # must not raise
        fake_transformers.pipeline.assert_called_once()
        assert not engine.is_loaded

    def test_parse_pipeline_output_safe(self):
        engine = ModerationEngine()
        raw = [{"generated_text": [{"role": "assistant", "content": "SAFE"}]}]
        assert engine._parse_pipeline_output(raw) == "SAFE"

    def test_parse_pipeline_output_unsafe(self):
        engine = ModerationEngine()
        raw = [{"generated_text": [{"role": "assistant", "content": "UNSAFE"}]}]
        assert engine._parse_pipeline_output(raw) == "UNSAFE"

    def test_check_with_local_model_raises_when_not_loaded(self):
        engine = ModerationEngine()
        with pytest.raises(RuntimeError, match="not loaded"):
            engine.check_with_local_model("test prompt")

    def test_check_with_local_model_safe(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(
            return_value=[{"generated_text": [{"role": "assistant", "content": "SAFE"}]}]
        )
        result = engine.check_with_local_model("a beautiful sunset")
        assert result.is_safe is True
        assert result.source == "local"

    def test_check_with_local_model_unsafe(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(
            return_value=[{"generated_text": [{"role": "assistant", "content": "UNSAFE"}]}]
        )
        result = engine.check_with_local_model("explicit harmful content")
        assert result.is_safe is False
        assert result.source == "local"
        assert "Flagged by local model" in result.reason

    def test_check_with_agent_api_safe(self):
        engine = ModerationEngine()
        engine._model_id = "google/gemma-3-1b-it"

        mock_choice = MagicMock()
        mock_choice.message.content = "SAFE"
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]

        mock_client = MagicMock()
        mock_client.chat_completion.return_value = mock_response

        with patch("huggingface_hub.InferenceClient", return_value=mock_client):
            result = engine.check_with_agent_api("a nice landscape", token="hf_test")

        assert result.is_safe is True
        assert result.source == "agent_api"

    def test_check_with_agent_api_unsafe(self):
        engine = ModerationEngine()
        engine._model_id = "google/gemma-3-1b-it"

        mock_choice = MagicMock()
        mock_choice.message.content = "UNSAFE"
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]

        mock_client = MagicMock()
        mock_client.chat_completion.return_value = mock_response

        with patch("huggingface_hub.InferenceClient", return_value=mock_client):
            result = engine.check_with_agent_api("explicit content", token="hf_test")

        assert result.is_safe is False
        assert result.source == "agent_api"
        assert "Flagged by agent API" in result.reason

    def test_check_uses_local_model_first(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(
            return_value=[{"generated_text": [{"role": "assistant", "content": "SAFE"}]}]
        )
        with patch.object(engine, "check_with_agent_api") as mock_api:
            result = engine.check("safe prompt")
        assert result.is_safe is True
        assert result.source == "local"
        mock_api.assert_not_called()

    def test_check_falls_back_to_agent_api_when_local_fails(self):
        engine = ModerationEngine()
        engine._pipeline = MagicMock(side_effect=RuntimeError("inference error"))
        engine._model_id = "google/gemma-3-1b-it"

        fallback_result = ModerationResult(is_safe=True, source="agent_api")
        with patch.object(
            engine, "check_with_agent_api", return_value=fallback_result
        ) as mock_api:
            result = engine.check("test prompt", hf_token="hf_test")

        assert result.source == "agent_api"
        mock_api.assert_called_once()

    def test_check_fails_open_when_both_unavailable(self):
        engine = ModerationEngine()
        # Local not loaded, no model_id so agent API branch skipped
        result = engine.check("test prompt")
        assert result.is_safe is True
        assert result.source == "none"

    def test_check_falls_open_when_agent_api_also_fails(self):
        engine = ModerationEngine()
        engine._model_id = "google/gemma-3-1b-it"
        with patch.object(engine, "check_with_agent_api", side_effect=Exception("API error")):
            result = engine.check("test prompt", hf_token="hf_test")
        assert result.is_safe is True
        assert result.source == "none"
mock_settings.HF_API_TOKEN = None mock_engine.check.return_value = ModerationResult(is_safe=True, source="local") check_prompt("a beautiful sunset") # should not raise def test_rejects_unsafe_prompt(self): with patch("app.config.settings") as mock_settings, \ patch("app.utils.moderation._moderation_engine") as mock_engine: mock_settings.MODERATION_ENABLED = True mock_settings.HF_API_TOKEN = None mock_engine.check.return_value = ModerationResult( is_safe=False, reason="Flagged by local model: UNSAFE", source="local" ) with pytest.raises(ValueError, match="Flagged by local model"): check_prompt("explicit content") def test_skips_when_moderation_disabled(self): with patch("app.config.settings") as mock_settings, patch( "app.utils.moderation._moderation_engine" ) as mock_engine: mock_settings.MODERATION_ENABLED = False check_prompt("any content") mock_engine.check.assert_not_called() # ── ImageRequest schema integration ────────────────────────────────────────── class TestImageRequestModeration: def test_unsafe_prompt_raises_validation_error(self): with patch("app.config.settings") as mock_settings, \ patch("app.utils.moderation._moderation_engine") as mock_engine: mock_settings.MODERATION_ENABLED = True mock_settings.HF_API_TOKEN = None mock_engine.check.return_value = ModerationResult( is_safe=False, reason="Flagged by local model: UNSAFE", source="local" ) with pytest.raises(ValidationError, match="content moderation|Flagged"): from app.schemas.generate import ImageRequest ImageRequest(prompt="explicit content") def test_safe_prompt_passes_validation(self): with patch("app.config.settings") as mock_settings, \ patch("app.utils.moderation._moderation_engine") as mock_engine: mock_settings.MODERATION_ENABLED = True mock_settings.HF_API_TOKEN = None mock_engine.check.return_value = ModerationResult(is_safe=True, source="local") from app.schemas.generate import ImageRequest req = ImageRequest(prompt="a mountain landscape") assert req.prompt == "a mountain landscape" 
================================================ FILE: tests/test_moderation_service.py ================================================ """Tests for the content moderation service.""" from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest from app.services.moderation_service import ModerationError, check_prompt_safety # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_moderation_response(flagged: bool, categories: dict[str, bool] | None = None): """Build a mock httpx.Response that mimics the OpenAI moderation format.""" if categories is None: categories = {} payload = { "results": [ { "flagged": flagged, "categories": categories, "category_scores": {k: 0.9 if v else 0.01 for k, v in categories.items()}, } ] } response = MagicMock(spec=httpx.Response) response.json.return_value = payload response.raise_for_status = MagicMock() return response # --------------------------------------------------------------------------- # Tests: moderation disabled # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_check_prompt_safety_disabled_skips_api(): """When moderation is disabled the function returns safe without any API call.""" with patch("app.services.moderation_service.settings") as mock_settings: mock_settings.MODERATION_ENABLED = False is_safe, reason = await check_prompt_safety("a test prompt") assert is_safe is True assert reason is None @pytest.mark.asyncio async def test_check_prompt_safety_no_api_key_skips_api(): """When enabled but no API key is set, the function returns safe with a warning.""" with patch("app.services.moderation_service.settings") as mock_settings: mock_settings.MODERATION_ENABLED = True mock_settings.MODERATION_API_KEY = "" is_safe, reason = await check_prompt_safety("a test prompt") assert is_safe is 
True assert reason is None # --------------------------------------------------------------------------- # Tests: safe prompt # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_check_prompt_safety_safe_prompt(): """A prompt that the API says is safe returns (True, None).""" mock_response = _make_moderation_response(flagged=False) with patch("app.services.moderation_service.settings") as mock_settings, patch( "app.services.moderation_service.httpx.AsyncClient" ) as mock_client_cls: mock_settings.MODERATION_ENABLED = True mock_settings.MODERATION_API_KEY = "test-key" mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations" mock_settings.MODERATION_TIMEOUT = 10.0 mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) is_safe, reason = await check_prompt_safety("a beautiful sunset over the mountains") assert is_safe is True assert reason is None # --------------------------------------------------------------------------- # Tests: unsafe / flagged prompt # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_check_prompt_safety_flagged_prompt(): """A flagged prompt returns (False, reason) with the triggered categories.""" mock_response = _make_moderation_response( flagged=True, categories={"sexual": True, "violence": False, "hate": False}, ) with patch("app.services.moderation_service.settings") as mock_settings, patch( "app.services.moderation_service.httpx.AsyncClient" ) as mock_client_cls: mock_settings.MODERATION_ENABLED = True mock_settings.MODERATION_API_KEY = "test-key" mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations" mock_settings.MODERATION_TIMEOUT = 10.0 mock_client = AsyncMock() mock_client.post = 
AsyncMock(return_value=mock_response) mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) is_safe, reason = await check_prompt_safety("explicit nsfw content here") assert is_safe is False assert reason is not None assert "sexual" in reason @pytest.mark.asyncio async def test_check_prompt_safety_flagged_multiple_categories(): """Multiple triggered categories are all included in the reason string.""" mock_response = _make_moderation_response( flagged=True, categories={"sexual": True, "violence": True, "hate": False}, ) with patch("app.services.moderation_service.settings") as mock_settings, patch( "app.services.moderation_service.httpx.AsyncClient" ) as mock_client_cls: mock_settings.MODERATION_ENABLED = True mock_settings.MODERATION_API_KEY = "test-key" mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations" mock_settings.MODERATION_TIMEOUT = 10.0 mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) is_safe, reason = await check_prompt_safety("violent explicit content") assert is_safe is False assert "sexual" in reason assert "violence" in reason @pytest.mark.asyncio async def test_check_prompt_safety_flagged_no_categories(): """If the API flags the prompt but provides no categories, reason is generic.""" mock_response = _make_moderation_response(flagged=True, categories={}) with patch("app.services.moderation_service.settings") as mock_settings, patch( "app.services.moderation_service.httpx.AsyncClient" ) as mock_client_cls: mock_settings.MODERATION_ENABLED = True mock_settings.MODERATION_API_KEY = "test-key" mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations" mock_settings.MODERATION_TIMEOUT = 10.0 mock_client = AsyncMock() mock_client.post = 
AsyncMock(return_value=mock_response) mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) is_safe, reason = await check_prompt_safety("some flagged content") assert is_safe is False assert reason == "unsafe content" # --------------------------------------------------------------------------- # Tests: API failures # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_check_prompt_safety_http_error_raises_moderation_error(): """A network-level HTTP error raises ModerationError.""" with patch("app.services.moderation_service.settings") as mock_settings, patch( "app.services.moderation_service.httpx.AsyncClient" ) as mock_client_cls: mock_settings.MODERATION_ENABLED = True mock_settings.MODERATION_API_KEY = "test-key" mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations" mock_settings.MODERATION_TIMEOUT = 10.0 mock_client = AsyncMock() mock_client.post = AsyncMock( side_effect=httpx.ConnectError("Connection refused") ) mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) with pytest.raises(ModerationError): await check_prompt_safety("test prompt") @pytest.mark.asyncio async def test_check_prompt_safety_http_status_error_raises_moderation_error(): """A non-2xx API response raises ModerationError.""" mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = 401 mock_response.text = "Unauthorized" http_error = httpx.HTTPStatusError( "401 Unauthorized", request=MagicMock(), response=mock_response, ) mock_response.raise_for_status = MagicMock(side_effect=http_error) with patch("app.services.moderation_service.settings") as mock_settings, patch( "app.services.moderation_service.httpx.AsyncClient" ) as mock_client_cls: mock_settings.MODERATION_ENABLED = True 
mock_settings.MODERATION_API_KEY = "bad-key" mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations" mock_settings.MODERATION_TIMEOUT = 10.0 mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) with pytest.raises(ModerationError): await check_prompt_safety("test prompt") @pytest.mark.asyncio async def test_check_prompt_safety_empty_results(): """An API response with no results array is treated as safe.""" mock_response = MagicMock(spec=httpx.Response) mock_response.json.return_value = {"results": []} mock_response.raise_for_status = MagicMock() with patch("app.services.moderation_service.settings") as mock_settings, patch( "app.services.moderation_service.httpx.AsyncClient" ) as mock_client_cls: mock_settings.MODERATION_ENABLED = True mock_settings.MODERATION_API_KEY = "test-key" mock_settings.MODERATION_API_URL = "https://api.openai.com/v1/moderations" mock_settings.MODERATION_TIMEOUT = 10.0 mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) is_safe, reason = await check_prompt_safety("a normal prompt") assert is_safe is True assert reason is None ================================================ FILE: tests/test_schemas.py ================================================ """Tests for schemas — resolution validation, format handling.""" import pytest from pydantic import ValidationError from app.schemas.generate import ImageRequest from app.schemas.batch import BatchRequest class TestImageRequest: def test_defaults(self): req = ImageRequest(prompt="test") assert req.width == 1024 assert req.height == 1024 assert req.num_steps == 4 assert req.guidance_scale == 0.0 assert req.format.value == 
"png" assert req.seed is None assert req.num_images == 1 def test_num_images_up_to_4(self): req = ImageRequest(prompt="test", num_images=4) assert req.num_images == 4 def test_num_images_rejects_above_4(self): with pytest.raises(ValidationError): ImageRequest(prompt="test", num_images=5) def test_num_images_rejects_zero(self): with pytest.raises(ValidationError): ImageRequest(prompt="test", num_images=0) def test_width_must_be_multiple_of_8(self): with pytest.raises(ValidationError, match="multiple of 8"): ImageRequest(prompt="test", width=513) def test_height_must_be_multiple_of_8(self): with pytest.raises(ValidationError, match="multiple of 8"): ImageRequest(prompt="test", height=100 + 3) def test_valid_custom_resolution(self): req = ImageRequest(prompt="test", width=768, height=1344) assert req.width == 768 assert req.height == 1344 def test_rejects_empty_prompt(self): with pytest.raises(ValidationError): ImageRequest(prompt="") def test_rejects_oversized_resolution(self): with pytest.raises(ValidationError): ImageRequest(prompt="test", width=4096) def test_seed_bounds(self): req = ImageRequest(prompt="test", seed=0) assert req.seed == 0 req = ImageRequest(prompt="test", seed=2**32 - 1) assert req.seed == 2**32 - 1 with pytest.raises(ValidationError): ImageRequest(prompt="test", seed=-1) def test_all_formats(self): for fmt in ("png", "jpeg", "webp"): req = ImageRequest(prompt="test", format=fmt) assert req.format.value == fmt class TestBatchRequest: def test_valid_batch(self): batch = BatchRequest( name="test", prompts=[ImageRequest(prompt="a"), ImageRequest(prompt="b")], ) assert len(batch.prompts) == 2 def test_empty_prompts_rejected(self): with pytest.raises(ValidationError): BatchRequest(name="empty", prompts=[]) ================================================ FILE: tests/test_worker.py ================================================ """Tests for the GPU worker queue and processing logic.""" import asyncio from unittest.mock import AsyncMock, MagicMock, 
patch from datetime import datetime, timezone import pytest from PIL import Image from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from app.engine.flux_engine import GenerationResult from app.models.enums import JobPriority from app.models.job import Base, Job from app.worker.gpu_worker import GPUWorker, _QueueItem class TestQueuePriority: def test_realtime_before_batch(self): """REALTIME (priority=1) items should sort before BATCH (priority=5).""" rt = _QueueItem(priority=1, timestamp=100.0, job_id="rt") batch = _QueueItem(priority=5, timestamp=99.0, job_id="batch") assert rt < batch def test_same_priority_fifo(self): """Same priority: earlier timestamp wins.""" a = _QueueItem(priority=1, timestamp=100.0, job_id="a") b = _QueueItem(priority=1, timestamp=200.0, job_id="b") assert a < b class TestWorkerEnqueue: @pytest.mark.asyncio async def test_enqueue_increases_depth(self): engine = MagicMock() worker = GPUWorker(engine=engine) await worker.enqueue("job-1", priority=1) await worker.enqueue("job-2", priority=5) assert worker.queue_depth == 2 @pytest.mark.asyncio async def test_realtime_preempts_queued_batch_jobs(self): engine = MagicMock() worker = GPUWorker(engine=engine) await worker.enqueue("batch-1", priority=JobPriority.BATCH) await worker.enqueue("batch-2", priority=JobPriority.BATCH) await worker.enqueue("rt-1", priority=JobPriority.REALTIME) first = await worker._queue.get() assert first.job_id == "rt-1" class TestWorkerStats: def test_initial_stats(self): engine = MagicMock() worker = GPUWorker(engine=engine) stats = worker.get_stats() assert stats["jobs_processed"] == 0 assert stats["queue_depth"] == 0 assert stats["avg_inference_time_s"] == 0.0