Repository: jvdillon/netv Branch: main Commit: 9fde1364c4ff Files: 67 Total size: 1.0 MB Directory structure: gitextract_lipmxb5x/ ├── .dockerignore ├── .github/ │ └── workflows/ │ ├── ai-upscale.yml │ ├── ci.yml │ ├── ffmpeg-base.yml │ └── release.yml ├── .gitignore ├── Dockerfile ├── Dockerfile.ai_upscale ├── Dockerfile.ffmpeg ├── LICENSE ├── README.md ├── __init__.py ├── auth.py ├── auth_test.py ├── cache.py ├── cache_test.py ├── docker-compose.yml ├── entrypoint-ai_upscale.sh ├── entrypoint.sh ├── epg.py ├── epg_test.py ├── ffmpeg_command.py ├── ffmpeg_command_test.py ├── ffmpeg_session.py ├── ffmpeg_session_test.py ├── m3u.py ├── m3u_test.py ├── main.py ├── main_test.py ├── pyproject.toml ├── static/ │ └── js/ │ ├── app.js │ ├── favorites-grid.js │ ├── player.js │ ├── settings.js │ └── virtual-guide.js ├── templates/ │ ├── base.html │ ├── error.html │ ├── guide.html │ ├── login.html │ ├── movie_detail.html │ ├── player.html │ ├── search.html │ ├── series.html │ ├── series_detail.html │ ├── settings.html │ ├── setup.html │ └── vod.html ├── testing.py ├── tools/ │ ├── alignm3u.py │ ├── export-tensorrt.py │ ├── install-ai_upscale.sh │ ├── install-ffmpeg.sh │ ├── install-letsencrypt.sh │ ├── install-netv.sh │ ├── install-prereqs.sh │ ├── patches/ │ │ ├── dnn_backend_tensorrt.cpp │ │ ├── dnn_backend_torch.cpp │ │ ├── dnn_cuda_kernels.cu │ │ ├── dnn_cuda_kernels.h │ │ └── vf_dnn_processing.c │ ├── uninstall-netv.sh │ ├── xtream2m3u.py │ └── zap2xml.py ├── util.py ├── util_test.py ├── xtream.py └── xtream_test.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .dockerignore ================================================ # Git .git/ .gitignore # Python __pycache__/ *.py[cod] .venv/ .ruff_cache/ .pytest_cache/ uv.lock # App data (user-specific) cache/ .cache/ settings.json *.pem # Tools (not needed in container, except specific scripts) tools/ 
!tools/install-ffmpeg.sh !tools/patches/ !tools/install-ai_upscale.sh !tools/export-tensorrt.py # Tests *_test.py conftest.py # Docs *.md !README.md screenshots/ LICENSE # Docker Dockerfile docker-compose.yml .dockerignore ================================================ FILE: .github/workflows/ai-upscale.yml ================================================ name: AI Upscale Image on: workflow_dispatch: # Manual trigger push: branches: [main] paths: - "Dockerfile.ai_upscale" - "entrypoint-ai_upscale.sh" - ".github/workflows/ai-upscale.yml" workflow_run: # Also trigger after ffmpeg-base completes to pick up new ffmpeg image workflows: ["FFmpeg Base Image"] types: [completed] branches: [main] env: REGISTRY: ghcr.io IMAGE_NAME: ${{ github.repository }}-ai-upscale jobs: build: runs-on: ubuntu-latest # Skip if triggered by failed ffmpeg workflow if: ${{ github.event_name != 'workflow_run' || github.event.workflow_run.conclusion == 'success' }} permissions: contents: read packages: write strategy: matrix: include: - nvidia: 'cuda:12.4' base_image: 'ubuntu:22.04' - nvidia: 'cuda:12.6' base_image: 'ubuntu:24.04' - nvidia: 'cuda:12.8' base_image: 'ubuntu:24.04' - nvidia: 'cuda:13.0' base_image: 'ubuntu:24.04' latest: true steps: - name: Free disk space run: | # Remove large unnecessary packages to free up space for torch/tensorrt sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache sudo apt-get clean df -h - uses: actions/checkout@v4 with: ref: ${{ github.event.workflow_run.head_sha || github.sha }} - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Log in to Container Registry uses: docker/login-action@v3 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - name: Parse CUDA version id: cuda run: | # Extract "12.4" from "cuda:12.4" CUDA_VER="${{ matrix.nvidia }}" CUDA_VER="${CUDA_VER#cuda:}" echo "version=$CUDA_VER" >> $GITHUB_OUTPUT - name: Extract metadata id: meta 
uses: docker/metadata-action@v5 with: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} tags: | type=raw,value=cuda${{ steps.cuda.outputs.version }} type=raw,value=latest,enable=${{ matrix.latest == true }} - name: Build and push uses: docker/build-push-action@v6 with: context: . file: Dockerfile.ai_upscale push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} build-args: | FFMPEG_IMAGE=ghcr.io/${{ github.repository }}-ffmpeg:cuda${{ steps.cuda.outputs.version }} NVIDIA=${{ matrix.nvidia }} BASE_IMAGE=${{ matrix.base_image }} no-cache: true - name: Verify image run: | # Free space: BuildKit uses separate storage, prune it before pulling docker buildx prune -af || true docker system prune -af || true df -h docker pull ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:cuda${{ steps.cuda.outputs.version }} # Verify everything that should be enabled is actually linked # Note: --entrypoint bypasses the default entrypoint which tries to build TensorRT engines (requires GPU) docker run --rm --entrypoint sh ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:cuda${{ steps.cuda.outputs.version }} -c " echo '=== Verifying AI Upscale build ===' LDD_OUTPUT=\$(ldd /usr/local/bin/ffmpeg) # Verify ffmpeg binaries exist for bin in ffmpeg ffprobe ffplay; do test -x /usr/local/bin/\$bin || { echo \"ERROR: \$bin not found\"; exit 1; } echo \"OK: \$bin exists\" done # Verify non-NVIDIA dependencies are satisfied # Note: All NVIDIA libraries (libnvinfer, libcuda, libcudart) are loaded via dlopen, # so ffmpeg works even without NVIDIA GPU/drivers installed MISSING=\$(echo \"\$LDD_OUTPUT\" | grep 'not found' | grep -v -E 'libnvinfer|libnvonnxparser|libcudart|libcuda' || true) if [ -n \"\$MISSING\" ]; then echo 'ERROR: Missing non-NVIDIA libraries:' echo \"\$MISSING\" exit 1 fi echo 'OK: All non-NVIDIA dependencies satisfied' # Verify ffmpeg has NO hard CUDA dependency (all CUDA/TensorRT loaded via dlopen) if echo \"\$LDD_OUTPUT\" | grep -q 'libcudart'; then echo 'ERROR: 
libcudart linked at compile time (should use dlopen)' exit 1 fi echo 'OK: No libcudart dependency (CUDA Driver API via dlopen)' echo 'OK: libnvinfer loaded via dlopen (not in ldd output)' # Verify Python AI packages installed (use pip show, not import - import needs CUDA runtime) echo '=== Verifying Python packages ===' pip3 show torch >/dev/null 2>&1 || { echo 'ERROR: torch not installed'; exit 1; } echo 'OK: torch installed' pip3 show onnx >/dev/null 2>&1 || { echo 'ERROR: onnx not installed'; exit 1; } echo 'OK: onnx installed' pip3 show tensorrt >/dev/null 2>&1 || { echo 'ERROR: tensorrt not installed'; exit 1; } echo 'OK: tensorrt installed' # Verify AI upscale scripts exist echo '=== Verifying AI upscale scripts ===' test -x /app/tools/install-ai_upscale.sh || { echo 'ERROR: install-ai_upscale.sh not found'; exit 1; } echo 'OK: install-ai_upscale.sh exists' test -f /app/tools/export-tensorrt.py || { echo 'ERROR: export-tensorrt.py not found'; exit 1; } echo 'OK: export-tensorrt.py exists' echo '' echo '=== All verifications passed ===' " ================================================ FILE: .github/workflows/ci.yml ================================================ name: CI on: workflow_dispatch: # Manual trigger push: branches: [main] paths: # Only trigger on files used by main Dockerfile - "Dockerfile" - "*.py" - "pyproject.toml" - "templates/**" - "static/**" - "entrypoint.sh" - ".github/workflows/ci.yml" pull_request: branches: [main] workflow_run: # Trigger after ffmpeg-base completes to pick up new ffmpeg image workflows: ["FFmpeg Base Image"] types: [completed] branches: [main] env: REGISTRY: ghcr.io IMAGE_NAME: ${{ github.repository }} FFMPEG_IMAGE: ghcr.io/${{ github.repository }}-ffmpeg:latest jobs: test: runs-on: ubuntu-latest # Skip if triggered by failed ffmpeg workflow if: ${{ github.event_name != 'workflow_run' || github.event.workflow_run.conclusion == 'success' }} steps: - uses: actions/checkout@v4 with: # For workflow_run, checkout the commit 
that triggered ffmpeg build ref: ${{ github.event.workflow_run.head_sha || github.sha }} - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.11" - name: Install uv uses: astral-sh/setup-uv@v4 - name: Install dependencies run: uv sync --group dev - name: Lint with ruff run: uv run ruff check . - name: Type check with basedpyright run: uv run basedpyright - name: Run tests run: uv run pytest build: runs-on: ubuntu-latest needs: test permissions: contents: read packages: write steps: - uses: actions/checkout@v4 with: ref: ${{ github.event.workflow_run.head_sha || github.sha }} - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Log in to Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v3 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - name: Verify base image exists run: | if ! docker manifest inspect ${{ env.FFMPEG_IMAGE }} > /dev/null 2>&1; then echo "ERROR: Base image ${{ env.FFMPEG_IMAGE }} not found" echo "Run the 'FFmpeg Base Image' workflow first" exit 1 fi echo "Base image verified: ${{ env.FFMPEG_IMAGE }}" - name: Extract metadata id: meta uses: docker/metadata-action@v5 with: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} tags: | type=raw,value=latest,enable={{is_default_branch}} type=ref,event=branch type=sha,prefix= - name: Build and push uses: docker/build-push-action@v6 with: context: . 
push: ${{ github.event_name != 'pull_request' }} tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} build-args: | FFMPEG_IMAGE=${{ env.FFMPEG_IMAGE }} cache-from: type=gha cache-to: type=gha,mode=max ================================================ FILE: .github/workflows/ffmpeg-base.yml ================================================ name: FFmpeg Base Image on: schedule: # Build daily at 3 AM UTC - cron: "0 3 * * *" push: branches: [main] paths: - "Dockerfile.ffmpeg" - "tools/install-ffmpeg.sh" - ".github/workflows/ffmpeg-base.yml" workflow_dispatch: inputs: ffmpeg_version: description: "FFmpeg version (e.g., 7.1 or snapshot)" required: false default: "snapshot" env: REGISTRY: ghcr.io IMAGE_NAME: ${{ github.repository }}-ffmpeg jobs: build: runs-on: ubuntu-latest permissions: contents: read packages: write strategy: matrix: include: - nvidia: 'cuda:12.4' base_image: 'ubuntu:22.04' - nvidia: 'cuda:12.6' base_image: 'ubuntu:24.04' - nvidia: 'cuda:12.8' base_image: 'ubuntu:24.04' - nvidia: 'cuda:13.0' base_image: 'ubuntu:24.04' steps: - uses: actions/checkout@v4 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Log in to Container Registry uses: docker/login-action@v3 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - name: Parse CUDA version id: cuda run: | # Extract "12.4" from "cuda:12.4" CUDA_VER="${{ matrix.nvidia }}" CUDA_VER="${CUDA_VER#cuda:}" echo "version=$CUDA_VER" >> $GITHUB_OUTPUT - name: Generate build date id: date run: | echo "date=$(date -u +'%Y-%m-%d')" >> $GITHUB_OUTPUT echo "datetime=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT - name: Determine FFmpeg version id: version run: | VERSION="${{ github.event.inputs.ffmpeg_version || 'snapshot' }}" echo "version=$VERSION" >> $GITHUB_OUTPUT if [ "$VERSION" = "snapshot" ]; then echo "tag=${{ steps.date.outputs.date }}" >> $GITHUB_OUTPUT else echo "tag=$VERSION" >> $GITHUB_OUTPUT fi - name: 
Extract metadata id: meta uses: docker/metadata-action@v5 with: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} tags: | type=raw,value=cuda${{ steps.cuda.outputs.version }} type=raw,value=latest,enable=${{ steps.cuda.outputs.version == '13.0' }} type=raw,value=${{ steps.date.outputs.date }}-cuda${{ steps.cuda.outputs.version }},enable=${{ steps.version.outputs.version == 'snapshot' }} - name: Build and push uses: docker/build-push-action@v6 with: context: . file: Dockerfile.ffmpeg push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} build-args: | BUILD_DATE=${{ steps.date.outputs.datetime }} FFMPEG_VERSION=${{ steps.version.outputs.version }} NVIDIA=${{ matrix.nvidia }} FFMPEG_BASE_IMAGE=${{ matrix.base_image }} no-cache: true - name: Verify image run: | docker pull ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:cuda${{ steps.cuda.outputs.version }} # Verify ffmpeg build - check binaries and shared library dependencies # Note: Static libraries (libx264, libx265, etc.) are compiled into the binary # and won't appear in ldd output. If they were missing, the build would have failed. 
docker run --rm ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:cuda${{ steps.cuda.outputs.version }} sh -c " echo '=== Verifying ffmpeg build ===' # Verify all binaries exist for bin in ffmpeg ffprobe ffplay; do test -x /usr/local/bin/\$bin || { echo \"ERROR: \$bin not found\"; exit 1; } echo \"OK: \$bin exists\" done # Get ldd output for shared library checks LDD_OUTPUT=\$(ldd /usr/local/bin/ffmpeg) # Verify non-NVIDIA shared dependencies are satisfied MISSING=\$(echo \"\$LDD_OUTPUT\" | grep 'not found' | grep -v -E 'libnvinfer|libnvonnxparser|libcudart|libcuda' || true) if [ -n \"\$MISSING\" ]; then echo 'ERROR: Missing non-NVIDIA libraries:' echo \"\$MISSING\" exit 1 fi echo 'OK: All non-NVIDIA shared dependencies satisfied' # Verify NVIDIA libraries (all loaded via dlopen - no hard dependencies) # - CUDA Driver API (libcuda): loaded via dlopen when TensorRT backend used # - TensorRT (libnvinfer): loaded via dlopen when TensorRT backend used echo '=== Verifying NVIDIA libraries ===' # Verify ffmpeg has NO hard CUDA dependency if echo \"\$LDD_OUTPUT\" | grep -q 'libcudart'; then echo 'ERROR: libcudart linked at compile time (should use dlopen)' exit 1 fi echo 'OK: No libcudart dependency (uses CUDA Driver API via dlopen)' ls /usr/lib/x86_64-linux-gnu/libnvinfer.so* >/dev/null 2>&1 || { echo 'ERROR: TensorRT (libnvinfer) not installed'; exit 1; } echo 'OK: libnvinfer installed (loaded via dlopen)' # Verify libva is linked (we build it as shared for runtime) echo \"\$LDD_OUTPUT\" | grep -q libva || { echo 'ERROR: libva not linked'; exit 1; } echo 'OK: libva linked' echo '' echo '=== All verifications passed ===' " ================================================ FILE: .github/workflows/release.yml ================================================ name: Release on: push: tags: - "v*" env: REGISTRY: ghcr.io IMAGE_NAME: ${{ github.repository }} FFMPEG_IMAGE: ghcr.io/${{ github.repository }}-ffmpeg:latest jobs: release: runs-on: ubuntu-latest permissions: contents: write 
packages: write steps: - uses: actions/checkout@v4 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Log in to Container Registry uses: docker/login-action@v3 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - name: Extract metadata id: meta uses: docker/metadata-action@v5 with: images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} tags: | type=semver,pattern={{version}} type=semver,pattern={{major}}.{{minor}} type=semver,pattern={{major}} type=raw,value=latest - name: Build and push uses: docker/build-push-action@v6 with: context: . push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} build-args: | FFMPEG_IMAGE=${{ env.FFMPEG_IMAGE }} cache-from: type=gha cache-to: type=gha,mode=max - name: Generate release notes id: notes run: | echo "## Docker Image" >> notes.md echo "" >> notes.md echo "Pull the image:" >> notes.md echo "\`\`\`bash" >> notes.md echo "docker pull ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.ref_name }}" >> notes.md echo "\`\`\`" >> notes.md echo "" >> notes.md echo "## What's Changed" >> notes.md git log $(git describe --tags --abbrev=0 HEAD^ 2>/dev/null || git rev-list --max-parents=0 HEAD)..HEAD --pretty=format:"- %s" >> notes.md || true - name: Create GitHub Release uses: softprops/action-gh-release@v2 with: body_path: notes.md generate_release_notes: true ================================================ FILE: .gitignore ================================================ # Python __pycache__/ *.py[cod] .venv/ .ruff_cache/ .pytest_cache/ # Debugging **/*.out **/*.log # UV uv.lock # App data .cache/ cache/ # Certificates *.pem # Tools - user data and generated files **/*.gz **/*.json **/*.m3u **/*.xml tools/.zap2xml/ ================================================ FILE: Dockerfile ================================================ # netv application image # # Default build uses pre-built FFmpeg with full hardware support: # docker 
compose build # # Alternative: use apt FFmpeg (fewer codecs, no NVENC/QSV): # FFMPEG_IMAGE=ubuntu:24.04 docker compose build # # The optimized FFmpeg base image includes: # - NVENC (NVIDIA hardware encoding) # - VAAPI (Intel/AMD hardware encoding) # - QSV/VPL (Intel QuickSync) # - All major codecs (x264, x265, VP9, AV1, etc.) ARG FFMPEG_IMAGE=ghcr.io/jvdillon/netv-ffmpeg:latest FROM ${FFMPEG_IMAGE} ENV DEBIAN_FRONTEND=noninteractive # Install dependencies # - If using apt ffmpeg (ubuntu base): install ffmpeg + python # - If using compiled ffmpeg (netv-ffmpeg base): ffmpeg already present, just install python # Note: The conditional must be evaluated in shell, not in Dockerfile syntax RUN apt-get update && \ apt-get install -y --no-install-recommends \ gosu \ python3 \ python3-pip && \ # Conditionally install ffmpeg if not present from base image if [ ! -x /usr/local/bin/ffmpeg ] && [ ! -x /usr/bin/ffmpeg ]; then \ apt-get install -y --no-install-recommends ffmpeg; \ fi && \ rm -rf /var/lib/apt/lists/* # App setup WORKDIR /app # Copy application files (presence of pyproject.toml is checked by the RUN that follows) COPY pyproject.toml README.md ./ COPY *.py ./ COPY templates/ templates/ COPY static/ static/ # Verify critical files exist RUN test -f pyproject.toml || { echo "ERROR: pyproject.toml not found"; exit 1; } # Install Python dependencies # --ignore-installed: avoids "Cannot uninstall X, RECORD file not found" for apt packages # --break-system-packages: required for PEP 668 (Ubuntu 24.04+), doesn't exist in pip 22.0 (Ubuntu 22.04) # Feature-detect the flag via 'pip install --help' so the same step works with both pip versions RUN if python3 -m pip install --help 2>&1 | grep -q -- '--break-system-packages'; then \ python3 -m pip install --no-cache-dir --ignore-installed --break-system-packages .; \ else \ python3 -m pip install --no-cache-dir --ignore-installed .; \ fi # Runtime config EXPOSE 8000 # Environment variables (see README for details) ENV NETV_PORT=8000 ENV NETV_HTTPS="" ENV LOG_LEVEL=INFO # Create non-root user (entrypoint handles
permissions and group membership) RUN useradd -m netv # Copy entrypoint and set permissions with validation COPY entrypoint.sh /app/ RUN chmod +x /app/entrypoint.sh && \ test -x /app/entrypoint.sh || { echo "ERROR: entrypoint.sh not executable"; exit 1; } # Healthcheck: probes the app's root URL on port 8000 (hardcoded; it does not follow a changed NETV_PORT) # Note: start-period allows time for application startup HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ CMD python3 -c "import urllib.request; r=urllib.request.urlopen('http://localhost:8000/', timeout=5); exit(0 if r.status==200 else 1)" 2>/dev/null || exit 1 ENTRYPOINT ["/app/entrypoint.sh"] ================================================ FILE: Dockerfile.ai_upscale ================================================ # netv with AI Upscale (TensorRT super-resolution) # # This image includes everything needed for AI upscaling: # - FFmpeg with TensorRT DNN backend # - Python + torch + tensorrt for building engines # - Auto-builds TensorRT engines on first start (GPU-specific) # # REQUIREMENTS: # - Docker BuildKit (DOCKER_BUILDKIT=1 or use docker buildx) # - NVIDIA GPU with 8GB+ VRAM recommended # - nvidia-container-toolkit installed on host # # Build: # docker build -f Dockerfile.ai_upscale -t netv-ai-upscale . # # Run (engines are cached in volume): # docker run --gpus all -v netv-models:/models -p 8000:8000 netv-ai-upscale # # Note: First start takes ~2-3 minutes to build TensorRT engines for your GPU. # Subsequent starts are instant (engines cached in /models volume).
ARG FFMPEG_IMAGE=ghcr.io/jvdillon/netv-ffmpeg:latest FROM ${FFMPEG_IMAGE} # Build metadata (passed from workflow, used for documentation/debugging) ARG NVIDIA=cuda:13.0 ARG BASE_IMAGE=ubuntu:24.04 # Store build info as labels LABEL org.opencontainers.image.description="netv with AI Upscale (TensorRT)" LABEL ai.netv.cuda="${NVIDIA}" LABEL ai.netv.base="${BASE_IMAGE}" ENV DEBIAN_FRONTEND=noninteractive # Install dependencies RUN apt-get update && apt-get install -y --no-install-recommends \ gosu \ python3 \ python3-pip \ && rm -rf /var/lib/apt/lists/* # App setup WORKDIR /app # Copy application files with verification COPY pyproject.toml README.md ./ COPY *.py ./ COPY templates/ templates/ COPY static/ static/ COPY tools/ tools/ # Verify critical files exist RUN test -f pyproject.toml || { echo "ERROR: pyproject.toml not found"; exit 1; } && \ test -f tools/export-tensorrt.py || { echo "ERROR: export-tensorrt.py not found"; exit 1; } # Install Python dependencies with version constraints for stability # --ignore-installed: avoids "Cannot uninstall X, RECORD file not found" for apt packages # --break-system-packages: required for PEP 668 (Ubuntu 24.04+), doesn't exist in pip 22.0 (Ubuntu 22.04) # Version constraints: use compatible versions that work together RUN if python3 -m pip install --help 2>&1 | grep -q -- '--break-system-packages'; then \ PIP_OPTS="--no-cache-dir --ignore-installed --break-system-packages"; \ else \ PIP_OPTS="--no-cache-dir --ignore-installed"; \ fi && \ # Install with minimum version constraints for compatibility python3 -m pip install $PIP_OPTS \ 'torch>=2.1.0' \ 'onnx>=1.14.0' \ 'tensorrt>=9.0' \ . 
&& \ # Verify packages installed correctly python3 -c "import torch; import onnx; import tensorrt; print(f'Installed: torch={torch.__version__}, onnx={onnx.__version__}, tensorrt={tensorrt.__version__}')" && \ # Remove Windows-only TensorRT libraries to save ~500MB # Log what we're deleting for debugging echo "Removing Windows-only TensorRT libraries..." && \ find /usr -name '*_win_*.so*' -type f 2>/dev/null | head -5 | xargs -I{} echo " Removing: {}" && \ find /usr -name '*_win_*.so*' -type f -delete 2>/dev/null || true && \ find /usr -name '*_win.so*' -type f -delete 2>/dev/null || true && \ echo "Cleanup complete" # Runtime config EXPOSE 8000 # Environment variables (see README for details) ENV NETV_PORT=8000 ENV NETV_HTTPS="" ENV LOG_LEVEL=INFO ENV SR_ENGINE_DIR=/models # Create non-root user RUN useradd -m netv # Copy entrypoint and set permissions with validation COPY entrypoint-ai_upscale.sh /app/entrypoint.sh RUN chmod +x /app/entrypoint.sh && \ test -x /app/entrypoint.sh || { echo "ERROR: entrypoint.sh not executable"; exit 1; } # Create models directory with proper permissions (will be a volume mount point) RUN mkdir -p /models && \ chown netv:netv /models && \ chmod 755 /models && \ test -d /models && test -w /models || { echo "ERROR: /models not writable"; exit 1; } VOLUME /models # Healthcheck: probes the app's root URL on port 8000 (hardcoded; it does not follow a changed NETV_PORT) # Note: start-period=60s covers normal startup, but a first-start TensorRT engine build (~2-3 minutes) can take longer, so the container may be reported unhealthy until that build finishes HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \ CMD python3 -c "import urllib.request; r=urllib.request.urlopen('http://localhost:8000/', timeout=5); exit(0 if r.status==200 else 1)" 2>/dev/null || exit 1 ENTRYPOINT ["/app/entrypoint.sh"] ================================================ FILE: Dockerfile.ffmpeg ================================================ # FFmpeg Docker image with hardware acceleration (NVENC, VAAPI, QSV, AMF) # # REQUIREMENTS: # - Docker BuildKit (DOCKER_BUILDKIT=1 or use docker buildx) # - 20GB+ disk
space for build # # Uses install-ffmpeg.sh as single source of truth for build configuration. # # NVIDIA GPU compatibility: # FFmpeg compiles CUDA code to PTX (parallel thread execution) assembly, # NOT to GPU-specific SASS binary code. PTX is forward-compatible: the # NVIDIA driver JIT-compiles PTX to the actual GPU at runtime. # # This means a binary built with -arch=sm_52 (Maxwell) runs on ALL GPUs # from Maxwell through Blackwell and beyond (CUDA 13+ builds target sm_75 instead, # so they require Turing or newer). The only cost is JIT compilation # on first run (cached by driver for subsequent runs). # # For Docker builds, we use NVCC_GENCODE=minimum (sm_52 for CUDA <13, sm_75 for CUDA 13+) # to maximize GPU compatibility. For local builds, install-ffmpeg.sh defaults to # NVCC_GENCODE=native for best performance on the build machine. # ============================================================================= # Builder stage: compile FFmpeg using install-ffmpeg.sh # ============================================================================= ARG FFMPEG_BASE_IMAGE=ubuntu:24.04 FROM ${FFMPEG_BASE_IMAGE} AS builder # Build configuration - these are passed to install-ffmpeg.sh via environment # Hardware acceleration ARG ENABLE_NVIDIA_CUDA=1 ARG ENABLE_AMD_AMF=1 ARG ENABLE_TENSORRT=1 ARG ENABLE_LIBTORCH=0 ARG LIBTORCH_VERSION=2.5.0 ARG LIBTORCH_VARIANT=cu124 # Library builds ARG BUILD_LIBPLACEBO=1 ARG LIBPLACEBO_GIT_REF= ARG BUILD_LIBX265=1 ARG BUILD_LIBAOM=1 ARG BUILD_LIBWEBP=1 ARG BUILD_LIBVPL=1 ARG BUILD_LIBDAV1D=1 ARG BUILD_LIBSVTAV1=1 ARG BUILD_LIBVMAF=1 ARG BUILD_LIBVA=1 ARG BUILD_LIBJXL=1 ARG BUILD_LIBX264=1 ARG FFMPEG_VERSION=snapshot ARG NVIDIA=cuda:12.8 ARG NVCC_GENCODE=minimum ENV DEBIAN_FRONTEND=noninteractive # Override install-ffmpeg.sh paths for container ENV SRC_DIR=/src ENV BUILD_DIR=/opt/ffmpeg_build ENV BIN_DIR=/opt/bin ENV LIB_DIR=/opt/lib # Pass build args to script via env ENV ENABLE_NVIDIA_CUDA=${ENABLE_NVIDIA_CUDA} ENV ENABLE_AMD_AMF=${ENABLE_AMD_AMF} ENV ENABLE_TENSORRT=${ENABLE_TENSORRT} ENV
ENABLE_LIBTORCH=${ENABLE_LIBTORCH} ENV LIBTORCH_VERSION=${LIBTORCH_VERSION} ENV LIBTORCH_VARIANT=${LIBTORCH_VARIANT} ENV BUILD_LIBPLACEBO=${BUILD_LIBPLACEBO} ENV LIBPLACEBO_GIT_REF=${LIBPLACEBO_GIT_REF} ENV BUILD_LIBX265=${BUILD_LIBX265} ENV BUILD_LIBAOM=${BUILD_LIBAOM} ENV BUILD_LIBWEBP=${BUILD_LIBWEBP} ENV BUILD_LIBVPL=${BUILD_LIBVPL} ENV BUILD_LIBDAV1D=${BUILD_LIBDAV1D} ENV BUILD_LIBSVTAV1=${BUILD_LIBSVTAV1} ENV BUILD_LIBVMAF=${BUILD_LIBVMAF} ENV BUILD_LIBVA=${BUILD_LIBVA} ENV BUILD_LIBJXL=${BUILD_LIBJXL} ENV BUILD_LIBX264=${BUILD_LIBX264} ENV FFMPEG_VERSION=${FFMPEG_VERSION} ENV NVIDIA=${NVIDIA} ENV NVCC_GENCODE=${NVCC_GENCODE} # Pre-configure timezone to prevent tzdata interactive prompts ENV DEBIAN_FRONTEND=noninteractive ENV TZ=Etc/UTC RUN ln -fs /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone # Install sudo (install-ffmpeg.sh uses it, works as no-op when root) RUN apt-get update && apt-get install -y sudo # Copy build script and patches COPY tools/install-ffmpeg.sh /tmp/ COPY tools/patches /tmp/patches # Run the build script # Extract CUDA version from NVIDIA arg (e.g., "cuda:13.0" -> "13.0") # Note: echo BUILD_ARGS forces cache invalidation when any build arg changes RUN echo "BUILD_ARGS: NVIDIA=${NVIDIA} FFMPEG_VERSION=${FFMPEG_VERSION} NVCC_GENCODE=${NVCC_GENCODE} \ ENABLE_NVIDIA_CUDA=${ENABLE_NVIDIA_CUDA} ENABLE_AMD_AMF=${ENABLE_AMD_AMF} ENABLE_TENSORRT=${ENABLE_TENSORRT} \ ENABLE_LIBTORCH=${ENABLE_LIBTORCH} BUILD_LIBPLACEBO=${BUILD_LIBPLACEBO} BUILD_LIBX265=${BUILD_LIBX265} \ BUILD_LIBAOM=${BUILD_LIBAOM} BUILD_LIBWEBP=${BUILD_LIBWEBP} BUILD_LIBVPL=${BUILD_LIBVPL} \ BUILD_LIBDAV1D=${BUILD_LIBDAV1D} BUILD_LIBSVTAV1=${BUILD_LIBSVTAV1} BUILD_LIBVMAF=${BUILD_LIBVMAF} \ BUILD_LIBVA=${BUILD_LIBVA} BUILD_LIBJXL=${BUILD_LIBJXL} BUILD_LIBX264=${BUILD_LIBX264}" && \ chmod +x /tmp/install-ffmpeg.sh && \ CUDA_VERSION="${NVIDIA#cuda:}" && \ export CUDA_VERSION && \ /tmp/install-ffmpeg.sh # 
============================================================================= # Runtime stage: minimal image with just FFmpeg binaries # ============================================================================= ARG FFMPEG_BASE_IMAGE=ubuntu:24.04 FROM ${FFMPEG_BASE_IMAGE} ARG BUILD_DATE ARG FFMPEG_VERSION=snapshot ARG ENABLE_NVIDIA_CUDA=1 ARG ENABLE_AMD_AMF=1 ARG ENABLE_LIBTORCH=0 ARG BUILD_LIBPLACEBO=1 ARG BUILD_LIBX265=1 ARG BUILD_LIBAOM=1 ARG BUILD_LIBWEBP=1 ARG BUILD_LIBVPL=1 ARG BUILD_LIBDAV1D=1 ARG BUILD_LIBSVTAV1=1 ARG BUILD_LIBVMAF=1 ARG BUILD_LIBVA=1 ARG BUILD_LIBJXL=1 ARG BUILD_LIBX264=1 LABEL org.opencontainers.image.created="${BUILD_DATE}" LABEL org.opencontainers.image.title="netv-ffmpeg" LABEL org.opencontainers.image.description="FFmpeg with NVENC, VAAPI, QSV, AMF hardware acceleration" LABEL org.opencontainers.image.version="${FFMPEG_VERSION}" ENV DEBIAN_FRONTEND=noninteractive # Add Intel graphics PPA for newer Xe driver support (Ubuntu 24.04+ only) RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ --mount=type=cache,target=/var/lib/apt,sharing=locked \ . /etc/os-release && \ if [ "$VERSION_CODENAME" != "jammy" ]; then \ apt-get update && apt-get install -y --no-install-recommends software-properties-common && \ add-apt-repository -y ppa:kobuk-team/intel-graphics && \ apt-get update; \ fi # Add NVIDIA repo and install TensorRT runtime (for TensorRT DNN backend) # Note: TensorRT is loaded via dlopen, so ffmpeg works even if these aren't installed, # but including them means the Docker image works out of the box with --gpus all ARG ENABLE_TENSORRT=1 RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ --mount=type=cache,target=/var/lib/apt,sharing=locked \ if [ "$ENABLE_TENSORRT" = "1" ]; then \ . 
/etc/os-release && \ UBUNTU_VER=$(echo "$VERSION_ID" | tr -d '.') && \ apt-get update && apt-get install -y --no-install-recommends ca-certificates curl && \ curl -fsSL "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VER}/x86_64/cuda-keyring_1.1-1_all.deb" \ -o /tmp/cuda-keyring.deb && \ dpkg -i /tmp/cuda-keyring.deb && rm /tmp/cuda-keyring.deb && \ apt-get update && \ if [ "$VERSION_CODENAME" = "jammy" ]; then \ apt-get install -y --no-install-recommends libnvinfer8 libnvinfer-plugin8; \ else \ apt-get install -y --no-install-recommends libnvinfer10 libnvinfer-plugin10; \ fi && \ rm -rf /var/lib/apt/lists/*; \ fi # Runtime libraries for FFmpeg # Note: x265, libaom, libwebp, libvpl, libdav1d are statically linked when built from source RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ --mount=type=cache,target=/var/lib/apt,sharing=locked \ . /etc/os-release && \ if [ "$VERSION_CODENAME" = "jammy" ]; then \ LIBVPX=libvpx7; \ LIBSRT=libsrt1.4-openssl; \ LIBUNISTRING=libunistring2; \ LIBASOUND=libasound2; \ LIBSNDIO=libsndio7.0; \ else \ LIBVPX=libvpx9; \ LIBSRT=libsrt1.5-openssl; \ LIBUNISTRING=libunistring5; \ LIBASOUND=libasound2t64; \ LIBSNDIO=libsndio7.0; \ fi && \ apt-get update && apt-get install -y --no-install-recommends \ # Core codec libs libass9 \ libbluray2 \ libfdk-aac2 \ libmp3lame0 \ libopus0 \ libvorbis0a \ libvorbisenc2 \ $LIBVPX \ # Text/font rendering libfontconfig1 \ libfreetype6 \ libfribidi0 \ libharfbuzz0b \ # Audio/video processing librubberband2 \ libsoxr0 \ libvidstab1.1 \ libzimg2 \ libnuma1 \ # Network/crypto $LIBSRT \ libssl3 \ # X11/display libxcb1 \ libxcb-shm0 \ libxcb-shape0 \ libxcb-xfixes0 \ libxv1 \ libx11-6 \ libxext6 \ # Hardware accel: # - NVENC/CUDA: provided by nvidia-container-toolkit from host (no pkg needed) # - VAAPI: intel-media-va-driver-non-free (Intel), mesa-va-drivers (AMD), libva from source (below) # - OpenCL: ocl-icd-libopencl1 (ICD loader), backend from host (NVIDIA) or 
mesa-opencl-icd (AMD) # - Vulkan: libvulkan1 (conditional below), driver from host libvdpau1 \ intel-media-va-driver-non-free \ mesa-va-drivers \ ocl-icd-libopencl1 \ # Intel oneVPL/QSV runtime for Intel GPU hardware encoding (modern Intel CPUs) libmfx-gen1.2 \ # Other deps zlib1g \ $LIBUNISTRING \ liblzma5 \ liblzo2-2 \ $LIBASOUND \ libdrm2 \ $LIBSNDIO \ libsdl2-2.0-0 \ libpulse0 # Conditional runtime libs for apt-based packages (when not built from source) RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ --mount=type=cache,target=/var/lib/apt,sharing=locked \ . /etc/os-release && \ HWY_PKG="libhwy1t64" && \ if [ "$VERSION_CODENAME" = "jammy" ]; then HWY_PKG="libhwy0"; fi && \ APT_PKGS="" && \ [ "$BUILD_LIBX265" != "1" ] && APT_PKGS="$APT_PKGS libx265-199" ; \ [ "$BUILD_LIBAOM" != "1" ] && APT_PKGS="$APT_PKGS libaom3" ; \ [ "$BUILD_LIBWEBP" != "1" ] && APT_PKGS="$APT_PKGS libwebp7 libwebpmux3" ; \ [ "$BUILD_LIBVPL" != "1" ] && APT_PKGS="$APT_PKGS libvpl2" ; \ [ "$BUILD_LIBDAV1D" != "1" ] && APT_PKGS="$APT_PKGS libdav1d7" ; \ [ "$BUILD_LIBSVTAV1" != "1" ] && APT_PKGS="$APT_PKGS libsvtav1enc1d1" ; \ [ "$BUILD_LIBVMAF" != "1" ] && APT_PKGS="$APT_PKGS libvmaf3" ; \ [ "$BUILD_LIBPLACEBO" = "1" ] && APT_PKGS="$APT_PKGS libvulkan1" ; \ [ "$BUILD_LIBVA" != "1" ] && APT_PKGS="$APT_PKGS libva2 libva-drm2 libva-x11-2" ; \ [ "$BUILD_LIBJXL" = "1" ] && APT_PKGS="$APT_PKGS libbrotli1 $HWY_PKG" ; \ [ "$BUILD_LIBJXL" != "1" ] && APT_PKGS="$APT_PKGS libjxl0.7" ; \ [ "$BUILD_LIBX264" != "1" ] && APT_PKGS="$APT_PKGS libx264-164" ; \ if [ -n "$APT_PKGS" ]; then apt-get update && apt-get install -y --no-install-recommends $APT_PKGS; fi # Tell libva where to find system drivers and which driver to use # Our libva is built with prefix=/opt/ffmpeg_build but drivers are system-installed # Default LIBVA_DRIVER_NAME=iHD (Intel iHD driver, supports Xe kernel driver, Gen8+) # Override at runtime for other GPUs: # - AMD: LIBVA_DRIVER_NAME=radeonsi # - Older Intel: 
LIBVA_DRIVER_NAME=i965 # - Auto-detect: unset LIBVA_DRIVER_NAME (let libva auto-select) ENV LIBVA_DRIVERS_PATH=/usr/lib/x86_64-linux-gnu/dri # Note: Set to empty to allow auto-detection, or override at container runtime ARG LIBVA_DRIVER_NAME_DEFAULT=iHD ENV LIBVA_DRIVER_NAME=${LIBVA_DRIVER_NAME_DEFAULT} # Copy FFmpeg binaries from builder and verify they exist COPY --from=builder /opt/bin/ffmpeg /usr/local/bin/ COPY --from=builder /opt/bin/ffprobe /usr/local/bin/ COPY --from=builder /opt/bin/ffplay /usr/local/bin/ # Verify FFmpeg binaries are executable and have expected size (not empty) RUN for bin in ffmpeg ffprobe ffplay; do \ if [ ! -x "/usr/local/bin/$bin" ]; then \ echo "ERROR: $bin not executable or not found"; exit 1; \ fi; \ SIZE=$(stat -c%s "/usr/local/bin/$bin" 2>/dev/null || echo 0); \ if [ "$SIZE" -lt 1000000 ]; then \ echo "ERROR: $bin seems too small (${SIZE} bytes), may be corrupt"; exit 1; \ fi; \ done && echo "FFmpeg binaries verified" # Copy built libva if compiled from source (for Intel Xe kernel driver support) # libva needs to be in /opt/lib to match the rpath embedded in ffmpeg binary RUN --mount=type=bind,from=builder,source=/opt/lib,target=/tmp/ffmpeg_libs \ if [ "$BUILD_LIBVA" = "1" ] && [ -f /tmp/ffmpeg_libs/libva.so ]; then \ mkdir -p /opt/lib && \ cp -a /tmp/ffmpeg_libs/libva*.so* /opt/lib/ && \ echo "/opt/lib" > /etc/ld.so.conf.d/ffmpeg.conf && \ ldconfig; \ fi # Copy VMAF model files if built from source (needed for -vf libvmaf filter) RUN --mount=type=bind,from=builder,source=/opt/ffmpeg_build/share,target=/tmp/ffmpeg_share \ if [ "$BUILD_LIBVMAF" = "1" ] && [ -d /tmp/ffmpeg_share/libvmaf ]; then \ mkdir -p /usr/local/share && \ cp -a /tmp/ffmpeg_share/libvmaf /usr/local/share/; \ fi # Copy LibTorch shared libraries if torch enabled (needed for DNN filters) RUN --mount=type=bind,from=builder,source=/src,target=/tmp/src \ if [ "$ENABLE_LIBTORCH" = "1" ] && [ -d /tmp/src/libtorch/lib ]; then \ mkdir -p /opt/lib && \ cp -a 
/tmp/src/libtorch/lib/*.so* /opt/lib/ && \ echo "/opt/lib" > /etc/ld.so.conf.d/libtorch.conf && \ ldconfig; \ fi # Verify all dependencies are satisfied (exclude NVIDIA libs - provided at runtime by nvidia-container-toolkit) # NVIDIA libraries excluded: libnvinfer, libnvonnxparser, libcudart, libcuda, libnvcuvid, libnvrtc RUN set -e && \ echo "=== Checking library dependencies ===" && \ LDD_OUTPUT=$(ldd /usr/local/bin/ffmpeg 2>&1) && \ ALL_MISSING=$(echo "$LDD_OUTPUT" | grep "not found" || true) && \ if [ -n "$ALL_MISSING" ]; then \ echo "Libraries reported as 'not found':" && \ echo "$ALL_MISSING" && \ # Filter out NVIDIA libraries (provided by nvidia-container-toolkit at runtime) MISSING=$(echo "$ALL_MISSING" | grep -v -E "libnv|libcuda|libcublas|libcurand|libcufft" || true) && \ if [ -n "$MISSING" ]; then \ echo "ERROR: Non-NVIDIA libraries missing:" && \ echo "$MISSING" && \ exit 1; \ fi && \ echo "All missing libs are NVIDIA (provided at runtime)"; \ fi && \ echo "All FFmpeg dependencies satisfied" # Capture ffmpeg capabilities (may fail if TensorRT enabled - NVIDIA libs only available at runtime) RUN ffmpeg -version > /ffmpeg-version.txt && \ ffmpeg -hide_banner -encoders > /ffmpeg-encoders.txt && \ ffmpeg -hide_banner -decoders > /ffmpeg-decoders.txt && \ ffmpeg -hide_banner -filters > /ffmpeg-filters.txt || \ echo "Skipped (NVIDIA libs not available during build)" CMD ["ffmpeg", "-version"] ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to the Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
END OF TERMS AND CONDITIONS Copyright 2024 Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # neTV A minimal, self-hosted web interface for IPTV streams. ![EPG Guide](screenshots/epg.png) ![Player](screenshots/player.png) ![VOD](screenshots/vod.png) ![Series](screenshots/series.png) ![Settings](screenshots/settings.png) ## Why This Exists We built neTV because we couldn't find a clean, lightweight interface for Xtream IPTV services. Existing solutions were either bloated media centers or clunky apps that didn't work well across devices. **neTV is intentionally minimal.** It does one thing: play your IPTV streams with a clean UI that works on desktop, tablet, mobile, and Chromecast. We also prioritize **keyboard navigation** throughout (though still rough around the edges). The entire app is theoretically usable with just arrow keys, Enter, and Escape -- perfect for media PCs, HTPCs, or anyone who prefers keeping hands on the keyboard (like me). ### Disclaimer This is a **player only** -- it does not provide any content. You must have your own IPTV subscription that provides Xtream Codes API access or M3U playlists. Users are responsible for ensuring they have legal rights to access any content through their IPTV providers. 
## Features - **Live TV** with EPG grid guide - **Movies & Series** with metadata, seasons, episodes - **AI Upscale** - Real-time 4x upscaling via TensorRT (480p → 4K @ 85fps) - **Chromecast** support (HTTPS required) - **Closed captions** with style customization - **Search** across all content (supports regex) - **Favorites** with drag-and-drop ordering - **Resume playback** for VOD content - **Responsive** - works on desktop, tablet, mobile - **Keyboard navigation** - 10-foot UI friendly ### Transcoding Extensively optimized for minimal latency and CPU usage: - **Smart passthrough** - h264+aac streams remux without re-encoding (zero CPU) - **Full GPU pipeline** - NVDEC decode → NVENC/VAAPI encode, CPU stays idle - **Probe caching** - Streams probed once, series episodes share probe data - **Interlace detection** - Auto-deinterlaces OTA/cable, skips progressive - **Smart seeking** - Reuses segments for backward seeks, only transcodes gaps - **Session recovery** - VOD sessions survive restarts, resume where you left off - **HTTPS passthrough** - Auto-proxies HTTP streams when behind HTTPS ### 4K AI Upscaling Real-time 4x upscaling using Real-ESRGAN via TensorRT. Transforms 480p/720p/1080p content to pristine 4K at up to 85fps (480p source on an RTX 5090). Perfect for older shows and low-bitrate streams. | Before (720p source) | After (4K AI Upscale) | |---|---| | ![Before](screenshots/ai-upscale_price-is-right_disabled.png) | ![After](screenshots/ai-upscale_price-is-right_enabled.png) | | ![Before](screenshots/ai-upscale_cleopatra_disabled.png) | ![After](screenshots/ai-upscale_cleopatra_enabled.png) | | ![Before](screenshots/ai-upscale_batman_disabled.png) | ![After](screenshots/ai-upscale_batman_enabled.png) | Requires an Nvidia GPU and the [AI Upscale Docker image](#optional-ai-upscaling-nvidia-gpu-only). The Settings page shows AI Upscale options when TensorRT engines are available.
## Alternatives If you want a full-featured media center, you might be happier with: - **[Jellyfin](https://jellyfin.org/)** - Free, open-source media system - **[Emby](https://emby.media/)** - Media server with IPTV support - **[Plex](https://plex.tv/)** - Popular media platform with live TV These are excellent, mature projects with large communities. neTV exists for users who find them overkill and just want a simple IPTV player. | | neTV | [nodecast-tv] | [Jellyfin] | [Emby] | [Plex] | |---|---|---|---|---|---| | **Focus** | IPTV | IPTV | General media | General media | General media | | **Xtream Codes** | ✅ | ✅ | ❌ | ❌ | ❌ | | **M3U playlists** | ✅ | ✅ | ✅ | ✅ | ⚠️ Via [xTeVe] | | **XMLTV EPG** | ✅ | ⚠️ Via provider | ✅ | ✅ | ✅ | | **Local media** | ❌ | ❌ | ✅ | ✅ | ✅ | | **Live TV** | ✅ | ✅ | ✅ | ✅ | ✅ | | **VOD (movies/series)** | ✅ | ✅ | ✅ | ✅ | ✅ | | **DVR recording** | ❌ | ❌ | ✅ | ✅ | ⚠️ Pass | | **Catchup/timeshift** | ❌ | ❌ | ⚠️ Plugin | ⚠️ Plugin | ❌ | | **Live rewind buffer** | ✅ | ❌ | ⚠️ Via DVR | ⚠️ Via DVR | ⚠️ Via DVR | | **Resume playback** | ✅ | ❌ | ✅ | ✅ | ✅ | | **Multi-user** | ✅ | ✅ | ✅ | ✅ | ✅ | | **User roles** | ⚠️ Admin/viewer | ⚠️ Admin/viewer | ✅ Granular | ✅ Granular | ✅ Granular | | **Stream limits** | ✅ Per-user, per-source | ❌ | ⚠️ Per-user | ⚠️ Per-user | ⚠️ Per-user | | **Library permissions** | N/A | N/A | ✅ Per-library | ✅ Per-library | ✅ Per-library | | **Favorites** | ✅ Drag-and-drop | ✅ | ✅ | ✅ | ✅ | | **Search** | ✅ Regex | ✅ Basic | ✅ Basic | ✅ Basic | ✅ Basic | | **Video transcoding** | ✅ | ❌ | ✅ | ✅ | ✅ | | **Audio transcoding** | ✅ | ✅ | ✅ | ✅ | ✅ | | **Transcode only if needed** | ✅ Auto mode | ❌ | ⚠️ Per-library | ⚠️ Per-library | ⚠️ Per-client | | **NVENC** | ✅ | ❌ | ✅ | ✅ | ⚠️ Pass | | **VAAPI** | ✅ | ❌ | ✅ | ✅ | ⚠️ Pass | | **QSV** | ✅ | ❌ | ✅ | ✅ | ⚠️ Pass | | **AI Upscale (4x)** | ✅ TensorRT | ❌ | ⚠️ Plugin | ❌ | ❌ | | **Software fallback** | ✅ | ❌ Browser | ✅ | ✅ | ✅ | | **Legacy GPU** | ✅ Any | ❌ No (browser) | ✅ 
Any | ✅ Any | ⚠️ Driver 450+ | | **ffprobe caching** | ✅ Dynamic | ❌ None | ⚠️ Offline | ⚠️ Offline | ⚠️ Offline | | **Episode probe reuse** | ✅ MRU | ❌ No | ⚠️ Per-file | ⚠️ Per-file | ⚠️ Per-file | | **Session recovery** | ✅ Yes | ❌ No | ⚠️ Via DB | ⚠️ Via DB | ⚠️ Via DB | | **Auto deinterlace** | ✅ Yes | ❌ No | ⚠️ Manual | ⚠️ Manual | ⚠️ Manual | | **Subtitles** | ⚠️ WebVTT | ❌ No | ✅ Full | ✅ Full | ✅ Full | | **Chromecast** | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes | ✅ Yes | | **Keyboard/remote** | ✅ 10-foot UI | ⚠️ Basic | ✅ 10-foot UI | ✅ 10-foot UI | ✅ 10-foot UI | | **Mobile apps** | ⚠️ Web only | ⚠️ Web only | ✅ Native | ✅ Native | ✅ Native | | **Subscription** | ✅ Free | ✅ Free | ✅ Free | ⚠️ Premiere | ⚠️ Pass | | **Setup complexity** | ✅ Minimal | ✅ Minimal | ⚠️ Moderate | ⚠️ Moderate | ⚠️ Moderate | | **License** | Apache 2.0 | GPL v3 | GPL v2 | GPL v2 | Proprietary | | **Stack** | Python, FFmpeg | Node.js | .NET, FFmpeg | .NET, FFmpeg | Proprietary | *Corrections welcome — [open an issue](https://github.com/jvdillon/netv/issues).* [nodecast-tv]: https://github.com/technomancer702/nodecast-tv [Jellyfin]: https://jellyfin.org [Emby]: https://emby.media [Plex]: https://plex.tv [xTeVe]: https://github.com/xteve-project/xTeVe ## Installation ### Docker Create a `docker-compose.yml`: ```yaml services: netv: image: ghcr.io/jvdillon/netv:latest ports: - "8000:8000" volumes: - ./cache:/app/cache - /etc/localtime:/etc/localtime:ro devices: - /dev/dri:/dev/dri # for hardware transcoding (remove if no GPU) restart: unless-stopped ``` Then run: ```bash docker compose up -d ``` Open http://localhost:8000. To update: `docker compose pull && docker compose up -d` #### Optional: Nonfree (proprietary) FFMPEG optimized for Nvidia or AMD and/or Intel GPU We provide a custom built ffmpeg with Nvidia, AMD, and Intel _proprietary support_ for GPUs. Notably, essential packages are built from source and often _significantly_ newer than what is baked into Ubuntu 2024 (LTS). 
The custom built ffmpeg is not required unless you want: - best possible GPU performance, - bleeding edge capability, - to use AMD discrete GPU, - realtime AI upscaling (Nvidia only). Note: the custom built ffmpeg will generally work even if a dependency is not available. In such cases the specific capability will not be available but other capabilities will still work. In this sense the custom built ffmpeg is a "kitchen sink" build. | | Custom ffmpeg | Ubuntu ffmpeg | |---|---|---| | Intel or AMD Integrated GPU (VAAPI) | ✅ | ✅ | | Intel Integrated GPU (QSV QuickSync) | ✅ | ✅ | | Nvidia Discrete GPU (NVENC via LLVM) | ❌ | ✅ | | Nvidia Discrete GPU (NVENC via nvcc) | ✅ | ❌ | | AMD Discrete GPU (AMF) | ✅ | ❌ | | Fraunhofer FDK AAC | ✅ | ❌ | | Realtime AI Upscale (Nvidia TensorRT/Cuda) | ✅ | ❌ | | AV1 Vulkan | ✅ | ❌ | | Torch (Nvidia Cuda) | ⚠️ Optional | ❌ | For Nvidia, you will need the [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). To determine which CUDA build of ffmpeg to use, check your driver version and compute capability: ```bash nvidia-smi --query-gpu=driver_version,compute_cap --format=csv,noheader # Example: 580.87.02, 8.6 → Driver 580, compute ≥7.5 → use cuda13.0 ``` Find your CUDA version ([source](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html)): | Driver | < 7.5 (Maxwell/Pascal/Volta) | ≥ 7.5 (Turing+) | |--------|------------------------------|-----------------| | 550 | cuda12.4 | cuda12.4 | | 560 | cuda12.6 | cuda12.6 | | 570 | cuda12.8 | cuda12.8 | | 580+ | cuda12.8 | cuda13.0 | Then run, substituting the version from the table above (e.g. `cuda13.0`): ```bash FFMPEG_IMAGE=ghcr.io/jvdillon/netv-ffmpeg:cuda13.0 docker compose --profile nvidia up -d ``` For AMD or Intel, it does not matter which version you choose, nor do you need CUDA installed.
#### Optional: AI Upscaling (Nvidia GPU only) For real-time 2x or 4x AI upscaling (4x: 720p → 4K at ~39fps or 480p → 4K at ~85fps on RTX 5090): ```bash git clone https://github.com/jvdillon/netv.git cd netv docker build -f Dockerfile.ai_upscale -t netv-ai-upscale . docker run --gpus all -v netv-models:/models -v ./cache:/app/cache -p 8000:8000 netv-ai-upscale ``` First start builds TensorRT engines for your GPU (~2-3 min). Engines are cached in the `netv-models` volume for instant subsequent starts. Requirements: - Nvidia GPU (RTX 20xx or newer recommended) - [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) - Driver 535+ (CUDA 12.x) #### Docker Custom Builds For customization or development: ```bash git clone https://github.com/jvdillon/netv.git cd netv docker compose build # optimized FFmpeg (default) # FFMPEG_IMAGE=ubuntu:24.04 docker compose build # or stock FFmpeg docker compose up -d ``` To update: `git pull && docker compose build && docker compose up -d` #### Options ```bash NETV_PORT=9000 docker compose up -d # custom port NETV_HTTPS=1 docker compose up -d # enable HTTPS (mount certs first) ``` ### Debian/Ubuntu (`systemd`) For peak FFMPEG performance, Chromecast (requires HTTPS), and auto-start: ```bash # 1. Install prerequisites (uv, Python) ./tools/install-prereqs.sh # 2. (Optional) Get HTTPS certificates (required for Chromecast) ./tools/install-letsencrypt.sh yourdomain.com # 3. (Optional) Build FFmpeg (required for optimal Nvidia encoding efficiency) ./tools/install-ffmpeg.sh # 4. (Optional) Build AI Upscale engines (requires Nvidia GPU + TensorRT) uv sync --group ai_upscale ./tools/install-ai_upscale.sh # 5. 
Install systemd service sudo ./tools/install-netv.sh # default port=8000 or --port 9000 ``` Manage with: ```bash sudo systemctl status netv # Check status sudo systemctl restart netv # Restart after updates journalctl -u netv -f # View logs sudo systemctl edit netv --full # Change port or other settings sudo ./tools/uninstall-netv.sh # Uninstall ``` ### Development/Testing Requires Python 3.11+ and [uv](https://docs.astral.sh/uv/): ```bash git clone https://github.com/jvdillon/netv.git cd netv uv run ./main.py --port 8000 # --https ``` Or with pip: ```bash pip install . ./main.py --port 8000 ``` Open http://localhost:8000, create an admin account, and add your IPTV source. ### Additional Gems There are also some useful applications in `tools/`: - `zap2xml.py`: Scrape guide data into XML (I `crontab` this at 5am daily). - `alignm3u.py`: Useful for reworking your HDHomeRun m3u to align with the guide. - `xtream2m3u.py`: Dump xtream to m3u, useful for making Emby work with IPTV. ## Troubleshooting ### Debug Logging Enable verbose logs to diagnose EPG, M3U parsing, or other issues. **Docker:** In `docker-compose.yml`, change `LOG_LEVEL=INFO` to `LOG_LEVEL=DEBUG`, then restart: ```bash docker compose down && docker compose up -d docker compose logs -f ``` **Systemd:** ```bash sudo systemctl edit netv ``` Add: ```ini [Service] Environment="LOG_LEVEL=DEBUG" ``` Then restart and view logs: ```bash sudo systemctl restart netv journalctl -u netv -f ``` **Manual / Development:** ```bash LOG_LEVEL=DEBUG ./main.py # or ./main.py --debug ``` ## Q&A ### Where can I get free IPTV? Check out [iptv-org/iptv](https://github.com/iptv-org/iptv) -- a community-maintained collection of publicly available IPTV channels from around the world. ### Where can I get TV guide data? The free choice is [iptv-org/epg](https://github.com/iptv-org/epg), but this has never worked reliably for me.
For a more robust solution, consider [Schedules Direct](https://schedulesdirect.org/) -- your membership helps fund Open Source projects. Alternatively you can use `tools/zap2xml.py`. I've used this for over a year and found it to be very reliable -- it scrapes guide data from zap2it/gracenote. ### How do I set up HDHomeRun? HDHomeRun devices provide an M3U playlist, but it lacks EPG channel IDs. Use the scripts in `tools/` to fetch guide data and align it: ```bash # 1. Get your HDHomeRun lineup (replace IP with your device's IP) wget http://192.168.1.87/lineup.m3u -O tools/lineup.m3u # 2. Fetch TV guide data for your area ./tools/zap2xml.py --zip 90210 # 3. Align the M3U with the guide (adds tvg-id for EPG matching) ./tools/alignm3u.py --input tools/lineup.m3u --xmltv tools/xmltv.xml --output tools/ota.m3u ``` Then add `tools/ota.m3u` as an M3U source in neTV settings. And set up a cron job to refresh the guide daily (e.g., `0 5 * * * /usr/bin/python3 /path/to/netv/tools/zap2xml.py --zip 90210 && cp /path/to/netv/tools/xmltv.xml /var/www/html/`). ### How do I enable hardware transcoding? Hardware transcoding is auto-detected. Check Settings to see available encoders. - **Intel/AMD (VAAPI)**: Works automatically if `/dev/dri` exists. - **Nvidia**: Requires [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). See the [custom FFmpeg installation section](#optional-nonfree-proprietary-ffmpeg-optimized-for-nvidia-or-amd-andor-intel-gpu) for the driver/compute compatibility table. - **No GPU / VPS**: If `/dev/dri` doesn't exist, comment out the `devices` section in `docker-compose.yml`, or compose will fail to start. ### How do I install CUDA on Ubuntu? Tested on Ubuntu 24.04 LTS, 25.04, and 25.10.
```bash # Step 1: Remove existing Nvidia packages sudo apt purge -y '^nv.*' '^libnv.*' '^cuda-.*' '^libcuda-.*' '^cudnn[0-9]*-.*' '^libcudnn[0-9]*-.*' sudo apt autoremove -y # Step 2: Add Nvidia CUDA repository wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb sudo dpkg -i cuda-keyring_1.1-1_all.deb sudo apt modernize-sources || true sudo apt update # Step 3: Install driver and CUDA toolkit # For Turing+ GPUs (RTX 20 series and newer, compute >=7.5): sudo apt install -y nvidia-open cuda-toolkit-13 cudnn9-cuda-13 libcudnn9-dev-cuda-13 libnvinfer-bin # For Maxwell/Pascal GPUs (GTX 900/1000 series, compute <7.5): # Driver 590 dropped support. Pin to 580 and use CUDA 12.8. # Note: Maxwell/Pascal requires nvidia-driver (proprietary), not nvidia-open. # sudo apt install -y nvidia-driver-pinning-580 # sudo apt install -y nvidia-driver-580 cuda-toolkit-12-8 cudnn9-cuda-12-8 libcudnn9-dev-cuda-12 libnvinfer-bin # sudo update-alternatives --set cuda /usr/local/cuda-12.8 # Step 4: Configure environment tee -a ~/.bashrc << 'EOF' export CUDA_HOME=/usr/local/cuda if [ -d $CUDA_HOME ]; then export PATH="${CUDA_HOME}/bin${PATH:+:${PATH}}" export LD_LIBRARY_PATH="${CUDA_HOME}/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" fi unset CUDA_HOME EOF source ~/.bashrc # Step 5: Verify installation nvidia-smi --query-gpu=name,compute_cap,driver_version --format=csv,noheader nvcc --version ``` ### What are the keyboard shortcuts? | Key | Action | |-----|--------| | `Space` / `k` | Play/pause | | `f` | Fullscreen | | `m` | Mute | | `c` | Toggle captions | | `i` | Toggle info overlay | | `←` / `→` | Seek ±10s | | `↑` / `↓` | Volume | | `j` | Jump to time | | `Esc` | Back / close | ### What Does "neTV" Mean? Yes. 
We leave pronunciation and meaning as an exercise for your idiom: - **N-E-T-V** -- "Any TV", say it out loud - **≠TV** -- "Not Equals TV", because we're `!=` traditional cable - **Net-V** -- "Net Vision", because it streams video over your network - **Ni!-TV** -- For the [Knights who say Ni](https://www.youtube.com/watch?v=zIV4poUZAQo) We will also accept a shrubbery. One that looks nice. And not too expensive. ## Support If you find neTV useful, consider buying me a coffee: [![Buy Me a Coffee](https://img.shields.io/badge/Buy%20Me%20a%20Coffee-ffdd00?style=flat&logo=buy-me-a-coffee&logoColor=black)](https://buymeacoffee.com/jvdillon) ## License Apache License 2.0 ================================================ FILE: __init__.py ================================================ ================================================ FILE: auth.py ================================================ """Authentication: users, passwords, tokens, JWT.""" from __future__ import annotations from typing import Any import hashlib import hmac import json import pathlib import secrets import time APP_DIR = pathlib.Path(__file__).parent # Use old "cache" if it exists (backwards compat), otherwise ".cache" _OLD_CACHE = APP_DIR / "cache" CACHE_DIR = _OLD_CACHE if _OLD_CACHE.exists() else APP_DIR / ".cache" SERVER_SETTINGS_FILE = CACHE_DIR / "server_settings.json" USERS_DIR = CACHE_DIR / "users" TOKEN_EXPIRY = 86400 * 7 # 7 days def _get_settings_file() -> pathlib.Path: """Get the settings file.""" return SERVER_SETTINGS_FILE def _get_secret_key() -> str: """Get or generate secret key (persisted in settings).""" settings_file = _get_settings_file() settings = {} if settings_file.exists(): settings = json.loads(settings_file.read_text()) if "secret_key" not in settings: settings["secret_key"] = secrets.token_hex(32) settings_file.write_text(json.dumps(settings, indent=2)) return settings["secret_key"] def _hash_password(password: str, salt: str | None = None) -> str: """Hash password with 
salt using PBKDF2.""" if salt is None: salt = secrets.token_hex(16) key = hashlib.pbkdf2_hmac("sha256", password.encode(), salt.encode(), 100000) return f"{salt}:{key.hex()}" def _verify_hashed_password(password: str, hashed: str) -> bool: """Verify password against hash.""" if ":" not in hashed: return False # Invalid hash format salt, _ = hashed.split(":", 1) return hmac.compare_digest(_hash_password(password, salt), hashed) def _get_users() -> dict[str, dict[str, Any]]: """Get users from settings. Returns empty dict if no users configured. User format: {username: {password: str, admin: bool}} """ settings_file = _get_settings_file() if settings_file.exists(): settings = json.loads(settings_file.read_text()) return settings.get("users", {}) return {} def get_all_usernames() -> list[str]: """Get list of all usernames.""" return list(_get_users().keys()) def is_setup_required() -> bool: """Check if initial setup is required (no users configured).""" return len(_get_users()) == 0 def create_user(username: str, password: str, admin: bool = False) -> None: """Create a new user with hashed password.""" settings_file = _get_settings_file() settings = {} if settings_file.exists(): settings = json.loads(settings_file.read_text()) users = settings.get("users", {}) # First user is always admin if len(users) == 0: admin = True users[username] = {"password": _hash_password(password), "admin": admin} settings["users"] = users settings_file.write_text(json.dumps(settings, indent=2)) # Create user directory for per-user settings user_dir = USERS_DIR / username user_dir.mkdir(parents=True, exist_ok=True) def _ensure_one_admin(users: dict[str, dict[str, Any]]) -> None: """Ensure at least one user is admin. Promotes first user if needed.""" if not users or any(u.get("admin") for u in users.values()): return next(iter(users.values()))["admin"] = True def delete_user(username: str) -> bool: """Delete a user. 
Returns True if deleted, False if not found.""" settings_file = _get_settings_file() if not settings_file.exists(): return False settings = json.loads(settings_file.read_text()) users = settings.get("users", {}) if username not in users: return False del users[username] _ensure_one_admin(users) settings["users"] = users settings_file.write_text(json.dumps(settings, indent=2)) return True def verify_password(username: str, password: str) -> bool: """Verify username and password.""" users = _get_users() user_data = users.get(username, {"password": _hash_password("dummy")}) stored = user_data["password"] valid = _verify_hashed_password(password, stored) return valid and username in users def change_password(username: str, new_password: str) -> bool: """Change a user's password. Returns True if successful.""" settings_file = _get_settings_file() if not settings_file.exists(): return False settings = json.loads(settings_file.read_text()) users = settings.get("users", {}) if username not in users: return False users[username]["password"] = _hash_password(new_password) settings["users"] = users settings_file.write_text(json.dumps(settings, indent=2)) return True def is_admin(username: str) -> bool: """Check if user is an admin.""" users = _get_users() user_data = users.get(username, {}) return user_data.get("admin", False) def set_admin(username: str, admin: bool) -> bool: """Set admin status for a user. 
Returns True if successful.""" settings_file = _get_settings_file() if not settings_file.exists(): return False settings = json.loads(settings_file.read_text()) users = settings.get("users", {}) if username not in users: return False users[username]["admin"] = admin _ensure_one_admin(users) settings["users"] = users settings_file.write_text(json.dumps(settings, indent=2)) return True def get_users_with_admin() -> list[dict[str, Any]]: """Get list of users with their admin status and limits.""" users = _get_users() return [ { "username": u, "admin": d.get("admin", False), "max_streams_per_source": d.get("max_streams_per_source", {}), "unavailable_groups": d.get("unavailable_groups", []), } for u, d in users.items() ] def get_user_limits(username: str) -> dict[str, Any]: """Get user's stream limits and group restrictions.""" users = _get_users() user_data = users.get(username, {}) return { "max_streams_per_source": user_data.get("max_streams_per_source", {}), "unavailable_groups": user_data.get("unavailable_groups", []), } def set_user_limits( username: str, max_streams_per_source: dict[str, int] | None = None, unavailable_groups: list[str] | None = None, ) -> bool: """Set user's stream limits and/or group restrictions. 
Returns True if successful.""" settings_file = _get_settings_file() if not settings_file.exists(): return False settings = json.loads(settings_file.read_text()) users = settings.get("users", {}) if username not in users: return False if max_streams_per_source is not None: users[username]["max_streams_per_source"] = max_streams_per_source if unavailable_groups is not None: users[username]["unavailable_groups"] = unavailable_groups settings["users"] = users settings_file.write_text(json.dumps(settings, indent=2)) return True def create_token(payload: dict[str, Any]) -> str: """Create a signed JWT-like token.""" payload = {**payload, "exp": int(time.time()) + TOKEN_EXPIRY} data = json.dumps(payload, separators=(",", ":")).encode() sig = hmac.new(_get_secret_key().encode(), data, hashlib.sha256).hexdigest() return f"{data.hex()}.{sig}" def verify_token(token: str) -> dict[str, Any] | None: """Verify token and return payload, or None if invalid/expired.""" try: data_hex, sig = token.split(".") data = bytes.fromhex(data_hex) expected = hmac.new(_get_secret_key().encode(), data, hashlib.sha256).hexdigest() if not hmac.compare_digest(sig, expected): return None payload = json.loads(data) if payload.get("exp", 0) < time.time(): return None return payload except Exception: return None ================================================ FILE: auth_test.py ================================================ """Tests for auth.py.""" from __future__ import annotations from pathlib import Path from unittest import mock import json import pytest @pytest.fixture def auth_module(tmp_path: Path): """Import auth module with temp settings file.""" import auth # Patch settings files to temp location original_server = auth.SERVER_SETTINGS_FILE original_users = auth.USERS_DIR auth.SERVER_SETTINGS_FILE = tmp_path / "server_settings.json" auth.USERS_DIR = tmp_path / "users" auth.USERS_DIR.mkdir(exist_ok=True) yield auth auth.SERVER_SETTINGS_FILE = original_server auth.USERS_DIR = original_users 
class TestPasswordHashing: def test_hash_password_creates_salt(self, auth_module): hashed = auth_module._hash_password("mypassword") assert ":" in hashed salt, key = hashed.split(":") assert len(salt) == 32 # 16 bytes hex assert len(key) == 64 # 32 bytes hex def test_hash_password_with_salt_deterministic(self, auth_module): salt = "a" * 32 h1 = auth_module._hash_password("test", salt) h2 = auth_module._hash_password("test", salt) assert h1 == h2 def test_verify_hashed_password_correct(self, auth_module): hashed = auth_module._hash_password("secret") assert auth_module._verify_hashed_password("secret", hashed) def test_verify_hashed_password_wrong(self, auth_module): hashed = auth_module._hash_password("secret") assert not auth_module._verify_hashed_password("wrong", hashed) class TestUserManagement: def test_is_setup_required_no_users(self, auth_module): assert auth_module.is_setup_required() def test_create_user_and_verify(self, auth_module): auth_module.create_user("admin", "password123") assert not auth_module.is_setup_required() assert auth_module.verify_password("admin", "password123") assert not auth_module.verify_password("admin", "wrongpass") assert not auth_module.verify_password("nobody", "password123") class TestTokens: def test_create_and_verify_token(self, auth_module): payload = {"user": "admin", "role": "admin"} token = auth_module.create_token(payload) result = auth_module.verify_token(token) assert result is not None assert result["user"] == "admin" assert result["role"] == "admin" assert "exp" in result def test_token_format(self, auth_module): token = auth_module.create_token({"test": 1}) assert "." 
in token _, sig = token.split(".") assert len(sig) == 64 # sha256 hex def test_invalid_token_rejected(self, auth_module): assert auth_module.verify_token("invalid") is None assert auth_module.verify_token("abc.def") is None assert auth_module.verify_token("") is None def test_tampered_token_rejected(self, auth_module): token = auth_module.create_token({"user": "admin"}) # Tamper with signature data, _ = token.split(".") tampered = f"{data}.{'0' * 64}" assert auth_module.verify_token(tampered) is None def test_expired_token_rejected(self, auth_module): # Create token with expired time with mock.patch.object(auth_module, "TOKEN_EXPIRY", -1): token = auth_module.create_token({"user": "admin"}) assert auth_module.verify_token(token) is None class TestSecretKey: def test_get_secret_key_generates_and_persists(self, auth_module): key1 = auth_module._get_secret_key() assert len(key1) == 64 # 32 bytes hex # Should return same key key2 = auth_module._get_secret_key() assert key1 == key2 # Should be persisted (uses _get_settings_file which returns legacy if server doesn't exist) settings_file = auth_module._get_settings_file() settings = json.loads(settings_file.read_text()) assert settings["secret_key"] == key1 if __name__ == "__main__": from testing import run_tests run_tests(__file__) ================================================ FILE: cache.py ================================================ """File cache, settings, sources management.""" from __future__ import annotations from collections.abc import Callable from dataclasses import dataclass, field from typing import Any import hashlib import json import logging import pathlib import subprocess import threading import time import urllib.parse log = logging.getLogger(__name__) # =========================================================================== # VAAPI Auto-Detection # =========================================================================== def _get_gpu_vendor() -> str | None: """Detect GPU vendor ID via 
lspci or sysfs. Returns '8086' (Intel) or '1002' (AMD).""" # Try lspci first (works on bare metal) try: result = subprocess.run(["lspci", "-nn"], capture_output=True, text=True, timeout=5) for line in result.stdout.splitlines(): if "VGA" in line or "Display" in line or "3D" in line: if "[8086:" in line: return "8086" if "[1002:" in line: return "1002" except Exception: pass # Fallback: check sysfs (works in containers) drm_path = pathlib.Path("/sys/class/drm") if drm_path.exists(): for card in drm_path.iterdir(): if card.name.startswith("card") and card.name[4:].isdigit(): vendor_file = card / "device" / "vendor" if vendor_file.exists(): vendor = vendor_file.read_text().strip().replace("0x", "") if vendor in ("8086", "1002"): return vendor return None def _detect_vaapi_device() -> str | None: """Auto-detect the VAAPI render device. Returns '/dev/dri/renderD128' or None.""" render = pathlib.Path("/dev/dri/renderD128") return str(render) if render.exists() else None def _detect_libva_driver() -> str | None: """Auto-detect LIBVA driver name. Returns 'iHD', 'i965', 'radeonsi', or None.""" vendor = _get_gpu_vendor() if vendor == "8086": # iHD for Intel Gen8+ (Broadwell 2014+), supports Xe driver # Fall back to i965 for older Intel GPUs dri_path = _detect_dri_path() if dri_path and pathlib.Path(f"{dri_path}/iHD_drv_video.so").exists(): return "iHD" return "i965" if vendor == "1002": return "radeonsi" return None def _detect_dri_path() -> str | None: """Auto-detect the system DRI drivers path. Returns path like '/usr/lib/x86_64-linux-gnu/dri' or None. 
""" # Check common locations in order of preference candidates = [ "/usr/lib/x86_64-linux-gnu/dri", # Debian/Ubuntu "/usr/lib64/dri", # Fedora/RHEL "/usr/lib/dri", # Arch ] for path in candidates: if pathlib.Path(path).is_dir(): return path return None # Cached detection results (computed once at import) VAAPI_DEVICE = _detect_vaapi_device() LIBVA_DRIVER = _detect_libva_driver() DRI_PATH = _detect_dri_path() APP_DIR = pathlib.Path(__file__).parent # Use old "cache" if it exists (backwards compat), otherwise ".cache" _OLD_CACHE = APP_DIR / "cache" CACHE_DIR = _OLD_CACHE if _OLD_CACHE.exists() else APP_DIR / ".cache" CACHE_DIR.mkdir(exist_ok=True) SERVER_SETTINGS_FILE = CACHE_DIR / "server_settings.json" USERS_DIR = CACHE_DIR / "users" USERS_DIR.mkdir(exist_ok=True) LOGOS_DIR = CACHE_DIR / "logos" LOGOS_DIR.mkdir(exist_ok=True) # Cache TTLs in seconds LIVE_CACHE_TTL = 2 * 3600 # 2 hours EPG_CACHE_TTL = 6 * 3600 # 6 hours VOD_CACHE_TTL = 12 * 3600 # 12 hours SERIES_CACHE_TTL = 12 * 3600 # 12 hours INFO_CACHE_TTL = 7 * 24 * 3600 # 7 days max for series/movie info INFO_CACHE_STALE = 24 * 3600 # Refresh in background after 24 hours LOGO_CACHE_TTL = 7 * 24 * 3600 # 7 days for logos (server-side) LOGO_BROWSER_TTL = 24 * 3600 # 1 day for browser cache (re-validates before server expires) LOGO_MAX_SIZE = 1024 * 1024 # 1MB max logo size # In-memory cache _cache: dict[str, Any] = {} _cache_lock = threading.Lock() def _parse_json_file(path: str) -> tuple[Any, float] | None: """Parse JSON file - runs in separate process to avoid GIL blocking.""" try: with open(path) as f: data = json.load(f) return data.get("data"), data.get("timestamp", 0) except Exception: return None def load_file_cache(name: str, use_process: bool = False) -> tuple[Any, float] | None: """Load cached data from file. Returns (data, timestamp) or None. 
Args: name: Cache file name (without .json extension) use_process: If True, parse in separate process to avoid GIL blocking """ path = CACHE_DIR / f"{name}.json" if not path.exists(): return None if use_process: import concurrent.futures with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: future = executor.submit(_parse_json_file, str(path)) return future.result(timeout=60) try: data = json.loads(path.read_text()) return data.get("data"), data.get("timestamp", 0) except Exception: return None def save_file_cache(name: str, data: Any) -> None: """Save data to cache file with current timestamp.""" path = CACHE_DIR / f"{name}.json" path.write_text(json.dumps({"data": data, "timestamp": time.time()})) def clear_all_caches() -> None: """Clear memory cache except EPG (file cache preserved for restart).""" with _cache_lock: epg = _cache.get("epg") _cache.clear() if epg: _cache["epg"] = epg def clear_all_file_caches() -> int: """Clear all data file caches (live, vod, series). 
Returns count deleted.""" cache_files = ["live_data.json", "vod_data.json", "series_data.json"] deleted = 0 for name in cache_files: path = CACHE_DIR / name if path.exists(): path.unlink() deleted += 1 # Also clear memory cache clear_all_caches() return deleted def get_cache() -> dict[str, Any]: """Get reference to memory cache.""" return _cache def get_cache_lock() -> threading.Lock: """Get cache lock.""" return _cache_lock def _sanitize_name(name: str) -> str: """Sanitize a name for use as a directory/file name.""" # Remove path traversal and special chars name = name.replace("..", "").replace("/", "_").replace("\\", "_") name = "".join(c for c in name if c.isalnum() or c in "-_ ") return name[:224] or "default" def _url_to_filename(url: str) -> str: """Derive a readable filename from URL with hash suffix to avoid collisions.""" # Always include hash suffix to avoid collisions url_hash = hashlib.md5(url.encode()).hexdigest()[:8] parsed = urllib.parse.urlparse(url) path = parsed.path.rstrip("/") if path: # Get last path component name = path.split("/")[-1] # Strip extension, we'll add our own if "." in name: name = name.rsplit(".", 1)[0] name = _sanitize_name(name) if name and len(name) >= 2: return f"{name}_{url_hash}" return url_hash def get_cached_logo(source_name: str, url: str) -> pathlib.Path | None: """Get cached logo path if valid and not expired. Returns None if not cached.""" safe_source = _sanitize_name(source_name) filename = _url_to_filename(url) source_dir = LOGOS_DIR / safe_source if not source_dir.exists(): return None # Look for file with any extension for ext in ("png", "jpg", "jpeg", "gif", "webp", "svg"): path = source_dir / f"{filename}.{ext}" if path.exists(): age = time.time() - path.stat().st_mtime if age < LOGO_CACHE_TTL: return path # Expired, delete it path.unlink(missing_ok=True) return None def save_logo(source_name: str, url: str, data: bytes, content_type: str) -> pathlib.Path: """Save logo to cache. 
Returns the saved path.""" safe_source = _sanitize_name(source_name) filename = _url_to_filename(url) source_dir = LOGOS_DIR / safe_source source_dir.mkdir(parents=True, exist_ok=True) # Determine extension from content-type ext_map = { "image/png": "png", "image/jpeg": "jpg", "image/gif": "gif", "image/webp": "webp", "image/svg+xml": "svg", } ext = ext_map.get(content_type.split(";")[0].strip(), "png") path = source_dir / f"{filename}.{ext}" # Atomic write: write to temp file then rename tmp = path.with_suffix(".tmp") tmp.write_bytes(data) tmp.rename(path) return path def get_cached_info(cache_key: str, fetch_fn: Callable[[], Any], force: bool = False) -> Any: """Get info from memory cache, file cache, or fetch. Stale-while-revalidate.""" cached = load_file_cache(cache_key) cached_data, cached_ts = cached if cached else (None, 0) age = time.time() - cached_ts if force and cached_data: _cache.pop(cache_key, None) cached_data = None if cache_key in _cache and not force: if cached_ts and age > INFO_CACHE_STALE: def bg_refresh() -> None: try: data = fetch_fn() _cache[cache_key] = data save_file_cache(cache_key, data) log.info("Background refreshed %s", cache_key) except Exception as e: log.warning("Background refresh failed for %s: %s", cache_key, e) threading.Thread(target=bg_refresh, daemon=True).start() return _cache[cache_key] if cached_data and age < INFO_CACHE_TTL: _cache[cache_key] = cached_data if age > INFO_CACHE_STALE: def bg_refresh() -> None: try: data = fetch_fn() _cache[cache_key] = data save_file_cache(cache_key, data) log.info("Background refreshed %s", cache_key) except Exception as e: log.warning("Background refresh failed for %s: %s", cache_key, e) threading.Thread(target=bg_refresh, daemon=True).start() return cached_data data = fetch_fn() _cache[cache_key] = data save_file_cache(cache_key, data) return data def _test_encoder(cmd: list[str], timeout: int = 5, env: dict | None = None) -> tuple[bool, str]: """Test if an encoder works. 
Returns (success, error_message).""" try: run_env = None if env: import os run_env = os.environ.copy() run_env.update(env) result = subprocess.run(cmd, capture_output=True, timeout=timeout, env=run_env) if result.returncode == 0: return True, "" stderr = result.stderr.decode(errors="replace").strip() # Extract the most relevant error line for line in stderr.split("\n"): if line and not line.startswith("["): return False, line return False, stderr if stderr else "unknown error" except subprocess.TimeoutExpired: return False, "timeout" except FileNotFoundError: return False, "ffmpeg not found" except Exception as e: return False, str(e) def detect_encoders() -> dict[str, bool]: """Detect available FFmpeg H.264 encoders by testing actual hardware.""" log.info("Detecting hardware encoders...") encoders = { "nvenc": False, "amf": False, "qsv": False, "vaapi": False, } # Test input: 1 frame of 256x256 black (64x64 is below NVENC minimum on newer GPUs) test_input = ["-f", "lavfi", "-i", "color=black:s=256x256:d=0.04", "-frames:v", "1"] base_cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"] null_out = ["-f", "null", "-"] # NVENC: try nvenc directly ok, err = _test_encoder(base_cmd + test_input + ["-c:v", "h264_nvenc"] + null_out) encoders["nvenc"] = ok if ok: log.info(" NVENC (h264_nvenc): available") else: log.info(" NVENC (h264_nvenc): unavailable - %s", err) # AMF: try amf directly ok, err = _test_encoder(base_cmd + test_input + ["-c:v", "h264_amf"] + null_out) encoders["amf"] = ok if ok: log.info(" AMF (h264_amf): available") else: log.info(" AMF (h264_amf): unavailable - %s", err) # QSV: needs hwaccel init ok, err = _test_encoder( base_cmd + ["-hwaccel", "qsv", "-hwaccel_output_format", "qsv"] + test_input + ["-c:v", "h264_qsv"] + null_out ) encoders["qsv"] = ok if ok: log.info(" QSV (h264_qsv): available") else: log.info(" QSV (h264_qsv): unavailable - %s", err) # VA-API: needs device, hwupload, and driver env vars for hybrid GPU systems 
vaapi_baseline_only = False if VAAPI_DEVICE and LIBVA_DRIVER and DRI_PATH: vaapi_env = { "LIBVA_DRIVER_NAME": LIBVA_DRIVER, "LIBVA_DRIVERS_PATH": DRI_PATH, } # Try high profile first, fall back to constrained_baseline for older GPUs ok, err = _test_encoder( base_cmd + ["-init_hw_device", f"vaapi=va:{VAAPI_DEVICE}"] + test_input + ["-vf", "format=nv12,hwupload", "-c:v", "h264_vaapi"] + null_out, env=vaapi_env, ) if not ok: # Some older AMD GPUs (GCN 1.0) only support baseline profile ok, err = _test_encoder( base_cmd + ["-init_hw_device", f"vaapi=va:{VAAPI_DEVICE}"] + test_input + [ "-vf", "format=nv12,hwupload", "-c:v", "h264_vaapi", "-profile:v", "constrained_baseline", ] + null_out, env=vaapi_env, ) if ok: vaapi_baseline_only = True encoders["vaapi"] = ok encoders["vaapi_baseline_only"] = vaapi_baseline_only if ok: profile_note = " (baseline only)" if vaapi_baseline_only else "" log.info( " VAAPI (h264_vaapi): available%s (device=%s, driver=%s)", profile_note, VAAPI_DEVICE, LIBVA_DRIVER, ) else: log.info(" VAAPI (h264_vaapi): unavailable - %s", err) else: log.info(" VAAPI (h264_vaapi): unavailable - no Intel/AMD GPU detected") return encoders AVAILABLE_ENCODERS = detect_encoders() def refresh_encoders() -> dict[str, bool]: """Re-detect available encoders and update the cache.""" global AVAILABLE_ENCODERS AVAILABLE_ENCODERS = detect_encoders() return AVAILABLE_ENCODERS def _default_encoder() -> str: """Return first available encoder option. Preference order: nvenc > amf > qsv > vaapi > software For nvenc/amf, prefer +vaapi fallback if VAAPI is available. 
""" if AVAILABLE_ENCODERS.get("nvenc"): return "nvenc+vaapi" if AVAILABLE_ENCODERS.get("vaapi") else "nvenc+software" if AVAILABLE_ENCODERS.get("amf"): return "amf+vaapi" if AVAILABLE_ENCODERS.get("vaapi") else "amf+software" if AVAILABLE_ENCODERS.get("qsv"): return "qsv" if AVAILABLE_ENCODERS.get("vaapi"): return "vaapi" return "software" @dataclass(slots=True) class Source: id: str name: str type: str # "xtream", "m3u", or "epg" url: str username: str = "" password: str = "" epg_timeout: int = 120 # seconds epg_schedule: list[str] = field(default_factory=list) # ["03:00", "15:00"] epg_enabled: bool = True # Whether to fetch EPG from this source epg_url: str = "" # EPG URL (auto-detected from M3U/Xtream, or manual override) deinterlace_fallback: bool = True # Deinterlace when probe is skipped (for OTA/HDHomeRun) max_streams: int = 0 # Max concurrent streams from this source (0 = unlimited) def load_server_settings() -> dict[str, Any]: """Load server-wide settings.""" if SERVER_SETTINGS_FILE.exists(): data: dict[str, Any] = json.loads(SERVER_SETTINGS_FILE.read_text()) else: data = {} data.setdefault("transcode_mode", "auto") # Migrate old transcode_hw values to new format old_hw = data.get("transcode_hw", "") if old_hw == "nvidia": data["transcode_hw"] = ( "nvenc+vaapi" if AVAILABLE_ENCODERS.get("vaapi") else "nvenc+software" ) elif old_hw == "intel": data["transcode_hw"] = "qsv" # "vaapi" and "software" remain unchanged data.setdefault("transcode_hw", _default_encoder()) data.setdefault("vod_transcode_cache_mins", 60) # 0 = no caching (dead sessions cleaned immediately) data.setdefault("live_transcode_cache_secs", 0) data.setdefault("live_dvr_mins", 0) # 0 = disabled (default 30 sec buffer) data.setdefault("transcode_dir", "") # Empty = system temp dir data.setdefault("probe_live", True) data.setdefault("probe_movies", True) data.setdefault("probe_series", False) data.setdefault("sources", []) data.setdefault("users", {}) data.setdefault("user_agent_preset", 
"tivimate") data.setdefault("user_agent_custom", "") return data def save_server_settings(settings: dict[str, Any]) -> None: """Save server-wide settings.""" SERVER_SETTINGS_FILE.write_text(json.dumps(settings, indent=2)) def _validate_username(username: str) -> None: """Validate username to prevent path traversal and length attacks.""" if ( not username or len(username) > 64 or ".." in username or "/" in username or "\\" in username ): raise ValueError("Invalid username") def load_user_settings(username: str) -> dict[str, Any]: """Load per-user settings.""" _validate_username(username) user_file = USERS_DIR / username / "settings.json" if user_file.exists(): data = json.loads(user_file.read_text()) else: data = {} data.setdefault("guide_filter", []) data.setdefault("captions_enabled", True) data.setdefault("watch_history", {}) data.setdefault("favorites", {"series": {}, "movies": {}}) data.setdefault("cc_lang", "") data.setdefault("cc_style", {}) data.setdefault("cast_host", "") return data def save_user_settings(username: str, settings: dict[str, Any]) -> None: """Save per-user settings.""" _validate_username(username) user_dir = USERS_DIR / username user_dir.mkdir(exist_ok=True) (user_dir / "settings.json").write_text(json.dumps(settings, indent=2)) def get_watch_position(username: str, stream_url: str) -> dict[str, Any] | None: """Get saved watch position for a stream. 
Returns None if not found or >=95% watched.""" settings = load_user_settings(username) history = settings.get("watch_history", {}) entry = history.get(stream_url) if not entry: return None # Reset if >=95% watched if entry.get("duration", 0) > 0: pct = entry.get("position", 0) / entry["duration"] if pct >= 0.95: return None return entry def save_watch_position(username: str, stream_url: str, position: float, duration: float) -> None: """Save watch position for a stream.""" settings = load_user_settings(username) history = settings.setdefault("watch_history", {}) history[stream_url] = { "position": position, "duration": duration, "updated": time.time(), } # Keep only last 200 entries if len(history) > 200: sorted_entries = sorted(history.items(), key=lambda x: x[1].get("updated", 0), reverse=True) settings["watch_history"] = dict(sorted_entries[:200]) save_user_settings(username, settings) def get_sources() -> list[Source]: """Get list of configured sources.""" settings = load_server_settings() return [Source(**s) for s in settings.get("sources", [])] def update_source_epg_url(source_id: str, epg_url: str) -> None: """Update a source's epg_url in settings (only if currently empty).""" if not epg_url: return settings = load_server_settings() for s in settings.get("sources", []): if s["id"] == source_id and not s.get("epg_url"): s["epg_url"] = epg_url save_server_settings(settings) log.info("Saved EPG URL for source %s: %s", source_id, epg_url) break ================================================ FILE: cache_test.py ================================================ """Tests for cache.py.""" from __future__ import annotations from pathlib import Path from unittest import mock import subprocess import pytest import cache @pytest.fixture def cache_module(tmp_path: Path): """Import cache module with temp directories.""" # Patch paths to temp locations original_server_settings = cache.SERVER_SETTINGS_FILE original_users_dir = cache.USERS_DIR original_cache_dir = 
cache.CACHE_DIR cache.SERVER_SETTINGS_FILE = tmp_path / "server_settings.json" cache.USERS_DIR = tmp_path / "users" cache.USERS_DIR.mkdir(exist_ok=True) cache.CACHE_DIR = tmp_path / "cache" cache.CACHE_DIR.mkdir(exist_ok=True) # Clear memory cache cache._cache.clear() yield cache cache.SERVER_SETTINGS_FILE = original_server_settings cache.USERS_DIR = original_users_dir cache.CACHE_DIR = original_cache_dir cache._cache.clear() class TestFileCache: def test_save_and_load_file_cache(self, cache_module): cache_module.save_file_cache("test", {"key": "value"}) result = cache_module.load_file_cache("test") assert result is not None data, ts = result assert data == {"key": "value"} assert ts > 0 def test_load_nonexistent_cache(self, cache_module): assert cache_module.load_file_cache("nonexistent") is None def test_load_corrupted_cache(self, cache_module): path = cache_module.CACHE_DIR / "corrupted.json" path.write_text("not valid json") assert cache_module.load_file_cache("corrupted") is None class TestMemoryCache: def test_get_cache_returns_reference(self, cache_module): cache = cache_module.get_cache() cache["test"] = 123 assert cache_module.get_cache()["test"] == 123 def test_clear_all_caches_preserves_epg(self, cache_module): cache = cache_module.get_cache() cache["epg"] = {"data": "epg"} cache["live"] = {"data": "live"} cache_module.clear_all_caches() assert "epg" in cache assert "live" not in cache class TestCachedInfo: def test_get_cached_info_calls_fetch(self, cache_module): fetch_fn = mock.Mock(return_value={"result": 42}) result = cache_module.get_cached_info("test_key", fetch_fn) assert result == {"result": 42} fetch_fn.assert_called_once() def test_get_cached_info_uses_memory_cache(self, cache_module): fetch_fn = mock.Mock(return_value={"result": 1}) cache_module.get_cached_info("key1", fetch_fn) cache_module.get_cached_info("key1", fetch_fn) # Only called once - second call uses memory cache fetch_fn.assert_called_once() def 
test_get_cached_info_force_bypasses_memory(self, cache_module): fetch_fn = mock.Mock(return_value={"result": 1}) cache_module.get_cached_info("key2", fetch_fn) cache_module.get_cached_info("key2", fetch_fn, force=True) assert fetch_fn.call_count == 2 class TestSettings: def test_load_settings_defaults(self, cache_module): settings = cache_module.load_server_settings() assert settings["sources"] == [] assert settings["transcode_mode"] == "auto" assert settings["transcode_hw"] in ( "nvenc+vaapi", "nvenc+software", "amf+vaapi", "amf+software", "qsv", "vaapi", "software", ) assert settings["probe_movies"] is True def test_save_and_load_settings(self, cache_module): settings = {"sources": [{"id": "s1", "name": "Test"}], "custom": True} cache_module.save_server_settings(settings) loaded = cache_module.load_server_settings() assert loaded["sources"] == [{"id": "s1", "name": "Test"}] assert loaded["custom"] is True class TestUserSettings: def test_load_user_settings_defaults(self, cache_module): settings = cache_module.load_user_settings("testuser") assert settings["guide_filter"] == [] assert settings["captions_enabled"] is True assert settings["watch_history"] == {} def test_save_and_load_user_settings(self, cache_module): settings = {"guide_filter": ["cat1", "cat2"], "captions_enabled": False} cache_module.save_user_settings("testuser", settings) loaded = cache_module.load_user_settings("testuser") assert loaded["guide_filter"] == ["cat1", "cat2"] assert loaded["captions_enabled"] is False def test_watch_position_save_and_get(self, cache_module): cache_module.save_watch_position("user1", "http://video.url", 120.5, 3600.0) entry = cache_module.get_watch_position("user1", "http://video.url") assert entry is not None assert entry["position"] == 120.5 assert entry["duration"] == 3600.0 def test_watch_position_resets_at_95_percent(self, cache_module): # Save at 96% watched cache_module.save_watch_position("user1", "http://video.url", 960.0, 1000.0) entry = 
cache_module.get_watch_position("user1", "http://video.url") assert entry is None # Should be reset class TestSource: def test_source_dataclass(self, cache_module): source = cache_module.Source( id="test", name="Test Source", type="xtream", url="http://example.com", ) assert source.id == "test" assert source.username == "" assert source.epg_timeout == 120 assert source.epg_enabled is True def test_get_sources_empty(self, cache_module): sources = cache_module.get_sources() assert sources == [] def test_get_sources_from_settings(self, cache_module): settings = { "sources": [ { "id": "s1", "name": "Source 1", "type": "m3u", "url": "http://example.com/playlist.m3u", } ] } cache_module.save_server_settings(settings) sources = cache_module.get_sources() assert len(sources) == 1 assert sources[0].id == "s1" assert sources[0].type == "m3u" class TestUpdateSourceEpgUrl: def test_update_source_epg_url(self, cache_module): settings = {"sources": [{"id": "s1", "name": "S1", "type": "m3u", "url": "http://x"}]} cache_module.save_server_settings(settings) cache_module.update_source_epg_url("s1", "http://epg.example.com") loaded = cache_module.load_server_settings() assert loaded["sources"][0]["epg_url"] == "http://epg.example.com" def test_update_source_epg_url_not_overwrite(self, cache_module): settings = { "sources": [ { "id": "s1", "name": "S1", "type": "m3u", "url": "http://x", "epg_url": "http://existing", } ] } cache_module.save_server_settings(settings) cache_module.update_source_epg_url("s1", "http://new") loaded = cache_module.load_server_settings() assert loaded["sources"][0]["epg_url"] == "http://existing" def test_update_source_epg_url_empty_noop(self, cache_module): settings = {"sources": [{"id": "s1", "name": "S1", "type": "m3u", "url": "http://x"}]} cache_module.save_server_settings(settings) cache_module.update_source_epg_url("s1", "") loaded = cache_module.load_server_settings() assert "epg_url" not in loaded["sources"][0] class TestEncoderDetection: """Tests for 
encoder detection functions.""" def test_test_encoder_success(self): """Test _test_encoder returns (True, '') on successful command.""" with mock.patch("subprocess.run") as mock_run: mock_run.return_value = mock.Mock(returncode=0) ok, err = cache._test_encoder(["echo", "test"]) assert ok is True assert err == "" mock_run.assert_called_once() def test_test_encoder_failure(self): """Test _test_encoder returns (False, error) on non-zero return code.""" with mock.patch("subprocess.run") as mock_run: mock_run.return_value = mock.Mock(returncode=1, stderr=b"encoder not found") ok, err = cache._test_encoder(["false"]) assert ok is False assert "encoder not found" in err def test_test_encoder_timeout(self): """Test _test_encoder returns (False, 'timeout') on timeout.""" with mock.patch("subprocess.run") as mock_run: mock_run.side_effect = subprocess.TimeoutExpired(cmd=["test"], timeout=5) ok, err = cache._test_encoder(["sleep", "100"], timeout=5) assert ok is False assert err == "timeout" def test_test_encoder_exception(self): """Test _test_encoder returns (False, error) on exception.""" with mock.patch("subprocess.run") as mock_run: mock_run.side_effect = FileNotFoundError("ffmpeg not found") ok, err = cache._test_encoder(["nonexistent_command"]) assert ok is False assert err == "ffmpeg not found" def test_detect_encoders_all_available(self): """Test detect_encoders when all hardware is available.""" with ( mock.patch.object(cache, "_test_encoder", return_value=(True, "")), mock.patch.object(cache, "VAAPI_DEVICE", "/dev/dri/renderD128"), mock.patch.object(cache, "LIBVA_DRIVER", "i965"), mock.patch.object(cache, "DRI_PATH", "/usr/lib/x86_64-linux-gnu/dri"), ): result = cache.detect_encoders() assert result == { "nvenc": True, "amf": True, "qsv": True, "vaapi": True, "vaapi_baseline_only": False, } def test_detect_encoders_none_available(self): """Test detect_encoders when no hardware is available.""" with mock.patch.object(cache, "_test_encoder", return_value=(False, "not 
found")): result = cache.detect_encoders() assert result == { "nvenc": False, "amf": False, "qsv": False, "vaapi": False, } def test_detect_encoders_partial(self): """Test detect_encoders with mixed hardware availability.""" def mock_test(cmd, timeout=5, env=None): # Return True only for VAAPI if "h264_vaapi" in cmd: return True, "" return False, "not available" with ( mock.patch.object(cache, "_test_encoder", side_effect=mock_test), mock.patch.object(cache, "VAAPI_DEVICE", "/dev/dri/renderD128"), mock.patch.object(cache, "LIBVA_DRIVER", "i965"), mock.patch.object(cache, "DRI_PATH", "/usr/lib/x86_64-linux-gnu/dri"), ): result = cache.detect_encoders() assert result["nvenc"] is False assert result["amf"] is False assert result["qsv"] is False assert result["vaapi"] is True def test_detect_encoders_nvenc_only(self): """Test detect_encoders when only NVENC is available.""" def mock_test(cmd, timeout=5, env=None): if "h264_nvenc" in cmd: return True, "" return False, "not available" with mock.patch.object(cache, "_test_encoder", side_effect=mock_test): result = cache.detect_encoders() assert result["nvenc"] is True assert result["amf"] is False assert result["qsv"] is False assert result["vaapi"] is False def test_detect_encoders_vaapi_command_structure(self): """Test detect_encoders passes correct VAAPI command structure when GPU detected.""" captured_cmds = [] captured_envs = [] def capture_cmd(cmd, timeout=5, env=None): captured_cmds.append(cmd) captured_envs.append(env) return False, "test" # Mock auto-detected VAAPI device with ( mock.patch.object(cache, "_test_encoder", side_effect=capture_cmd), mock.patch.object(cache, "VAAPI_DEVICE", "/dev/dri/renderD128"), mock.patch.object(cache, "LIBVA_DRIVER", "i965"), mock.patch.object(cache, "DRI_PATH", "/usr/lib/x86_64-linux-gnu/dri"), ): cache.detect_encoders() # Find VAAPI commands (now 2: High profile first, then baseline fallback) vaapi_cmds = [c for c in captured_cmds if "h264_vaapi" in c] assert len(vaapi_cmds) == 
2 # First command: High profile (default, no -profile:v) assert "-init_hw_device" in vaapi_cmds[0] assert "hwupload" in " ".join(vaapi_cmds[0]) assert "constrained_baseline" not in " ".join(vaapi_cmds[0]) # Second command: constrained_baseline fallback assert "constrained_baseline" in " ".join(vaapi_cmds[1]) def test_detect_encoders_qsv_command_structure(self): """Test detect_encoders passes correct QSV command structure.""" captured_cmds = [] def capture_cmd(cmd, timeout=5, env=None): captured_cmds.append(cmd) return False, "test" with mock.patch.object(cache, "_test_encoder", side_effect=capture_cmd): cache.detect_encoders() # Find QSV command qsv_cmd = [c for c in captured_cmds if "h264_qsv" in c][0] assert "-hwaccel" in qsv_cmd assert "qsv" in qsv_cmd assert "-hwaccel_output_format" in qsv_cmd def test_refresh_encoders_updates_global(self): """Test refresh_encoders updates AVAILABLE_ENCODERS.""" original = cache.AVAILABLE_ENCODERS.copy() with mock.patch.object( cache, "detect_encoders", return_value={"nvenc": True, "amf": True, "qsv": True, "vaapi": True}, ): result = cache.refresh_encoders() assert cache.AVAILABLE_ENCODERS == { "nvenc": True, "amf": True, "qsv": True, "vaapi": True, } assert result == cache.AVAILABLE_ENCODERS # Restore original cache.AVAILABLE_ENCODERS = original def test_default_encoder_prefers_nvenc_with_vaapi(self): """Test _default_encoder prefers NVENC+VAAPI when both available.""" original = cache.AVAILABLE_ENCODERS.copy() cache.AVAILABLE_ENCODERS = { "nvenc": True, "amf": True, "qsv": True, "vaapi": True, } try: assert cache._default_encoder() == "nvenc+vaapi" finally: cache.AVAILABLE_ENCODERS = original def test_default_encoder_nvenc_without_vaapi(self): """Test _default_encoder uses NVENC+software when VAAPI unavailable.""" original = cache.AVAILABLE_ENCODERS.copy() cache.AVAILABLE_ENCODERS = { "nvenc": True, "amf": False, "qsv": False, "vaapi": False, } try: assert cache._default_encoder() == "nvenc+software" finally: 
cache.AVAILABLE_ENCODERS = original def test_default_encoder_falls_back_to_amf(self): """Test _default_encoder falls back to AMF when NVENC unavailable.""" original = cache.AVAILABLE_ENCODERS.copy() cache.AVAILABLE_ENCODERS = { "nvenc": False, "amf": True, "qsv": True, "vaapi": True, } try: assert cache._default_encoder() == "amf+vaapi" finally: cache.AVAILABLE_ENCODERS = original def test_default_encoder_falls_back_to_qsv(self): """Test _default_encoder falls back to QSV when NVENC/AMF unavailable.""" original = cache.AVAILABLE_ENCODERS.copy() cache.AVAILABLE_ENCODERS = { "nvenc": False, "amf": False, "qsv": True, "vaapi": True, } try: assert cache._default_encoder() == "qsv" finally: cache.AVAILABLE_ENCODERS = original def test_default_encoder_falls_back_to_vaapi(self): """Test _default_encoder falls back to VAAPI when NVENC/AMF/QSV unavailable.""" original = cache.AVAILABLE_ENCODERS.copy() cache.AVAILABLE_ENCODERS = { "nvenc": False, "amf": False, "qsv": False, "vaapi": True, } try: assert cache._default_encoder() == "vaapi" finally: cache.AVAILABLE_ENCODERS = original def test_default_encoder_falls_back_to_software(self): """Test _default_encoder falls back to software as last resort.""" original = cache.AVAILABLE_ENCODERS.copy() cache.AVAILABLE_ENCODERS = { "nvenc": False, "amf": False, "qsv": False, "vaapi": False, } try: assert cache._default_encoder() == "software" finally: cache.AVAILABLE_ENCODERS = original class TestLogoCache: """Tests for logo caching functions.""" def test_sanitize_name_removes_path_traversal(self): assert ".." 
not in cache._sanitize_name("../../../etc/passwd") assert "/" not in cache._sanitize_name("foo/bar") assert "\\" not in cache._sanitize_name("foo\\bar") def test_sanitize_name_keeps_safe_chars(self): assert cache._sanitize_name("my-source_123") == "my-source_123" assert cache._sanitize_name("Source Name") == "Source Name" def test_sanitize_name_truncates_long_names(self): long_name = "a" * 300 result = cache._sanitize_name(long_name) assert len(result) == 224 def test_sanitize_name_empty_returns_default(self): assert cache._sanitize_name("") == "default" assert cache._sanitize_name("!!!") == "default" def test_url_to_filename_extracts_name(self): result = cache._url_to_filename("http://example.com/logos/channel1.png") assert result.startswith("channel1_") assert len(result) == len("channel1_") + 8 # name + underscore + 8 char hash def test_url_to_filename_strips_extension(self): result = cache._url_to_filename("http://example.com/logo.png") assert not result.endswith(".png") assert result.startswith("logo_") def test_url_to_filename_hash_differs_by_url(self): r1 = cache._url_to_filename("http://example.com/a/logo.png") r2 = cache._url_to_filename("http://example.com/b/logo.png") # Same base name but different hashes assert r1.startswith("logo_") assert r2.startswith("logo_") assert r1 != r2 def test_url_to_filename_fallback_to_hash(self): result = cache._url_to_filename("http://example.com/") assert len(result) == 8 # Just the hash def test_save_and_get_cached_logo(self, cache_module, tmp_path): cache_module.LOGOS_DIR = tmp_path / "logos" cache_module.LOGOS_DIR.mkdir() # Save a logo data = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 # Fake PNG path = cache_module.save_logo( "TestSource", "http://example.com/logo.png", data, "image/png" ) assert path.exists() assert path.suffix == ".png" assert path.read_bytes() == data # Get cached logo cached = cache_module.get_cached_logo("TestSource", "http://example.com/logo.png") assert cached == path def 
test_get_cached_logo_returns_none_when_missing(self, cache_module, tmp_path): cache_module.LOGOS_DIR = tmp_path / "logos" cache_module.LOGOS_DIR.mkdir() cached = cache_module.get_cached_logo("NoSource", "http://missing.com/logo.png") assert cached is None def test_get_cached_logo_expires(self, cache_module, tmp_path): import time cache_module.LOGOS_DIR = tmp_path / "logos" cache_module.LOGOS_DIR.mkdir() # Save a logo data = b"\x89PNG" + b"\x00" * 100 path = cache_module.save_logo("TestSource", "http://example.com/old.png", data, "image/png") # Backdate the file old_time = time.time() - cache_module.LOGO_CACHE_TTL - 100 import os os.utime(path, (old_time, old_time)) # Should be expired cached = cache_module.get_cached_logo("TestSource", "http://example.com/old.png") assert cached is None assert not path.exists() # Should be deleted def test_save_logo_content_type_mapping(self, cache_module, tmp_path): cache_module.LOGOS_DIR = tmp_path / "logos" cache_module.LOGOS_DIR.mkdir() data = b"test" assert cache_module.save_logo("s", "http://a.com/1", data, "image/jpeg").suffix == ".jpg" assert cache_module.save_logo("s", "http://a.com/2", data, "image/gif").suffix == ".gif" assert cache_module.save_logo("s", "http://a.com/3", data, "image/webp").suffix == ".webp" assert cache_module.save_logo("s", "http://a.com/4", data, "image/svg+xml").suffix == ".svg" assert cache_module.save_logo("s", "http://a.com/5", data, "unknown/type").suffix == ".png" if __name__ == "__main__": from testing import run_tests run_tests(__file__) ================================================ FILE: docker-compose.yml ================================================ # Docker Compose for neTV # # Build: # docker compose build # Default (optimized FFmpeg) # FFMPEG_IMAGE=ubuntu:24.04 docker compose build # Alternative (apt FFmpeg) # # Run: # docker compose up -d # Auto-detects hardware (Intel/AMD) # docker compose --profile nvidia up -d # NVIDIA GPU (driver 580+, Turing+) # # NVIDIA with older 
drivers/GPUs (see README for driver/compute compatibility):
#   FFMPEG_IMAGE=ghcr.io/jvdillon/netv-ffmpeg:cuda12.8 docker compose --profile nvidia up -d
#
# Local CUDA 12.4 FFmpeg build (from Dockerfile.ffmpeg):
#   docker build --progress plain --build-arg NVIDIA=cuda:12.4 --build-arg FFMPEG_BASE_IMAGE=ubuntu:22.04 -f Dockerfile.ffmpeg -t netv-ffmpeg:cuda12.4 .
#   FFMPEG_IMAGE=netv-ffmpeg:cuda12.4 docker compose --profile nvidia build
#   FFMPEG_IMAGE=netv-ffmpeg:cuda12.4 docker compose --profile nvidia up -d
#
# No GPU or /dev/dri? Comment out the 'devices' section below.

services:
  netv:
    build:
      context: .
      args:
        FFMPEG_IMAGE: ${FFMPEG_IMAGE:-ghcr.io/jvdillon/netv-ffmpeg:latest}
    image: netv
    container_name: netv
    ports:
      - "${NETV_PORT:-8000}:8000"
    environment:
      - NETV_PORT=8000
      - NETV_HTTPS=${NETV_HTTPS:-}
      - LOG_LEVEL=INFO  # DEBUG for verbose logging
    volumes:
      - ./cache:/app/cache
      - /etc/localtime:/etc/localtime:ro  # Use host timezone for EPG
      # For HTTPS, also mount your certificates:
      # - /etc/letsencrypt:/etc/letsencrypt:ro
      # Use system RAM instead of using local storage to store transcodes
      # - /dev/shm:/tmp
    restart: unless-stopped
    # Hardware acceleration (Intel/AMD) - comment out if no /dev/dri
    devices:
      - /dev/dri:/dev/dri

  # NVIDIA GPU: docker compose --profile nvidia up -d
  # NOTE(review): Compose always enables services that have no `profiles` key,
  # so `--profile nvidia` may start the plain `netv` service as well. Both
  # services use container_name "netv" and the same host port; confirm they
  # do not conflict.
  netv-nvidia:
    extends:
      service: netv
    container_name: netv
    profiles:
      - nvidia
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
================================================
FILE: entrypoint-ai_upscale.sh
================================================
#!/bin/sh
set -e
# Entrypoint for AI Upscale image
#
# Same as base entrypoint, plus:
#   - Auto-builds TensorRT engines on first start if missing

# Fix cache directory ownership
mkdir -p /app/cache
if [ "$(stat -c '%U' /app/cache)" != "netv" ]; then
    chown -R netv:netv /app/cache 2>/dev/null || true
fi
# Ensure writable even on filesystems that ignore chown (e.g., some NAS mounts)
if ! gosu netv sh -c "touch /app/cache/.perm_test && rm /app/cache/.perm_test" 2>/dev/null; then
    chmod -R u+rwX,g+rwX /app/cache 2>/dev/null || true
    chmod g+s /app/cache 2>/dev/null || true
fi
# Final verification - warn if still not writable
if ! gosu netv sh -c "touch /app/cache/.perm_test && rm /app/cache/.perm_test" 2>/dev/null; then
    echo "WARNING: /app/cache is not writable by netv user"
    echo "Cache operations may fail. Check volume permissions."
fi

# Fix models directory ownership
mkdir -p /models
if [ "$(stat -c '%U' /models)" != "netv" ]; then
    chown -R netv:netv /models 2>/dev/null || true
fi
if ! gosu netv sh -c "touch /models/.perm_test && rm /models/.perm_test" 2>/dev/null; then
    chmod -R u+rwX,g+rwX /models 2>/dev/null || true
fi
if ! gosu netv sh -c "touch /models/.perm_test && rm /models/.perm_test" 2>/dev/null; then
    echo "WARNING: /models is not writable by netv user"
    echo "TensorRT engine caching may fail. Check volume permissions."
fi

# Add netv user to render device group (for VAAPI hardware encoding)
if [ -e /dev/dri/renderD128 ]; then
    RENDER_GID=$(stat -c '%g' /dev/dri/renderD128)
    RENDER_ADDED=false
    # groupadd fails when the GID already belongs to a group in the image
    # (e.g. "video"), which previously left netv out of the device group.
    # Join whichever group owns the GID; create "hostrender" only if none does.
    # GID 65534 (nogroup) is deliberately not joined: it is reported below.
    if [ "$RENDER_GID" != "65534" ]; then
        RENDER_GROUP=$(getent group "$RENDER_GID" | cut -d: -f1)
        if [ -z "$RENDER_GROUP" ] && groupadd --gid "$RENDER_GID" hostrender 2>/dev/null; then
            RENDER_GROUP=hostrender
        fi
        if [ -n "$RENDER_GROUP" ] && usermod -aG "$RENDER_GROUP" netv 2>/dev/null; then
            RENDER_ADDED=true
        fi
    fi
    if [ "$RENDER_ADDED" = "false" ]; then
        echo "WARNING: Could not add netv to render group (GID $RENDER_GID)"
        if [ "$RENDER_GID" = "65534" ]; then
            echo " GID 65534 (nogroup) indicates Docker user namespace mapping issue."
            echo " This is usually harmless - VAAPI may still work if container has device access."
            echo " To fix: ensure 'render' group exists on host and user is in it, or use --privileged"
        else
            echo " VAAPI hardware encoding may not be available."
            echo " To fix on host: sudo usermod -aG render \$USER (then restart Docker)"
        fi
    fi
fi

# Build TensorRT engines if missing (first run only)
# Builds both recommended models: 4x-compact (quality) and 2x-liveaction-span (fast)
if ! ls /models/4x-compact_*p_fp16.engine >/dev/null 2>&1; then
    echo "========================================"
    echo "AI Upscale: First start detected"
    echo "========================================"
    echo "Building TensorRT engines for your GPU..."
    echo "Models: 4x-compact (quality), 2x-liveaction-span (fast)"
    echo "This only happens once (cached in /models volume)."
    echo ""
    # Run as netv user so files have correct ownership
    if ! gosu netv env MODEL_DIR=/models MODEL="recommended" /app/tools/install-ai_upscale.sh; then
        echo "ERROR: Failed to build TensorRT engines"
        echo "Check GPU compatibility and CUDA installation"
        exit 1
    fi
    # Verify engines were created
    if ! ls /models/4x-compact_*p_fp16.engine >/dev/null 2>&1; then
        echo "ERROR: TensorRT engines not found after build"
        echo "Build may have succeeded but produced no output"
        exit 1
    fi
fi

# Drop to netv user and run the app
exec gosu netv python3 main.py --port "${NETV_PORT:-8000}" ${NETV_HTTPS:+--https}
================================================
FILE: entrypoint.sh
================================================
#!/bin/sh
set -e
# Entrypoint: fix permissions and drop to netv user
#
# Handles two common Docker issues:
#   1. Bind-mounted ./cache owned by host user (permission denied)
#   2. /dev/dri/renderD128 GID mismatch (VAAPI unavailable)

# Fix cache directory ownership (skip if already correct to avoid slow recursive chown)
# Build/runtime note: this only applies to bind-mounted cache (e.g., NAS),
# not to image layers, so it does not affect build reproducibility.
mkdir -p /app/cache
if [ "$(stat -c '%U' /app/cache)" != "netv" ]; then
    chown -R netv:netv /app/cache 2>/dev/null || true
fi
# Ensure writable even on filesystems that ignore chown (e.g., some NAS mounts)
if ! gosu netv sh -c "touch /app/cache/.perm_test && rm /app/cache/.perm_test" 2>/dev/null; then
    chmod -R u+rwX,g+rwX /app/cache 2>/dev/null || true
    chmod g+s /app/cache 2>/dev/null || true
fi
# Final verification - warn if still not writable
if ! gosu netv sh -c "touch /app/cache/.perm_test && rm /app/cache/.perm_test" 2>/dev/null; then
    echo "WARNING: /app/cache is not writable by netv user"
    echo "Cache operations may fail. Check volume permissions."
fi

mkdir -p /app/cache/users
if [ "$(stat -c '%U' /app/cache/users)" != "netv" ]; then
    chown -R netv:netv /app/cache/users 2>/dev/null || true
fi
# Ensure writable even on filesystems that ignore chown (e.g., some NAS mounts)
if ! gosu netv sh -c "touch /app/cache/users/.perm_test && rm /app/cache/users/.perm_test" 2>/dev/null; then
    chmod -R u+rwX,g+rwX /app/cache/users 2>/dev/null || true
    chmod g+s /app/cache/users 2>/dev/null || true
fi
# Final verification - warn if still not writable
if ! gosu netv sh -c "touch /app/cache/users/.perm_test && rm /app/cache/users/.perm_test" 2>/dev/null; then
    echo "WARNING: /app/cache/users is not writable by netv user"
    echo "Cache operations may fail. Check volume permissions."
fi

# Add netv user to render device group (for VAAPI hardware encoding)
if [ -e /dev/dri/renderD128 ]; then
    RENDER_GID=$(stat -c '%g' /dev/dri/renderD128)
    RENDER_ADDED=false
    # groupadd fails when the GID already belongs to a group in the image
    # (e.g. "video"), which previously left netv out of the device group.
    # Join whichever group owns the GID; create "hostrender" only if none does.
    # GID 65534 (nogroup) is deliberately not joined: it is reported below.
    if [ "$RENDER_GID" != "65534" ]; then
        RENDER_GROUP=$(getent group "$RENDER_GID" | cut -d: -f1)
        if [ -z "$RENDER_GROUP" ] && groupadd --gid "$RENDER_GID" hostrender 2>/dev/null; then
            RENDER_GROUP=hostrender
        fi
        if [ -n "$RENDER_GROUP" ] && usermod -aG "$RENDER_GROUP" netv 2>/dev/null; then
            RENDER_ADDED=true
        fi
    fi
    if [ "$RENDER_ADDED" = "false" ]; then
        echo "WARNING: Could not add netv to render group (GID $RENDER_GID)"
        if [ "$RENDER_GID" = "65534" ]; then
            echo " GID 65534 (nogroup) indicates Docker user namespace mapping issue."
            echo " This is usually harmless - VAAPI may still work if container has device access."
            echo " To fix: ensure 'render' group exists on host and user is in it, or use --privileged"
        else
            echo " VAAPI hardware encoding may not be available."
echo " To fix on host: sudo usermod -aG render \$USER (then restart Docker)" fi fi fi # Drop to netv user and run the app exec gosu netv python3 main.py --port "${NETV_PORT:-8000}" ${NETV_HTTPS:+--https} ================================================ FILE: epg.py ================================================ """EPG storage and XMLTV parsing.""" from __future__ import annotations from dataclasses import dataclass from datetime import UTC, datetime, timedelta, timezone from pathlib import Path import contextlib import gzip import logging import re import sqlite3 import threading import time import defusedxml.ElementTree as ET # Safe XML parsing from util import safe_urlopen log = logging.getLogger(__name__) # ============================================================================= # Data Types # ============================================================================= @dataclass(slots=True) class Program: channel_id: str title: str start: datetime stop: datetime desc: str = "" source_id: str = "" # ============================================================================= # SQLite Storage # ============================================================================= _DB_PATH: Path | None = None _local = threading.local() def init(cache_dir: Path) -> None: """Initialize EPG database.""" global _DB_PATH _DB_PATH = cache_dir / "epg.db" conn = _get_conn() conn.executescript(""" CREATE TABLE IF NOT EXISTS channels ( id TEXT PRIMARY KEY, name TEXT, source_id TEXT ); CREATE TABLE IF NOT EXISTS icons ( channel_id TEXT PRIMARY KEY, url TEXT ); CREATE TABLE IF NOT EXISTS programs ( id INTEGER PRIMARY KEY, channel_id TEXT, title TEXT, start_ts REAL, stop_ts REAL, desc TEXT, source_id TEXT ); CREATE INDEX IF NOT EXISTS idx_programs_channel_time ON programs(channel_id, start_ts, stop_ts); CREATE INDEX IF NOT EXISTS idx_programs_time ON programs(start_ts); """) conn.commit() def _get_conn() -> sqlite3.Connection: """Get thread-local database connection.""" if not 
hasattr(_local, "conn") or _local.conn is None: if _DB_PATH is None: raise RuntimeError("EPG database not initialized") _local.conn = sqlite3.connect(_DB_PATH, timeout=30.0) _local.conn.row_factory = sqlite3.Row _local.conn.execute("PRAGMA journal_mode=WAL") return _local.conn def clear() -> None: """Clear all EPG data.""" conn = _get_conn() conn.executescript("DELETE FROM programs; DELETE FROM channels; DELETE FROM icons;") conn.commit() def clear_source(source_id: str) -> None: """Clear EPG data for a specific source.""" conn = _get_conn() conn.execute("DELETE FROM programs WHERE source_id = ?", (source_id,)) conn.execute("DELETE FROM channels WHERE source_id = ?", (source_id,)) conn.commit() def insert_channel(channel_id: str, name: str, source_id: str) -> None: """Insert or update a channel.""" conn = _get_conn() conn.execute( "INSERT OR REPLACE INTO channels (id, name, source_id) VALUES (?, ?, ?)", (channel_id, name, source_id), ) def insert_icon(channel_id: str, url: str) -> None: """Insert or update a channel icon.""" conn = _get_conn() conn.execute( "INSERT OR REPLACE INTO icons (channel_id, url) VALUES (?, ?)", (channel_id, url), ) def insert_programs(programs: list[tuple[str, str, float, float, str, str]]) -> None: """Bulk insert programs. 
Each tuple: (channel_id, title, start_ts, stop_ts, desc, source_id).""" conn = _get_conn() conn.executemany( "INSERT INTO programs (channel_id, title, start_ts, stop_ts, desc, source_id) VALUES (?, ?, ?, ?, ?, ?)", programs, ) def commit() -> None: """Commit current transaction.""" _get_conn().commit() def get_icon(channel_id: str) -> str: """Get icon URL for a channel.""" conn = _get_conn() row = conn.execute("SELECT url FROM icons WHERE channel_id = ?", (channel_id,)).fetchone() return row["url"] if row else "" def get_programs_in_range( channel_id: str, start: datetime, end: datetime, preferred_source_id: str = "", ) -> list[Program]: """Get programs for a channel within a time range.""" conn = _get_conn() start_ts = start.timestamp() end_ts = end.timestamp() rows = conn.execute( """ SELECT channel_id, title, start_ts, stop_ts, desc, source_id FROM programs WHERE channel_id = ? AND stop_ts > ? AND start_ts < ? ORDER BY start_ts """, (channel_id, start_ts, end_ts), ).fetchall() programs = [ Program( channel_id=row["channel_id"], title=row["title"], start=datetime.fromtimestamp(row["start_ts"], tz=UTC), stop=datetime.fromtimestamp(row["stop_ts"], tz=UTC), desc=row["desc"] or "", source_id=row["source_id"] or "", ) for row in rows ] if not preferred_source_id or len(programs) <= 1: return programs # Deduplicate overlapping programs, preferring the preferred source result: list[Program] = [] for p in programs: dominated = False for i, existing in enumerate(result): if p.start < existing.stop and p.stop > existing.start: if p.source_id == preferred_source_id and existing.source_id != preferred_source_id: result[i] = p dominated = True break if not dominated: result.append(p) return sorted(result, key=lambda p: p.start) _MAX_IN_CLAUSE = 500 # SQLite limit is 999, stay well below def _dedupe_programs(programs: list[Program], preferred_source_id: str) -> list[Program]: """Deduplicate overlapping programs, preferring the preferred source.""" if not preferred_source_id or 
len(programs) <= 1: return programs result: list[Program] = [] for p in programs: dominated = False for i, existing in enumerate(result): # Check for overlap if p.start < existing.stop and p.stop > existing.start: # Prefer the preferred source if p.source_id == preferred_source_id and existing.source_id != preferred_source_id: result[i] = p dominated = True break if not dominated: result.append(p) return sorted(result, key=lambda p: p.start) def get_programs_batch( channel_ids: list[str], start: datetime, end: datetime, preferred_sources: dict[str, str] | None = None, ) -> dict[str, list[Program]]: """Get programs for multiple channels in a single query. Args: channel_ids: List of EPG channel IDs to query start: Start of time window end: End of time window preferred_sources: Optional dict mapping channel_id -> preferred source_id for deduplication. If provided, overlapping programs from the preferred source will be kept over programs from other sources. """ if not channel_ids: return {} conn = _get_conn() start_ts = start.timestamp() end_ts = end.timestamp() result: dict[str, list[Program]] = {ch: [] for ch in channel_ids} # Process in chunks to avoid huge IN clauses for i in range(0, len(channel_ids), _MAX_IN_CLAUSE): chunk = channel_ids[i : i + _MAX_IN_CLAUSE] placeholders = ",".join("?" * len(chunk)) rows = conn.execute( f""" SELECT channel_id, title, start_ts, stop_ts, desc, source_id FROM programs WHERE channel_id IN ({placeholders}) AND stop_ts > ? AND start_ts < ? 
ORDER BY channel_id, start_ts """, [*chunk, start_ts, end_ts], ).fetchall() for row in rows: result[row["channel_id"]].append( Program( channel_id=row["channel_id"], title=row["title"], start=datetime.fromtimestamp(row["start_ts"], tz=UTC), stop=datetime.fromtimestamp(row["stop_ts"], tz=UTC), desc=row["desc"] or "", source_id=row["source_id"] or "", ) ) # Deduplicate overlapping programs if preferred_sources provided if preferred_sources: for ch_id in result: if ch_id in preferred_sources and result[ch_id]: result[ch_id] = _dedupe_programs(result[ch_id], preferred_sources[ch_id]) channels_with_programs = sum(1 for progs in result.values() if progs) log.debug( "EPG batch query: requested %d channel IDs, found programs for %d", len(channel_ids), channels_with_programs, ) return result def get_icons_batch(channel_ids: list[str]) -> dict[str, str]: """Get icons for multiple channels in a single query.""" if not channel_ids: return {} conn = _get_conn() result: dict[str, str] = {} for i in range(0, len(channel_ids), _MAX_IN_CLAUSE): chunk = channel_ids[i : i + _MAX_IN_CLAUSE] placeholders = ",".join("?" * len(chunk)) rows = conn.execute( f"SELECT channel_id, url FROM icons WHERE channel_id IN ({placeholders})", chunk, ).fetchall() for row in rows: result[row["channel_id"]] = row["url"] return result def has_programs() -> bool: """Check if there are any programs in the database.""" conn = _get_conn() row = conn.execute("SELECT 1 FROM programs LIMIT 1").fetchone() return row is not None def get_program_count() -> int: """Get total program count.""" conn = _get_conn() row = conn.execute("SELECT COUNT(*) FROM programs").fetchone() return row[0] if row else 0 def get_channel_count() -> int: """Get total channel count.""" conn = _get_conn() row = conn.execute("SELECT COUNT(*) FROM channels").fetchone() return row[0] if row else 0 def prune_old_programs(before: datetime) -> int: """Delete programs that ended before the given time. 
Returns count deleted.""" conn = _get_conn() cursor = conn.execute("DELETE FROM programs WHERE stop_ts < ?", (before.timestamp(),)) conn.commit() return cursor.rowcount # ============================================================================= # XMLTV Parsing # ============================================================================= def _parse_epg_time(s: str) -> datetime: """Parse XMLTV time format: 20241130120000 +0000 or 20241130120000+0530.""" s = s.replace(" ", "") if len(s) >= 14: dt = datetime.strptime(s[:14], "%Y%m%d%H%M%S") if len(s) > 14: tz_str = s[14:] sign = -1 if tz_str[0] == "-" else 1 tz_hours = int(tz_str[1:3]) if len(tz_str) >= 3 else 0 tz_mins = int(tz_str[3:5]) if len(tz_str) >= 5 else 0 offset = timedelta(hours=tz_hours, minutes=tz_mins) dt = dt.replace(tzinfo=timezone(sign * offset)) return dt return datetime.now(UTC) def _sanitize_epg_xml(xml_str: str) -> str: """Try to fix corrupted EPG XML by extracting valid elements.""" channels = re.findall(r"]*>.*?", xml_str, re.DOTALL) programmes = re.findall( r']+"\s+stop="[^"<>]+"\s+channel="[^"<>]+"[^>]*>.*?', xml_str, re.DOTALL, ) log.info("Sanitized EPG: extracted %d channels, %d programmes", len(channels), len(programmes)) return '\n\n' + "\n".join(channels) + "\n".join(programmes) + "\n" def fetch_epg( epg_url: str, cache_dir: Path, timeout: int = 120, source_id: str = "", user_agent: str | None = None, ) -> int: """Fetch and parse XMLTV EPG data directly into sqlite. Args: epg_url: URL of the XMLTV EPG feed cache_dir: Directory for debug files if parsing fails timeout: Request timeout in seconds source_id: Source identifier for multi-source support user_agent: User-Agent header to send. If None, uses default. Returns: Number of programs inserted. 
""" with safe_urlopen(epg_url, timeout=timeout, user_agent=user_agent) as resp: content = resp.read() with contextlib.suppress(Exception): content = gzip.decompress(content) xml_str = content.decode("utf-8") try: root = ET.fromstring(xml_str) except ET.ParseError as e: debug_file = cache_dir / f"epg_debug_{int(time.time())}.xml" debug_file.write_text(xml_str) log.warning("EPG parse failed (%s), attempting sanitization...", e) try: sanitized = _sanitize_epg_xml(xml_str) root = ET.fromstring(sanitized) log.info("Sanitized EPG parsed successfully") except ET.ParseError as e2: log.error("Sanitized EPG also failed: %s", e2) raise # Parse channels directly into sqlite channel_ids: set[str] = set() for ch in root.findall("channel"): ch_id = ch.get("id", "") channel_ids.add(ch_id) name_el = ch.find("display-name") name = name_el.text if name_el is not None and name_el.text else ch_id insert_channel(ch_id, name, source_id) icon_el = ch.find("icon") if icon_el is not None: insert_icon(ch_id, icon_el.get("src", "")) # Parse programs in batches batch: list[tuple[str, str, float, float, str, str]] = [] batch_size = 10000 program_count = 0 program_channel_ids: set[str] = set() for prog in root.findall("programme"): ch_id = prog.get("channel", "") program_channel_ids.add(ch_id) start_str = prog.get("start", "") stop_str = prog.get("stop", "") title_el = prog.find("title") title = title_el.text if title_el is not None and title_el.text else "Unknown" desc_el = prog.find("desc") desc = desc_el.text if desc_el is not None and desc_el.text else "" try: start = _parse_epg_time(start_str) stop = _parse_epg_time(stop_str) except Exception: continue batch.append((ch_id, title, start.timestamp(), stop.timestamp(), desc, source_id)) program_count += 1 if len(batch) >= batch_size: insert_programs(batch) batch.clear() if batch: insert_programs(batch) commit() log.debug( "EPG parsed: %d channels, %d unique program channel IDs, %d programs", len(channel_ids), len(program_channel_ids), 
program_count,
    )
    return program_count
================================================
FILE: epg_test.py
================================================
"""Tests for epg.py - EPG storage and parsing."""

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from pathlib import Path

import pytest

from epg import Program
import epg


@pytest.fixture
def db(tmp_path: Path):
    """Initialize EPG database in temp directory."""
    epg.init(tmp_path)
    yield epg
    # Clear thread-local connection so the next test's init() opens a fresh
    # connection to its own temp database (_get_conn reconnects when conn is None).
    if hasattr(epg._local, "conn"):
        epg._local.conn.close()
        epg._local.conn = None


class TestInit:
    """Tests for database initialization."""

    def test_init_creates_tables(self, db):
        """init() creates the channels, icons and programs tables."""
        conn = db._get_conn()
        tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
        table_names = {t["name"] for t in tables}
        assert "channels" in table_names
        assert "icons" in table_names
        assert "programs" in table_names


class TestChannels:
    """Tests for channel operations."""

    def test_insert_channel(self, db):
        """An inserted channel is stored with its name and source."""
        db.insert_channel("ch1", "Channel One", "src1")
        db.commit()
        conn = db._get_conn()
        row = conn.execute("SELECT * FROM channels WHERE id = ?", ("ch1",)).fetchone()
        assert row["name"] == "Channel One"
        assert row["source_id"] == "src1"

    def test_insert_channel_upsert(self, db):
        """Re-inserting the same channel ID replaces the row rather than duplicating it."""
        db.insert_channel("ch1", "Old Name", "src1")
        db.insert_channel("ch1", "New Name", "src1")
        db.commit()
        conn = db._get_conn()
        rows = conn.execute("SELECT * FROM channels WHERE id = ?", ("ch1",)).fetchall()
        assert len(rows) == 1
        assert rows[0]["name"] == "New Name"


class TestIcons:
    """Tests for icon operations."""

    def test_insert_icon(self, db):
        """A stored icon URL is returned by get_icon."""
        db.insert_icon("ch1", "http://example.com/icon.png")
        db.commit()
        result = db.get_icon("ch1")
        assert result == "http://example.com/icon.png"

    def test_get_icon_not_found(self, db):
        """get_icon returns an empty string for unknown channels."""
        result = db.get_icon("nonexistent")
        assert result == ""

    def test_get_icons_batch(self, db):
        """get_icons_batch returns only the requested channels' icons."""
        db.insert_icon("ch1", "http://example.com/1.png")
        db.insert_icon("ch2", "http://example.com/2.png")
        db.insert_icon("ch3", "http://example.com/3.png")
        db.commit()
        result = db.get_icons_batch(["ch1", "ch3"])
        assert result == {
            "ch1": "http://example.com/1.png",
            "ch3": "http://example.com/3.png",
        }

    def test_get_icons_batch_empty(self, db):
        """An empty request yields an empty mapping."""
        result = db.get_icons_batch([])
        assert result == {}


class TestPrograms:
    """Tests for program operations."""

    def test_insert_programs(self, db):
        """Bulk-inserted programs are all counted."""
        now = datetime.now(UTC)
        programs = [
            (
                "ch1",
                "Show 1",
                now.timestamp(),
                (now + timedelta(hours=1)).timestamp(),
                "Desc 1",
                "src1",
            ),
            (
                "ch1",
                "Show 2",
                (now + timedelta(hours=1)).timestamp(),
                (now + timedelta(hours=2)).timestamp(),
                "Desc 2",
                "src1",
            ),
        ]
        db.insert_programs(programs)
        db.commit()
        count = db.get_program_count()
        assert count == 2

    def test_get_programs_in_range(self, db):
        """Only programs overlapping the window are returned, ordered by start."""
        now = datetime.now(UTC).replace(minute=0, second=0, microsecond=0)
        programs = [
            (
                "ch1",
                "Show 1",
                now.timestamp(),
                (now + timedelta(hours=1)).timestamp(),
                "Desc 1",
                "src1",
            ),
            (
                "ch1",
                "Show 2",
                (now + timedelta(hours=1)).timestamp(),
                (now + timedelta(hours=2)).timestamp(),
                "Desc 2",
                "src1",
            ),
            (
                "ch1",
                "Show 3",
                (now + timedelta(hours=2)).timestamp(),
                (now + timedelta(hours=3)).timestamp(),
                "Desc 3",
                "src1",
            ),
        ]
        db.insert_programs(programs)
        db.commit()
        # Query for middle hour
        result = db.get_programs_in_range(
            "ch1",
            now + timedelta(minutes=30),
            now + timedelta(hours=1, minutes=30),
        )
        assert len(result) == 2
        assert result[0].title == "Show 1"
        assert result[1].title == "Show 2"

    def test_get_programs_in_range_empty(self, db):
        """A channel with no programs yields an empty list."""
        now = datetime.now(UTC)
        result = db.get_programs_in_range("ch1", now, now + timedelta(hours=1))
        assert result == []

    def test_get_programs_batch(self, db):
        """Every requested channel gets a key, including those with no programs."""
        now = datetime.now(UTC).replace(minute=0, second=0, microsecond=0)
        programs = [
            ("ch1", "Show A", now.timestamp(), (now + timedelta(hours=1)).timestamp(), "", "src1"),
            ("ch2", "Show B", now.timestamp(), (now + timedelta(hours=1)).timestamp(), "", "src1"),
        ]
        db.insert_programs(programs)
        db.commit()
        result = db.get_programs_batch(
            ["ch1", "ch2", "ch3"],
            now,
            now + timedelta(hours=1),
        )
        assert len(result["ch1"]) == 1
        assert len(result["ch2"]) == 1
        assert len(result["ch3"]) == 0
        assert result["ch1"][0].title == "Show A"
        assert result["ch2"][0].title == "Show B"

    def test_get_programs_batch_empty_channels(self, db):
        """An empty channel list yields an empty mapping."""
        result = db.get_programs_batch([], datetime.now(UTC), datetime.now(UTC))
        assert result == {}

    def test_has_programs_false(self, db):
        """A fresh database has no programs."""
        assert db.has_programs() is False

    def test_has_programs_true(self, db):
        """has_programs turns True after a committed insert."""
        now = datetime.now(UTC)
        db.insert_programs(
            [("ch1", "Show", now.timestamp(), (now + timedelta(hours=1)).timestamp(), "", "src1")]
        )
        db.commit()
        assert db.has_programs() is True

    def test_get_program_count(self, db):
        """The program count goes from 0 to the number inserted."""
        now = datetime.now(UTC)
        assert db.get_program_count() == 0
        db.insert_programs(
            [
                (
                    "ch1",
                    "Show 1",
                    now.timestamp(),
                    (now + timedelta(hours=1)).timestamp(),
                    "",
                    "src1",
                ),
                (
                    "ch1",
                    "Show 2",
                    (now + timedelta(hours=1)).timestamp(),
                    (now + timedelta(hours=2)).timestamp(),
                    "",
                    "src1",
                ),
            ]
        )
        db.commit()
        assert db.get_program_count() == 2

    def test_get_channel_count(self, db):
        """The channel count goes from 0 to the number inserted."""
        assert db.get_channel_count() == 0
        db.insert_channel("ch1", "Channel 1", "src1")
        db.insert_channel("ch2", "Channel 2", "src1")
        db.commit()
        assert db.get_channel_count() == 2


class TestClear:
    """Tests for clear operations."""

    def test_clear_all(self, db):
        """clear() removes channels, programs and icons."""
        now = datetime.now(UTC)
        db.insert_channel("ch1", "Channel 1", "src1")
        db.insert_icon("ch1", "http://example.com/icon.png")
        db.insert_programs(
            [("ch1", "Show", now.timestamp(), (now + timedelta(hours=1)).timestamp(), "", "src1")]
        )
        db.commit()
        db.clear()
        assert db.get_channel_count() == 0
        assert db.get_program_count() == 0
        assert db.get_icon("ch1") == ""

    def test_clear_source(self, db):
        """clear_source() removes only the given source's channels and programs."""
        now = datetime.now(UTC)
        db.insert_channel("ch1", "Channel 1", "src1")
        db.insert_channel("ch2", "Channel 2", "src2")
        db.insert_programs(
            [
                (
                    "ch1",
                    "Show 1",
                    now.timestamp(),
                    (now + timedelta(hours=1)).timestamp(),
                    "",
                    "src1",
                ),
                (
                    "ch2",
                    "Show 2",
                    now.timestamp(),
                    (now + timedelta(hours=1)).timestamp(),
                    "",
                    "src2",
                ),
            ]
        )
        db.commit()
        db.clear_source("src1")
        assert db.get_channel_count() == 1
        assert db.get_program_count() == 1


class TestPrune:
    """Tests for prune operations."""

    def test_prune_old_programs(self, db):
        """Programs that ended before the cutoff are deleted and counted."""
        now = datetime.now(UTC)
        old = now - timedelta(days=2)
        db.insert_programs(
            [
                (
                    "ch1",
                    "Old Show",
                    old.timestamp(),
                    (old + timedelta(hours=1)).timestamp(),
                    "",
                    "src1",
                ),
                (
                    "ch1",
                    "New Show",
                    now.timestamp(),
                    (now + timedelta(hours=1)).timestamp(),
                    "",
                    "src1",
                ),
            ]
        )
        db.commit()
        deleted = db.prune_old_programs(now - timedelta(days=1))
        assert deleted == 1
        assert db.get_program_count() == 1


class TestPreferredSource:
    """Tests for preferred source deduplication."""

    def test_prefer_source_in_range(self, db):
        """Identical slots from two sources collapse to the preferred source's program."""
        now = datetime.now(UTC).replace(minute=0, second=0, microsecond=0)
        # Two overlapping programs from different sources
        db.insert_programs(
            [
                (
                    "ch1",
                    "From Src1",
                    now.timestamp(),
                    (now + timedelta(hours=1)).timestamp(),
                    "",
                    "src1",
                ),
                (
                    "ch1",
                    "From Src2",
                    now.timestamp(),
                    (now + timedelta(hours=1)).timestamp(),
                    "",
                    "src2",
                ),
            ]
        )
        db.commit()
        # Prefer src2
        result = db.get_programs_in_range(
            "ch1", now, now + timedelta(hours=1), preferred_source_id="src2"
        )
        assert len(result) == 1
        assert result[0].title == "From Src2"
        # Prefer src1
        # NOTE(review): this case requires dedup to drop a non-preferred program
        # that overlaps an already-kept preferred one; it depends on the row
        # order returned for equal start_ts values.
        result = db.get_programs_in_range(
            "ch1", now, now + timedelta(hours=1), preferred_source_id="src1"
        )
        assert len(result) == 1
        assert result[0].title == "From Src1"


class TestProgram:
    """Tests for Program dataclass."""

    def test_program_dataclass(self):
        """All constructor fields are stored on the instance."""
        now = datetime.now(UTC)
        p = Program(
            channel_id="ch1",
            title="Test Show",
            start=now,
            stop=now + timedelta(hours=1),
            desc="Description",
            source_id="src1",
        )
        assert p.channel_id == "ch1"
        assert p.title == "Test Show"
        assert p.desc == "Description"
        assert p.source_id == "src1"

    def test_program_defaults(self):
        now = datetime.now(UTC)
        p = Program(channel_id="ch1", title="Test", start=now, stop=now + timedelta(hours=1))
        assert p.desc
== "" assert p.source_id == "" if __name__ == "__main__": from testing import run_tests run_tests(__file__) ================================================ FILE: ffmpeg_command.py ================================================ """FFmpeg command building and media probing.""" from __future__ import annotations from collections.abc import Callable from contextlib import suppress from dataclasses import dataclass from typing import Any, Literal import json import logging import pathlib import re import subprocess import tempfile import threading import time # Import VAAPI auto-detection results (avoid circular import by importing constants only) from cache import AVAILABLE_ENCODERS, VAAPI_DEVICE log = logging.getLogger(__name__) HwAccel = Literal[ "nvenc+vaapi", "nvenc+software", "amf+vaapi", "amf+software", "qsv", "vaapi", "software" ] def _parse_hw(hw: HwAccel) -> tuple[str, str]: """Parse hw into (encoder, fallback). e.g. 'nvenc+vaapi' -> ('nvenc', 'vaapi')""" if "+" in hw: encoder, fallback = hw.split("+", 1) return encoder, fallback return hw, "software" # standalone options fallback to software # Timing constants _HLS_SEGMENT_DURATION_SEC = 3.0 # Short segments for faster startup/seeking _PROBE_CACHE_TTL_SEC = 3_600 _SERIES_PROBE_CACHE_TTL_SEC = 7 * 24 * 3_600 # 7 days _PROBE_TIMEOUT_SEC = 30 # Segment file naming SEG_PREFIX = "seg" # Segment files are named seg000.ts, seg001.ts, etc. 
# Live buffer length, in seconds, used when DVR is turned off.
DEFAULT_LIVE_BUFFER_SECS: float = 30.0

# Text-based subtitle codecs. probe_media() only exposes subtitle streams in
# one of these codecs as selectable SubtitleStream entries.
TEXT_SUBTITLE_CODECS: set[str] = {"subrip", "ass", "ssa", "mov_text", "webvtt", "srt"}

# Named User-Agent presets. get_user_agent() looks these up for any
# "user_agent_preset" value other than "default" or "custom".
_USER_AGENT_PRESETS: dict[str, str] = {
    "vlc": "VLC/3.0.20 LibVLC/3.0.20",
    "chrome": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
    ),
    "tivimate": "TiviMate/4.7.0",
}

# Minimum CUDA compute capability at which NVDEC can decode each codec:
# h264 -> Maxwell+, hevc -> Pascal+ (HEVC 10-bit needs Pascal; the Maxwell
# GM206 edge case is deliberately ignored), av1 -> Ampere+.
# Source: https://developer.nvidia.com/video-encode-and-decode-gpu-support-matrix-new
_NVDEC_MIN_COMPUTE: dict[str, float] = dict(h264=5.0, hevc=6.0, av1=8.0)

# VAAPI and QSV have no clean runtime capability probe, unlike NVIDIA. Parsing
# `vainfo` is unreliable because its output differs between the i965, iHD and
# radeonsi drivers. These conservative static lists cover codecs that nearly
# every GPU from the last decade can decode.
_VAAPI_SAFE_CODECS: set[str] = {"h264", "hevc", "mpeg2video", "vp8", "vp9", "vc1", "av1"}
_QSV_SAFE_CODECS: set[str] = {"h264", "hevc", "mpeg2video", "vp9", "vc1", "av1"}

# Height cap, in pixels, for each max_resolution setting value.
_MAX_RES_HEIGHT: dict[str, int] = {"4k": 2160, "1080p": 1080, "720p": 720, "480p": 480}

# Quality preset -> quantizer value. Lower numbers mean higher quality.
_QUALITY_QP: dict[str, int] = dict(high=20, medium=28, low=35)  # QP values
_QUALITY_CRF: dict[str, int] = dict(high=20, medium=26, low=32)  # CRF values

# --- Mutable module state ---------------------------------------------------
# Guards _probe_cache and _series_probe_cache. This is a plain,
# non-reentrant Lock, so it must not be re-acquired while already held.
_probe_lock = threading.Lock()
# url -> (time.time() when cached, MediaInfo or None, text subtitle streams)
_probe_cache: dict[str, tuple[float, MediaInfo | None, list[SubtitleStream]]] = {}
# series_id -> {"name": str, "mru": episode_id, "episodes": {episode_id: (time, info, subs)}}
_series_probe_cache: dict[int, dict[str, Any]] = {}
# Lazily probed capabilities; None means "not probed yet".
_gpu_nvdec_codecs: set[str] | None = None
_has_libplacebo: bool | None = None
# Settings loader installed by init(). Until then it returns an empty dict.
_load_settings: Callable[[], dict[str, Any]] = dict

# AI Upscale (super-resolution) configuration, set by init(). This is the
# directory holding TensorRT engines named {model}_{height}p_fp16.engine.
# An empty string means AI Upscale is disabled.
_sr_engine_dir: str = ""

# Prefer the legacy "cache" directory when it already exists (backwards
# compatibility); otherwise use ".cache".
_OLD_CACHE = pathlib.Path(__file__).parent / "cache" _CACHE_DIR = _OLD_CACHE if _OLD_CACHE.exists() else pathlib.Path(__file__).parent / ".cache" _SERIES_PROBE_CACHE_FILE = _CACHE_DIR / "series_probe_cache.json" _LANG_NAMES = { "eng": "English", "spa": "Spanish", "fre": "French", "ger": "German", "por": "Portuguese", "ita": "Italian", "jpn": "Japanese", "kor": "Korean", "chi": "Chinese", "ara": "Arabic", "rus": "Russian", "und": "Unknown", } @dataclass(slots=True) class SubtitleStream: index: int lang: str name: str @dataclass(slots=True) class MediaInfo: video_codec: str audio_codec: str pix_fmt: str audio_channels: int = 0 audio_sample_rate: int = 0 audio_profile: str = "" # e.g. "LC", "HE-AAC", "HE-AACv2" subtitle_codecs: list[str] | None = None duration: float = 0.0 height: int = 0 video_bitrate: int = 0 # bits per second, 0 if unknown interlaced: bool = False # True if field_order indicates interlaced is_10bit: bool = False # True if pix_fmt indicates 10-bit color is_hdr: bool = False # True if color transfer indicates HDR is_hls: bool = False # True if format is HLS (for input options) def init( load_settings: Callable[[], dict[str, Any]], sr_engine_dir: str = "", ) -> None: """Initialize module with settings loader and optional AI Upscale config.""" global _load_settings, _sr_engine_dir _load_settings = load_settings _sr_engine_dir = sr_engine_dir _load_series_probe_cache() if _sr_engine_dir: log.info("AI Upscale enabled: engine_dir=%s", _sr_engine_dir) def get_settings() -> dict[str, Any]: """Get current settings.""" return _load_settings() def get_ffmpeg_env() -> dict[str, str] | None: """Get environment for ffmpeg subprocess. Returns None (ffmpeg has libtorch via rpath).""" # ffmpeg is built with -Wl,-rpath pointing to libtorch, so no LD_LIBRARY_PATH needed return None def _find_sr_engine(model_name: str, source_height: int) -> tuple[str, int, int, int] | None: """Find the best matching SR engine file for the given model and resolution. 
Returns (engine_path, input_height, input_width, scale_factor) or None if not found. """ import pathlib engine_dir = pathlib.Path(_sr_engine_dir) if not engine_dir.exists(): return None # Find all engines for this model # Engine naming: {model}_{height}p_fp16.engine engines: list[tuple[int, pathlib.Path]] = [] for engine in engine_dir.glob(f"{model_name}_*p_fp16.engine"): # Extract height from filename name = engine.stem # e.g., "2x-liveaction-span_1080p_fp16" parts = name.rsplit("_", 2) if len(parts) >= 3: height_str = parts[1].rstrip("p") if height_str.isdigit(): engines.append((int(height_str), engine)) if not engines: return None # Determine scale factor from model name prefix (e.g., "2x-", "4x-") scale_match = re.match(r"^(\d+)x-", model_name) if scale_match: scale = int(scale_match.group(1)) elif model_name == "realesrgan": # Legacy model name - was 4x scale = 4 else: log.error( "SR: cannot determine scale from model name: %s (expected Nx- prefix)", model_name ) return None # Sort by height ascending engines.sort(key=lambda x: x[0]) # Select appropriate engine based on source height if source_height <= 0: # Probe failed - use highest resolution engine engine_height, engine_path = engines[-1] log.warning("SR: probe failed, using %dp engine", engine_height) else: # Find engine closest to but >= source height, or use largest if source is bigger engine_height, engine_path = engines[-1] # default to largest for h, p in engines: if h >= source_height: engine_height, engine_path = h, p break # Calculate width assuming 16:9 aspect ratio, rounded to multiple of 8 engine_width = ((engine_height * 16 // 9) + 7) // 8 * 8 return str(engine_path), engine_height, engine_width, scale def _build_sr_filter(source_height: int, target_height: int) -> str: """Build AI Upscale filter string if needed. Returns empty string if disabled. SR is controlled by sr_model setting - if a model is selected, SR is applied when source height < target height. 
""" if not _sr_engine_dir: return "" # Get selected model from settings settings = _load_settings() model_name = settings.get("sr_model", "") if not model_name: return "" # SR disabled (Off selected) # Find engine for this model and resolution engine_info = _find_sr_engine(model_name, source_height) if not engine_info: log.warning("SR: no engine found for model=%s, source=%dp", model_name, source_height) return "" engine_path, engine_height, engine_width, scale = engine_info # Apply SR when source resolution is below target (upscaling scenario) if target_height and source_height >= target_height: log.info( "SR: skipping %s - source %dp >= target %dp", model_name, source_height, target_height ) return "" log.info( "SR: applying %s (%dx) to %dp -> %dp", model_name, scale, source_height, target_height or (source_height * scale), ) # Build SR filter chain: # 1. Scale to engine's expected input size (preserving aspect with padding if needed) # 2. Convert to RGB (model expects 3-channel RGB input) # 3. hwupload to GPU (critical for performance - keeps data on GPU) # 4. Apply SR via TensorRT dnn_processing (outputs Nx resolution on GPU) # 5. 
Scale down on GPU to target resolution sr_filter = ( f"scale={engine_width}:{engine_height}:force_original_aspect_ratio=decrease," f"pad={engine_width}:{engine_height}:(ow-iw)/2:(oh-ih)/2," f"format=rgb24," f"hwupload," f"dnn_processing=dnn_backend=tensorrt:model={engine_path}" ) if target_height: # After dnn_processing, data is on GPU - use scale_cuda with explicit params sr_filter += f",scale_cuda=w=-2:h={target_height}" return sr_filter def get_hls_segment_duration() -> float: """Get HLS segment duration in seconds.""" return _HLS_SEGMENT_DURATION_SEC # =========================================================================== # GPU Detection # =========================================================================== def _get_gpu_nvdec_codecs() -> set[str]: """Get supported NVDEC codecs, probing GPU on first call.""" global _gpu_nvdec_codecs if _gpu_nvdec_codecs is not None: return _gpu_nvdec_codecs _gpu_nvdec_codecs = set() try: result = subprocess.run( ["nvidia-smi", "--query-gpu=name,compute_cap", "--format=csv,noheader"], capture_output=True, text=True, timeout=5, ) if result.returncode != 0: log.info("No NVIDIA GPU detected") return _gpu_nvdec_codecs # Parse "NVIDIA GeForce GTX TITAN X, 5.2" line = result.stdout.strip().split("\n")[0] parts = line.rsplit(",", 1) if len(parts) != 2: return _gpu_nvdec_codecs gpu_name = parts[0].strip() compute_cap = float(parts[1].strip()) _gpu_nvdec_codecs = { codec for codec, min_cap in _NVDEC_MIN_COMPUTE.items() if compute_cap >= min_cap } log.info( "GPU: %s (compute %.1f) NVDEC: %s", gpu_name, compute_cap, _gpu_nvdec_codecs or "none", ) except Exception as e: log.debug("GPU probe failed: %s", e) return _gpu_nvdec_codecs def _has_libplacebo_filter() -> bool: """Check if FFmpeg has libplacebo filter available (for GPU HDR tone mapping).""" global _has_libplacebo if _has_libplacebo is not None: return _has_libplacebo _has_libplacebo = False try: result = subprocess.run( ["ffmpeg", "-filters"], capture_output=True, 
text=True, timeout=5, ) _has_libplacebo = "libplacebo" in result.stdout log.info("libplacebo filter available: %s", _has_libplacebo) except Exception as e: log.debug("libplacebo probe failed: %s", e) return _has_libplacebo # =========================================================================== # User-Agent # =========================================================================== def get_user_agent() -> str | None: """Get user-agent string from settings, or None to use FFmpeg default.""" settings = _load_settings() preset = settings.get("user_agent_preset", "default") if preset == "default": return None if preset == "custom": return settings.get("user_agent_custom") or None return _USER_AGENT_PRESETS.get(preset) # =========================================================================== # Transcode Directory # =========================================================================== def get_transcode_dir() -> pathlib.Path: """Get the transcode output directory. Falls back to system temp if not set or inaccessible.""" custom_dir = _load_settings().get("transcode_dir", "") if custom_dir: path = pathlib.Path(custom_dir) try: path.mkdir(parents=True, exist_ok=True) return path except (PermissionError, OSError) as e: log.warning("Transcode dir %s inaccessible (%s), using temp dir", custom_dir, e) return pathlib.Path(tempfile.gettempdir()) # =========================================================================== # Series Probe Cache Persistence # =========================================================================== def _load_series_probe_cache() -> None: """Load series probe cache from disk.""" if not _SERIES_PROBE_CACHE_FILE.exists(): return try: data = json.loads(_SERIES_PROBE_CACHE_FILE.read_text()) count = 0 with _probe_lock: for sid_str, series_data in data.items(): sid = int(sid_str) if sid not in _series_probe_cache: _series_probe_cache[sid] = { "name": series_data.get("name", ""), "mru": series_data.get("mru"), "episodes": {}, } else: 
_series_probe_cache[sid].setdefault("name", series_data.get("name", "")) _series_probe_cache[sid].setdefault("mru", series_data.get("mru")) _series_probe_cache[sid].setdefault("episodes", {}) for eid_str, entry in series_data.get("episodes", {}).items(): eid = int(eid_str) if eid in _series_probe_cache[sid]["episodes"]: continue # Use .get() for all fields to handle corrupt/incomplete cache video_codec = entry.get("video_codec", "") if not video_codec: continue # Skip entries without video codec media_info = MediaInfo( video_codec=video_codec, audio_codec=entry.get("audio_codec", ""), pix_fmt=entry.get("pix_fmt", ""), audio_channels=entry.get("audio_channels", 0), audio_sample_rate=entry.get("audio_sample_rate", 0), subtitle_codecs=entry.get("subtitle_codecs"), duration=entry.get("duration", 0), height=entry.get("height", 0), video_bitrate=entry.get("video_bitrate", 0), interlaced=entry.get("interlaced", False), is_10bit=entry.get("is_10bit", False), is_hdr=entry.get("is_hdr", False), is_hls=entry.get("is_hls", False), ) subs = [ SubtitleStream(s["index"], s.get("lang", "und"), s.get("name", "")) for s in entry.get("subtitles", []) ] _series_probe_cache[sid]["episodes"][eid] = ( entry.get("time", 0), media_info, subs, ) count += 1 log.info("Loaded %d series probe cache entries", count) except Exception as e: log.warning("Failed to load series probe cache: %s", e) def _save_series_probe_cache() -> None: """Save series probe cache to disk.""" with _probe_lock: data: dict[str, dict[str, Any]] = {} for sid, series_data in _series_probe_cache.items(): episodes = series_data.get("episodes", {}) data[str(sid)] = { "name": series_data.get("name", ""), "mru": series_data.get("mru"), "episodes": {}, } for eid, (cache_time, media_info, subs) in episodes.items(): if media_info is None: continue data[str(sid)]["episodes"][str(eid)] = { "time": cache_time, "video_codec": media_info.video_codec, "audio_codec": media_info.audio_codec, "pix_fmt": media_info.pix_fmt, 
"audio_channels": media_info.audio_channels, "audio_sample_rate": media_info.audio_sample_rate, "subtitle_codecs": media_info.subtitle_codecs, "duration": media_info.duration, "height": media_info.height, "video_bitrate": media_info.video_bitrate, "interlaced": media_info.interlaced, "is_10bit": media_info.is_10bit, "is_hdr": media_info.is_hdr, "subtitles": [{"index": s.index, "lang": s.lang, "name": s.name} for s in subs], } try: _SERIES_PROBE_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True) _SERIES_PROBE_CACHE_FILE.write_text(json.dumps(data, indent=2)) except Exception as e: log.warning("Failed to save series probe cache: %s", e) # =========================================================================== # Probe Cache Management # =========================================================================== def get_series_probe_cache_stats() -> list[dict[str, Any]]: """Get stats about cached series probes for settings UI.""" with _probe_lock: log.info( "get_series_probe_cache_stats: cache has %d series: %s", len(_series_probe_cache), list(_series_probe_cache.keys()), ) result = [] for series_id, series_data in _series_probe_cache.items(): episodes = series_data.get("episodes", {}) if not episodes: continue # Get most recent entry for display info most_recent = max(episodes.values(), key=lambda x: x[0]) _, media_info, subs = most_recent if media_info is None: continue # Build episode list episode_list = [] for eid, (_, emedia, esubs) in episodes.items(): if emedia: episode_list.append( { "episode_id": eid, "duration": emedia.duration, "subtitle_count": len(esubs), } ) result.append( { "series_id": series_id, "name": series_data.get("name", ""), "mru": series_data.get("mru"), "episode_count": len(episodes), "video_codec": media_info.video_codec, "audio_codec": media_info.audio_codec, "subtitle_count": len(subs), "episodes": sorted(episode_list, key=lambda x: x["episode_id"]), } ) return sorted(result, key=lambda x: x.get("name") or str(x["series_id"])) def 
clear_all_probe_cache() -> int: """Clear all probe caches. Returns count of entries cleared.""" with _probe_lock: url_count = len(_probe_cache) series_count = sum(len(s.get("episodes", {})) for s in _series_probe_cache.values()) _probe_cache.clear() _series_probe_cache.clear() _save_series_probe_cache() log.info("Cleared probe cache: %d URL entries, %d series entries", url_count, series_count) return url_count + series_count def invalidate_series_probe_cache(series_id: int, episode_id: int | None = None) -> None: """Invalidate cached probe for series/episode. If episode_id is None, clears entire series. Otherwise clears just that episode. """ with _probe_lock: if series_id not in _series_probe_cache: return if episode_id is None: del _series_probe_cache[series_id] log.info("Cleared probe cache for series=%d", series_id) else: series_data = _series_probe_cache[series_id] episodes = series_data.get("episodes", {}) if episode_id in episodes: del episodes[episode_id] log.info( "Cleared probe cache for series=%d episode=%d", series_id, episode_id, ) _save_series_probe_cache() def clear_series_mru(series_id: int) -> None: """Clear only the MRU for a series, keeping episode cache intact.""" with _probe_lock: if series_id not in _series_probe_cache: return if "mru" in _series_probe_cache[series_id]: del _series_probe_cache[series_id]["mru"] log.info("Cleared MRU for series=%d", series_id) _save_series_probe_cache() def restore_probe_cache_entry( url: str, media_info: MediaInfo, subs: list[SubtitleStream], series_id: int | None = None, episode_id: int | None = None, ) -> None: """Restore a probe cache entry (used during session recovery).""" now = time.time() with _probe_lock: if url not in _probe_cache: _probe_cache[url] = (now, media_info, subs) if series_id is not None: if series_id not in _series_probe_cache: _series_probe_cache[series_id] = {"name": "", "episodes": {}} _series_probe_cache[series_id].setdefault("episodes", {}) eid = episode_id or 0 if eid not in 
_series_probe_cache[series_id]["episodes"]: _series_probe_cache[series_id]["episodes"][eid] = (now, media_info, subs) # =========================================================================== # Media Probing # =========================================================================== def _lang_display_name(code: str) -> str: return _LANG_NAMES.get(code, code.upper()) def resolve_hls_master_playlist(url: str) -> str: """Resolve HLS master playlist to highest bandwidth variant URL. If the URL points to an HLS master playlist (contains #EXT-X-STREAM-INF), this fetches and parses it to find the variant with the highest bandwidth. Returns the resolved variant URL, or the original URL if not a master playlist or on any error. """ from urllib.parse import urljoin import urllib.request if not url.endswith(".m3u8") and ".m3u8?" not in url: return url # Not an m3u8, return as-is try: req = urllib.request.Request(url) user_agent = get_user_agent() if user_agent: req.add_header("User-Agent", user_agent) with urllib.request.urlopen(req, timeout=10) as response: content = response.read().decode("utf-8", errors="replace") # Check if this is a master playlist (has #EXT-X-STREAM-INF) if "#EXT-X-STREAM-INF" not in content: return url # Not a master playlist # Parse variants: each #EXT-X-STREAM-INF is followed by a URL line lines = content.strip().split("\n") variants: list[tuple[int, str]] = [] for i, line in enumerate(lines): if line.startswith("#EXT-X-STREAM-INF:"): # Extract BANDWIDTH from the tag bandwidth = 0 for attr in line.split(":")[1].split(","): if attr.startswith("BANDWIDTH="): with suppress(ValueError): bandwidth = int(attr.split("=")[1]) break # Next non-comment line is the variant URL for j in range(i + 1, len(lines)): variant_line = lines[j].strip() if variant_line and not variant_line.startswith("#"): # Resolve relative URL variant_url = urljoin(url, variant_line) variants.append((bandwidth, variant_url)) break if not variants: log.warning("HLS master playlist 
has no variants: %s", url[:80]) return url # Select highest bandwidth variant variants.sort(key=lambda x: x[0], reverse=True) best_bandwidth, best_url = variants[0] log.info( "HLS master playlist resolved: %d variants, selected %d bps: %s", len(variants), best_bandwidth, best_url[:80], ) return best_url except Exception as e: log.warning("Failed to resolve HLS master playlist %s: %s", url[:80], e) return url def probe_media( url: str, series_id: int | None = None, episode_id: int | None = None, series_name: str = "", ) -> tuple[MediaInfo | None, list[SubtitleStream]]: """Probe media, returns (media_info, subtitles).""" # Check series/episode cache first cache_hit_result: tuple[MediaInfo, list[SubtitleStream]] | None = None save_mru = False if series_id is not None: with _probe_lock: series_data = _series_probe_cache.get(series_id) if series_data: episodes = series_data.get("episodes", {}) mru_eid = series_data.get("mru") # Try exact episode first if episode_id is not None and episode_id in episodes: cache_time, media_info, subtitles = episodes[episode_id] if time.time() - cache_time < _SERIES_PROBE_CACHE_TTL_SEC: # Update MRU to this episode if series_data.get("mru") != episode_id: series_data["mru"] = episode_id save_mru = True log.info( "Probe cache hit for series=%d episode=%d", series_id, episode_id, ) cache_hit_result = (media_info, subtitles) # Fall back to MRU if set elif mru_eid is not None and mru_eid in episodes: cache_time, media_info, subtitles = episodes[mru_eid] if time.time() - cache_time < _SERIES_PROBE_CACHE_TTL_SEC: log.info( "Probe cache hit for series=%d (fallback from mru=%d)", series_id, mru_eid, ) cache_hit_result = (media_info, subtitles) # Save MRU update outside the lock to avoid deadlock if save_mru: _save_series_probe_cache() if cache_hit_result: return cache_hit_result # Check URL cache (for movies, or series cache miss) with _probe_lock: cached = _probe_cache.get(url) if cached: cache_time, media_info, subtitles = cached if time.time() 
- cache_time < _PROBE_CACHE_TTL_SEC: log.info("Probe cache hit for %s", url[:50]) return media_info, subtitles log.info( "Probe cache miss for %s (series=%s, episode=%s)", url[:50], series_id, episode_id, ) # Build base probe command # MPEG-TS streams (HDHomeRun, live TV) need ~1MB to reach first keyframe # which contains the sequence header with dimensions. GOP at 15Mbps = ~1-2MB. base_cmd = [ "ffprobe", "-probesize", "1000000", # Had to increase for HDHomerun; was 50000. "-analyzeduration", "1500000", # Had to increase for HDHomerun; was 500000. "-v", "quiet", "-print_format", "json", "-show_streams", "-show_format", ] user_agent = get_user_agent() if user_agent: base_cmd.extend(["-user_agent", user_agent]) # Try probe without forcing HLS first, retry with HLS options if it fails is_hls = False data = None for force_hls in (False, True): try: cmd = base_cmd.copy() if force_hls: cmd.extend(["-f", "hls", "-extension_picky", "0"]) cmd.append(url) log.info("Probing%s: %s", " (HLS mode)" if force_hls else "", " ".join(cmd)) result = subprocess.run( cmd, check=False, capture_output=True, text=True, timeout=_PROBE_TIMEOUT_SEC, ) if result.returncode == 0: data = json.loads(result.stdout) # Check detected format or if we forced HLS format_name = data.get("format", {}).get("format_name", "").lower() is_hls = force_hls or "hls" in format_name break except Exception as e: log.warning("Probe failed%s: %s", " (HLS mode)" if force_hls else "", e) continue if data is None: return None, [] video_codec = audio_codec = pix_fmt = audio_profile = "" audio_channels = audio_sample_rate = 0 subtitle_codecs: list[str] = [] subtitles: list[SubtitleStream] = [] height = 0 video_bitrate = 0 interlaced = False is_10bit = False is_hdr = False for stream in data.get("streams", []): codec = stream.get("codec_name", "").lower() codec_type = stream.get("codec_type", "") if codec_type == "video" and not video_codec: video_codec = codec pix_fmt = stream.get("pix_fmt", "") height = 
stream.get("height", 0) or 0 # Detect interlacing from field_order (tt, bb, tb, bt = interlaced) field_order = stream.get("field_order", "").lower() interlaced = field_order in ("tt", "bb", "tb", "bt") # Detect 10-bit from pix_fmt (e.g. yuv420p10le, p010le) # Check for "p10" or "10le/10be" to avoid false positive on yuv410p is_10bit = "p10" in pix_fmt or "10le" in pix_fmt or "10be" in pix_fmt # Detect HDR from color_transfer (PQ = smpte2084, HLG = arib-std-b67) color_transfer = stream.get("color_transfer", "").lower() is_hdr = color_transfer in ("smpte2084", "arib-std-b67") # Try to get bitrate from stream, fall back to format with suppress(ValueError, TypeError): video_bitrate = int(stream.get("bit_rate", 0) or 0) elif codec_type == "audio" and not audio_codec: audio_codec = codec audio_channels = stream.get("channels", 0) audio_sample_rate = int(stream.get("sample_rate", 0) or 0) audio_profile = stream.get("profile", "") elif codec_type == "subtitle": subtitle_codecs.append(codec) if codec in TEXT_SUBTITLE_CODECS: idx = stream.get("index") if idx is not None: tags = stream.get("tags", {}) lang = tags.get("language", "und").lower() name = tags.get("name") or tags.get("title") or _lang_display_name(lang) subtitles.append( SubtitleStream( index=idx, lang=lang, name=name, ) ) duration = 0.0 fmt = data.get("format", {}) if fmt.get("duration"): with suppress(ValueError, TypeError): duration = float(fmt["duration"]) # Fall back to format bitrate if stream bitrate unavailable (common for MKV) if not video_bitrate and fmt.get("bit_rate"): with suppress(ValueError, TypeError): video_bitrate = int(fmt["bit_rate"]) if not video_codec: return None, [] media_info = MediaInfo( video_codec=video_codec, audio_codec=audio_codec, pix_fmt=pix_fmt, audio_channels=audio_channels, audio_sample_rate=audio_sample_rate, audio_profile=audio_profile, subtitle_codecs=subtitle_codecs or None, duration=duration, height=height, video_bitrate=video_bitrate, interlaced=interlaced, 
is_10bit=is_10bit, is_hdr=is_hdr, is_hls=is_hls, ) # Only cache if we got valid video info (height > 0) if height <= 0: log.warning("Probe returned invalid height=%d, not caching: %s", height, url[:80]) return media_info, subtitles with _probe_lock: _probe_cache[url] = (time.time(), media_info, subtitles) # Cache by series_id/episode_id if provided if series_id is not None: if series_id not in _series_probe_cache: _series_probe_cache[series_id] = {"name": series_name, "episodes": {}} elif not _series_probe_cache[series_id].get("name") and series_name: _series_probe_cache[series_id]["name"] = series_name eid = episode_id if episode_id is not None else 0 _series_probe_cache[series_id].setdefault("episodes", {})[eid] = ( time.time(), media_info, subtitles, ) # Set MRU to this episode old_mru = _series_probe_cache[series_id].get("mru") _series_probe_cache[series_id]["mru"] = eid log.info( "Probe cached: series=%s episode=%s, mru changed from %s to %s", series_id, eid, old_mru, eid, ) if series_id is not None: _save_series_probe_cache() return media_info, subtitles # =========================================================================== # FFmpeg Command Building # =========================================================================== def _build_video_args( *, copy_video: bool, hw: HwAccel, deinterlace: bool, use_hw_pipeline: bool, max_resolution: str, quality: str, is_hdr: bool = False, source_height: int = 0, ) -> tuple[list[str], list[str]]: """Build video args. 
Returns (pre_input_args, post_input_args).""" if copy_video: return [], ["-c:v", "copy"] # Parse hw into encoder and fallback enc_type, fallback = _parse_hw(hw) max_h = _MAX_RES_HEIGHT.get(max_resolution) # Check if SR should be applied (discrete GPUs only) sr_filter = "" sr_model = _load_settings().get("sr_model", "") if sr_model and enc_type in ("nvenc", "amf") and _sr_engine_dir: sr_filter = _build_sr_filter(source_height, max_h or 0) # SR requires CPU frames, so disable hw pipeline when SR active if sr_filter: use_hw_pipeline = False # Fall back gracefully if VAAPI is needed but no device was detected needs_vaapi = enc_type == "vaapi" or fallback == "vaapi" if needs_vaapi and not VAAPI_DEVICE: if enc_type == "vaapi": # Pure VAAPI encoder requested but not available - fall back to software log.warning("VAAPI unavailable (no Intel/AMD GPU), falling back to software encoding") enc_type = "software" else: # VAAPI fallback requested but not available - use software decode instead log.warning("VAAPI fallback unavailable (no Intel/AMD GPU), using software decode") fallback = "software" # Height expr for scale filter (scale down only, -2 keeps width divisible by 2) h = f"min(ih\\,{max_h})" if max_h else None qp = _QUALITY_QP.get(quality, 28) if enc_type == "nvenc": if use_hw_pipeline: # CUDA decode path pre = [ "-hwaccel", "cuda", "-hwaccel_output_format", "cuda", "-extra_hw_frames", "3", ] scale = f"scale_cuda=-2:{h}:format=nv12" if h else "scale_cuda=format=nv12" deint = "yadif_cuda=0," if deinterlace else "" # mode=0 keeps original framerate # HDR tone mapping: prefer libplacebo (Vulkan GPU), fall back to CPU zscale+tonemap if is_hdr: if _has_libplacebo_filter(): tonemap = "hwdownload,format=p010le,libplacebo=tonemapping=hable:colorspace=bt709:color_primaries=bt709:color_trc=bt709,format=nv12,hwupload_cuda," else: tonemap = 
"hwdownload,format=p010le,zscale=t=linear:npl=100,format=gbrpf32le,zscale=p=bt709,tonemap=hable:desat=0,zscale=t=bt709:m=bt709:r=tv,format=nv12,hwupload_cuda," else: tonemap = "" vf = f"{deint}{tonemap}{scale}" elif fallback == "vaapi": # VAAPI decode + VAAPI filters + hwdownload + hwupload_cuda for NVENC pre = [ "-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi", "-hwaccel_device", VAAPI_DEVICE, ] scale = f"scale_vaapi=w=-2:h={h}:format=nv12" if h else "scale_vaapi=format=nv12" tonemap = "tonemap_vaapi=format=nv12:t=bt709:m=bt709:p=bt709," if is_hdr else "" deint = "deinterlace_vaapi," if deinterlace else "" vf = f"{deint}{tonemap}{scale},hwdownload,format=nv12,hwupload_cuda" else: # Software decode, upload to GPU for scaling/encoding pre = [] scale = f"scale_cuda=-2:{h}:format=nv12" if h else "scale_cuda=format=nv12" # HDR tone mapping: prefer libplacebo (Vulkan GPU), fall back to CPU zscale+tonemap # Deinterlace before tonemap (CPU yadif) for consistency with hw decode path if sr_filter: # SR path: CPU decode -> deinterlace -> SR (GPU) -> encode # SR filter ends with scale_cuda, outputs CUDA frames ready for nvenc # Need init_hw_device for TensorRT dnn_processing to use GPU pre = ["-init_hw_device", "cuda=cu", "-filter_hw_device", "cu"] deint = "yadif=0," if deinterlace else "" vf = f"{deint}{sr_filter}" elif is_hdr: deint = "yadif=0," if deinterlace else "" # CPU deinterlace before tonemap if _has_libplacebo_filter(): tonemap = "libplacebo=tonemapping=hable:colorspace=bt709:color_primaries=bt709:color_trc=bt709,format=nv12,hwupload_cuda," else: tonemap = "zscale=t=linear:npl=100,format=gbrpf32le,zscale=p=bt709,tonemap=hable:desat=0,zscale=t=bt709:m=bt709:r=tv,format=nv12,hwupload_cuda," vf = f"{deint}{tonemap}{scale}" else: deint = "yadif_cuda=0," if deinterlace else "" # GPU deinterlace after upload tonemap = "format=nv12,hwupload_cuda," vf = f"{tonemap}{deint}{scale}" preset = "p4" if deinterlace or sr_filter else "p2" encoder = "h264_nvenc" # Lookahead 
for better quality, B-frames for compression, AQ for adaptive quantization enc_opts = [ "-preset", preset, "-rc", "constqp", "-qp", str(qp), "-rc-lookahead", "32", "-bf", "3", "-spatial-aq", "1", "-temporal-aq", "1", ] elif enc_type == "amf": # AMF has no hardware decode - always uses fallback for decode/filter if fallback == "vaapi": # VAAPI decode + VAAPI filters + hwdownload for AMF encode pre = [ "-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi", "-hwaccel_device", VAAPI_DEVICE, ] scale = f"scale_vaapi=w=-2:h={h}:format=nv12" if h else "scale_vaapi=format=nv12" tonemap = "tonemap_vaapi=format=nv12:t=bt709:m=bt709:p=bt709," if is_hdr else "" deint = "deinterlace_vaapi," if deinterlace else "" vf = f"{deint}{tonemap}{scale},hwdownload,format=nv12" else: # Software decode + software filters pre = [] if is_hdr: tonemap = "zscale=t=linear:npl=100,format=gbrpf32le,zscale=p=bt709,tonemap=hable:desat=0,zscale=t=bt709:m=bt709:r=tv,format=nv12," else: tonemap = "" deint = "yadif=0," if deinterlace else "" scale = f"scale=-2:{h}" if h else "" vf = f"{deint}{tonemap}{scale},format=nv12".strip(",").replace(",,", ",") encoder = "h264_amf" enc_opts = [ "-rc", "cqp", "-qp_i", str(qp), "-qp_p", str(qp), "-quality", "balanced", ] elif enc_type == "vaapi": if use_hw_pipeline: pre = [ "-hwaccel", "vaapi", "-hwaccel_output_format", "vaapi", "-hwaccel_device", VAAPI_DEVICE, "-extra_hw_frames", "3", ] scale = f"scale_vaapi=w=-2:h={h}:format=nv12" if h else "scale_vaapi=format=nv12" # HDR tone mapping on VAAPI tonemap = "tonemap_vaapi=format=nv12:t=bt709:m=bt709:p=bt709," if is_hdr else "" vf = f"deinterlace_vaapi,{tonemap}{scale}" if deinterlace else f"{tonemap}{scale}" else: # Software decode, upload to GPU for scaling/encoding pre = ["-vaapi_device", VAAPI_DEVICE] scale = f"scale_vaapi=w=-2:h={h}:format=nv12" if h else "scale_vaapi=format=nv12" tonemap = "tonemap_vaapi=format=nv12:t=bt709:m=bt709:p=bt709," if is_hdr else "" deint = "deinterlace_vaapi," if deinterlace else "" 
vf = f"format=nv12,hwupload,{deint}{tonemap}{scale}" encoder = "h264_vaapi" # Use baseline profile for older GPUs (e.g., AMD GCN 1.0) that don't support High profile if AVAILABLE_ENCODERS.get("vaapi_baseline_only"): # Baseline doesn't support B-frames enc_opts = ["-rc_mode", "CQP", "-qp", str(qp), "-profile:v", "constrained_baseline"] else: enc_opts = ["-rc_mode", "CQP", "-qp", str(qp), "-bf", "3"] elif enc_type == "qsv": if use_hw_pipeline: pre = ["-hwaccel", "qsv", "-hwaccel_output_format", "qsv"] scale = f"scale_qsv=w=-2:h={h}:format=nv12" if h else "scale_qsv=format=nv12" # Combine deinterlace and tonemap into single vpp_qsv call when possible if deinterlace and is_hdr: vf = f"vpp_qsv=deinterlace=2:tonemap=1:format=nv12,{scale}" elif deinterlace: vf = f"vpp_qsv=deinterlace=2,{scale}" elif is_hdr: vf = f"vpp_qsv=tonemap=1:format=nv12,{scale}" else: vf = scale else: # Software decode, upload to GPU for scaling/encoding pre = ["-init_hw_device", "qsv=hw", "-filter_hw_device", "hw"] scale = f"scale_qsv=w=-2:h={h}:format=nv12" if h else "scale_qsv=format=nv12" # Combine deinterlace and tonemap into single vpp_qsv call when possible if deinterlace and is_hdr: vf = f"format=nv12,hwupload=extra_hw_frames=64,vpp_qsv=deinterlace=2:tonemap=1:format=nv12,{scale}" elif deinterlace: vf = f"format=nv12,hwupload=extra_hw_frames=64,vpp_qsv=deinterlace=2,{scale}" elif is_hdr: vf = ( f"format=nv12,hwupload=extra_hw_frames=64,vpp_qsv=tonemap=1:format=nv12,{scale}" ) else: vf = f"format=nv12,hwupload=extra_hw_frames=64,{scale}" encoder = "h264_qsv" enc_opts = [ "-global_quality", str(qp), "-bf", "3", "-look_ahead", "1", "-look_ahead_depth", "40", ] elif enc_type == "software": pre = [] # HDR tone mapping on CPU if is_hdr: tonemap = "zscale=t=linear:npl=100,format=gbrpf32le,zscale=p=bt709,tonemap=hable:desat=0,zscale=t=bt709:m=bt709:r=tv,format=yuv420p," else: tonemap = "" deint = "yadif=0," if deinterlace else "" # mode=0 keeps original framerate if h: vf = 
f"{deint}{tonemap}scale=-2:{h},format=yuv420p" else: vf = f"{deint}{tonemap}format=yuv420p".rstrip(",") crf = _QUALITY_CRF.get(quality, 26) encoder = "libx264" enc_opts = ["-preset", "veryfast", "-crf", str(crf), "-bf", "3"] else: raise ValueError(f"Unrecognized hardware encoder: '{enc_type}'.") post = ["-vf", vf, "-c:v", encoder, *enc_opts, "-g", "60"] return pre, post def _build_audio_args(*, copy_audio: bool, audio_sample_rate: int) -> list[str]: """Build audio args.""" if copy_audio: return ["-c:a", "copy"] rate = str(audio_sample_rate) if audio_sample_rate in (44100, 48000) else "48000" return ["-c:a", "aac", "-ac", "2", "-ar", rate, "-b:a", "192k", "-profile:a", "aac_low"] def get_live_hls_list_size() -> int: """Get hls_list_size for live streams based on DVR setting.""" dvr_mins = _load_settings().get("live_dvr_mins", 0) if dvr_mins <= 0: # Default buffer when DVR disabled return int(DEFAULT_LIVE_BUFFER_SECS / _HLS_SEGMENT_DURATION_SEC) # DVR enabled: calculate segments from minutes return int(dvr_mins * 60 / _HLS_SEGMENT_DURATION_SEC) def build_hls_ffmpeg_cmd( input_url: str, hw: HwAccel, output_dir: str, is_vod: bool = False, subtitles: list[SubtitleStream] | None = None, media_info: MediaInfo | None = None, max_resolution: str = "1080p", quality: str = "high", user_agent: str | None = None, deinterlace_fallback: bool | None = None, ) -> list[str]: """Build ffmpeg command for HLS transcoding.""" # Check if we can copy streams directly (compatible codecs, no processing needed) max_h = _MAX_RES_HEIGHT.get(max_resolution, 9999) needs_scale = media_info and media_info.height > max_h # SR requires re-encode (can't copy video when SR is active) sr_active = bool(_sr_engine_dir and _load_settings().get("sr_model", "")) copy_video = bool( media_info and media_info.video_codec == "h264" and media_info.pix_fmt == "yuv420p" and not needs_scale and not sr_active and not media_info.interlaced # Can't copy if deinterlacing needed ) copy_audio = bool( media_info and 
media_info.audio_codec == "aac" and media_info.audio_channels <= 2 and media_info.audio_sample_rate in (44100, 48000) # HE-AAC has browser compatibility issues - only copy LC-AAC and "HE" not in media_info.audio_profile ) # Full hardware pipeline if GPU supports the codec # Parse hw to get encoder type enc_type, _ = _parse_hw(hw) codec = media_info.video_codec if media_info else "" use_hw_pipeline = bool( not copy_video and media_info and ( (enc_type == "nvenc" and codec in _get_gpu_nvdec_codecs()) or (enc_type == "vaapi" and codec in _VAAPI_SAFE_CODECS) or (enc_type == "qsv" and codec in _QSV_SAFE_CODECS) # AMF never has hw decode pipeline - always False ) ) # Deinterlace: use probe result if available, else use fallback setting # (fallback defaults to True for live, False for VOD when not explicitly set) fallback = deinterlace_fallback if deinterlace_fallback is not None else (not is_vod) # If probe failed (height=0), don't trust interlaced flag - use fallback probe_valid = media_info is not None and media_info.height > 0 deinterlace = media_info.interlaced if probe_valid and media_info else fallback # Build component arg lists video_pre, video_post = _build_video_args( copy_video=copy_video, hw=hw, deinterlace=deinterlace, use_hw_pipeline=use_hw_pipeline, max_resolution=max_resolution, quality=quality, is_hdr=media_info.is_hdr if media_info else False, source_height=media_info.height if media_info else 0, ) audio_args = _build_audio_args( copy_audio=copy_audio, audio_sample_rate=media_info.audio_sample_rate if media_info else 0, ) # Base args cmd = [ "ffmpeg", "-hide_banner", "-loglevel", "error", "-noautorotate", ] # Hwaccel args (before -i) cmd.extend(video_pre) # Probe args (only when no media_info, since we already probed) if media_info is None: probe_size = "50000" if is_vod else "5000000" analyze_dur = "500000" if is_vod else "5000000" cmd.extend(["-probesize", probe_size, "-analyzeduration", analyze_dur]) # Input args cmd.extend( [ "-fflags", 
"+discardcorrupt+genpts", "-err_detect", "ignore_err", "-reconnect", "1", "-reconnect_streamed", "1", "-reconnect_on_network_error", "1", "-reconnect_on_http_error", "4xx,5xx", "-reconnect_delay_max", "30", ] ) if user_agent: cmd.extend(["-user_agent", user_agent]) # Use HLS demuxer options if probe detected HLS format if media_info and media_info.is_hls: cmd.extend(["-f", "hls", "-extension_picky", "0"]) cmd.extend(["-i", input_url]) # Subtitle extraction for i, sub in enumerate(subtitles or []): cmd.extend( [ "-map", f"0:{sub.index}", "-c:s", "webvtt", "-flush_packets", "1", f"{output_dir}/sub{i}.vtt", ] ) # Stream mapping + video + audio cmd.extend(["-map", "0:v:0", "-map", "0:a:0"]) cmd.extend(video_post) cmd.extend(audio_args) # HLS output args cmd.extend( [ "-max_delay", "5000000", "-f", "hls", "-hls_time", str(int(_HLS_SEGMENT_DURATION_SEC)), "-hls_list_size", "0" if is_vod else str(get_live_hls_list_size()), "-hls_segment_filename", f"{output_dir}/{SEG_PREFIX}%03d.ts", ] ) if is_vod: cmd.extend( [ "-hls_init_time", "2", "-hls_flags", "independent_segments", "-hls_playlist_type", "event", ] ) else: cmd.extend(["-hls_flags", "delete_segments"]) cmd.append(f"{output_dir}/stream.m3u8") return cmd ================================================ FILE: ffmpeg_command_test.py ================================================ """Tests for ffmpeg command generation and media probing.""" from pathlib import Path from unittest.mock import MagicMock, patch import json import tempfile import pytest from ffmpeg_command import ( _MAX_RES_HEIGHT, HwAccel, MediaInfo, SubtitleStream, _build_audio_args, _build_video_args, _get_gpu_nvdec_codecs, build_hls_ffmpeg_cmd, clear_all_probe_cache, clear_series_mru, get_live_hls_list_size, get_series_probe_cache_stats, get_transcode_dir, get_user_agent, invalidate_series_probe_cache, probe_media, restore_probe_cache_entry, ) @pytest.fixture(autouse=True) def mock_vaapi_device(): """Mock VAAPI_DEVICE for all tests to allow VAAPI tests on 
CI without hardware.""" with patch("ffmpeg_command.VAAPI_DEVICE", "/dev/dri/renderD128"): yield class FakeMediaInfo: """Fake media info for testing.""" def __init__( self, video_codec: str = "h264", audio_codec: str = "aac", pix_fmt: str = "yuv420p", audio_channels: int = 2, audio_sample_rate: int = 48000, audio_profile: str = "LC", height: int = 1080, interlaced: bool = False, is_10bit: bool = False, is_hdr: bool = False, is_hls: bool = False, ): self.video_codec = video_codec self.audio_codec = audio_codec self.pix_fmt = pix_fmt self.audio_channels = audio_channels self.audio_sample_rate = audio_sample_rate self.audio_profile = audio_profile self.height = height self.interlaced = interlaced self.is_10bit = is_10bit self.is_hdr = is_hdr self.is_hls = is_hls # ============================================================================= # Video Args Tests # ============================================================================= class TestBuildVideoArgs: """Tests for _build_video_args.""" @pytest.mark.parametrize( "hw", ["nvenc+vaapi", "nvenc+software", "amf+vaapi", "amf+software", "qsv", "vaapi", "software"], ) @pytest.mark.parametrize("deinterlace", [True, False]) @pytest.mark.parametrize("max_resolution", ["1080p", "720p", "4k"]) def test_all_hw_combinations(self, hw: HwAccel, deinterlace: bool, max_resolution: str): """Test all hardware/deinterlace/resolution combinations produce valid args.""" pre, post = _build_video_args( copy_video=False, hw=hw, deinterlace=deinterlace, use_hw_pipeline=(hw not in ("software", "nvenc+software", "amf+software")), max_resolution=max_resolution, quality="high", ) if hw in ("nvenc+vaapi", "nvenc+software") or hw in ("amf+vaapi", "amf+software"): assert pre == [] or "-hwaccel" in pre elif hw == "qsv": assert "-hwaccel" in pre assert "qsv" in pre elif hw == "vaapi": assert "-hwaccel" in pre assert "vaapi" in pre else: assert pre == [] assert "-vf" in post assert "-c:v" in post assert "-g" in post assert "60" in post 
@pytest.mark.parametrize( "hw", ["nvenc+vaapi", "nvenc+software", "amf+vaapi", "amf+software", "qsv", "vaapi", "software"], ) def test_copy_video(self, hw: HwAccel): """Test copy_video returns minimal args.""" pre, post = _build_video_args( copy_video=True, hw=hw, deinterlace=False, use_hw_pipeline=False, max_resolution="1080p", quality="high", ) assert pre == [] assert post == ["-c:v", "copy"] def test_nvenc_hw_pipeline_filters(self): """Test NVENC with hw pipeline uses CUDA filters.""" pre, post = _build_video_args( copy_video=False, hw="nvenc+software", deinterlace=True, use_hw_pipeline=True, max_resolution="1080p", quality="high", ) assert "-hwaccel" in pre vf = post[post.index("-vf") + 1] assert "yadif_cuda" in vf assert "scale_cuda" in vf def test_nvenc_sw_fallback_filters(self): """Test NVENC without hw pipeline uses SW decode + GPU processing.""" pre, post = _build_video_args( copy_video=False, hw="nvenc+software", deinterlace=True, use_hw_pipeline=False, max_resolution="1080p", quality="high", ) assert pre == [] vf = post[post.index("-vf") + 1] # Upload to GPU, then deinterlace (mode=0 for original framerate) and scale on GPU assert "hwupload_cuda" in vf assert "yadif_cuda=0" in vf assert "scale_cuda" in vf def test_vaapi_filters(self): """Test VAAPI uses VAAPI filters.""" pre, post = _build_video_args( copy_video=False, hw="vaapi", deinterlace=True, use_hw_pipeline=True, max_resolution="1080p", quality="high", ) vf = post[post.index("-vf") + 1] assert "deinterlace_vaapi" in vf assert "scale_vaapi" in vf def test_qsv_filters(self): """Test QSV uses QSV filters.""" pre, post = _build_video_args( copy_video=False, hw="qsv", deinterlace=True, use_hw_pipeline=True, max_resolution="1080p", quality="high", ) vf = post[post.index("-vf") + 1] assert "vpp_qsv" in vf assert "scale_qsv" in vf def test_software_filters(self): """Test software uses yadif (mode=0 for original framerate) and scale.""" pre, post = _build_video_args( copy_video=False, hw="software", 
deinterlace=True, use_hw_pipeline=False, max_resolution="1080p", quality="high", ) assert pre == [] vf = post[post.index("-vf") + 1] assert "yadif=0" in vf @pytest.mark.parametrize( "quality,expected_qp", [("high", "20"), ("medium", "28"), ("low", "35")] ) def test_quality_presets(self, quality: str, expected_qp: str): """Test quality presets map to correct QP values.""" _, post = _build_video_args( copy_video=False, hw="vaapi", deinterlace=False, use_hw_pipeline=True, max_resolution="1080p", quality=quality, ) assert expected_qp in post def test_invalid_hw_raises(self): """Test invalid hardware raises ValueError.""" with pytest.raises(ValueError, match="Unrecognized hardware"): _build_video_args( copy_video=False, hw="invalid", # type: ignore deinterlace=False, use_hw_pipeline=False, max_resolution="1080p", quality="high", ) @patch("ffmpeg_command._has_libplacebo_filter", return_value=True) def test_nvenc_hdr_with_libplacebo(self, mock_placebo): """Test NVENC HDR uses libplacebo when available.""" _, post = _build_video_args( copy_video=False, hw="nvenc+software", deinterlace=False, use_hw_pipeline=True, max_resolution="1080p", quality="high", is_hdr=True, ) vf = post[post.index("-vf") + 1] assert "libplacebo" in vf assert "tonemapping=hable" in vf # Should download from CUDA, process, re-upload assert "hwdownload" in vf assert "hwupload_cuda" in vf @patch("ffmpeg_command._has_libplacebo_filter", return_value=False) def test_nvenc_hdr_zscale_fallback(self, mock_placebo): """Test NVENC HDR falls back to zscale when libplacebo unavailable.""" _, post = _build_video_args( copy_video=False, hw="nvenc+software", deinterlace=False, use_hw_pipeline=True, max_resolution="1080p", quality="high", is_hdr=True, ) vf = post[post.index("-vf") + 1] assert "zscale" in vf assert "tonemap=hable" in vf assert "libplacebo" not in vf @patch("ffmpeg_command._has_libplacebo_filter", return_value=True) def test_nvenc_hdr_deinterlace_order(self, mock_placebo): """Test NVENC HDR hw decode 
deinterlaces BEFORE tonemap.""" _, post = _build_video_args( copy_video=False, hw="nvenc+software", deinterlace=True, use_hw_pipeline=True, max_resolution="1080p", quality="high", is_hdr=True, ) vf = post[post.index("-vf") + 1] # Deinterlace should come before tonemap in hw decode path deint_pos = vf.find("yadif_cuda") tonemap_pos = vf.find("libplacebo") assert deint_pos < tonemap_pos, f"deinterlace should come before tonemap: {vf}" @patch("ffmpeg_command._has_libplacebo_filter", return_value=True) def test_nvenc_sw_hdr_deinterlace_order(self, mock_placebo): """Test NVENC HDR sw decode uses CPU deinterlace before tonemap.""" _, post = _build_video_args( copy_video=False, hw="nvenc+software", deinterlace=True, use_hw_pipeline=False, max_resolution="1080p", quality="high", is_hdr=True, ) vf = post[post.index("-vf") + 1] # SW decode HDR should use CPU yadif before tonemap assert "yadif=0" in vf # CPU deinterlace, not yadif_cuda deint_pos = vf.find("yadif=0") tonemap_pos = vf.find("libplacebo") assert deint_pos < tonemap_pos, f"CPU deinterlace should come before tonemap: {vf}" def test_vaapi_hdr_tonemap(self): """Test VAAPI HDR uses tonemap_vaapi filter.""" _, post = _build_video_args( copy_video=False, hw="vaapi", deinterlace=False, use_hw_pipeline=True, max_resolution="1080p", quality="high", is_hdr=True, ) vf = post[post.index("-vf") + 1] assert "tonemap_vaapi" in vf # ============================================================================= # Audio Args Tests # ============================================================================= class TestBuildAudioArgs: """Tests for _build_audio_args.""" def test_copy_audio(self): """Test copy_audio returns copy args.""" args = _build_audio_args(copy_audio=True, audio_sample_rate=48000) assert args == ["-c:a", "copy"] @pytest.mark.parametrize( "sample_rate,expected", [ (44100, "44100"), (48000, "48000"), (96000, "48000"), (0, "48000"), ], ) def test_sample_rates(self, sample_rate: int, expected: str): """Test sample 
rate handling.""" args = _build_audio_args(copy_audio=False, audio_sample_rate=sample_rate) assert "-ar" in args assert expected in args # ============================================================================= # HLS Command Tests # ============================================================================= class TestBuildHlsFfmpegCmd: """Tests for build_hls_ffmpeg_cmd.""" @pytest.mark.parametrize( "hw", ["nvenc+vaapi", "nvenc+software", "amf+vaapi", "amf+software", "qsv", "vaapi", "software"], ) @pytest.mark.parametrize("is_vod", [True, False]) def test_command_structure(self, hw: HwAccel, is_vod: bool): """Test command has correct structure for all hw/vod combinations.""" cmd = build_hls_ffmpeg_cmd( "http://test/stream", hw, "/tmp/output", is_vod=is_vod, ) assert cmd[0] == "ffmpeg" assert "-i" in cmd assert "-map" in cmd assert "-c:v" in cmd assert "-c:a" in cmd assert "-f" in cmd assert "hls" in cmd i_idx = cmd.index("-i") if "-hwaccel" in cmd: hwaccel_idx = cmd.index("-hwaccel") assert hwaccel_idx < i_idx, "hwaccel must come before -i" if "-vf" in cmd: vf_idx = cmd.index("-vf") assert vf_idx > i_idx, "-vf must come after -i" def test_vod_hls_flags(self): """Test VOD has correct HLS flags.""" cmd = build_hls_ffmpeg_cmd("http://test", "software", "/tmp", is_vod=True) assert "-hls_playlist_type" in cmd assert "event" in cmd assert "-hls_list_size" in cmd assert cmd[cmd.index("-hls_list_size") + 1] == "0" def test_live_hls_flags(self): """Test live has correct HLS flags.""" cmd = build_hls_ffmpeg_cmd("http://test", "software", "/tmp", is_vod=False) assert "delete_segments" in cmd assert "-hls_list_size" in cmd assert cmd[cmd.index("-hls_list_size") + 1] == "10" def test_copy_video_with_compatible_media(self): """Test copy_video is used for compatible VOD media.""" media = FakeMediaInfo(video_codec="h264", pix_fmt="yuv420p", height=1080) cmd = build_hls_ffmpeg_cmd( "http://test", "vaapi", "/tmp", is_vod=True, media_info=media, # type: ignore 
max_resolution="1080p", ) assert "-c:v" in cmd assert cmd[cmd.index("-c:v") + 1] == "copy" assert "-hwaccel" not in cmd def test_no_copy_for_10bit(self): """Test 10-bit content is transcoded, not copied.""" media = FakeMediaInfo(video_codec="h264", pix_fmt="yuv420p10le", height=1080) cmd = build_hls_ffmpeg_cmd( "http://test", "vaapi", "/tmp", is_vod=True, media_info=media, # type: ignore ) assert cmd[cmd.index("-c:v") + 1] != "copy" assert "-hwaccel" in cmd def test_no_copy_when_scaling_needed(self): """Test scaling requirement prevents copy.""" media = FakeMediaInfo(video_codec="h264", pix_fmt="yuv420p", height=2160) cmd = build_hls_ffmpeg_cmd( "http://test", "vaapi", "/tmp", is_vod=True, media_info=media, # type: ignore max_resolution="1080p", ) assert cmd[cmd.index("-c:v") + 1] != "copy" def test_user_agent(self): """Test user agent is included when provided.""" cmd = build_hls_ffmpeg_cmd( "http://test", "software", "/tmp", user_agent="TestAgent/1.0", ) assert "-user_agent" in cmd assert "TestAgent/1.0" in cmd def test_probe_args_without_media_info(self): """Test probe args are added when no media_info.""" cmd = build_hls_ffmpeg_cmd("http://test", "software", "/tmp", media_info=None) assert "-probesize" in cmd assert "-analyzeduration" in cmd def test_no_probe_args_with_media_info(self): """Test probe args are skipped when media_info provided.""" media = FakeMediaInfo() cmd = build_hls_ffmpeg_cmd("http://test", "software", "/tmp", media_info=media) # type: ignore assert "-probesize" not in cmd def test_subtitle_extraction(self): """Test subtitle streams are extracted.""" subs = [ SubtitleStream(index=2, lang="eng", name="English"), SubtitleStream(index=3, lang="spa", name="Spanish"), ] cmd = build_hls_ffmpeg_cmd("http://test", "software", "/tmp/out", subtitles=subs) assert "-map" in cmd assert "0:2" in cmd assert "0:3" in cmd assert "/tmp/out/sub0.vtt" in cmd assert "/tmp/out/sub1.vtt" in cmd # 
============================================================================= # Aspect Ratio Tests # ============================================================================= class TestAspectRatioHandling: """Tests for various aspect ratio content.""" @pytest.mark.parametrize( "input_height,max_res,should_scale", [ (1080, "1080p", False), (1080, "720p", True), (720, "1080p", False), (2160, "1080p", True), (1600, "1080p", True), (1600, "4k", False), ], ) def test_scaling_decisions(self, input_height: int, max_res: str, should_scale: bool): """Test correct scaling decisions for various input heights.""" media = FakeMediaInfo(height=input_height, pix_fmt="yuv420p10le") cmd = build_hls_ffmpeg_cmd( "http://test", "vaapi", "/tmp", is_vod=True, media_info=media, # type: ignore max_resolution=max_res, ) vf = cmd[cmd.index("-vf") + 1] max_h = _MAX_RES_HEIGHT.get(max_res, 9999) # Comma is escaped in FFmpeg filter expressions height_expr = f"min(ih\\,{max_h})" assert height_expr in vf, f"Expected {height_expr} in {vf}" # ============================================================================= # GPU Detection Tests # ============================================================================= class TestGpuDetection: """Tests for GPU/NVDEC detection.""" def test_nvidia_gpu_detected(self): """Test NVIDIA GPU detection parses compute capability.""" import ffmpeg_command ffmpeg_command._gpu_nvdec_codecs = None # Reset cache mock_result = MagicMock() mock_result.returncode = 0 mock_result.stdout = "NVIDIA GeForce RTX 3080, 8.6\n" with patch("subprocess.run", return_value=mock_result): codecs = _get_gpu_nvdec_codecs() assert "h264" in codecs assert "hevc" in codecs assert "av1" in codecs def test_no_nvidia_gpu(self): """Test handling when no NVIDIA GPU present.""" import ffmpeg_command ffmpeg_command._gpu_nvdec_codecs = None mock_result = MagicMock() mock_result.returncode = 1 with patch("subprocess.run", return_value=mock_result): codecs = _get_gpu_nvdec_codecs() assert 
codecs == set() def test_older_nvidia_gpu(self): """Test older GPU with limited NVDEC support.""" import ffmpeg_command ffmpeg_command._gpu_nvdec_codecs = None mock_result = MagicMock() mock_result.returncode = 0 mock_result.stdout = "NVIDIA GeForce GTX 960, 5.2\n" with patch("subprocess.run", return_value=mock_result): codecs = _get_gpu_nvdec_codecs() assert "h264" in codecs assert "hevc" not in codecs assert "av1" not in codecs # ============================================================================= # User Agent Tests # ============================================================================= class TestUserAgent: """Tests for user agent handling.""" def test_default_user_agent(self): """Test default preset returns None.""" with patch("ffmpeg_command._load_settings", return_value={"user_agent_preset": "default"}): assert get_user_agent() is None def test_vlc_user_agent(self): """Test VLC preset.""" with patch("ffmpeg_command._load_settings", return_value={"user_agent_preset": "vlc"}): ua = get_user_agent() assert ua is not None assert "VLC" in ua def test_chrome_user_agent(self): """Test Chrome preset.""" with patch("ffmpeg_command._load_settings", return_value={"user_agent_preset": "chrome"}): ua = get_user_agent() assert ua is not None assert "Chrome" in ua def test_custom_user_agent(self): """Test custom user agent.""" with patch( "ffmpeg_command._load_settings", return_value={"user_agent_preset": "custom", "user_agent_custom": "MyAgent/1.0"}, ): assert get_user_agent() == "MyAgent/1.0" def test_custom_empty_returns_none(self): """Test empty custom user agent returns None.""" with patch( "ffmpeg_command._load_settings", return_value={"user_agent_preset": "custom", "user_agent_custom": ""}, ): assert get_user_agent() is None # ============================================================================= # Transcode Directory Tests # ============================================================================= class TestTranscodeDir: """Tests for 
transcode directory handling.""" def test_default_transcode_dir(self): """Test default uses system temp.""" with patch("ffmpeg_command._load_settings", return_value={}): path = get_transcode_dir() assert path == Path(tempfile.gettempdir()) def test_custom_transcode_dir(self, tmp_path): """Test custom directory is used and created.""" custom_dir = tmp_path / "custom_transcode" with patch( "ffmpeg_command._load_settings", return_value={"transcode_dir": str(custom_dir)} ): path = get_transcode_dir() assert path == custom_dir assert custom_dir.exists() # ============================================================================= # HLS List Size Tests # ============================================================================= class TestHlsListSize: """Tests for HLS list size calculation.""" def test_default_list_size(self): """Test default (DVR disabled) uses 10 segments.""" with patch("ffmpeg_command._load_settings", return_value={}): assert get_live_hls_list_size() == 10 def test_dvr_enabled_list_size(self): """Test DVR enabled calculates segments from minutes.""" with patch("ffmpeg_command._load_settings", return_value={"live_dvr_mins": 5}): # 5 min = 300 sec / 3 sec per segment = 100 segments assert get_live_hls_list_size() == 100 # ============================================================================= # Probe Media Tests # ============================================================================= class TestProbeMedia: """Tests for media probing.""" def test_probe_success(self): """Test successful probe parses media info.""" import ffmpeg_command ffmpeg_command._probe_cache.clear() probe_output = { "streams": [ { "codec_type": "video", "codec_name": "h264", "pix_fmt": "yuv420p", "height": 1080, "field_order": "progressive", }, { "codec_type": "audio", "codec_name": "aac", "channels": 2, "sample_rate": "48000", }, ], "format": {"duration": "3600.0"}, } mock_result = MagicMock() mock_result.returncode = 0 mock_result.stdout = json.dumps(probe_output) 
with ( patch("subprocess.run", return_value=mock_result), patch("ffmpeg_command._load_settings", return_value={}), ): media_info, subs = probe_media("http://test/video.mp4") assert media_info is not None assert media_info.video_codec == "h264" assert media_info.audio_codec == "aac" assert media_info.height == 1080 assert media_info.duration == 3600.0 assert not media_info.interlaced def test_probe_interlaced_detection(self): """Test interlaced content detection.""" import ffmpeg_command ffmpeg_command._probe_cache.clear() probe_output = { "streams": [ { "codec_type": "video", "codec_name": "mpeg2video", "pix_fmt": "yuv420p", "height": 1080, "field_order": "tt", # Top field first = interlaced }, ], "format": {}, } mock_result = MagicMock() mock_result.returncode = 0 mock_result.stdout = json.dumps(probe_output) with ( patch("subprocess.run", return_value=mock_result), patch("ffmpeg_command._load_settings", return_value={}), ): media_info, _ = probe_media("http://test/interlaced.ts") assert media_info is not None assert media_info.interlaced is True @pytest.mark.parametrize( "pix_fmt,expected", [ ("yuv420p10le", True), # 10-bit little endian ("yuv420p10be", True), # 10-bit big endian ("yuv422p10le", True), # 10-bit 4:2:2 ("p010le", True), # CUDA/VAAPI 10-bit format ("yuv420p", False), # 8-bit ("yuv410p", False), # 4:1:0 chroma, NOT 10-bit (was a false positive) ("nv12", False), # 8-bit NV12 ], ) def test_probe_10bit_detection(self, pix_fmt: str, expected: bool): """Test 10-bit content detection from pix_fmt.""" import ffmpeg_command ffmpeg_command._probe_cache.clear() probe_output = { "streams": [ { "codec_type": "video", "codec_name": "hevc", "pix_fmt": pix_fmt, "height": 2160, }, ], "format": {}, } mock_result = MagicMock() mock_result.returncode = 0 mock_result.stdout = json.dumps(probe_output) with ( patch("subprocess.run", return_value=mock_result), patch("ffmpeg_command._load_settings", return_value={}), ): media_info, _ = 
probe_media(f"http://test/{pix_fmt}.mkv") assert media_info is not None assert media_info.is_10bit is expected, f"pix_fmt={pix_fmt} should be is_10bit={expected}" @pytest.mark.parametrize( "color_transfer,expected", [ ("smpte2084", True), # PQ (HDR10, HDR10+, Dolby Vision) ("arib-std-b67", True), # HLG ("bt709", False), # SDR ("bt2020-10", False), # Wide gamut but not HDR transfer ("", False), # Unknown/missing ], ) def test_probe_hdr_detection(self, color_transfer: str, expected: bool): """Test HDR content detection from color_transfer.""" import ffmpeg_command ffmpeg_command._probe_cache.clear() probe_output = { "streams": [ { "codec_type": "video", "codec_name": "hevc", "pix_fmt": "yuv420p10le", "height": 2160, "color_transfer": color_transfer, }, ], "format": {}, } mock_result = MagicMock() mock_result.returncode = 0 mock_result.stdout = json.dumps(probe_output) with ( patch("subprocess.run", return_value=mock_result), patch("ffmpeg_command._load_settings", return_value={}), ): media_info, _ = probe_media(f"http://test/{color_transfer or 'unknown'}.mkv") assert media_info is not None assert media_info.is_hdr is expected, ( f"color_transfer={color_transfer} should be is_hdr={expected}" ) def test_probe_failure(self): """Test probe failure returns None.""" import ffmpeg_command ffmpeg_command._probe_cache.clear() mock_result = MagicMock() mock_result.returncode = 1 with ( patch("subprocess.run", return_value=mock_result), patch("ffmpeg_command._load_settings", return_value={}), ): media_info, subs = probe_media("http://test/bad.mp4") assert media_info is None assert subs == [] def test_probe_cache_hit(self): """Test probe cache returns cached result.""" import time import ffmpeg_command ffmpeg_command._probe_cache.clear() cached_info = MediaInfo( video_codec="h264", audio_codec="aac", pix_fmt="yuv420p", ) ffmpeg_command._probe_cache["http://cached"] = (time.time(), cached_info, []) with ( patch("subprocess.run") as mock_run, patch("ffmpeg_command._load_settings", 
return_value={}),
        ):
            media_info, _ = probe_media("http://cached")
        # A cache hit must return before any subprocess is spawned.
        mock_run.assert_not_called()
        assert media_info == cached_info

    def test_probe_extracts_subtitles(self):
        """Test subtitle stream extraction."""
        import ffmpeg_command

        ffmpeg_command._probe_cache.clear()
        probe_output = {
            "streams": [
                {"codec_type": "video", "codec_name": "h264", "pix_fmt": "yuv420p"},
                {
                    "codec_type": "subtitle",
                    "codec_name": "subrip",
                    "index": 2,
                    "tags": {"language": "eng", "title": "English"},
                },
                {
                    "codec_type": "subtitle",
                    "codec_name": "ass",
                    "index": 3,
                    "tags": {"language": "jpn"},
                },
            ],
            "format": {},
        }
        mock_result = MagicMock()
        mock_result.returncode = 0
        mock_result.stdout = json.dumps(probe_output)
        with (
            patch("subprocess.run", return_value=mock_result),
            patch("ffmpeg_command._load_settings", return_value={}),
        ):
            _, subs = probe_media("http://test/subs.mkv")
        # Each subtitle keeps its original stream index; lang/name come from tags.
        assert len(subs) == 2
        assert subs[0].index == 2
        assert subs[0].lang == "eng"
        assert subs[0].name == "English"
        assert subs[1].index == 3
        assert subs[1].lang == "jpn"


# =============================================================================
# Probe Cache Management Tests
# =============================================================================


class TestProbeCacheManagement:
    """Tests for probe cache management functions."""

    def test_clear_all_probe_cache(self):
        """Test clearing all probe caches."""
        import time
        import ffmpeg_command

        # Clear first to ensure known state
        ffmpeg_command._probe_cache.clear()
        ffmpeg_command._series_probe_cache.clear()
        ffmpeg_command._probe_cache["url1"] = (time.time(), None, [])
        ffmpeg_command._probe_cache["url2"] = (time.time(), None, [])
        ffmpeg_command._series_probe_cache[123] = {"episodes": {1: (time.time(), None, [])}}
        # Saving the series cache is patched out (presumably a disk write).
        with patch("ffmpeg_command._save_series_probe_cache"):
            count = clear_all_probe_cache()
        # Two URL entries plus one series episode entry.
        assert count == 3
        assert len(ffmpeg_command._probe_cache) == 0
        assert len(ffmpeg_command._series_probe_cache) == 0

    def test_invalidate_series_probe_cache_entire_series(self):
        """Test
invalidating entire series cache."""
        import ffmpeg_command

        ffmpeg_command._series_probe_cache[123] = {
            "name": "Test",
            "episodes": {1: (0, None, []), 2: (0, None, [])},
        }
        with patch("ffmpeg_command._save_series_probe_cache"):
            invalidate_series_probe_cache(123)
        # Without episode_id the whole series entry is dropped.
        assert 123 not in ffmpeg_command._series_probe_cache

    def test_invalidate_series_probe_cache_single_episode(self):
        """Test invalidating single episode cache."""
        import ffmpeg_command

        ffmpeg_command._series_probe_cache[123] = {
            "name": "Test",
            "episodes": {1: (0, None, []), 2: (0, None, [])},
        }
        with patch("ffmpeg_command._save_series_probe_cache"):
            invalidate_series_probe_cache(123, episode_id=1)
        # Only the named episode goes; the series and its other episodes stay.
        assert 123 in ffmpeg_command._series_probe_cache
        assert 1 not in ffmpeg_command._series_probe_cache[123]["episodes"]
        assert 2 in ffmpeg_command._series_probe_cache[123]["episodes"]

    def test_clear_series_mru(self):
        """Test clearing series MRU."""
        import ffmpeg_command

        ffmpeg_command._series_probe_cache[123] = {
            "name": "Test",
            "mru": 5,
            "episodes": {1: (0, None, [])},
        }
        with patch("ffmpeg_command._save_series_probe_cache"):
            clear_series_mru(123)
        # Clearing the MRU marker must leave cached episodes in place.
        assert "mru" not in ffmpeg_command._series_probe_cache[123]
        assert "episodes" in ffmpeg_command._series_probe_cache[123]

    def test_restore_probe_cache_entry(self):
        """Test restoring probe cache entry."""
        import ffmpeg_command

        ffmpeg_command._probe_cache.clear()
        ffmpeg_command._series_probe_cache.clear()
        media_info = MediaInfo(video_codec="h264", audio_codec="aac", pix_fmt="yuv420p")
        subs = [SubtitleStream(index=2, lang="eng", name="English")]
        restore_probe_cache_entry("http://test", media_info, subs, series_id=123, episode_id=5)
        # A restored entry lands in both the URL cache and the per-series cache.
        assert "http://test" in ffmpeg_command._probe_cache
        assert 123 in ffmpeg_command._series_probe_cache
        assert 5 in ffmpeg_command._series_probe_cache[123]["episodes"]

    def test_get_series_probe_cache_stats(self):
        """Test getting cache stats for UI."""
        import time
        import ffmpeg_command

        ffmpeg_command._series_probe_cache.clear()
        ffmpeg_command._series_probe_cache[123] = {
            "name": "Test Series",
            "mru": 2,
            "episodes": {
                1: (time.time(), MediaInfo("h264", "aac", "yuv420p"), []),
                2: (time.time(), MediaInfo("h264", "aac", "yuv420p"), []),
            },
        }
        stats = get_series_probe_cache_stats()
        # One stats row per series, counting its cached episodes.
        assert len(stats) == 1
        assert stats[0]["series_id"] == 123
        assert stats[0]["name"] == "Test Series"
        assert stats[0]["episode_count"] == 2


class TestResolveHlsMasterPlaylist:
    """Tests for resolve_hls_master_playlist function."""

    def test_non_m3u8_url_returns_unchanged(self):
        """Non-m3u8 URLs should be returned unchanged."""
        from ffmpeg_command import resolve_hls_master_playlist

        url = "http://example.com/video.mp4"
        assert resolve_hls_master_playlist(url) == url

    def test_master_playlist_selects_highest_bandwidth(self):
        """Master playlist should resolve to highest bandwidth variant."""
        from unittest.mock import MagicMock, patch

        from ffmpeg_command import resolve_hls_master_playlist

        master_playlist = """#EXTM3U
#EXT-X-STREAM-INF:BANDWIDTH=500000,RESOLUTION=640x360
low.m3u8
#EXT-X-STREAM-INF:BANDWIDTH=1500000,RESOLUTION=1280x720
mid.m3u8
#EXT-X-STREAM-INF:BANDWIDTH=3000000,RESOLUTION=1920x1080
high.m3u8
"""
        # The mock supports `with urlopen(...) as resp:` by returning itself
        # from __enter__.
        mock_response = MagicMock()
        mock_response.read.return_value = master_playlist.encode("utf-8")
        mock_response.__enter__ = MagicMock(return_value=mock_response)
        mock_response.__exit__ = MagicMock(return_value=False)
        with patch("urllib.request.urlopen", return_value=mock_response):
            result = resolve_hls_master_playlist("http://example.com/master.m3u8")
        assert result == "http://example.com/high.m3u8"

    def test_media_playlist_returns_unchanged(self):
        """Media playlist (not master) should return original URL."""
        from unittest.mock import MagicMock, patch

        from ffmpeg_command import resolve_hls_master_playlist

        # No #EXT-X-STREAM-INF entries: this is a media playlist, not a master.
        media_playlist = """#EXTM3U
#EXT-X-VERSION:3
#EXT-X-TARGETDURATION:10
#EXTINF:10.0,
segment0.ts
#EXTINF:10.0,
segment1.ts
"""
        mock_response = MagicMock()
        mock_response.read.return_value = media_playlist.encode("utf-8")
        mock_response.__enter__ = MagicMock(return_value=mock_response)
        mock_response.__exit__ = MagicMock(return_value=False)
        with patch("urllib.request.urlopen", return_value=mock_response):
            result = resolve_hls_master_playlist("http://example.com/stream.m3u8")
        assert result == "http://example.com/stream.m3u8"

    def test_fetch_error_returns_original_url(self):
        """On fetch error, should return original URL."""
        from unittest.mock import patch

        from ffmpeg_command import resolve_hls_master_playlist

        with patch("urllib.request.urlopen", side_effect=Exception("Network error")):
            result = resolve_hls_master_playlist("http://example.com/master.m3u8")
        assert result == "http://example.com/master.m3u8"

    def test_relative_url_resolved_correctly(self):
        """Relative variant URLs should be resolved against base URL."""
        from unittest.mock import MagicMock, patch

        from ffmpeg_command import resolve_hls_master_playlist

        master_playlist = """#EXTM3U
#EXT-X-STREAM-INF:BANDWIDTH=3000000
../streams/1080p/index.m3u8
"""
        mock_response = MagicMock()
        mock_response.read.return_value = master_playlist.encode("utf-8")
        mock_response.__enter__ = MagicMock(return_value=mock_response)
        mock_response.__exit__ = MagicMock(return_value=False)
        with patch("urllib.request.urlopen", return_value=mock_response):
            result = resolve_hls_master_playlist("http://example.com/hls/master.m3u8")
        # "../" climbs out of /hls/ relative to the master playlist's URL.
        assert result == "http://example.com/streams/1080p/index.m3u8"


if __name__ == "__main__":
    from testing import run_tests

    run_tests(__file__)

================================================
FILE: ffmpeg_session.py
================================================
"""FFmpeg session lifecycle management."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any
import asyncio
import contextlib
import json
import logging
import pathlib
import re
import shutil
import tempfile
import threading
import time
import uuid

from fastapi import HTTPException

from ffmpeg_command import (
    SEG_PREFIX,
    HwAccel,
MediaInfo, SubtitleStream, build_hls_ffmpeg_cmd, get_ffmpeg_env, get_hls_segment_duration, get_settings, get_transcode_dir, get_user_agent, invalidate_series_probe_cache, probe_media, resolve_hls_master_playlist, restore_probe_cache_entry, ) log = logging.getLogger(__name__) # Timing constants _POLL_INTERVAL_SEC = 0.2 _QUICK_FAILURE_THRESHOLD_SEC = 10.0 _HEARTBEAT_TIMEOUT_SEC = 30.0 # 30 sec without progress poll = dead # Wait timeouts (seconds) _PLAYLIST_WAIT_TIMEOUT_SEC = 30.0 _PLAYLIST_WAIT_SEEK_TIMEOUT_SEC = 40.0 _REUSE_ACTIVE_WAIT_TIMEOUT_SEC = 15.0 _RESUME_WAIT_TIMEOUT_SEC = 10.0 _RESUME_SEGMENT_WAIT_TIMEOUT_SEC = 5.0 # Size thresholds _MIN_SEGMENT_SIZE_BYTES = 1_000 # Module state _transcode_sessions: dict[str, dict[str, Any]] = {} _url_to_session: dict[str, str] = {} # URL -> session_id (all content types) _transcode_lock = threading.Lock() _background_tasks: set[asyncio.Task[None]] = set() class _DeadProcess: """Placeholder for dead/recovered processes.""" returncode = -1 def terminate(self) -> None: pass def kill(self) -> None: pass # =========================================================================== # Cache Timeout Helpers # =========================================================================== def get_vod_cache_timeout() -> int: """Get VOD session cache timeout in seconds.""" return get_settings().get("vod_transcode_cache_mins", 60) * 60 def get_live_cache_timeout() -> int: """Get live session cache timeout in seconds.""" return get_settings().get("live_transcode_cache_secs", 0) # =========================================================================== # Session Validity # =========================================================================== def _is_process_alive(proc: Any) -> bool: """Check if process is still running.""" if proc is None: return False if isinstance(proc, _DeadProcess): return False if hasattr(proc, "returncode"): return proc.returncode is None return False def is_session_valid(session: dict[str, Any]) -> bool: 
"""Check if session is still valid (not expired). A session is valid if: - Has received a heartbeat (progress poll) within timeout, AND - Process is still running, OR process is dead but within cache timeout """ last_access = session.get("last_access", session["started"]) time_since_heartbeat = time.time() - last_access # No heartbeat in 30 sec = dead regardless of process state if time_since_heartbeat > _HEARTBEAT_TIMEOUT_SEC: return False # Active process with recent heartbeat = valid if _is_process_alive(session.get("process")): return True # Dead process: check cache timeout is_vod = session.get("is_vod", False) cache_timeout = get_vod_cache_timeout() if is_vod else get_live_cache_timeout() if cache_timeout <= 0: return False # No caching of dead sessions return time_since_heartbeat < cache_timeout def _kill_process(proc: Any) -> bool: """Kill process gracefully (SIGTERM then SIGKILL), return True if killed.""" try: # Try graceful termination first (lets ffmpeg flush buffers) proc.terminate() # Give it a moment to exit cleanly for _ in range(10): # 100ms total if proc.returncode is not None: return True time.sleep(0.01) # Force kill if still running proc.kill() return True except (ProcessLookupError, OSError): return False # =========================================================================== # Session Start/Stop # =========================================================================== def stop_session(session_id: str, force: bool = False) -> None: """Stop a transcode session.""" with _transcode_lock: session = _transcode_sessions.get(session_id) if not session: return # Skip stop if session was accessed recently (race with seeking/resume, # or multiple users watching same stream) if not force and time.time() - session.get("last_access", 0) < 5.0: log.info("Ignoring stop for recently-accessed session %s", session_id) return if _kill_process(session["process"]): log.info("Killed ffmpeg for session %s", session_id) # Cache session if timeout > 0 is_vod 
= session.get("is_vod", False) cache_timeout = get_vod_cache_timeout() if is_vod else get_live_cache_timeout() if not force and cache_timeout > 0: session["last_access"] = time.time() log.info( "Session %s cached (vod=%s, ffmpeg stopped, segments kept)", session_id, is_vod, ) return _transcode_sessions.pop(session_id, None) url = session.get("url") if url: _url_to_session.pop(url, None) dir_to_remove = session["dir"] shutil.rmtree(dir_to_remove, ignore_errors=True) log.info("Stopped transcode session %s", session_id) def cleanup_expired_sessions() -> None: """Clean up all expired sessions (VOD and live).""" with _transcode_lock: expired = [ sid for sid, session in list(_transcode_sessions.items()) if not is_session_valid(session) ] for session_id in expired: stop_session(session_id, force=True) def shutdown() -> None: """Kill all running ffmpeg processes for clean shutdown.""" with _transcode_lock: for session_id, session in list(_transcode_sessions.items()): proc = session.get("process") if proc and _kill_process(proc): log.info("Shutdown: killed ffmpeg for session %s", session_id) _transcode_sessions.clear() # =========================================================================== # Stream Limits # =========================================================================== def get_user_sessions(username: str) -> list[tuple[str, dict[str, Any]]]: """Get all active sessions for a user, sorted by start time (oldest first).""" with _transcode_lock: sessions = [ (sid, s) for sid, s in _transcode_sessions.items() if s.get("username") == username ] return sorted(sessions, key=lambda x: x[1].get("started", 0)) def get_source_sessions(source_id: str) -> list[tuple[str, dict[str, Any]]]: """Get all active sessions for a source, sorted by start time (oldest first).""" with _transcode_lock: sessions = [ (sid, s) for sid, s in _transcode_sessions.items() if s.get("source_id") == source_id ] return sorted(sessions, key=lambda x: x[1].get("started", 0)) def 
enforce_stream_limits( username: str, source_id: str | None, user_max: int, source_max: int, ) -> str | None: """Enforce stream limits, stopping oldest sessions if needed. Returns error message if source is at capacity and user can't reclaim, or None if limits are satisfied. """ # Check source limit first (hard limit - can only reclaim own slots) if source_id and source_max > 0: source_sessions = get_source_sessions(source_id) if len(source_sessions) >= source_max: user_source_sessions = [ (sid, s) for sid, s in source_sessions if s.get("username") == username ] if user_source_sessions: oldest_sid, _ = user_source_sessions[0] log.info( "Source %s at limit (%d), stopping user %s's oldest session %s", source_id, source_max, username, oldest_sid, ) stop_session(oldest_sid, force=True) else: return f"Source at capacity ({source_max} streams)" # Check user limit (soft limit - auto-rotate oldest) if user_max > 0: user_sessions = get_user_sessions(username) if len(user_sessions) >= user_max: oldest_sid, _ = user_sessions[0] log.info( "User %s at limit (%d), stopping oldest session %s", username, user_max, oldest_sid, ) stop_session(oldest_sid, force=True) return None # =========================================================================== # Session Recovery (Startup) # =========================================================================== def cleanup_and_recover_sessions() -> None: """Clean up orphaned transcode dirs and recover valid VOD sessions. Called on startup to: 1. Remove all orphaned dirs (no session.json - leftover live sessions) 2. Remove expired VOD dirs (older than cache timeout) 3. 
Recover valid VOD sessions for resume """ cache_timeout = get_vod_cache_timeout() now = time.time() removed = recovered = 0 for d in get_transcode_dir().glob("netv_transcode_*"): if not d.is_dir(): continue info_file = d / "session.json" try: mtime = d.stat().st_mtime except OSError: shutil.rmtree(d, ignore_errors=True) removed += 1 continue # No session.json = orphaned (live session or failed VOD) if not info_file.exists(): shutil.rmtree(d, ignore_errors=True) removed += 1 continue # Expired VOD session if now - mtime > cache_timeout: shutil.rmtree(d, ignore_errors=True) removed += 1 continue # No segments = nothing to recover if not list(d.glob(f"{SEG_PREFIX}*.ts")): shutil.rmtree(d, ignore_errors=True) removed += 1 continue # Try to recover VOD session try: info = json.loads(info_file.read_text()) if not (info.get("is_vod") and info.get("url")): shutil.rmtree(d, ignore_errors=True) removed += 1 continue session_id = info["session_id"] url = info["url"] new_seek = info.get("seek_offset", 0) with _transcode_lock: _transcode_sessions[session_id] = { "dir": str(d), "process": _DeadProcess(), "started": info.get("started", mtime), "url": url, "is_vod": True, "last_access": now, # Use current time, not mtime, to avoid immediate expiration "subtitles": info.get("subtitles") or info.get("subtitle_indices"), "duration": info.get("duration", 0), "seek_offset": new_seek, "series_id": info.get("series_id"), "episode_id": info.get("episode_id"), "username": info.get("username", ""), "source_id": info.get("source_id", ""), } # Prefer session with seek_offset or more recent mtime existing_id = _url_to_session.get(url) if existing_id: existing = _transcode_sessions.get(existing_id, {}) existing_seek = existing.get("seek_offset", 0) existing_mtime = existing.get("last_access", 0) if (new_seek > 0 and existing_seek == 0) or ( existing_seek == 0 and new_seek == 0 and mtime > existing_mtime ): _url_to_session[url] = session_id else: _url_to_session[url] = session_id # Restore probe 
cache if p := info.get("probe"): media_info = MediaInfo( video_codec=p.get("video_codec", ""), audio_codec=p.get("audio_codec", ""), pix_fmt=p.get("pix_fmt", ""), audio_channels=p.get("audio_channels", 0), audio_sample_rate=p.get("audio_sample_rate", 0), subtitle_codecs=p.get("subtitle_codecs"), duration=info.get("duration", 0), height=p.get("height", 0), video_bitrate=p.get("video_bitrate", 0), interlaced=p.get("interlaced", False), ) subs = [ SubtitleStream(s["index"], s.get("lang", "und"), s.get("name", "")) for s in (info.get("subtitles") or []) if isinstance(s, dict) and "index" in s ] restore_probe_cache_entry( url, media_info, subs, info.get("series_id"), info.get("episode_id"), ) recovered += 1 log.debug("Recovered VOD session %s for %s", session_id, url[:50]) except Exception as e: log.warning("Failed to recover session from %s: %s", d, e) shutil.rmtree(d, ignore_errors=True) removed += 1 if removed or recovered: log.info( "Startup cleanup: removed %d orphaned dirs, recovered %d VOD sessions", removed, recovered, ) # =========================================================================== # FFmpeg Monitoring # =========================================================================== async def _monitor_ffmpeg_stderr( process: asyncio.subprocess.Process, session_id: str, stderr_lines: list[str] | None = None, ) -> None: assert process.stderr is not None while True: line = await process.stderr.readline() if not line: break text = line.decode().rstrip() if stderr_lines is not None: stderr_lines.append(text) is_fatal = "fatal" in text.lower() or "aborting" in text.lower() level = logging.WARNING if is_fatal else logging.DEBUG log.log(level, "ffmpeg:%s %s", session_id, text) async def _monitor_resume_ffmpeg( process: asyncio.subprocess.Process, session_id: str, url: str, ) -> None: start_time = time.time() await _monitor_ffmpeg_stderr(process, session_id) await process.wait() if process.returncode != 0: log.warning( "Resume ffmpeg exited with code %s for 
session %s", process.returncode, session_id, ) if time.time() - start_time < _QUICK_FAILURE_THRESHOLD_SEC: log.info("Resume failed quickly, invalidating session %s", session_id) with _transcode_lock: _url_to_session.pop(url, None) session = _transcode_sessions.pop(session_id, None) # Clean up output directory if session: shutil.rmtree(session["dir"], ignore_errors=True) async def _monitor_seek_ffmpeg( process: asyncio.subprocess.Process, session_id: str, ) -> None: await _monitor_ffmpeg_stderr(process, session_id) await process.wait() if process.returncode != 0: log.warning( "Seek ffmpeg exited with code %s for session %s", process.returncode, session_id, ) def _spawn_background_task(coro: Any) -> None: task = asyncio.create_task(coro) _background_tasks.add(task) task.add_done_callback(_background_tasks.discard) # =========================================================================== # Playlist Helpers # =========================================================================== async def _wait_for_playlist( playlist_path: pathlib.Path, process: asyncio.subprocess.Process, min_segments: int = 1, timeout_sec: float = _PLAYLIST_WAIT_TIMEOUT_SEC, ) -> bool: """Wait for playlist with min_segments, checking process health.""" output_dir = playlist_path.parent deadline = time.monotonic() + timeout_sec while time.monotonic() < deadline: if process.returncode is not None: return False if playlist_path.exists(): content = playlist_path.read_text() seg_count = content.count("#EXTINF") if seg_count >= min_segments: seg_files = list(output_dir.glob(f"{SEG_PREFIX}*.ts")) if len(seg_files) >= min_segments: first_seg = min(seg_files, key=lambda f: f.name) if ( first_seg.stat().st_size > _MIN_SEGMENT_SIZE_BYTES and process.returncode is None ): return True await asyncio.sleep(_POLL_INTERVAL_SEC) return False def _calc_hls_duration(playlist_path: pathlib.Path, segment_count: int) -> float: """Calculate HLS duration from playlist or estimate from segment count.""" if 
playlist_path.exists(): durations = re.findall(r"#EXTINF:([\d.]+)", playlist_path.read_text()) if durations: return sum(float(d) for d in durations) return segment_count * get_hls_segment_duration() def _build_subtitle_tracks( session_id: str, sub_info: list[dict[str, Any]], ) -> list[dict[str, Any]]: if not sub_info or not isinstance(sub_info[0], dict): return [] return [ { "url": f"/subs/{session_id}/sub{i}.vtt", "lang": s["lang"], "label": s["name"], "default": i == 0, } for i, s in enumerate(sub_info) ] def _regenerate_playlist(output_dir: pathlib.Path, start_segment: int) -> None: """Regenerate HLS playlist starting from a specific segment (for smart seek).""" playlist_path = output_dir / "stream.m3u8" seg_duration = get_hls_segment_duration() # Find all existing segments from start_segment onwards segments = [] for seg_file in sorted(output_dir.glob(f"{SEG_PREFIX}*.ts")): try: seg_num = int(seg_file.stem[len(SEG_PREFIX) :]) if seg_num >= start_segment and seg_file.stat().st_size > _MIN_SEGMENT_SIZE_BYTES: segments.append((seg_num, seg_file.name)) except ValueError: pass if not segments: return # Build playlist lines = [ "#EXTM3U", "#EXT-X-VERSION:3", f"#EXT-X-TARGETDURATION:{int(seg_duration) + 1}", f"#EXT-X-MEDIA-SEQUENCE:{start_segment}", "#EXT-X-PLAYLIST-TYPE:EVENT", ] for _, seg_name in segments: lines.append(f"#EXTINF:{seg_duration:.6f},") lines.append(seg_name) playlist_path.write_text("\n".join(lines) + "\n") log.debug("Regenerated playlist with %d segments starting at %d", len(segments), start_segment) # =========================================================================== # Session Snapshots # =========================================================================== @dataclass(slots=True) class _SessionSnapshot: """Immutable snapshot of session state for lock-free access.""" output_dir: str process: Any seek_offset: float subtitles: list[dict[str, Any]] duration: float def _get_session_snapshot(session_id: str) -> _SessionSnapshot | None: 
"""Get atomic snapshot of session state under lock.""" with _transcode_lock: session = _transcode_sessions.get(session_id) if not session: return None session["last_access"] = time.time() return _SessionSnapshot( output_dir=session["dir"], process=session["process"], seek_offset=session.get("seek_offset", 0), subtitles=session.get("subtitles") or [], duration=session.get("duration", 0), ) def _update_session_process(session_id: str, process: Any) -> bool: """Atomically update session process. Returns False if session gone.""" with _transcode_lock: session = _transcode_sessions.get(session_id) if not session: return False session["process"] = process return True def _build_session_response( session_id: str, snap: _SessionSnapshot, playlist_path: pathlib.Path, ) -> dict[str, Any]: """Build response dict for existing session, recalculating duration.""" segments = list(playlist_path.parent.glob(f"{SEG_PREFIX}*.ts")) return { "session_id": session_id, "playlist": f"/transcode/{session_id}/stream.m3u8", "subtitles": _build_subtitle_tracks(session_id, snap.subtitles), "duration": snap.duration, "seek_offset": snap.seek_offset, "transcoded_duration": _calc_hls_duration(playlist_path, len(segments)), } # =========================================================================== # Existing Session Handling # =========================================================================== def _get_existing_session(url: str) -> tuple[str | None, bool, float]: """Get existing session info atomically. 
Returns (session_id, is_valid, seek_offset).""" with _transcode_lock: existing_id = _url_to_session.get(url) if not existing_id: return None, False, 0.0 session = _transcode_sessions.get(existing_id) if not session: return None, False, 0.0 return ( existing_id, is_session_valid(session), session.get("seek_offset", 0), ) async def _handle_existing_vod_session( existing_id: str, url: str, hw: HwAccel, do_probe: bool, max_resolution: str = "1080p", quality: str = "high", ) -> dict[str, Any] | None: """Handle existing VOD session: reuse active, return cached, or append. Returns None to trigger fresh start if session is invalid. """ snap = _get_session_snapshot(existing_id) if not snap: return None playlist_path = pathlib.Path(snap.output_dir) / "stream.m3u8" segments = sorted(pathlib.Path(snap.output_dir).glob(f"{SEG_PREFIX}*.ts")) # Case 1: Active session - reuse it if snap.process.returncode is None: log.info("Reusing active session %s", existing_id) await _wait_for_playlist( playlist_path, snap.process, min_segments=1, timeout_sec=_REUSE_ACTIVE_WAIT_TIMEOUT_SEC, ) return _build_session_response(existing_id, snap, playlist_path) # Case 2: Dead session with no segments - invalid if not segments: stop_session(existing_id, force=True) with _transcode_lock: _url_to_session.pop(url, None) return None # Case 3: Dead session with seek_offset - return cached content if snap.seek_offset > 0: log.info( "Returning cached session %s (seek_offset=%.1f)", existing_id, snap.seek_offset, ) return _build_session_response(existing_id, snap, playlist_path) # Case 4: Dead session, no seek_offset - append new content hls_duration = _calc_hls_duration(playlist_path, len(segments)) log.info("Resuming session %s from %.1fs", existing_id, hls_duration) # Resolve HLS master playlist to highest bandwidth variant url = await asyncio.to_thread(resolve_hls_master_playlist, url) media_info = ( (await asyncio.to_thread(probe_media, url, None, None, ""))[0] if do_probe else None ) cmd = 
build_hls_ffmpeg_cmd( url, hw, snap.output_dir, True, None, media_info, max_resolution, quality, get_user_agent(), None, ) i_idx = cmd.index("-i") cmd.insert(i_idx, str(hls_duration)) cmd.insert(i_idx, "-ss") try: hls_flags_idx = cmd.index("-hls_flags") cmd[hls_flags_idx + 1] += "+append_list" except ValueError: cmd.extend(["-hls_flags", "append_list"]) cmd.extend(["-start_number", str(len(segments))]) process = await asyncio.create_subprocess_exec( *cmd, stdin=asyncio.subprocess.DEVNULL, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=get_ffmpeg_env(), ) if not _update_session_process(existing_id, process): _kill_process(process) return None _spawn_background_task(_monitor_resume_ffmpeg(process, existing_id, url)) log.info("Started resume ffmpeg pid=%s for %s", process.pid, existing_id) deadline = time.monotonic() + _RESUME_SEGMENT_WAIT_TIMEOUT_SEC next_seg = f"{SEG_PREFIX}{len(segments):03d}.ts" while time.monotonic() < deadline: if process.returncode is not None: log.warning("Resume ffmpeg died immediately for %s", existing_id) return None if (pathlib.Path(snap.output_dir) / next_seg).exists(): break await asyncio.sleep(_POLL_INTERVAL_SEC) await _wait_for_playlist( playlist_path, process, min_segments=1, timeout_sec=_RESUME_WAIT_TIMEOUT_SEC, ) return _build_session_response(existing_id, snap, playlist_path) async def _try_reuse_session( existing_id: str, url: str, is_vod: bool, content_type: str, ) -> dict[str, Any] | None: """Try to reuse an existing valid session. 
Returns response or None if can't reuse.""" if is_vod: settings = get_settings() return await _handle_existing_vod_session( existing_id, url, settings.get("transcode_hw", "software"), settings.get( {"movie": "probe_movies", "series": "probe_series"}.get(content_type, ""), False ), settings.get("max_resolution", "1080p"), settings.get("quality", "high"), ) # Live: return existing session if snapshot available snap = _get_session_snapshot(existing_id) if not snap: return None playlist_path = pathlib.Path(snap.output_dir) / "stream.m3u8" return _build_session_response(existing_id, snap, playlist_path) def _cleanup_invalid_session(url: str, session_id: str) -> None: """Clean up an invalid/expired session.""" with _transcode_lock: _url_to_session.pop(url, None) stop_session(session_id, force=True) # =========================================================================== # Core Transcode Logic # =========================================================================== async def _do_start_transcode( url: str, content_type: str, series_id: int | None, episode_id: int | None, old_seek_offset: float, series_name: str = "", deinterlace_fallback: bool = True, username: str = "", source_id: str = "", ) -> dict[str, Any]: """Core transcode logic. 
Raises HTTPException on failure.""" # Resolve HLS master playlist to highest bandwidth variant url = await asyncio.to_thread(resolve_hls_master_playlist, url) settings = get_settings() hw = settings.get("transcode_hw", "software") max_resolution = settings.get("max_resolution", "1080p") quality = settings.get("quality", "high") is_vod = content_type in ("movie", "series") probe_key = {"movie": "probe_movies", "series": "probe_series", "live": "probe_live"} do_probe = settings.get(probe_key.get(content_type, ""), False) session_id = str(uuid.uuid4()) output_dir = tempfile.mkdtemp( prefix=f"netv_transcode_{session_id}_", dir=get_transcode_dir(), ) playlist_path = pathlib.Path(output_dir) / "stream.m3u8" media_info: MediaInfo | None = None subtitles: list[SubtitleStream] = [] if do_probe: media_info, subtitles = await asyncio.to_thread( probe_media, url, series_id, episode_id, series_name ) if media_info: subs_str = ( ",".join(media_info.subtitle_codecs) if media_info.subtitle_codecs else "none" ) if subtitles: subs_str += f" [extract:{','.join(s.lang for s in subtitles)}]" bitrate_str = ( f"{media_info.video_bitrate / 1_000_000:.1f}Mbps" if media_info.video_bitrate else "?" 
) log.info( "Probe: video=%s/%s/%dp/%s%s audio=%s/%dch/%dHz duration=%.0fs subs=%s", media_info.video_codec, media_info.pix_fmt, media_info.height, bitrate_str, "/interlaced" if media_info.interlaced else "", media_info.audio_codec, media_info.audio_channels, media_info.audio_sample_rate, media_info.duration, subs_str, ) cmd = build_hls_ffmpeg_cmd( url, hw, output_dir, is_vod, subtitles, media_info, max_resolution, quality, get_user_agent(), deinterlace_fallback, ) if old_seek_offset > 0: i_idx = cmd.index("-i") cmd.insert(i_idx, str(old_seek_offset)) cmd.insert(i_idx, "-ss") log.info("Applying seek_offset=%.1f from previous session", old_seek_offset) log.info( "Starting transcode session %s (vod=%s): %s", session_id, is_vod, " ".join(cmd), ) process = await asyncio.create_subprocess_exec( *cmd, stdin=asyncio.subprocess.DEVNULL, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=get_ffmpeg_env(), ) stderr_lines: list[str] = [] _spawn_background_task(_monitor_ffmpeg_stderr(process, session_id, stderr_lines)) sub_info = [{"index": s.index, "lang": s.lang, "name": s.name} for s in subtitles] total_duration = media_info.duration if media_info else 0.0 with _transcode_lock: _transcode_sessions[session_id] = { "dir": output_dir, "process": process, "started": time.time(), "url": url, "is_vod": is_vod, "last_access": time.time(), "subtitles": sub_info, "duration": total_duration, "seek_offset": old_seek_offset, "series_id": series_id, "episode_id": episode_id, "username": username, "source_id": source_id, } _url_to_session[url] = session_id if is_vod: session_info: dict[str, Any] = { "session_id": session_id, "url": url, "is_vod": True, "started": time.time(), "subtitles": sub_info, "duration": total_duration, "seek_offset": old_seek_offset, "series_id": series_id, "episode_id": episode_id, "username": username, "source_id": source_id, } if media_info: session_info["probe"] = { "video_codec": media_info.video_codec, "audio_codec": media_info.audio_codec, 
"pix_fmt": media_info.pix_fmt, "audio_channels": media_info.audio_channels, "audio_sample_rate": media_info.audio_sample_rate, "subtitle_codecs": media_info.subtitle_codecs, "height": media_info.height, "video_bitrate": media_info.video_bitrate, "interlaced": media_info.interlaced, } (pathlib.Path(output_dir) / "session.json").write_text(json.dumps(session_info)) timeout = _PLAYLIST_WAIT_SEEK_TIMEOUT_SEC if old_seek_offset > 0 else _PLAYLIST_WAIT_TIMEOUT_SEC if not await _wait_for_playlist( playlist_path, process, min_segments=2, timeout_sec=timeout, ): # Wait for process to fully exit and stderr to be captured with contextlib.suppress(TimeoutError): await asyncio.wait_for(process.wait(), timeout=1.0) # Give stderr monitor time to process final output await asyncio.sleep(0.1) error_msg = "\n".join(stderr_lines[-10:]) if stderr_lines else "unknown" log.error( "ffmpeg:%s failed (exit %d): %s", session_id, process.returncode or -1, error_msg, ) stop_session(session_id) raise HTTPException(500, "Transcode failed - check server logs for details") return { "session_id": session_id, "playlist": f"/transcode/{session_id}/stream.m3u8", "subtitles": _build_subtitle_tracks(session_id, sub_info), "duration": total_duration, "seek_offset": old_seek_offset, } async def start_transcode( url: str, content_type: str = "live", series_id: int | None = None, episode_id: int | None = None, series_name: str = "", deinterlace_fallback: bool = True, username: str = "", source_id: str = "", user_max_streams: int = 0, source_max_streams: int = 0, ) -> dict[str, Any]: """Start or reuse a transcode session.""" # Enforce stream limits if username: error = enforce_stream_limits(username, source_id, user_max_streams, source_max_streams) if error: raise HTTPException(status_code=429, detail=error) is_vod = content_type in ("movie", "series") existing_id, is_valid, old_seek_offset = _get_existing_session(url) # Try to reuse existing valid session if existing_id and is_valid: log.info("Found valid 
existing session %s (vod=%s)", existing_id, is_vod) result = await _try_reuse_session(existing_id, url, is_vod, content_type) if result: return result # Clean up any existing invalid session if existing_id: log.info("Cleaning up invalid session %s", existing_id) _cleanup_invalid_session(url, existing_id) # Start fresh transcode (with retry for series probe cache staleness) try: return await _do_start_transcode( url, content_type, series_id, episode_id, old_seek_offset, series_name, deinterlace_fallback, username, source_id, ) except HTTPException: if series_id is None: raise log.info("Transcode failed, clearing probe cache and retrying") invalidate_series_probe_cache(series_id, episode_id) return await _do_start_transcode( url, content_type, series_id, episode_id, old_seek_offset, series_name, deinterlace_fallback, username, source_id, ) # =========================================================================== # Session Query/Update # =========================================================================== def get_session(session_id: str) -> dict[str, Any] | None: """Get a copy of session dict (safe to use outside lock).""" with _transcode_lock: session = _transcode_sessions.get(session_id) return dict(session) if session else None def touch_session(session_id: str) -> bool: """Update session last_access timestamp (heartbeat). 
Returns True if session exists.""" with _transcode_lock: session = _transcode_sessions.get(session_id) if session: session["last_access"] = time.time() return True return False def get_session_progress(session_id: str) -> dict[str, Any] | None: """Get transcode progress for a session.""" touch_session(session_id) session = get_session(session_id) if not session: return None playlist_path = pathlib.Path(session["dir"]) / "stream.m3u8" if not playlist_path.exists(): return {"segment_count": 0, "duration": 0.0} durations = re.findall(r"#EXTINF:([\d.]+)", playlist_path.read_text()) return { "segment_count": len(durations), "duration": sum(float(d) for d in durations), } def clear_url_session(url: str) -> str | None: """Clear URL-to-session mapping.""" with _transcode_lock: return _url_to_session.pop(url, None) # =========================================================================== # Seek # =========================================================================== @dataclass(slots=True) class _SeekSessionInfo: """Snapshot of session info needed for seek.""" url: str output_dir: str process: Any subtitles: list[dict[str, Any]] series_id: int | None episode_id: int | None def _get_seek_session_info(session_id: str) -> _SeekSessionInfo | None: """Get session info for seek atomically. Returns None if not VOD.""" with _transcode_lock: session = _transcode_sessions.get(session_id) if not session or not session.get("is_vod"): return None return _SeekSessionInfo( url=session["url"], output_dir=session["dir"], process=session["process"], subtitles=session.get("subtitles") or [], series_id=session.get("series_id"), episode_id=session.get("episode_id"), ) def _update_seek_session( session_id: str, url: str, process: Any, seek_time: float, ) -> bool: """Update session after seek. 
Returns False if session gone.""" with _transcode_lock: session = _transcode_sessions.get(session_id) if not session: return False session["process"] = process session["seek_offset"] = seek_time if url: _url_to_session[url] = session_id return True async def seek_transcode(session_id: str, seek_time: float) -> dict[str, Any]: """Seek to a specific time in a VOD session.""" info = _get_seek_session_info(session_id) if not info: raise HTTPException(404, "Session not found or not VOD") settings = get_settings() hw = settings.get("transcode_hw", "software") max_resolution = settings.get("max_resolution", "1080p") quality = settings.get("quality", "high") seg_duration = get_hls_segment_duration() segment_num = int(seek_time / seg_duration) output_path = pathlib.Path(info.output_dir) target_segment = output_path / f"{SEG_PREFIX}{segment_num:03d}.ts" # Smart seek: if target segment exists, no need to restart ffmpeg if target_segment.exists() and target_segment.stat().st_size > _MIN_SEGMENT_SIZE_BYTES: log.info( "Smart seek: segment %d exists for time %.1fs, skipping ffmpeg restart", segment_num, seek_time, ) with _transcode_lock: session = _transcode_sessions.get(session_id) if session: session["seek_offset"] = seek_time _regenerate_playlist(output_path, segment_num) return {"session_id": session_id, "playlist": f"/transcode/{session_id}/stream.m3u8"} # Kill existing process if _kill_process(info.process): log.info("Killed ffmpeg for seek in session %s", session_id) # Clear playlist but keep segments (for backward seeks later) playlist_file = output_path / "stream.m3u8" playlist_file.unlink(missing_ok=True) # Only clear segments AFTER target (we might seek back to earlier ones) for seg_file in output_path.glob(f"{SEG_PREFIX}*.ts"): try: seg_num = int(seg_file.stem[len(SEG_PREFIX) :]) if seg_num >= segment_num: seg_file.unlink(missing_ok=True) except ValueError: pass for vtt_file in output_path.glob("sub*.vtt"): vtt_file.unlink(missing_ok=True) # Resolve HLS master 
playlist to highest bandwidth variant url = await asyncio.to_thread(resolve_hls_master_playlist, info.url) # Use probe_series if series_id, else probe_movies probe_setting = "probe_series" if info.series_id else "probe_movies" do_probe = settings.get(probe_setting, False) if do_probe: media_info = ( await asyncio.to_thread( probe_media, url, info.series_id, info.episode_id, ) )[0] else: media_info = None subtitles: list[SubtitleStream] = [] for s in info.subtitles: if isinstance(s, dict) and "index" in s: subtitles.append( SubtitleStream( index=s["index"], lang=s.get("lang", "und"), name=s.get("name", "Unknown"), ) ) cmd = build_hls_ffmpeg_cmd( url, hw, info.output_dir, True, subtitles or None, media_info, max_resolution, quality, get_user_agent(), None, ) i_idx = cmd.index("-i") cmd.insert(i_idx, str(seek_time)) cmd.insert(i_idx, "-ss") # Shift output timestamps so subtitles start at 0 after seek f_idx = cmd.index("-f") cmd.insert(f_idx, str(-seek_time)) cmd.insert(f_idx, "-output_ts_offset") cmd.extend(["-start_number", str(segment_num)]) log.info( "Seek transcode %s to %.1fs (seg %d): %s", session_id, seek_time, segment_num, " ".join(cmd), ) process = await asyncio.create_subprocess_exec( *cmd, stdin=asyncio.subprocess.DEVNULL, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=get_ffmpeg_env(), ) if not _update_seek_session(session_id, info.url, process, seek_time): _kill_process(process) raise HTTPException(404, "Session disappeared during seek") # Persist seek_offset session_json = output_path / "session.json" if session_json.exists(): try: data = json.loads(session_json.read_text()) data["seek_offset"] = seek_time session_json.write_text(json.dumps(data)) except Exception as e: log.warning("Failed to update session.json for %s: %s", session_id, e) _spawn_background_task(_monitor_seek_ffmpeg(process, session_id)) if not await _wait_for_playlist( playlist_file, process, min_segments=2, timeout_sec=_PLAYLIST_WAIT_TIMEOUT_SEC, ): raise 
HTTPException(500, "Seek transcode timed out waiting for playlist") log.info("Seek ready: %s", playlist_file) return { "ok": True, "segment": segment_num, "time": seek_time, } ================================================ FILE: ffmpeg_session_test.py ================================================ """Tests for ffmpeg session management.""" from unittest.mock import patch import json import pathlib import tempfile import time from ffmpeg_session import ( _HEARTBEAT_TIMEOUT_SEC, _build_subtitle_tracks, _calc_hls_duration, _DeadProcess, _is_process_alive, _kill_process, _regenerate_playlist, _transcode_lock, _transcode_sessions, _url_to_session, cleanup_and_recover_sessions, cleanup_expired_sessions, clear_url_session, enforce_stream_limits, get_live_cache_timeout, get_session, get_session_progress, get_source_sessions, get_user_sessions, get_vod_cache_timeout, is_session_valid, shutdown, stop_session, touch_session, ) class FakeProcess: """Fake async process for testing.""" def __init__(self, alive: bool = True, killed: bool = False): self.returncode = None if alive else 0 self._killed = killed def terminate(self) -> None: if self._killed: raise ProcessLookupError("No such process") self.returncode = -15 # SIGTERM def kill(self) -> None: if self._killed: raise ProcessLookupError("No such process") self.returncode = -9 # SIGKILL def _clear_session_state(): """Clear all session state for test isolation.""" with _transcode_lock: _transcode_sessions.clear() _url_to_session.clear() # ============================================================================= # Process Lifecycle Tests # ============================================================================= class TestIsProcessAlive: """Tests for _is_process_alive.""" def test_none_is_dead(self): assert _is_process_alive(None) is False def test_dead_process_placeholder(self): assert _is_process_alive(_DeadProcess()) is False def test_alive_process(self): proc = FakeProcess(alive=True) assert 
_is_process_alive(proc) is True def test_dead_process(self): proc = FakeProcess(alive=False) assert _is_process_alive(proc) is False class TestKillProcess: """Tests for _kill_process.""" def test_kill_alive_process(self): proc = FakeProcess(alive=True) assert _kill_process(proc) is True # SIGTERM (-15) is used first; if process exits, SIGKILL (-9) isn't needed assert proc.returncode == -15 def test_kill_already_dead(self): proc = FakeProcess(alive=True, killed=True) assert _kill_process(proc) is False # ============================================================================= # Session Validity Tests # ============================================================================= class TestIsSessionValid: """Tests for is_session_valid with heartbeat timeout.""" def test_active_process_with_recent_heartbeat(self): """Active process + recent heartbeat = valid.""" session = { "process": FakeProcess(alive=True), "started": time.time(), "last_access": time.time(), "is_vod": False, } assert is_session_valid(session) is True def test_active_process_stale_heartbeat(self): """Active process but no heartbeat in 5+ min = invalid.""" session = { "process": FakeProcess(alive=True), "started": time.time() - 400, "last_access": time.time() - 400, # 6+ min ago "is_vod": False, } assert is_session_valid(session) is False def test_dead_process_live_session_no_cache(self): """Dead process, live session, cache=0 = invalid.""" with patch("ffmpeg_session.get_live_cache_timeout", return_value=0): session = { "process": FakeProcess(alive=False), "started": time.time(), "last_access": time.time(), "is_vod": False, } assert is_session_valid(session) is False def test_dead_process_vod_session_within_cache(self): """Dead process, VOD session, within cache timeout = valid.""" with patch("ffmpeg_session.get_vod_cache_timeout", return_value=3600): session = { "process": FakeProcess(alive=False), "started": time.time() - 10, "last_access": time.time() - 10, # 10 sec ago (within 30 sec 
heartbeat) "is_vod": True, } assert is_session_valid(session) is True def test_dead_process_vod_session_expired_cache(self): """Dead process, VOD session, past cache timeout = invalid.""" with patch("ffmpeg_session.get_vod_cache_timeout", return_value=60): session = { "process": FakeProcess(alive=False), "started": time.time() - 120, "last_access": time.time() - 120, # 2 min ago, cache is 1 min "is_vod": True, } assert is_session_valid(session) is False def test_heartbeat_timeout_boundary(self): """Test exactly at heartbeat timeout boundary.""" # Just under timeout = valid (if process alive) session = { "process": FakeProcess(alive=True), "started": time.time() - (_HEARTBEAT_TIMEOUT_SEC - 1), "last_access": time.time() - (_HEARTBEAT_TIMEOUT_SEC - 1), "is_vod": False, } assert is_session_valid(session) is True # Just over timeout = invalid session["last_access"] = time.time() - (_HEARTBEAT_TIMEOUT_SEC + 1) assert is_session_valid(session) is False def test_missing_last_access_uses_started(self): """If last_access missing, falls back to started time.""" session = { "process": FakeProcess(alive=True), "started": time.time(), "is_vod": False, } assert is_session_valid(session) is True # ============================================================================= # Cache Timeout Tests # ============================================================================= class TestCacheTimeouts: """Tests for cache timeout getters.""" def test_vod_cache_timeout_default(self): """VOD cache default is 60 min = 3600 sec.""" with patch("ffmpeg_session.get_settings", return_value={}): assert get_vod_cache_timeout() == 3600 def test_vod_cache_timeout_custom(self): """VOD cache from settings.""" with patch("ffmpeg_session.get_settings", return_value={"vod_transcode_cache_mins": 30}): assert get_vod_cache_timeout() == 1800 def test_live_cache_timeout_default(self): """Live cache default is 0 (no caching).""" with patch("ffmpeg_session.get_settings", return_value={}): assert 
get_live_cache_timeout() == 0 def test_live_cache_timeout_custom(self): """Live cache from settings.""" with patch("ffmpeg_session.get_settings", return_value={"live_transcode_cache_secs": 30}): assert get_live_cache_timeout() == 30 # ============================================================================= # Session Start/Stop Tests # ============================================================================= class TestStopSession: """Tests for stop_session.""" def setup_method(self): _clear_session_state() def teardown_method(self): _clear_session_state() def test_stop_nonexistent_session(self): """Stopping nonexistent session is a no-op.""" stop_session("nonexistent") # Should not raise def test_stop_session_force(self): """Force stop removes session.""" with tempfile.TemporaryDirectory() as tmp: session_id = "test-123" with _transcode_lock: _transcode_sessions[session_id] = { "process": FakeProcess(alive=True), "dir": tmp, "url": "http://test", "last_access": time.time(), } _url_to_session["http://test"] = session_id stop_session(session_id, force=True) assert session_id not in _transcode_sessions assert "http://test" not in _url_to_session def test_stop_session_skip_recent_vod(self): """Skip stop for recently-accessed VOD session (race protection for seeking).""" session_id = "test-456" with _transcode_lock: _transcode_sessions[session_id] = { "process": FakeProcess(alive=True), "dir": "/tmp/test", "url": "http://test", "is_vod": True, # Grace period only applies to VOD "last_access": time.time(), # Just now } stop_session(session_id, force=False) # VOD session should still exist because it was recently accessed assert session_id in _transcode_sessions def test_stop_session_skips_recent_live(self): """Live sessions also get grace period for multi-user support.""" with tempfile.TemporaryDirectory() as tmp: session_id = "test-live" with _transcode_lock: _transcode_sessions[session_id] = { "process": FakeProcess(alive=True), "dir": tmp, "url": 
"http://live", "is_vod": False, "last_access": time.time(), # Just now } _url_to_session["http://live"] = session_id with patch("ffmpeg_session.get_live_cache_timeout", return_value=0): stop_session(session_id, force=False) # Live session should still exist because it was recently accessed assert session_id in _transcode_sessions def test_stop_session_multi_user_grace_period(self): """Stopping session while another user watching should preserve session.""" with tempfile.TemporaryDirectory() as tmp: session_id = "test-shared" with _transcode_lock: _transcode_sessions[session_id] = { "process": FakeProcess(alive=True), "dir": tmp, "url": "http://shared-stream", "is_vod": False, "last_access": time.time() - 10, # User A started 10 sec ago } _url_to_session["http://shared-stream"] = session_id # User B accesses stream (simulates progress poll or segment request) touch_session(session_id) # User A disconnects and triggers stop with patch("ffmpeg_session.get_live_cache_timeout", return_value=0): stop_session(session_id, force=False) # Session should survive because User B just accessed it assert session_id in _transcode_sessions assert _transcode_sessions[session_id]["process"].returncode is None def test_stop_session_caches_vod(self): """Stop caches VOD session instead of removing it.""" with tempfile.TemporaryDirectory() as tmp: session_id = "test-vod" with _transcode_lock: _transcode_sessions[session_id] = { "process": FakeProcess(alive=True), "dir": tmp, "url": "http://vod", "is_vod": True, "last_access": time.time() - 10, # Old enough to stop } _url_to_session["http://vod"] = session_id with patch("ffmpeg_session.get_vod_cache_timeout", return_value=3600): stop_session(session_id, force=False) # Session should still exist (cached) assert session_id in _transcode_sessions class TestCleanupExpiredSessions: """Tests for cleanup_expired_sessions.""" def setup_method(self): _clear_session_state() def teardown_method(self): _clear_session_state() def 
test_cleanup_removes_expired(self): """Cleanup removes expired sessions.""" with tempfile.TemporaryDirectory() as tmp: session_id = "expired-session" with _transcode_lock: _transcode_sessions[session_id] = { "process": FakeProcess(alive=False), "dir": tmp, "url": "http://expired", "is_vod": False, "started": time.time() - 400, "last_access": time.time() - 400, # Expired } with patch("ffmpeg_session.get_live_cache_timeout", return_value=0): cleanup_expired_sessions() assert session_id not in _transcode_sessions def test_cleanup_keeps_valid(self): """Cleanup keeps valid sessions.""" session_id = "valid-session" with _transcode_lock: _transcode_sessions[session_id] = { "process": FakeProcess(alive=True), "dir": "/tmp/test", "url": "http://valid", "is_vod": False, "started": time.time(), "last_access": time.time(), } cleanup_expired_sessions() assert session_id in _transcode_sessions class TestShutdown: """Tests for shutdown.""" def setup_method(self): _clear_session_state() def teardown_method(self): _clear_session_state() def test_shutdown_kills_all_processes(self): """Shutdown kills all processes and clears sessions.""" proc1 = FakeProcess(alive=True) proc2 = FakeProcess(alive=True) with _transcode_lock: _transcode_sessions["s1"] = {"process": proc1, "dir": "/tmp/1"} _transcode_sessions["s2"] = {"process": proc2, "dir": "/tmp/2"} shutdown() # SIGTERM (-15) is used first; if process exits, SIGKILL (-9) isn't needed assert proc1.returncode == -15 assert proc2.returncode == -15 assert len(_transcode_sessions) == 0 # ============================================================================= # Stream Limits Tests # ============================================================================= class TestGetUserSessions: """Tests for get_user_sessions.""" def setup_method(self): _clear_session_state() def teardown_method(self): _clear_session_state() def test_get_user_sessions_filters_by_username(self): """Returns only sessions for specified user.""" with 
_transcode_lock: _transcode_sessions["s1"] = {"username": "alice", "started": 1} _transcode_sessions["s2"] = {"username": "bob", "started": 2} _transcode_sessions["s3"] = {"username": "alice", "started": 3} sessions = get_user_sessions("alice") assert len(sessions) == 2 assert sessions[0][0] == "s1" # Sorted by start time assert sessions[1][0] == "s3" def test_get_user_sessions_empty(self): """Returns empty list for user with no sessions.""" sessions = get_user_sessions("nobody") assert sessions == [] class TestGetSourceSessions: """Tests for get_source_sessions.""" def setup_method(self): _clear_session_state() def teardown_method(self): _clear_session_state() def test_get_source_sessions_filters_by_source(self): """Returns only sessions for specified source.""" with _transcode_lock: _transcode_sessions["s1"] = {"source_id": "src1", "started": 1} _transcode_sessions["s2"] = {"source_id": "src2", "started": 2} _transcode_sessions["s3"] = {"source_id": "src1", "started": 3} sessions = get_source_sessions("src1") assert len(sessions) == 2 assert sessions[0][0] == "s1" assert sessions[1][0] == "s3" class TestEnforceStreamLimits: """Tests for enforce_stream_limits.""" def setup_method(self): _clear_session_state() def teardown_method(self): _clear_session_state() def test_no_limits_returns_none(self): """No limits set = no error.""" result = enforce_stream_limits("alice", None, 0, 0) assert result is None def test_user_limit_stops_oldest(self): """User at limit stops their oldest session.""" with tempfile.TemporaryDirectory() as tmp: with _transcode_lock: _transcode_sessions["s1"] = { "username": "alice", "started": 1, "process": FakeProcess(alive=True), "dir": tmp, "url": "http://1", "last_access": 0, # Old enough to stop } _transcode_sessions["s2"] = { "username": "alice", "started": 2, "process": FakeProcess(alive=True), "dir": "/tmp/2", "url": "http://2", "last_access": time.time(), } result = enforce_stream_limits("alice", None, 2, 0) assert result is None assert 
"s1" not in _transcode_sessions assert "s2" in _transcode_sessions def test_source_limit_stops_user_session(self): """Source at limit stops user's oldest session on that source.""" with tempfile.TemporaryDirectory() as tmp: with _transcode_lock: _transcode_sessions["s1"] = { "username": "alice", "source_id": "src1", "started": 1, "process": FakeProcess(alive=True), "dir": tmp, "url": "http://1", "last_access": 0, } result = enforce_stream_limits("alice", "src1", 0, 1) assert result is None assert "s1" not in _transcode_sessions def test_source_limit_returns_error_for_other_user(self): """Source at limit with other user's session returns error.""" with _transcode_lock: _transcode_sessions["s1"] = { "username": "bob", "source_id": "src1", "started": 1, "process": FakeProcess(alive=True), "dir": "/tmp/1", "url": "http://1", } result = enforce_stream_limits("alice", "src1", 0, 1) assert result == "Source at capacity (1 streams)" assert "s1" in _transcode_sessions # Not stopped # ============================================================================= # Session Query/Update Tests # ============================================================================= class TestGetSession: """Tests for get_session.""" def setup_method(self): _clear_session_state() def teardown_method(self): _clear_session_state() def test_get_existing_session(self): """Returns copy of session dict.""" with _transcode_lock: _transcode_sessions["test"] = {"dir": "/tmp", "url": "http://test"} session = get_session("test") assert session is not None assert session["dir"] == "/tmp" def test_get_nonexistent_session(self): """Returns None for nonexistent session.""" assert get_session("nonexistent") is None class TestTouchSession: """Tests for touch_session (heartbeat).""" def setup_method(self): _clear_session_state() def teardown_method(self): _clear_session_state() def test_touch_updates_last_access(self): """Touch updates last_access timestamp.""" old_time = time.time() - 100 with 
_transcode_lock: _transcode_sessions["test"] = {"last_access": old_time} result = touch_session("test") assert result is True assert _transcode_sessions["test"]["last_access"] > old_time def test_touch_nonexistent_returns_false(self): """Touch returns False for nonexistent session.""" assert touch_session("nonexistent") is False class TestGetSessionProgress: """Tests for get_session_progress.""" def setup_method(self): _clear_session_state() def teardown_method(self): _clear_session_state() def test_progress_with_playlist(self): """Returns progress from playlist.""" with tempfile.TemporaryDirectory() as tmp: playlist = pathlib.Path(tmp) / "stream.m3u8" playlist.write_text("#EXTM3U\n#EXTINF:3.0,\nseg0.ts\n#EXTINF:3.0,\nseg1.ts\n") with _transcode_lock: _transcode_sessions["test"] = {"dir": tmp, "last_access": 0} progress = get_session_progress("test") assert progress is not None assert progress["segment_count"] == 2 assert progress["duration"] == 6.0 def test_progress_no_playlist(self): """Returns zero progress without playlist.""" with tempfile.TemporaryDirectory() as tmp: with _transcode_lock: _transcode_sessions["test"] = {"dir": tmp, "last_access": 0} progress = get_session_progress("test") assert progress == {"segment_count": 0, "duration": 0.0} def test_progress_nonexistent_session(self): """Returns None for nonexistent session.""" assert get_session_progress("nonexistent") is None class TestClearUrlSession: """Tests for clear_url_session.""" def setup_method(self): _clear_session_state() def teardown_method(self): _clear_session_state() def test_clear_existing_url(self): """Clears existing URL mapping.""" with _transcode_lock: _url_to_session["http://test"] = "session-123" result = clear_url_session("http://test") assert result == "session-123" assert "http://test" not in _url_to_session def test_clear_nonexistent_url(self): """Returns None for nonexistent URL.""" result = clear_url_session("http://nonexistent") assert result is None # 
============================================================================= # Playlist Helper Tests # ============================================================================= class TestCalcHlsDuration: """Tests for _calc_hls_duration.""" def test_duration_from_playlist(self): """Calculates duration from EXTINF entries.""" with tempfile.TemporaryDirectory() as tmp: playlist = pathlib.Path(tmp) / "stream.m3u8" playlist.write_text("#EXTM3U\n#EXTINF:3.5,\nseg0.ts\n#EXTINF:3.0,\nseg1.ts\n") duration = _calc_hls_duration(playlist, 2) assert duration == 6.5 def test_duration_estimate_from_segments(self): """Estimates duration when playlist missing.""" with patch("ffmpeg_session.get_hls_segment_duration", return_value=3.0): playlist = pathlib.Path("/nonexistent/stream.m3u8") duration = _calc_hls_duration(playlist, 5) assert duration == 15.0 class TestBuildSubtitleTracks: """Tests for _build_subtitle_tracks.""" def test_builds_track_list(self): """Builds subtitle track list.""" sub_info = [ {"index": 2, "lang": "eng", "name": "English"}, {"index": 3, "lang": "jpn", "name": "Japanese"}, ] tracks = _build_subtitle_tracks("session-123", sub_info) assert len(tracks) == 2 assert tracks[0]["url"] == "/subs/session-123/sub0.vtt" assert tracks[0]["lang"] == "eng" assert tracks[0]["label"] == "English" assert tracks[0]["default"] is True assert tracks[1]["default"] is False def test_empty_sub_info(self): """Returns empty list for no subtitles.""" assert _build_subtitle_tracks("s", []) == [] assert _build_subtitle_tracks("s", None) == [] # type: ignore[arg-type] def test_non_dict_sub_info(self): """Returns empty list for old format (indices only).""" assert _build_subtitle_tracks("s", [2, 3]) == [] # type: ignore[arg-type] class TestRegeneratePlaylist: """Tests for _regenerate_playlist.""" def test_regenerates_playlist_from_segments(self): """Regenerates playlist from segment files.""" with tempfile.TemporaryDirectory() as tmp: output_dir = pathlib.Path(tmp) # Create segment 
files (output_dir / "seg000.ts").write_bytes(b"x" * 2000) (output_dir / "seg001.ts").write_bytes(b"x" * 2000) (output_dir / "seg002.ts").write_bytes(b"x" * 2000) with patch("ffmpeg_session.get_hls_segment_duration", return_value=3.0): _regenerate_playlist(output_dir, start_segment=1) playlist = output_dir / "stream.m3u8" assert playlist.exists() content = playlist.read_text() assert "#EXT-X-MEDIA-SEQUENCE:1" in content assert "seg001.ts" in content assert "seg002.ts" in content assert "seg000.ts" not in content # Before start_segment def test_regenerate_skips_small_segments(self): """Skips segments smaller than threshold.""" with tempfile.TemporaryDirectory() as tmp: output_dir = pathlib.Path(tmp) (output_dir / "seg000.ts").write_bytes(b"x" * 500) # Too small (output_dir / "seg001.ts").write_bytes(b"x" * 2000) # OK with patch("ffmpeg_session.get_hls_segment_duration", return_value=3.0): _regenerate_playlist(output_dir, start_segment=0) content = (output_dir / "stream.m3u8").read_text() assert "seg001.ts" in content assert "seg000.ts" not in content # ============================================================================= # Session Recovery Tests # ============================================================================= class TestCleanupAndRecoverSessions: """Tests for cleanup_and_recover_sessions.""" def setup_method(self): _clear_session_state() def teardown_method(self): _clear_session_state() def test_removes_orphaned_dirs(self): """Removes dirs without session.json.""" with tempfile.TemporaryDirectory() as tmp: transcode_dir = pathlib.Path(tmp) orphan = transcode_dir / "netv_transcode_orphan" orphan.mkdir() (orphan / "seg000.ts").write_bytes(b"data") with ( patch("ffmpeg_session.get_transcode_dir", return_value=transcode_dir), patch("ffmpeg_session.get_vod_cache_timeout", return_value=3600), ): cleanup_and_recover_sessions() assert not orphan.exists() def test_recovers_valid_vod_session(self): """Recovers valid VOD session with segments.""" with 
tempfile.TemporaryDirectory() as tmp: transcode_dir = pathlib.Path(tmp) vod_dir = transcode_dir / "netv_transcode_vod123" vod_dir.mkdir() # Create session.json session_info = { "session_id": "vod123", "url": "http://movie.mp4", "is_vod": True, "started": time.time(), "duration": 3600, } (vod_dir / "session.json").write_text(json.dumps(session_info)) (vod_dir / "seg000.ts").write_bytes(b"x" * 2000) with ( patch("ffmpeg_session.get_transcode_dir", return_value=transcode_dir), patch("ffmpeg_session.get_vod_cache_timeout", return_value=3600), ): cleanup_and_recover_sessions() assert "vod123" in _transcode_sessions assert _url_to_session.get("http://movie.mp4") == "vod123" def test_removes_expired_vod_session(self): """Removes expired VOD session (older than cache timeout).""" with tempfile.TemporaryDirectory() as tmp: transcode_dir = pathlib.Path(tmp) vod_dir = transcode_dir / "netv_transcode_expired" vod_dir.mkdir() session_info = { "session_id": "expired", "url": "http://old.mp4", "is_vod": True, } (vod_dir / "session.json").write_text(json.dumps(session_info)) (vod_dir / "seg000.ts").write_bytes(b"x" * 2000) # Very short cache timeout with ( patch("ffmpeg_session.get_transcode_dir", return_value=transcode_dir), patch("ffmpeg_session.get_vod_cache_timeout", return_value=0), ): cleanup_and_recover_sessions() assert not vod_dir.exists() assert "expired" not in _transcode_sessions if __name__ == "__main__": from testing import run_tests run_tests(__file__) ================================================ FILE: m3u.py ================================================ """M3U parsing, live/VOD/series data loading.""" from __future__ import annotations from typing import Any import logging import re import threading import time from cache import ( LIVE_CACHE_TTL, SERIES_CACHE_TTL, VOD_CACHE_TTL, get_cache, get_cache_lock, get_sources, load_file_cache, save_file_cache, update_source_epg_url, ) from util import safe_urlopen from xtream import XtreamClient log = 
logging.getLogger(__name__) _refresh_in_progress: set[str] = set() _fetch_locks: dict[str, threading.Lock] = { "live": threading.Lock(), "vod": threading.Lock(), "series": threading.Lock(), "epg": threading.Lock(), } def parse_m3u(content: str, source_id: str) -> tuple[list[dict], list[dict], str]: """Parse M3U content, return (categories, streams, epg_url).""" categories: dict[str, dict] = {} streams: list[dict] = [] stream_id_counter = 0 epg_url = "" lines = content.strip().split("\n") i = 0 while i < len(lines): line = lines[i].strip() if line.startswith("#EXTM3U"): match = re.search(r'(?:url-tvg|x-tvg-url)="([^"]*)"', line) if match: epg_url = match.group(1) elif line.startswith("#EXTINF:"): attrs: dict[str, str] = {} match = re.search(r"#EXTINF:[^,]*,(.*)", line) name = match.group(1).strip() if match else "Unknown" for attr_match in re.finditer(r'(\w+[-\w]*)="([^"]*)"', line): attrs[attr_match.group(1)] = attr_match.group(2) i += 1 while i < len(lines) and (not lines[i].strip() or lines[i].startswith("#")): i += 1 url = lines[i].strip() if i < len(lines) else "" group = attrs.get("group-title", "Uncategorized") if group not in categories: cat_slug = re.sub(r"[^a-zA-Z0-9]+", "_", group).strip("_").lower() cat_id = f"{source_id}_{cat_slug}" categories[group] = { "category_id": cat_id, "category_name": group, "parent_id": 0, "source_id": source_id, } stream_id_counter += 1 streams.append( { "stream_id": f"{source_id}_{stream_id_counter}", "name": name, "stream_icon": attrs.get("tvg-logo", ""), "epg_channel_id": attrs.get("tvg-id", ""), "category_ids": [categories[group]["category_id"]], "direct_url": url, "source_id": source_id, } ) i += 1 streams_with_epg = sum(1 for s in streams if s.get("epg_channel_id")) log.debug( "M3U parsed: %d streams (%d with tvg-id, %d without), %d categories", len(streams), streams_with_epg, len(streams) - streams_with_epg, len(categories), ) return list(categories.values()), streams, epg_url def fetch_m3u(url: str, source_id: str, 
timeout: int = 30) -> tuple[list[dict], list[dict], str]: """Fetch and parse M3U from URL, return (categories, streams, epg_url).""" with safe_urlopen(url, timeout=timeout) as resp: content = resp.read().decode("utf-8") return parse_m3u(content, source_id) def _fetch_all_live_data() -> tuple[list[dict], list[dict], list[tuple[str, int, str]]]: """Fetch live categories/streams from all sources.""" all_categories: list[dict] = [] all_streams: list[dict] = [] epg_urls: list[tuple[str, int, str]] = [] for source in get_sources(): try: if source.type == "xtream": client = XtreamClient(source.url, source.username, source.password) cats = client.get_live_categories() streams = client.get_live_streams() for c in cats: c["source_id"] = source.id c["category_id"] = f"{source.id}_{c['category_id']}" for s in streams: s["source_id"] = source.id s["source_type"] = "xtream" s["source_url"] = source.url s["source_username"] = source.username s["source_password"] = source.password orig_cats = s.get("category_ids") or [s.get("category_id")] s["category_ids"] = [f"{source.id}_{c}" for c in orig_cats if c] all_categories.extend(cats) all_streams.extend(streams) if source.epg_enabled: epg_urls.append((client.epg_url, source.epg_timeout, source.id)) elif source.type == "m3u": cats, streams, epg_url = fetch_m3u(source.url, source.id) all_categories.extend(cats) all_streams.extend(streams) if epg_url and source.epg_enabled: epg_urls.append((epg_url, source.epg_timeout, source.id)) elif source.type == "epg": if source.epg_enabled: epg_urls.append((source.url, source.epg_timeout, source.id)) except Exception as e: log.error("Error loading source %s: %s", source.name, e) return all_categories, all_streams, epg_urls def fetch_source_live_data(source: Any) -> tuple[list[dict], list[dict], str | None, int]: """Fetch live data for a single source. 
Returns (cats, streams, epg_url, epg_timeout).""" cats: list[dict] = [] streams: list[dict] = [] epg_url: str | None = None if source.type == "xtream": client = XtreamClient(source.url, source.username, source.password) cats = client.get_live_categories() streams = client.get_live_streams() for c in cats: c["source_id"] = source.id c["category_id"] = f"{source.id}_{c['category_id']}" for s in streams: s["source_id"] = source.id s["source_type"] = "xtream" s["source_url"] = source.url s["source_username"] = source.username s["source_password"] = source.password orig_cats = s.get("category_ids") or [s.get("category_id")] s["category_ids"] = [f"{source.id}_{c}" for c in orig_cats if c] detected_epg = client.epg_url update_source_epg_url(source.id, detected_epg) epg_url = detected_epg if source.epg_enabled else None elif source.type == "m3u": cats, streams, detected_epg = fetch_m3u(source.url, source.id) update_source_epg_url(source.id, detected_epg) epg_url = detected_epg if source.epg_enabled else None elif source.type == "epg": epg_url = source.url return cats, streams, epg_url, source.epg_timeout def fetch_source_vod_data(source: Any) -> tuple[list[dict], list[dict]]: """Fetch VOD data for a single Xtream source.""" if source.type != "xtream": return [], [] client = XtreamClient(source.url, source.username, source.password) cats = client.get_vod_categories() streams = client.get_vod_streams() # Tag with source_id for playback for c in cats: c["source_id"] = source.id for s in streams: s["source_id"] = source.id return cats, streams def parse_epg_urls(raw: list) -> list[tuple[str, int, str]]: """Convert JSON list back to tuples (JSON stores tuples as lists).""" return [(u[0], u[1], u[2]) for u in raw if isinstance(u, (list, tuple)) and len(u) >= 3] def load_all_live_data() -> tuple[list[dict], list[dict], list[tuple[str, int, str]]]: """Load live data with file cache and stale-while-revalidate.""" _cache = get_cache() _cache_lock = get_cache_lock() cached = 
load_file_cache("live_data") now = time.time() if cached: data, ts = cached cats, streams = data["cats"], data["streams"] epg_urls = parse_epg_urls(data.get("epg_urls", [])) age = now - ts if age > LIVE_CACHE_TTL and "live" not in _refresh_in_progress: _refresh_in_progress.add("live") def refresh() -> None: try: log.info("Refreshing live data in background") new_cats, new_streams, new_epg_urls = _fetch_all_live_data() save_file_cache( "live_data", {"cats": new_cats, "streams": new_streams, "epg_urls": new_epg_urls}, ) with _cache_lock: _cache.pop("live_categories", None) _cache.pop("live_streams", None) _cache["epg_urls"] = new_epg_urls log.info("Live data refreshed") finally: _refresh_in_progress.discard("live") threading.Thread(target=refresh, daemon=True).start() return cats, streams, epg_urls with _fetch_locks["live"]: cached = load_file_cache("live_data") if cached: data, _ = cached return data["cats"], data["streams"], parse_epg_urls(data.get("epg_urls", [])) log.info("No live cache, fetching") cats, streams, epg_urls = _fetch_all_live_data() save_file_cache("live_data", {"cats": cats, "streams": streams, "epg_urls": epg_urls}) return cats, streams, epg_urls def _fetch_vod_data() -> tuple[list[dict], list[dict]]: """Fetch VOD categories and streams from all Xtream sources.""" all_cats: list[dict] = [] all_streams: list[dict] = [] for source in get_sources(): if source.type != "xtream": continue try: client = XtreamClient(source.url, source.username, source.password) cats = client.get_vod_categories() streams = client.get_vod_streams() # Tag with source_id for playback and access control for c in cats: c["source_id"] = source.id for s in streams: s["source_id"] = source.id all_cats.extend(cats) all_streams.extend(streams) except Exception as e: log.warning("Failed to fetch VOD from source %s: %s", source.id, e) return all_cats, all_streams def load_vod_data() -> tuple[list[dict], list[dict]]: """Load VOD data with file cache and stale-while-revalidate.""" 
_cache = get_cache() _cache_lock = get_cache_lock() cached = load_file_cache("vod_data") now = time.time() if cached: data, ts = cached cats, streams = data["cats"], data["streams"] age = now - ts if age > VOD_CACHE_TTL and "vod" not in _refresh_in_progress: _refresh_in_progress.add("vod") def refresh() -> None: try: log.info("Refreshing VOD data in background") new_cats, new_streams = _fetch_vod_data() save_file_cache("vod_data", {"cats": new_cats, "streams": new_streams}) with _cache_lock: _cache.pop("vod_categories", None) _cache.pop("vod_streams", None) log.info("VOD data refreshed") finally: _refresh_in_progress.discard("vod") threading.Thread(target=refresh, daemon=True).start() return cats, streams with _fetch_locks["vod"]: cached = load_file_cache("vod_data") if cached: data, _ = cached return data["cats"], data["streams"] log.info("No VOD cache, fetching") cats, streams = _fetch_vod_data() if cats or streams: save_file_cache("vod_data", {"cats": cats, "streams": streams}) return cats, streams def _fetch_series_data() -> tuple[list[dict], list[dict]]: """Fetch series categories and list from all Xtream sources.""" all_cats: list[dict] = [] all_series: list[dict] = [] for source in get_sources(): if source.type != "xtream": continue try: client = XtreamClient(source.url, source.username, source.password) cats = client.get_series_categories() series = client.get_series() # Tag with source_id for playback and access control for c in cats: c["source_id"] = source.id for s in series: s["source_id"] = source.id all_cats.extend(cats) all_series.extend(series) except Exception as e: log.warning("Failed to fetch series from source %s: %s", source.id, e) return all_cats, all_series def load_series_data() -> tuple[list[dict], list[dict]]: """Load series data with file cache and stale-while-revalidate.""" _cache = get_cache() _cache_lock = get_cache_lock() cached = load_file_cache("series_data") now = time.time() if cached: data, ts = cached cats, series = 
data["cats"], data["series"] age = now - ts if age > SERIES_CACHE_TTL and "series" not in _refresh_in_progress: _refresh_in_progress.add("series") def refresh() -> None: try: log.info("Refreshing series data in background") new_cats, new_series = _fetch_series_data() save_file_cache("series_data", {"cats": new_cats, "series": new_series}) with _cache_lock: _cache.pop("series_categories", None) _cache.pop("series", None) log.info("Series data refreshed") finally: _refresh_in_progress.discard("series") threading.Thread(target=refresh, daemon=True).start() return cats, series with _fetch_locks["series"]: cached = load_file_cache("series_data") if cached: data, _ = cached return data["cats"], data["series"] log.info("No series cache, fetching") cats, series = _fetch_series_data() if cats or series: save_file_cache("series_data", {"cats": cats, "series": series}) return cats, series def get_first_xtream_client() -> XtreamClient | None: """Get the first available Xtream client (for VOD/series).""" for source in get_sources(): if source.type == "xtream": return XtreamClient(source.url, source.username, source.password) return None def get_xtream_client_by_source(source_id: str) -> XtreamClient | None: """Get Xtream client for a specific source ID.""" for source in get_sources(): if source.id == source_id and source.type == "xtream": return XtreamClient(source.url, source.username, source.password) return None def get_first_xtream_source_and_client() -> tuple[str, XtreamClient] | tuple[None, None]: """Get the first available Xtream source ID and client.""" for source in get_sources(): if source.type == "xtream": return source.id, XtreamClient(source.url, source.username, source.password) return None, None def get_fetch_lock(name: str) -> threading.Lock: """Get fetch lock by name.""" return _fetch_locks[name] def get_refresh_in_progress() -> set[str]: """Get refresh in progress set.""" return _refresh_in_progress ================================================ FILE: 
m3u_test.py ================================================ """Tests for m3u.py.""" from __future__ import annotations from pathlib import Path import pytest @pytest.fixture def m3u_module(tmp_path: Path): """Import m3u module with mocked cache.""" import cache cache.SERVER_SETTINGS_FILE = tmp_path / "server_settings.json" cache.USERS_DIR = tmp_path / "users" cache.USERS_DIR.mkdir(exist_ok=True) cache.CACHE_DIR = tmp_path / "cache" cache.CACHE_DIR.mkdir(exist_ok=True) cache.get_cache().clear() import m3u yield m3u cache.get_cache().clear() class TestParseM3u: def test_parse_basic_m3u(self, m3u_module): content = """#EXTM3U #EXTINF:-1 tvg-id="ch1" tvg-logo="http://logo.png" group-title="News",Channel One http://stream.example.com/ch1.m3u8 #EXTINF:-1 tvg-id="ch2" group-title="Sports",Channel Two http://stream.example.com/ch2.m3u8 """ cats, streams, _ = m3u_module.parse_m3u(content, "src1") assert len(cats) == 2 assert any(c["category_name"] == "News" for c in cats) assert any(c["category_name"] == "Sports" for c in cats) assert len(streams) == 2 assert streams[0]["name"] == "Channel One" assert streams[0]["epg_channel_id"] == "ch1" assert streams[0]["stream_icon"] == "http://logo.png" assert streams[0]["direct_url"] == "http://stream.example.com/ch1.m3u8" assert streams[0]["source_id"] == "src1" def test_parse_m3u_with_epg_url(self, m3u_module): content = """#EXTM3U url-tvg="http://epg.example.com/guide.xml" #EXTINF:-1,Test Channel http://test.stream """ _, _, epg_url = m3u_module.parse_m3u(content, "src1") assert epg_url == "http://epg.example.com/guide.xml" def test_parse_m3u_x_tvg_url(self, m3u_module): content = """#EXTM3U x-tvg-url="http://alt.epg.com/guide.xml" #EXTINF:-1,Test Channel http://test.stream """ _, _, epg_url = m3u_module.parse_m3u(content, "src1") assert epg_url == "http://alt.epg.com/guide.xml" def test_parse_m3u_uncategorized(self, m3u_module): content = """#EXTM3U #EXTINF:-1,No Group Channel http://stream.example.com/nogroupch.m3u8 """ cats, 
streams, _ = m3u_module.parse_m3u(content, "src1") assert len(cats) == 1 assert cats[0]["category_name"] == "Uncategorized" assert streams[0]["category_ids"][0].endswith("_uncategorized") def test_parse_m3u_category_ids_prefixed(self, m3u_module): content = """#EXTM3U #EXTINF:-1 group-title="Movies",Test http://test """ cats, streams, _ = m3u_module.parse_m3u(content, "mysource") assert cats[0]["category_id"].startswith("mysource_") assert streams[0]["category_ids"][0].startswith("mysource_") def test_parse_m3u_empty(self, m3u_module): cats, streams, epg_url = m3u_module.parse_m3u("", "src1") assert cats == [] assert streams == [] assert epg_url == "" class TestParseEpgUrls: def test_parse_tuple_list(self, m3u_module): raw = [["http://epg1.com", 120, "src1"], ["http://epg2.com", 60, "src2"]] result = m3u_module.parse_epg_urls(raw) assert len(result) == 2 assert result[0] == ("http://epg1.com", 120, "src1") assert result[1] == ("http://epg2.com", 60, "src2") def test_parse_tuple_passthrough(self, m3u_module): raw = [("http://epg.com", 100, "s1")] result = m3u_module.parse_epg_urls(raw) assert result[0] == ("http://epg.com", 100, "s1") def test_parse_empty(self, m3u_module): assert m3u_module.parse_epg_urls([]) == [] def test_parse_skips_malformed(self, m3u_module): raw = [["http://epg.com", 90], "plain_string", ["http://valid.com", 60, "src"]] result = m3u_module.parse_epg_urls(raw) assert len(result) == 1 assert result[0] == ("http://valid.com", 60, "src") class TestFetchLocks: def test_get_fetch_lock(self, m3u_module): lock = m3u_module.get_fetch_lock("live") assert lock is not None def test_get_refresh_in_progress(self, m3u_module): rip = m3u_module.get_refresh_in_progress() assert isinstance(rip, set) if __name__ == "__main__": from testing import run_tests run_tests(__file__) ================================================ FILE: main.py ================================================ #!/usr/bin/env python3 # /// script # requires-python = ">=3.11" # 
dependencies = ["fastapi", "uvicorn[standard]", "jinja2", "python-multipart", "cryptography", "defusedxml"] # /// """IPTV Web App. Usage: ./main.py [--port PORT] [--https] [--cert FILE --key FILE] Options: --port PORT Port to listen on (default: 8000) --https Enable HTTPS using Let's Encrypt certs (auto-detect domain) --cert FILE SSL certificate file (overrides --https) --key FILE SSL private key file (overrides --https) Examples: ./main.py # HTTP on port 8000 ./main.py --https # HTTPS with auto-detected Let's Encrypt certs ./main.py --cert c.pem --key k.pem # HTTPS with custom certs """ from __future__ import annotations from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import UTC, datetime, timedelta from typing import Annotated, Any from xml.sax.saxutils import escape as xml_escape import asyncio import concurrent.futures import contextlib import json import logging import os import pathlib import re import signal import subprocess import threading import time import urllib.error import urllib.parse from fastapi import Depends, FastAPI, Form, HTTPException, Query, Request from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse, Response from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from starlette.responses import StreamingResponse from auth import create_token, verify_password, verify_token from cache import ( AVAILABLE_ENCODERS, CACHE_DIR, LOGO_BROWSER_TTL, LOGO_MAX_SIZE, Source, clear_all_caches, clear_all_file_caches, get_cache, get_cache_lock, get_cached_info, get_cached_logo, get_sources, get_watch_position, load_file_cache, load_server_settings, load_user_settings, refresh_encoders, save_file_cache, save_logo, save_server_settings, save_user_settings, save_watch_position, update_source_epg_url, ) from epg import fetch_epg from m3u import ( fetch_m3u, fetch_source_live_data, fetch_source_vod_data, 
get_first_xtream_client, get_refresh_in_progress, get_xtream_client_by_source, load_all_live_data, load_series_data, load_vod_data, parse_epg_urls, ) from xtream import XtreamClient import auth import epg import ffmpeg_command import ffmpeg_session log = logging.getLogger() # SSE subscribers for EPG ready notifications (limit to prevent DoS) _epg_subscribers: set[asyncio.Queue[str]] = set() _shutdown_event: asyncio.Event | None = None # Set during shutdown to close SSE _MAX_SSE_SUBSCRIBERS = 100 # Login rate limiting: track failed attempts per IP _login_attempts: dict[str, list[float]] = {} _LOGIN_WINDOW = 300 # 5 minutes _LOGIN_MAX_ATTEMPTS = 10 # Category filter limits _MAX_FILTER_CATEGORIES = 10000 # ============================================================================= # App Setup # ============================================================================= APP_DIR = pathlib.Path(__file__).parent TEMPLATES = Jinja2Templates(directory=APP_DIR / "templates") TEMPLATES.env.auto_reload = True # Super-resolution engine directory (TensorRT engines for different resolutions) SR_ENGINE_DIR = pathlib.Path( os.environ.get("SR_ENGINE_DIR", pathlib.Path.home() / "ffmpeg_build/models") ) def get_sr_models() -> list[str]: """Get available AI Upscale models (unique model names from engine files).""" if not SR_ENGINE_DIR.exists(): return [] # Engine files are named: {model}_{height}p_fp16.engine # e.g., 4x-compact_1080p_fp16.engine, 2x-liveaction-span_720p_fp16.engine models = set() for engine in SR_ENGINE_DIR.glob("*_*p_fp16.engine"): # Extract model name by removing _{height}p_fp16.engine suffix name = engine.stem # e.g., "2x-liveaction-span_1080p_fp16" # Remove _fp16 and _{height}p parts = name.rsplit("_", 2) # ["2x-liveaction-span", "1080p", "fp16"] if len(parts) >= 3: models.add(parts[0]) # Sort with 4x-compact first (recommended), then alphabetically def sort_key(m: str) -> tuple[int, str]: if m == "4x-compact": return (0, m) return (1, m) return sorted(models, 
key=sort_key) def is_sr_available() -> bool: """Check if AI Upscale is available (at least one TensorRT engine exists).""" return len(get_sr_models()) > 0 def _logo_url_filter(url: str) -> str: """Wrap external logo URLs through /api/logo proxy.""" if not url or url.startswith("/") or url.startswith("data:"): return url # Already local or data URL # Use hostname as source for organization parsed = urllib.parse.urlparse(url) source = parsed.netloc.split(":")[0] if parsed.netloc else "external" return f"/api/logo?source={urllib.parse.quote(source)}&url={urllib.parse.quote(url)}" TEMPLATES.env.filters["logo_url"] = _logo_url_filter def _safe_float(value: float | str | None, default: float = 0.0) -> float: """Safely convert value to float, returning default on failure.""" if value is None: return default try: return float(value) except (ValueError, TypeError): return default # Thread locks for fetch operations _fetch_locks: dict[str, threading.Lock] = { "live": threading.Lock(), "vod": threading.Lock(), "series": threading.Lock(), "epg": threading.Lock(), } @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncIterator[None]: """Clean up orphaned transcodes and preload data on startup.""" # Initialize EPG database epg.init(CACHE_DIR) # Prune expired EPG data (keep 24h buffer for "what was just on") cutoff = datetime.now(UTC) - timedelta(hours=24) pruned = epg.prune_old_programs(cutoff) if pruned: log.info("Pruned %d expired EPG programs", pruned) # Initialize transcoding module with settings callback ffmpeg_command.init( load_server_settings, sr_engine_dir=str(SR_ENGINE_DIR) if is_sr_available() else "", ) # Kill orphaned ffmpeg processes try: result = subprocess.run( ["pgrep", "-f", "ffmpeg.*iptv_transcode"], check=False, capture_output=True, text=True, ) for pid in result.stdout.strip().split("\n"): if pid: try: os.kill(int(pid), signal.SIGKILL) log.info("Killed orphaned ffmpeg pid %s", pid) except (ProcessLookupError, ValueError): pass except Exception: pass 
# Clean up orphaned dirs and recover valid VOD sessions ffmpeg_session.cleanup_and_recover_sessions() # Preload all data in background threads (parallel) def load_live(): get_refresh_in_progress().add("guide_load") try: log.info("Preloading live data") cats, streams, epg_urls = load_all_live_data() with get_cache_lock(): get_cache()["live_categories"] = cats get_cache()["live_streams"] = streams get_cache()["epg_urls"] = epg_urls log.info("Live data loaded") finally: get_refresh_in_progress().discard("guide_load") def load_epg_data(): try: epg_urls = get_cache().get("epg_urls", []) if epg_urls: load_all_epg(epg_urls) log.info("EPG data loaded: %d programs", epg.get_program_count()) # Notify SSE subscribers for q in list(_epg_subscribers): with contextlib.suppress(Exception): q.put_nowait("epg_ready") except Exception as e: log.error("EPG load error: %s", e) def load_vod(): vod_cats, vod_streams = load_vod_data() with get_cache_lock(): get_cache()["vod_categories"] = vod_cats get_cache()["vod_streams"] = vod_streams log.info("VOD data loaded") def load_series(): series_cats, series_list = load_series_data() with get_cache_lock(): get_cache()["series_categories"] = series_cats get_cache()["series"] = series_list log.info("Series data loaded") # Start all preloads in parallel (EPG waits for live data internally) def load_all(): load_live() # EPG needs epg_urls from live data, so run after load_epg_data() threading.Thread(target=load_all, daemon=True).start() threading.Thread(target=load_vod, daemon=True).start() threading.Thread(target=load_series, daemon=True).start() log.info("Preload started: live+EPG, VOD, series loading in parallel") # Periodic cleanup of expired sessions (VOD and live) cleanup_stop = threading.Event() def cleanup_loop(): while not cleanup_stop.wait(60): # Check every minute ffmpeg_session.cleanup_expired_sessions() cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True) cleanup_thread.start() # EPG scheduler scheduler_stop = 
threading.Event() _last_triggered: dict[str, str] = {} # source_id -> last triggered time def scheduler_loop(): while not scheduler_stop.wait(30): # Check every 30 seconds now = datetime.now() current_time = now.strftime("%H:%M") for source in get_sources(): if current_time in source.epg_schedule: key = f"{source.id}_epg" # Only trigger once per scheduled time if ( _last_triggered.get(source.id) != current_time and key not in get_refresh_in_progress() ): log.info("Scheduled EPG refresh for %s at %s", source.name, current_time) _last_triggered[source.id] = current_time get_refresh_in_progress().add(key) def do_refresh(src: Source = source, k: str = key): try: epg_url = None if src.type == "xtream": client = XtreamClient(src.url, src.username, src.password) epg_url = client.epg_url elif src.type == "m3u": _, _, epg_url = fetch_m3u(src.url, src.id) elif src.type == "epg": epg_url = src.url if epg_url: _fetch_all_epg([(epg_url, src.epg_timeout, src.id)]) log.info("Scheduled EPG refresh complete for %s", src.name) except Exception as e: log.error("Scheduled EPG refresh failed for %s: %s", src.name, e) finally: get_refresh_in_progress().discard(k) threading.Thread(target=do_refresh, daemon=True).start() scheduler_thread = threading.Thread(target=scheduler_loop, daemon=True) scheduler_thread.start() yield # Shutdown - signal SSE connections to close global _shutdown_event _shutdown_event = asyncio.Event() _shutdown_event.set() cleanup_stop.set() scheduler_stop.set() ffmpeg_session.shutdown() app = FastAPI(title="neTV", lifespan=lifespan) app.mount("/static", StaticFiles(directory=APP_DIR / "static"), name="static") class AuthRequired(Exception): """Raised when authentication is required.""" @app.exception_handler(AuthRequired) async def auth_required_handler(request: Request, _exc: AuthRequired): return RedirectResponse("/login", status_code=303) @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): """Show nice HTML 
error pages for HTTP errors.""" # Only handle HTML requests, let API requests get JSON accept = request.headers.get("accept", "") if "text/html" not in accept: return JSONResponse({"detail": exc.detail}, status_code=exc.status_code) return TEMPLATES.TemplateResponse( request, "error.html", {"title": f"Error {exc.status_code}", "message": exc.detail}, status_code=exc.status_code, ) def get_current_user(request: Request) -> dict | None: token = request.cookies.get("token") if not token: return None return verify_token(token) def require_auth(request: Request) -> dict: user = get_current_user(request) if not user: raise AuthRequired return user def require_admin(request: Request) -> dict: user = require_auth(request) username = user.get("sub", "") if not auth.is_admin(username): raise HTTPException(403, "Admin access required") return user # ============================================================================= # Auth Routes # ============================================================================= @app.get("/setup", response_class=HTMLResponse) async def setup_page(request: Request): """Initial setup page - create first admin user.""" if not auth.is_setup_required(): return RedirectResponse("/login", status_code=303) return TEMPLATES.TemplateResponse(request, "setup.html", {"error": None}) @app.post("/setup") async def setup_create_user( request: Request, username: Annotated[str, Form()], password: Annotated[str, Form()], confirm: Annotated[str, Form()], ): """Create the initial admin user.""" if not auth.is_setup_required(): return RedirectResponse("/login", status_code=303) # Validate if len(username) < 3: return TEMPLATES.TemplateResponse( request, "setup.html", {"error": "Username must be at least 3 characters"} ) if len(password) < 8: return TEMPLATES.TemplateResponse( request, "setup.html", {"error": "Password must be at least 8 characters"} ) if password != confirm: return TEMPLATES.TemplateResponse( request, "setup.html", {"error": "Passwords do 
not match"} ) auth.create_user(username, password) return RedirectResponse("/login", status_code=303) @app.get("/login", response_class=HTMLResponse) async def login_page(request: Request, error: str | None = None): """Login page - redirects to setup if no users exist.""" if auth.is_setup_required(): return RedirectResponse("/setup", status_code=303) last_user = request.cookies.get("last_user", "") return TEMPLATES.TemplateResponse( request, "login.html", {"error": error, "last_user": last_user} ) def _check_rate_limit(ip: str) -> None: """Check login rate limit. Raises HTTPException if exceeded.""" now = time.time() attempts = _login_attempts.get(ip, []) # Clean old attempts for this IP attempts = [t for t in attempts if now - t < _LOGIN_WINDOW] if attempts: _login_attempts[ip] = attempts elif ip in _login_attempts: del _login_attempts[ip] # Periodically clean stale IPs (when dict is large) if len(_login_attempts) > 1000: stale = [k for k, v in _login_attempts.items() if not v or now - max(v) > _LOGIN_WINDOW] for k in stale[:100]: del _login_attempts[k] if len(attempts) >= _LOGIN_MAX_ATTEMPTS: raise HTTPException(429, "Too many login attempts, try again later") @app.post("/login") async def login( request: Request, username: Annotated[str, Form()], password: Annotated[str, Form()], ): """Authenticate user and create session.""" ip = request.client.host if request.client else "unknown" _check_rate_limit(ip) if not verify_password(username, password): _login_attempts.setdefault(ip, []).append(time.time()) return RedirectResponse("/login?error=invalid", status_code=303) token = create_token({"sub": username}) response = RedirectResponse("/", status_code=303) is_secure = request.url.scheme == "https" or "https" in request.headers.get("x-forwarded-proto", "").lower() or "https" in request.headers.get("x-forwarded-scheme", "").lower() response.set_cookie( "token", token, httponly=True, samesite="strict", max_age=86400 * 7, secure=is_secure ) 
response.set_cookie("last_user", username, max_age=86400 * 365, secure=is_secure) return response @app.get("/logout") async def logout(): response = RedirectResponse("/login", status_code=303) response.delete_cookie("token") return response # ============================================================================= # Main Pages # ============================================================================= @app.get("/favicon.ico") async def favicon(): return Response(status_code=204) @app.get("/", response_class=HTMLResponse) async def index(request: Request, _user: Annotated[dict, Depends(require_auth)]): return RedirectResponse("/guide", status_code=303) def _fetch_all_epg(epg_urls: list[tuple[str, int, str]]) -> int: """Fetch EPG from all URLs into sqlite (in parallel). Returns total program count.""" user_agent = ffmpeg_command.get_user_agent() def fetch_one(url_timeout_source: tuple[str, int, str]) -> tuple[str, int]: url, timeout, source_id = url_timeout_source try: log.info("Fetching EPG (timeout=%ds): %s", timeout, url[:80]) count = fetch_epg(url, CACHE_DIR, timeout=timeout, source_id=source_id, user_agent=user_agent) log.info("EPG done: %d programs from %s", count, url[:50]) return url, count except Exception as e: log.error("EPG failed: %s - %s", url[:50], e) return url, 0 total = 0 max_workers = min(len(epg_urls) or 1, 8) # Cap at 8 concurrent fetches with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as ex: futures = [ex.submit(fetch_one, u) for u in epg_urls] for future in concurrent.futures.as_completed(futures): _, count = future.result() total += count # Prune expired programs (keep 24h buffer) cutoff = datetime.now(UTC) - timedelta(hours=24) pruned = epg.prune_old_programs(cutoff) if pruned: log.info("Pruned %d expired EPG programs", pruned) log.info("EPG fetch complete: %d programs total", total) return total def load_all_epg(epg_urls: list[tuple[str, int, str]]) -> None: """Load EPG into sqlite database if empty. 
Args: epg_urls: List of (url, timeout, source_id) tuples """ if epg.has_programs(): log.info("EPG database has %d programs", epg.get_program_count()) return # No data - fetch synchronously with _fetch_locks["epg"]: if epg.has_programs(): return log.info("No EPG data, fetching") try: _fetch_all_epg(epg_urls) except Exception as e: log.error("EPG fetch failed: %s", e) get_cache()["epg_error"] = str(e) def _start_guide_background_load() -> None: """Start background loading of guide data if not already in progress.""" if "guide_load" in get_refresh_in_progress(): return get_refresh_in_progress().add("guide_load") def load(): try: log.info("Loading guide data in background") cats, streams, epg_urls = load_all_live_data() with get_cache_lock(): get_cache()["live_categories"] = cats get_cache()["live_streams"] = streams get_cache()["epg_urls"] = epg_urls try: _fetch_all_epg(epg_urls) except Exception as e: with get_cache_lock(): get_cache()["epg_error"] = str(e) log.info("Guide data loaded") finally: get_refresh_in_progress().discard("guide_load") threading.Thread(target=load, daemon=True).start() @app.get("/events/epg") async def epg_events(_user: Annotated[dict, Depends(require_auth)]): """SSE endpoint - notifies when EPG is ready.""" if len(_epg_subscribers) >= _MAX_SSE_SUBSCRIBERS: raise HTTPException(503, "Too many subscribers") queue: asyncio.Queue[str] = asyncio.Queue() _epg_subscribers.add(queue) async def event_stream(): try: # If EPG already loaded, send immediately if epg.has_programs(): yield "data: epg_ready\n\n" return # Wait for EPG ready event or shutdown while True: if _shutdown_event and _shutdown_event.is_set(): return try: event = await asyncio.wait_for(queue.get(), timeout=1) yield f"data: {event}\n\n" return except TimeoutError: continue except TimeoutError: yield "data: timeout\n\n" finally: _epg_subscribers.discard(queue) return StreamingResponse(event_stream(), media_type="text/event-stream") @app.get("/guide", response_class=HTMLResponse) async 
def guide_page( request: Request, user: Annotated[dict, Depends(require_auth)], offset: int = 0, # hours offset from now cats: str = "", # comma-separated category IDs ): username = user.get("sub", "") # Check if cats was explicitly in URL (even if empty) cats_in_url = "cats" in request.query_params # If no channel data in memory, try file cache first (async to avoid blocking) if "live_categories" not in get_cache() or "live_streams" not in get_cache(): cached = await asyncio.to_thread(load_file_cache, "live_data") if cached: data, _ = cached with get_cache_lock(): get_cache()["live_categories"] = data["cats"] get_cache()["live_streams"] = data["streams"] get_cache()["epg_urls"] = parse_epg_urls(data.get("epg_urls", [])) else: # No cache at all - start background load and show loading page _start_guide_background_load() return TEMPLATES.TemplateResponse( request, "guide.html", { "grid_data": [], "selected_cats": [], "cats_param": cats, "time_markers": [], "offset": offset, "window_start": "", "loading_message": "Loading channel data...", "channel_count": 0, "loading": True, }, ) categories = get_cache()["live_categories"] # EPG is optional - check sqlite db for data epg_loading = not epg.has_programs() # Get the full saved filter for dropdown (not just current URL filter) user_settings = load_user_settings(username) saved_filter_list = user_settings.get("guide_filter", []) saved_filter = set(saved_filter_list) # For fast lookup # Build ordered list of category objects matching user's saved order cat_by_id = {str(c["category_id"]): c for c in categories} ordered_filter_cats = [cat_by_id[cid] for cid in saved_filter_list if cid in cat_by_id] # Get saved VIEW selection (separate from Settings filter) saved_view_cats = user_settings.get("guide_selected_cats") # None = show all # Determine effective cats: URL param (if present) > saved view > all from filter if cats_in_url: # URL explicitly has cats param (could be empty for "none") effective_cats = cats elif 
saved_view_cats is not None: # Use saved view selection (could be [] for "none") effective_cats = ",".join(saved_view_cats) else: # Default: show all from settings filter effective_cats = ",".join(saved_filter_list) # Use helper to get filtered/sorted streams streams, ordered_cats, selected_cats = _get_guide_streams(effective_cats, username) total_count = len(streams) # Time window: 3 hours starting from offset now = datetime.now(UTC) window_start = now.replace(minute=0, second=0, microsecond=0) + timedelta(hours=offset) window_end = window_start + timedelta(hours=3) # Virtual scrolling: only render first batch, JS fetches rest on scroll # When disabled, render all rows server-side virtual_scroll_enabled = user_settings.get("virtual_scroll", True) initial_batch_size = 500 if virtual_scroll_enabled else total_count grid_data = _build_guide_rows(streams, 0, initial_batch_size, window_start, window_end) # Time markers (every 30 min) - convert to local time for display time_markers = [] for i in range(7): # 0, 30, 60, 90, 120, 150, 180 minutes t = window_start + timedelta(minutes=i * 30) t_local = t.astimezone() # Convert to local timezone time_markers.append( { "label": t_local.strftime("%H:%M"), "left_pct": (i * 30 / 180) * 100, } ) # Mobile time markers (2 hour window instead of 3) time_markers_mobile = [] for i in range(5): # 0, 30, 60, 90, 120 minutes t = window_start + timedelta(minutes=i * 30) t_local = t.astimezone() time_markers_mobile.append( { "label": t_local.strftime("%H:%M"), "left_pct": (i * 30 / 120) * 100, } ) return TEMPLATES.TemplateResponse( request, "guide.html", { "categories": categories, "selected_cats": selected_cats, "saved_filter": saved_filter, # Full saved filter for dropdown (set) "ordered_filter_cats": ordered_filter_cats, # Ordered list for dropdown "cats_param": cats, "effective_cats": effective_cats, # What's actually being used "grid_data": grid_data, "time_markers": time_markers, "time_markers_mobile": time_markers_mobile, "offset": 
offset, "window_start": window_start.strftime("%Y-%m-%d %H:%M"), "epg_error": get_cache().get("epg_error"), "epg_loading": epg_loading, "channel_count": len(grid_data), "total_count": total_count, # For virtual scrolling "virtual_scroll": virtual_scroll_enabled, "loading": False, "content_access": _get_content_access(username), }, ) def _get_guide_streams(cats: str, username: str) -> tuple[list[dict], list[str], set[str]]: """Get filtered and sorted streams for guide. Returns: Tuple of (filtered_streams, ordered_cat_ids, selected_cat_set) """ all_streams = get_cache().get("live_streams", []) # Parse selected category IDs (ordered list) ordered_cats: list[str] = [] if cats: ordered_cats = [c.strip() for c in cats.split(",") if c.strip()] selected_cats = set(ordered_cats) if not selected_cats: return [], ordered_cats, selected_cats # Get user's unavailable groups for filtering user_limits = auth.get_user_limits(username) unavailable_groups = set(user_limits.get("unavailable_groups", [])) cat_order = {c: i for i, c in enumerate(ordered_cats)} def stream_sort_key(s: dict) -> int: for c in s.get("category_ids") or []: if str(c) in cat_order: return cat_order[str(c)] return len(ordered_cats) def stream_allowed(s: dict) -> bool: cat_ids = s.get("category_ids") or [] return not any(f"cat:{c}" in unavailable_groups for c in cat_ids) streams = [ s for s in all_streams if any(str(c) in selected_cats for c in (s.get("category_ids") or [])) and stream_allowed(s) ] streams.sort(key=stream_sort_key) return streams, ordered_cats, selected_cats def _build_guide_rows( streams: list[dict], start_idx: int, count: int, window_start: datetime, window_end: datetime, ) -> list[dict]: """Build guide grid rows for a range of streams. Returns: List of row dicts with channel info and programs. 
""" end_idx = min(start_idx + count, len(streams)) slice_streams = streams[start_idx:end_idx] if not slice_streams: return [] # Collect EPG IDs for batch query epg_ids = [s.get("epg_channel_id") or "" for s in slice_streams] epg_ids_set = [e for e in epg_ids if e] # Batch fetch icons and programs icons_map = epg.get_icons_batch(epg_ids_set) if epg_ids_set else {} # Build preferred_sources for EPG matching preferred_sources = { epg_id: s.get("source_id", "") for s, epg_id in zip(slice_streams, epg_ids, strict=False) if epg_id and s.get("source_id") } programs_map = ( epg.get_programs_batch(epg_ids_set, window_start, window_end, preferred_sources) if epg_ids_set else {} ) # Build rows window_end_mobile = window_start + timedelta(hours=2) grid_data = [] for idx, (s, epg_id) in enumerate(zip(slice_streams, epg_ids, strict=False), start=start_idx): icon = s.get("stream_icon", "") or icons_map.get(epg_id, "") ch = { "stream_id": s["stream_id"], "name": s["name"], "icon": icon, "epg_id": epg_id, } row = {"channel": ch, "programs": [], "programs_mobile": [], "index": idx} for p in programs_map.get(epg_id, []): p_start = max(p.start, window_start) p_end = min(p.stop, window_end) start_mins = (p_start - window_start).total_seconds() / 60 duration_mins = (p_end - p_start).total_seconds() / 60 left_pct = (start_mins / 180) * 100 width_pct = (duration_mins / 180) * 100 row["programs"].append( { "title": p.title, "desc": p.desc, "start": p.start.strftime("%H:%M"), "end": p.stop.strftime("%H:%M"), "left_pct": left_pct, "width_pct": width_pct, } ) # Mobile: 2-hour window if p.start < window_end_mobile: p_end_m = min(p.stop, window_end_mobile) duration_mins_m = (p_end_m - p_start).total_seconds() / 60 left_pct_m = (start_mins / 120) * 100 width_pct_m = (duration_mins_m / 120) * 100 row["programs_mobile"].append( { "title": p.title, "desc": p.desc, "start": p.start.strftime("%H:%M"), "end": p.stop.strftime("%H:%M"), "left_pct": left_pct_m, "width_pct": width_pct_m, } ) 
grid_data.append(row) return grid_data @app.get("/api/guide/rows") async def guide_rows_api( user: Annotated[dict, Depends(require_auth)], start: int = Query(default=0, ge=0, description="Starting row index"), count: int = Query(default=130, ge=1, le=500, description="Number of rows to fetch"), offset: int = Query(default=0, ge=-168, le=168, description="Hours offset from now"), cats: str = "", ): """API endpoint for virtual scrolling - returns guide rows as JSON.""" username = user.get("sub", "") # Use saved filter if no cats provided if not cats: user_settings = load_user_settings(username) saved = user_settings.get("guide_filter", []) if saved: cats = ",".join(saved) # Ensure data is loaded if "live_streams" not in get_cache(): cached = await asyncio.to_thread(load_file_cache, "live_data") if cached: data, _ = cached with get_cache_lock(): get_cache()["live_categories"] = data["cats"] get_cache()["live_streams"] = data["streams"] get_cache()["epg_urls"] = parse_epg_urls(data.get("epg_urls", [])) streams, _, _ = _get_guide_streams(cats, username) total_count = len(streams) if total_count == 0: return JSONResponse({"rows": [], "total": 0, "start": start}) # Time window now = datetime.now(UTC) window_start = now.replace(minute=0, second=0, microsecond=0) + timedelta(hours=offset) window_end = window_start + timedelta(hours=3) rows = _build_guide_rows(streams, start, count, window_start, window_end) return JSONResponse( {"rows": rows, "total": total_count, "start": start}, headers={"Cache-Control": "no-store"}, ) def _start_vod_background_load() -> None: """Start background loading of VOD data if not already in progress.""" if "vod_load" in get_refresh_in_progress(): return get_refresh_in_progress().add("vod_load") def load(): try: log.info("Loading VOD data in background") vod_cats, vod_streams = load_vod_data() with get_cache_lock(): get_cache()["vod_categories"] = vod_cats get_cache()["vod_streams"] = vod_streams log.info("VOD data loaded") finally: 
get_refresh_in_progress().discard("vod_load") threading.Thread(target=load, daemon=True).start() @app.get("/vod", response_class=HTMLResponse) async def vod_page( request: Request, user: Annotated[dict, Depends(require_auth)], category: int | None = None, sort: str | None = None, ): # Load from file cache if not in memory (async to avoid blocking) if "vod_categories" not in get_cache() or "vod_streams" not in get_cache(): cached = await asyncio.to_thread(load_file_cache, "vod_data") if cached: data, _ = cached get_cache()["vod_categories"] = data["cats"] get_cache()["vod_streams"] = data["streams"] else: # No cache - start background load and show loading page _start_vod_background_load() username = user.get("sub", "") user_settings = load_user_settings(username) return TEMPLATES.TemplateResponse( request, "vod.html", { "categories": [], "streams": [], "current_category": category, "current_sort": sort, "loading": True, "favorites": user_settings.get("favorites", {"series": {}, "movies": {}}), }, ) username = user.get("sub", "") user_settings = load_user_settings(username) # Check if user has access to any movies content_access = _get_content_access(username) if not content_access["movies"]: raise HTTPException(403, "Access to movies is restricted") # Get user's unavailable groups for filtering user_limits = auth.get_user_limits(username) unavailable_groups = set(user_limits.get("unavailable_groups", [])) # Filter by group access (movies:{source_id}) def movie_allowed(s: dict) -> bool: source_id = s.get("source_id", "") return f"movies:{source_id}" not in unavailable_groups streams = [s for s in get_cache()["vod_streams"] if movie_allowed(s)] categories = [ c for c in get_cache()["vod_categories"] if f"movies:{c.get('source_id', '')}" not in unavailable_groups ] # Apply user's VOD category filter (if set) vod_filter = user_settings.get("vod_filter", []) if vod_filter: vod_filter_set = set(str(c) for c in vod_filter) categories = [c for c in categories if 
str(c.get("category_id")) in vod_filter_set] streams = [s for s in streams if str(s.get("category_id")) in vod_filter_set] # Filter by category if specified if category: streams = [s for s in streams if str(s.get("category_id")) == str(category)] # Sort if sort == "alpha": streams.sort(key=lambda s: (s.get("name") or "").lower()) elif sort == "rating": streams.sort(key=lambda s: _safe_float(s.get("rating")), reverse=True) elif sort == "newest": streams.sort(key=lambda s: int(s.get("added") or 0), reverse=True) return TEMPLATES.TemplateResponse( request, "vod.html", { "categories": categories, "streams": streams, "current_category": category, "current_sort": sort, "favorites": user_settings.get("favorites", {"series": {}, "movies": {}}), "content_access": _get_content_access(username), }, ) def _start_series_background_load() -> None: """Start background loading of series data if not already in progress.""" if "series_load" in get_refresh_in_progress(): return get_refresh_in_progress().add("series_load") def load(): try: log.info("Loading series data in background") series_cats, series_list = load_series_data() with get_cache_lock(): get_cache()["series_categories"] = series_cats get_cache()["series"] = series_list log.info("Series data loaded") finally: get_refresh_in_progress().discard("series_load") threading.Thread(target=load, daemon=True).start() @app.get("/series", response_class=HTMLResponse) async def series_page( request: Request, user: Annotated[dict, Depends(require_auth)], category: int | None = None, sort: str | None = None, ): # Load from file cache if not in memory (async to avoid blocking) if "series_categories" not in get_cache() or "series" not in get_cache(): cached = await asyncio.to_thread(load_file_cache, "series_data") if cached: data, _ = cached get_cache()["series_categories"] = data["cats"] get_cache()["series"] = data["series"] else: # No cache - start background load and show loading page _start_series_background_load() username = 
user.get("sub", "") user_settings = load_user_settings(username) return TEMPLATES.TemplateResponse( request, "series.html", { "categories": [], "series": [], "current_category": category, "current_sort": sort, "loading": True, "favorites": user_settings.get("favorites", {"series": {}, "movies": {}}), }, ) username = user.get("sub", "") user_settings = load_user_settings(username) # Check if user has access to any series content_access = _get_content_access(username) if not content_access["series"]: raise HTTPException(403, "Access to series is restricted") # Get user's unavailable groups for filtering user_limits = auth.get_user_limits(username) unavailable_groups = set(user_limits.get("unavailable_groups", [])) # Filter by group access (series:{source_id}) def series_allowed(s: dict) -> bool: source_id = s.get("source_id", "") return f"series:{source_id}" not in unavailable_groups series = [s for s in get_cache()["series"] if series_allowed(s)] categories = [ c for c in get_cache()["series_categories"] if f"series:{c.get('source_id', '')}" not in unavailable_groups ] # Apply user's series category filter (if set) series_filter = user_settings.get("series_filter", []) if series_filter: series_filter_set = set(str(c) for c in series_filter) categories = [c for c in categories if str(c.get("category_id")) in series_filter_set] series = [s for s in series if str(s.get("category_id")) in series_filter_set] # Filter by category if specified if category: series = [s for s in series if str(s.get("category_id")) == str(category)] # Sort if sort == "alpha": series.sort(key=lambda s: (s.get("name") or "").lower()) elif sort == "rating": series.sort(key=lambda s: _safe_float(s.get("rating")), reverse=True) elif sort == "newest": series.sort(key=lambda s: int(s.get("last_modified") or 0), reverse=True) return TEMPLATES.TemplateResponse( request, "series.html", { "categories": categories, "series": series, "current_category": category, "current_sort": sort, "favorites": 
user_settings.get("favorites", {"series": {}, "movies": {}}), "content_access": _get_content_access(username), }, ) @app.get("/series/{series_id}", response_class=HTMLResponse) async def series_detail_page( request: Request, series_id: int, user: Annotated[dict, Depends(require_auth)], refresh: bool = False, ): username = user.get("sub", "") # Check access for this specific series and get source_id source_id = "" if "series" in get_cache(): cached_series = next( (s for s in get_cache()["series"] if str(s.get("series_id")) == str(series_id)), None, ) if cached_series: source_id = cached_series.get("source_id", "") user_limits = auth.get_user_limits(username) unavailable_groups = set(user_limits.get("unavailable_groups", [])) if f"series:{source_id}" in unavailable_groups: raise HTTPException(403, "Access to this series is restricted") # Use the series' source, fall back to first Xtream source xtream = get_xtream_client_by_source(source_id) if source_id else None if not xtream: xtream = get_first_xtream_client() if not xtream: raise HTTPException(404, "No Xtream source configured") cache_key = f"series_info_{source_id}_{series_id}" if source_id else f"series_info_{series_id}" try: series_data = await asyncio.to_thread( get_cached_info, cache_key, lambda: xtream.get_series_info(series_id), refresh ) except (urllib.error.URLError, TimeoutError) as e: return TEMPLATES.TemplateResponse( request, "error.html", { "title": "Provider Error", "message": f"Failed to load series info: {e}", }, status_code=502, ) if refresh: log.info("Force refreshed series info %s", series_id) # Extract year from releaseDate if not present if series_data.get("info"): info = series_data["info"] if not info.get("year") and info.get("releaseDate"): info["year"] = info["releaseDate"][:4] # Strip redundant series title and episode numbers from episode titles if series_data.get("episodes"): for season_eps in series_data["episodes"].values(): for ep in season_eps: if ep.get("title"): # Remove patterns 
like "Series Name - S01E01 - Episode Title" # Keep only the actual episode title title = ep["title"] # Remove S##E## - pattern title = re.sub(r"^S\d+E\d+\s*-\s*", "", title) # Remove any leading "SeriesName - " pattern if " - " in title and len(title.split(" - ")) > 1: parts = title.split(" - ") # Take the last part which should be the actual episode title title = parts[-1] ep["title"] = title.strip() # Parse info field if it's JSON if ep.get("info"): if isinstance(ep["info"], str): try: info_obj = json.loads(ep["info"]) # Extract plot/description from parsed JSON if isinstance(info_obj, dict): ep["description"] = ( info_obj.get("plot") or info_obj.get("description") or "" ) except (json.JSONDecodeError, TypeError): pass elif isinstance(ep["info"], dict): # Already a dict ep["description"] = ( ep["info"].get("plot") or ep["info"].get("description") or "" ) username = user.get("sub", "") user_settings = load_user_settings(username) return TEMPLATES.TemplateResponse( request, "series_detail.html", { "series": series_data, "series_id": series_id, "favorites": user_settings.get("favorites", {"series": {}, "movies": {}}), }, ) @app.get("/movie/{stream_id}", response_class=HTMLResponse) async def movie_detail_page( request: Request, stream_id: int, user: Annotated[dict, Depends(require_auth)], ): username = user.get("sub", "") # Load from file cache if not in memory if "vod_streams" not in get_cache(): vod_cats, vod_streams = load_vod_data() get_cache()["vod_categories"] = vod_cats get_cache()["vod_streams"] = vod_streams vod_streams = get_cache().get("vod_streams", []) movie = next((m for m in vod_streams if m.get("stream_id") == stream_id), None) # Check access for this specific movie if movie: source_id = movie.get("source_id", "") user_limits = auth.get_user_limits(username) unavailable_groups = set(user_limits.get("unavailable_groups", [])) if f"movies:{source_id}" in unavailable_groups: raise HTTPException(403, "Access to this movie is restricted") # Fetch detailed 
movie info if movie: source_id = movie.get("source_id", "") xtream = get_xtream_client_by_source(source_id) if source_id else None if not xtream: xtream = get_first_xtream_client() if xtream: cache_key = ( f"vod_info_{source_id}_{stream_id}" if source_id else f"vod_info_{stream_id}" ) try: vod_info = await asyncio.to_thread( get_cached_info, cache_key, lambda: xtream.get_vod_info(stream_id) ) except (urllib.error.URLError, TimeoutError): vod_info = {} if vod_info and vod_info.get("info"): info = vod_info["info"] # Merge detailed info into movie object movie = {**movie} # Copy movie["plot"] = info.get("plot") or info.get("description", "") movie["director"] = info.get("director", "") movie["cast"] = info.get("cast") or info.get("actors", "") movie["genre"] = info.get("genre", "") movie["rating"] = info.get("rating", "") movie["year"] = info.get("releasedate", "")[:4] if info.get("releasedate") else "" movie["duration"] = info.get("duration", "") movie["cover_big"] = info.get("cover_big") or info.get("movie_image", "") movie["youtube_trailer"] = info.get("youtube_trailer", "") username = user.get("sub", "") user_settings = load_user_settings(username) return TEMPLATES.TemplateResponse( request, "movie_detail.html", { "movie": movie, "favorites": user_settings.get("favorites", {"series": {}, "movies": {}}), }, ) @dataclass(slots=True) class PlayerInfo: """Info needed to render the player page.""" url: str = "" is_m3u: bool = False channel_name: str = "" program_title: str = "" program_desc: str = "" deinterlace_fallback: bool = True # Used when probe is skipped source_id: str = "" # Source ID for stream limit tracking category_ids: list[str] | None = None # Category IDs for live streams (access check) def _get_episode_desc(ep: dict) -> str: """Extract description from episode info (handles str or dict).""" info = ep.get("info") if isinstance(info, str): try: info = json.loads(info) except (json.JSONDecodeError, TypeError): info = None if isinstance(info, dict): return 
info.get("plot") or info.get("description") or "" return ep.get("description") or ep.get("plot") or "" def _get_live_player_info(stream_id: str) -> PlayerInfo: """Get player info for live stream.""" _ensure_live_cache() stream = next( (s for s in get_cache()["live_streams"] if str(s.get("stream_id")) == stream_id), None, ) if not stream: return PlayerInfo() info = PlayerInfo(channel_name=stream.get("name", "")) if stream.get("direct_url"): info.url = stream["direct_url"] info.is_m3u = True elif stream.get("source_type") == "xtream": base, user, pwd = stream["source_url"], stream["source_username"], stream["source_password"] orig_id = stream_id.split("_")[-1] if "_" in stream_id else stream_id # URL-encode username/password to handle special chars like # in passwords user = urllib.parse.quote(user, safe="") pwd = urllib.parse.quote(pwd, safe="") info.url = f"{base}/live/{user}/{pwd}/{orig_id}.m3u8" # Look up source settings source_id = stream.get("source_id", "") info.source_id = source_id info.category_ids = stream.get("category_ids") if source_id: sources = load_server_settings().get("sources", []) source = next((s for s in sources if s.get("id") == source_id), None) if source: info.deinterlace_fallback = source.get("deinterlace_fallback", True) # Look up current program from EPG epg_id = stream.get("epg_channel_id") or "" if epg_id: now = datetime.now(UTC) programs = epg.get_programs_in_range(epg_id, now, now + timedelta(minutes=1)) if programs: info.program_title, info.program_desc = programs[0].title, programs[0].desc return info def _get_movie_player_info(stream_id: str, ext: str) -> PlayerInfo: """Get player info for movie.""" # Find movie in cache to get its source_id cached_movie = None if "vod_streams" in get_cache(): cached_movie = next( (m for m in get_cache()["vod_streams"] if str(m.get("stream_id")) == str(stream_id)), None, ) # Get client for the movie's source (fall back to first if not found) source_id = cached_movie.get("source_id", "") if 
cached_movie else "" xtream = get_xtream_client_by_source(source_id) if source_id else None if not xtream: xtream = get_first_xtream_client() if not xtream: return PlayerInfo() ext = ext or "mkv" info = PlayerInfo(url=xtream.build_stream_url("movie", int(stream_id), ext)) info.source_id = source_id cache_key = f"vod_info_{source_id}_{stream_id}" try: movie = get_cached_info(cache_key, lambda: xtream.get_vod_info(int(stream_id))) except (urllib.error.URLError, TimeoutError): return info if movie and movie.get("info"): m = movie["info"] name = m.get("name", "") year = str(m.get("year") or m.get("releasedate", ""))[:4] info.channel_name = f"{name} ({year})" if year else name info.program_desc = m.get("plot") or m.get("description") or "" return info def _get_series_player_info( stream_id: str, series_id: int | None, ext: str ) -> tuple[PlayerInfo, str | None]: """Get player info for series episode. Returns (info, next_episode_url).""" # Find series in cache to get its source_id cached_series = None source_id = "" if series_id and "series" in get_cache(): cached_series = next( (s for s in get_cache()["series"] if str(s.get("series_id")) == str(series_id)), None, ) if cached_series: source_id = cached_series.get("source_id", "") # Get client for the series' source (fall back to first if not found) xtream = get_xtream_client_by_source(source_id) if source_id else None if not xtream: xtream = get_first_xtream_client() if not xtream: return PlayerInfo(), None ext = ext or "mkv" info = PlayerInfo(url=xtream.build_stream_url("series", int(stream_id), ext)) info.source_id = source_id if not series_id: return info, None cache_key = f"series_info_{source_id}_{series_id}" try: series = get_cached_info(cache_key, lambda: xtream.get_series_info(series_id)) except (urllib.error.URLError, TimeoutError) as e: log.warning("Failed to fetch series info %s: %s", series_id, e) return info, None if not series: return info, None if series.get("info"): name = series["info"].get("name", "") 
year = series["info"].get("year", "") info.channel_name = f"{name} ({year})" if year else name # Build flat list of all episodes in order (season, episode) all_episodes: list[tuple[int, dict]] = [] for season_num, eps in sorted((series.get("episodes") or {}).items(), key=lambda x: int(x[0])): for ep in sorted(eps, key=lambda e: int(e.get("episode_num", 0))): all_episodes.append((int(season_num), ep)) # Find current episode and next next_episode_url = None for i, (season_num, ep) in enumerate(all_episodes): if str(ep.get("id")) == str(stream_id): title = re.sub(r"^S\d+E\d+\s*-\s*", "", ep.get("title", "")) if " - " in title: title = title.split(" - ")[-1] info.program_title = ( f"S{int(season_num):02d}E{int(ep.get('episode_num', 0)):02d} — {title.strip()}" ) info.program_desc = _get_episode_desc(ep) # Get next episode URL if i + 1 < len(all_episodes): _, next_ep = all_episodes[i + 1] next_ext = next_ep.get("container_extension") or ext next_episode_url = ( f"/play/series/{next_ep['id']}?series_id={series_id}&ext={next_ext}" ) break return info, next_episode_url def _ensure_live_cache() -> None: """Ensure live streams and EPG are loaded.""" if "live_streams" not in get_cache(): cats, streams, epg_urls = load_all_live_data() with get_cache_lock(): get_cache()["live_categories"] = cats get_cache()["live_streams"] = streams get_cache()["epg_urls"] = epg_urls if not epg.has_programs(): with contextlib.suppress(Exception): load_all_epg(get_cache().get("epg_urls", [])) @app.get("/play/{stream_type}/{stream_id:path}", response_class=HTMLResponse) async def player_page( request: Request, stream_type: str, stream_id: str, user: Annotated[dict, Depends(require_auth)], ext: str = "", series_id: int | None = None, ): """Render player page for live/movie/series stream.""" username = user.get("sub", "") next_episode_url = None if stream_type == "live": info = await asyncio.to_thread(_get_live_player_info, stream_id) elif stream_type == "movie": info = await 
asyncio.to_thread(_get_movie_player_info, stream_id, ext) elif stream_type == "series": info, next_episode_url = await asyncio.to_thread( _get_series_player_info, stream_id, series_id, ext ) else: raise HTTPException(404, "Invalid stream type") if not info.url: raise HTTPException(404, "Stream not found") # Check user's group access user_limits = auth.get_user_limits(username) unavailable_groups = set(user_limits.get("unavailable_groups", [])) log.info( "Access check: user=%s type=%s source_id=%s unavailable=%s", username, stream_type, info.source_id, unavailable_groups, ) if unavailable_groups: if stream_type == "live" and info.category_ids: # Live streams: blocked if any category is unavailable if any(f"cat:{cat_id}" in unavailable_groups for cat_id in info.category_ids): raise HTTPException(403, "Access to this channel is restricted") elif stream_type == "movie" and info.source_id: if f"movies:{info.source_id}" in unavailable_groups: raise HTTPException(403, "Access to movies is restricted") elif ( stream_type == "series" and info.source_id and f"series:{info.source_id}" in unavailable_groups ): raise HTTPException(403, "Access to series is restricted") log.info("Play %s/%s: %s", stream_type, stream_id, info.url) server_settings = load_server_settings() user_settings = load_user_settings(username) transcode_mode = server_settings.get("transcode_mode", "auto") is_https = request.url.scheme == "https" or "https" in request.headers.get("x-forwarded-proto", "").lower() or "https" in request.headers.get("x-forwarded-scheme", "").lower() if transcode_mode == "auto": needs_transcode = info.is_m3u or ext in ("mkv", "mp4", "avi", "wmv", "flv") mixed_content = is_https and info.url.startswith("http://") if needs_transcode or mixed_content: transcode_mode = "always" # Get saved watch position for VOD (per-user) resume_position = 0.0 if stream_type in ("movie", "series"): watch_entry = get_watch_position(username, info.url) if watch_entry: resume_position = 
watch_entry.get("position", 0.0) # For series, stream_id is episode_id episode_id = int(stream_id) if stream_type == "series" and stream_id.isdigit() else None # Extract series name from channel_name (format: "Series Name (Year)" or just "Series Name") series_name = "" if stream_type == "series" and info.channel_name: # Strip year suffix like " (2020)" series_name = re.sub(r"\s*\(\d{4}\)$", "", info.channel_name) return TEMPLATES.TemplateResponse( request, "player.html", { "raw_url": info.url, "transcode_mode": transcode_mode, "stream_type": stream_type, "channel_name": info.channel_name, "program_title": info.program_title, "program_desc": info.program_desc, "captions_enabled": user_settings.get("captions_enabled", False), "resume_position": resume_position, "series_id": series_id, "episode_id": episode_id, "series_name": series_name, "cc_lang": user_settings.get("cc_lang", ""), "cc_style": user_settings.get("cc_style", {}), "cast_host": user_settings.get("cast_host", ""), "next_episode_url": next_episode_url, "deinterlace_fallback": info.deinterlace_fallback, "source_id": info.source_id, "content_access": _get_content_access(username), }, ) @app.get("/search", response_class=HTMLResponse) async def search_page( request: Request, user: Annotated[dict, Depends(require_auth)], q: str = "", regex: bool = False, live: bool = False, vod: bool = False, series: bool = False, limit: int = 100, ): results: dict[str, list] = {"live": [], "vod": [], "series": []} # Default all on if none specified if not live and not vod and not series: live = vod = series = True if q: if regex: # Limit regex length to prevent ReDoS if len(q) > 100: raise HTTPException(400, "Regex pattern too long") try: pattern = re.compile(q, re.IGNORECASE) def match_fn(name: str) -> bool: try: # Timeout via match limit - search only first 1000 chars return pattern.search(name[:1000]) is not None except Exception: return False except re.error: def match_fn(name: str) -> bool: return False else: q_lower = 
q.lower() def match_fn(name: str) -> bool: return q_lower in name.lower() # Load live data (run in thread to avoid blocking) if live: if "live_streams" not in get_cache(): cats, streams, epg_urls = await asyncio.to_thread(load_all_live_data) with get_cache_lock(): get_cache()["live_categories"] = cats get_cache()["live_streams"] = streams get_cache()["epg_urls"] = epg_urls matched = sorted( [s for s in get_cache()["live_streams"] if match_fn(s.get("name") or "")], key=lambda x: x.get("name", "").lower(), ) results["live"] = matched[:limit] if limit else matched # Load VOD data (run in thread to avoid blocking) if vod: if "vod_streams" not in get_cache(): vod_cats, vod_streams = await asyncio.to_thread(load_vod_data) with get_cache_lock(): get_cache()["vod_categories"] = vod_cats get_cache()["vod_streams"] = vod_streams matched = sorted( [s for s in get_cache()["vod_streams"] if match_fn(s.get("name") or "")], key=lambda x: x.get("name", "").lower(), ) results["vod"] = matched[:limit] if limit else matched # Load series data (run in thread to avoid blocking) if series: if "series" not in get_cache(): series_cats, series_list = await asyncio.to_thread(load_series_data) with get_cache_lock(): get_cache()["series_categories"] = series_cats get_cache()["series"] = series_list matched = sorted( [s for s in get_cache()["series"] if match_fn(s.get("name") or "")], key=lambda x: x.get("name", "").lower(), ) results["series"] = matched[:limit] if limit else matched username = user.get("sub", "") user_settings = load_user_settings(username) # Filter results based on user access user_limits = auth.get_user_limits(username) unavailable_groups = set(user_limits.get("unavailable_groups", [])) # Filter live results by category access results["live"] = [ s for s in results["live"] if not any( f"cat:{cat_id}" in unavailable_groups for cat_id in (s.get("category_ids") or []) ) ] # Filter movie results by source access results["vod"] = [ s for s in results["vod"] if 
f"movies:{s.get('source_id', '')}" not in unavailable_groups ] # Filter series results by source access results["series"] = [ s for s in results["series"] if f"series:{s.get('source_id', '')}" not in unavailable_groups ] # Apply user's category filters from settings guide_filter = user_settings.get("guide_filter", []) if guide_filter: guide_filter_set = set(str(c) for c in guide_filter) results["live"] = [ s for s in results["live"] if any(str(cat_id) in guide_filter_set for cat_id in (s.get("category_ids") or [])) ] vod_filter = user_settings.get("vod_filter", []) if vod_filter: vod_filter_set = set(str(c) for c in vod_filter) results["vod"] = [s for s in results["vod"] if str(s.get("category_id")) in vod_filter_set] series_filter = user_settings.get("series_filter", []) if series_filter: series_filter_set = set(str(c) for c in series_filter) results["series"] = [ s for s in results["series"] if str(s.get("category_id")) in series_filter_set ] content_access = _get_content_access(username) return TEMPLATES.TemplateResponse( request, "search.html", { "query": q, "results": results, "regex": regex, "search_live": live, "search_vod": vod and content_access["movies"], "search_series": series and content_access["series"], "limit": limit, "favorites": user_settings.get("favorites", {"series": {}, "movies": {}}), "content_access": content_access, }, ) @app.get("/stream/{stream_type}/{stream_id}") async def stream_redirect( stream_type: str, stream_id: int, _user: Annotated[dict, Depends(require_auth)], ext: str = "", ): xtream = get_first_xtream_client() if not xtream: raise HTTPException(404, "No Xtream source configured") url = xtream.build_stream_url(stream_type, stream_id, ext) return RedirectResponse(url, status_code=302) @app.get("/playlist.xspf") async def playlist_xspf( _user: Annotated[dict, Depends(require_auth)], url: str, ): content = f""" {xml_escape(url)} """ return Response( content=content, media_type="application/xspf+xml", 
headers={"Content-Disposition": "attachment; filename=stream.xspf"}, ) # ============================================================================= # Transcoding routes (logic in ffmpeg_session.py) # ============================================================================= @app.get("/transcode/start") async def transcode_start( user: Annotated[dict, Depends(require_auth)], url: str, content_type: str = "live", # "movie", "series", or "live" series_id: int | None = None, episode_id: int | None = None, series_name: str = "", deinterlace_fallback: str = "1", # "1" or "0" source_id: str = "", ): """Start a transcode session, return session ID.""" deinterlace_fb = deinterlace_fallback == "1" username = user.get("sub", "") # Get user limits for this source user_limits = auth.get_user_limits(username) max_streams_per_source = user_limits.get("max_streams_per_source", {}) user_max_streams = max_streams_per_source.get(source_id, 0) if source_id else 0 # Get source max_streams (global limit for this source) source_max_streams = 0 if source_id: settings = load_server_settings() sources = settings.get("sources", []) source = next((s for s in sources if s.get("id") == source_id), None) if source: source_max_streams = source.get("max_streams", 0) return await ffmpeg_session.start_transcode( url, content_type, series_id, episode_id, series_name, deinterlace_fb, username, source_id, user_max_streams, source_max_streams, ) @app.get("/transcode/seek/{session_id}") async def transcode_seek( session_id: str, time: float, _user: Annotated[dict, Depends(require_auth)], ): """Seek VOD transcode to a new position.""" return await ffmpeg_session.seek_transcode(session_id, time) @app.get("/transcode/progress/{session_id}") async def transcode_progress( session_id: str, _user: Annotated[dict, Depends(require_auth)], ): """Get transcode progress (segment count, duration).""" progress = ffmpeg_session.get_session_progress(session_id) if not progress: raise HTTPException(404, "Session 
not found") return progress @app.get("/transcode/{session_id}/{filename}") async def transcode_file( request: Request, session_id: str, filename: str, ): """Serve HLS playlist or segments (no auth - session IDs are unguessable).""" # Prevent path traversal safe_filename = pathlib.Path(filename).name if safe_filename != filename or ".." in filename: raise HTTPException(400, "Invalid filename") session = ffmpeg_session.get_session(session_id) if not session: log.debug(f"[CAST] 404 session not found: {session_id}") raise HTTPException(404, "Transcode session not found") file_path = pathlib.Path(session["dir"]) / safe_filename if not file_path.exists(): log.debug(f"[CAST] 404 file not found: {file_path}") raise HTTPException(404, "File not found") # Log Chromecast requests ua = request.headers.get("user-agent", "") if "CrKey" in ua or "Chromecast" in ua.lower() or "cast" in ua.lower(): log.debug(f"[CAST] Chromecast request: {filename} UA={ua[:80]}") cors = {"Access-Control-Allow-Origin": "*"} if filename.endswith(".m3u8"): content = file_path.read_text() return Response( content=content, media_type="application/vnd.apple.mpegurl", headers={"Cache-Control": "no-cache, no-store, must-revalidate", **cors}, ) if filename.endswith(".vtt"): content = file_path.read_text() return Response(content=content, media_type="text/vtt", headers=cors) return FileResponse(file_path, media_type="video/mp2t", headers=cors) @app.get("/subs/{session_id}/{filename}") async def subtitle_file(session_id: str, filename: str): """Serve VTT subtitle files (no auth - session IDs are unguessable).""" # Prevent path traversal safe_filename = pathlib.Path(filename).name if safe_filename != filename or ".." 
in filename: raise HTTPException(400, "Invalid filename") if not safe_filename.endswith(".vtt"): raise HTTPException(400, "Only VTT files allowed") session = ffmpeg_session.get_session(session_id) if not session: raise HTTPException(404, "Session not found") file_path = pathlib.Path(session["dir"]) / safe_filename # Wait briefly for file, return empty VTT if not ready (client will poll again) for _ in range(15): # 3 seconds if file_path.exists() and file_path.stat().st_size > 20: break await asyncio.sleep(0.2) try: content = file_path.read_text() if file_path.exists() else "WEBVTT\n\n" except (UnicodeDecodeError, OSError): # File may be partially written or corrupted content = "WEBVTT\n\n" return Response( content=content, media_type="text/vtt", headers={"Access-Control-Allow-Origin": "*"}, ) @app.delete("/transcode/{session_id}") async def transcode_stop( session_id: str, _user: Annotated[dict, Depends(require_auth)], ): """Stop a transcode session (VOD sessions stay cached).""" ffmpeg_session.stop_session(session_id, force=False) return {"status": "stopped"} @app.post("/transcode/{session_id}/stop") async def transcode_stop_post( session_id: str, _user: Annotated[dict, Depends(require_auth)], ): """Stop a transcode session (POST for sendBeacon, VOD cached).""" ffmpeg_session.stop_session(session_id, force=False) return {"status": "stopped"} @app.delete("/transcode-clear") async def transcode_clear( url: str, _user: Annotated[dict, Depends(require_auth)], ): """Force-delete any cached transcode session for a URL.""" session_id = ffmpeg_session.clear_url_session(url) if session_id: ffmpeg_session.stop_session(session_id, force=True) log.info("Force-cleared transcode session %s for URL", session_id) return {"status": "cleared", "session_id": session_id} def _build_all_groups() -> list[dict[str, str]]: """Build list of all available groups for user restrictions. 
Groups are: - Live TV categories: cat:{category_id} displayed as "{source.name}: {category.name}" - Movies: movies:{source_id} displayed as "{source.name}: Movies" - Series: series:{source_id} displayed as "{source.name}: Series" """ groups = [] sources_by_id = {s.id: s for s in get_sources()} # Live TV categories for cat in get_cache().get("live_categories", []): source_id = cat.get("source_id", "") source = sources_by_id.get(source_id) if source: groups.append( { "id": f"cat:{cat['category_id']}", "name": f"{source.name}: {cat['category_name']}", "type": "live", } ) # Movies and Series for Xtream sources for source in get_sources(): if source.type == "xtream": groups.append( { "id": f"movies:{source.id}", "name": f"{source.name}: Movies", "type": "movies", } ) groups.append( { "id": f"series:{source.id}", "name": f"{source.name}: Series", "type": "series", } ) return groups def _get_content_access(username: str) -> dict[str, bool]: """Check if user has access to movies/series from any source. Returns dict with 'movies' and 'series' booleans. If no xtream sources exist, access is granted (nothing to restrict). 
""" user_limits = auth.get_user_limits(username) unavailable_groups = set(user_limits.get("unavailable_groups", [])) has_movies = False has_series = False has_xtream_source = False for source in get_sources(): if source.type == "xtream": has_xtream_source = True if f"movies:{source.id}" not in unavailable_groups: has_movies = True if f"series:{source.id}" not in unavailable_groups: has_series = True # If no xtream sources, allow access (nothing to restrict) if not has_xtream_source: return {"movies": True, "series": True} return {"movies": has_movies, "series": has_series} # Register template global for content access check TEMPLATES.env.globals["get_content_access"] = _get_content_access def _get_content_access_from_request(request: Request) -> dict[str, bool]: """Get content access from request (for use in base template).""" token = request.cookies.get("token") if not token: return {"movies": True, "series": True} # Not logged in, show all payload = verify_token(token) if not payload: return {"movies": True, "series": True} username = payload.get("sub", "") if not username: return {"movies": True, "series": True} return _get_content_access(username) TEMPLATES.env.globals["get_content_access_from_request"] = _get_content_access_from_request @app.get("/settings", response_class=HTMLResponse) async def settings_page(request: Request, user: Annotated[dict, Depends(require_auth)]): username = user.get("sub", "") is_admin = auth.is_admin(username) server_settings = load_server_settings() user_settings = load_user_settings(username) # Load categories (from file cache or trigger background load) if "live_categories" not in get_cache(): cached = await asyncio.to_thread(load_file_cache, "live_data") if cached: data, _ = cached with get_cache_lock(): get_cache()["live_categories"] = data["cats"] get_cache()["live_streams"] = data["streams"] get_cache()["epg_urls"] = parse_epg_urls(data.get("epg_urls", [])) else: # No cache - start background load 
_start_guide_background_load() # Load VOD categories if not cached if "vod_categories" not in get_cache(): cached = await asyncio.to_thread(load_file_cache, "vod_data") if cached: data, _ = cached with get_cache_lock(): get_cache()["vod_categories"] = data["cats"] get_cache()["vod_streams"] = data["streams"] # Load series categories if not cached if "series_categories" not in get_cache(): cached = await asyncio.to_thread(load_file_cache, "series_data") if cached: data, _ = cached with get_cache_lock(): get_cache()["series_categories"] = data["cats"] get_cache()["series"] = data["series"] # Build source_id -> source_name mapping source_names = {s["id"]: s["name"] for s in server_settings.get("sources", [])} # Filter categories based on user's unavailable groups user_limits = auth.get_user_limits(username) unavailable_groups = set(user_limits.get("unavailable_groups", [])) all_live_cats = get_cache().get("live_categories", []) live_categories = [ cat for cat in all_live_cats if f"cat:{cat['category_id']}" not in unavailable_groups ] vod_categories = get_cache().get("vod_categories", []) series_categories = get_cache().get("series_categories", []) return TEMPLATES.TemplateResponse( request, "settings.html", { # Server settings "sources": server_settings.get("sources", []), "transcode_mode": server_settings.get("transcode_mode", "auto"), "transcode_hw": server_settings.get("transcode_hw", "nvidia"), "max_resolution": server_settings.get("max_resolution", "1080p"), "quality": server_settings.get("quality", "high"), "vod_transcode_cache_mins": server_settings.get("vod_transcode_cache_mins", 60), "live_transcode_cache_secs": server_settings.get("live_transcode_cache_secs", 60), "live_dvr_mins": server_settings.get("live_dvr_mins", 0), "transcode_dir": server_settings.get("transcode_dir", ""), "probe_live": server_settings.get("probe_live", True), "probe_movies": server_settings.get("probe_movies", True), "probe_series": server_settings.get("probe_series", False), 
"user_agent_preset": server_settings.get("user_agent_preset", "default"), "user_agent_custom": server_settings.get("user_agent_custom", ""), "available_encoders": AVAILABLE_ENCODERS, "sr_available": is_sr_available(), "sr_models": get_sr_models(), "sr_model": server_settings.get("sr_model", ""), "all_users": auth.get_users_with_admin(), "all_groups": _build_all_groups(), "current_user": username, "is_admin": is_admin, # User settings "captions_enabled": user_settings.get("captions_enabled", False), "virtual_scroll": user_settings.get("virtual_scroll", True), "live_categories": live_categories, "vod_categories": vod_categories, "series_categories": series_categories, "source_names": source_names, "selected_cats": user_settings.get("guide_filter", []), "selected_vod_cats": user_settings.get("vod_filter", []), "selected_series_cats": user_settings.get("series_filter", []), "cc_lang": user_settings.get("cc_lang", ""), "cc_style": user_settings.get("cc_style", {}), "cast_host": user_settings.get("cast_host", ""), "content_access": _get_content_access(username), }, ) @app.post("/settings/guide-filter") async def settings_guide_filter( request: Request, user: Annotated[dict, Depends(require_auth)], ): username = user.get("sub", "") data = await request.json() cats = data.get("cats", []) if not isinstance(cats, list) or len(cats) > _MAX_FILTER_CATEGORIES: raise HTTPException(400, "Invalid filter list") user_settings = load_user_settings(username) user_settings["guide_filter"] = cats save_user_settings(username, user_settings) return {"status": "ok"} @app.post("/settings/vod-filter") async def settings_vod_filter( request: Request, user: Annotated[dict, Depends(require_auth)], ): username = user.get("sub", "") data = await request.json() cats = data.get("cats", []) if not isinstance(cats, list) or len(cats) > _MAX_FILTER_CATEGORIES: raise HTTPException(400, "Invalid filter list") user_settings = load_user_settings(username) user_settings["vod_filter"] = cats 
save_user_settings(username, user_settings) return {"status": "ok"} @app.post("/settings/series-filter") async def settings_series_filter( request: Request, user: Annotated[dict, Depends(require_auth)], ): username = user.get("sub", "") data = await request.json() cats = data.get("cats", []) if not isinstance(cats, list) or len(cats) > _MAX_FILTER_CATEGORIES: raise HTTPException(400, "Invalid filter list") user_settings = load_user_settings(username) user_settings["series_filter"] = cats save_user_settings(username, user_settings) return {"status": "ok"} @app.post("/settings/add") async def settings_add_source( _user: Annotated[dict, Depends(require_admin)], name: Annotated[str, Form()], source_type: Annotated[str, Form()], url: Annotated[str, Form()], username: Annotated[str, Form()] = "", password: Annotated[str, Form()] = "", epg_timeout: Annotated[int, Form()] = 120, epg_schedule: Annotated[str, Form()] = "", epg_enabled: Annotated[str, Form()] = "", # Checkbox: "on" if checked deinterlace_fallback: Annotated[str, Form()] = "", # Checkbox: "on" if checked max_streams: Annotated[int, Form()] = 0, ): # Validate inputs if not name or not name.strip(): raise HTTPException(400, "Name is required") if source_type not in ("xtream", "m3u", "epg"): raise HTTPException(400, "Invalid source type") parsed_url = urllib.parse.urlparse(url) if parsed_url.scheme not in ("http", "https"): raise HTTPException(400, "URL must use http or https") if len(name) > 200: raise HTTPException(400, "Name too long") # Parse schedule times schedule_list = [] for t in epg_schedule.split(","): t = t.strip() if t and re.match(r"^\d{1,2}:\d{2}$", t): schedule_list.append(t.zfill(5)) settings = load_server_settings() sources = settings.get("sources", []) source_id = f"src_{int(time.time())}_{len(sources)}" sources.append( { "id": source_id, "name": name, "type": source_type, "url": url.rstrip("/"), "username": username, "password": password, "epg_timeout": max(1, min(3600, epg_timeout)), 
"epg_schedule": schedule_list, "epg_enabled": epg_enabled == "on" or source_type == "epg", "deinterlace_fallback": deinterlace_fallback == "on", "max_streams": max(0, max_streams), } ) settings["sources"] = sources save_server_settings(settings) clear_all_caches() return RedirectResponse("/settings", status_code=303) @app.post("/settings/edit/{source_id}") async def settings_edit_source( source_id: str, _user: Annotated[dict, Depends(require_admin)], name: Annotated[str, Form()], source_type: Annotated[str, Form()], url: Annotated[str, Form()], username: Annotated[str, Form()] = "", password: Annotated[str, Form()] = "", epg_timeout: Annotated[int, Form()] = 120, epg_schedule: Annotated[str, Form()] = "", epg_enabled: Annotated[str, Form()] = "", # Checkbox: "on" if checked epg_url: Annotated[str, Form()] = "", deinterlace_fallback: Annotated[str, Form()] = "", # Checkbox: "on" if checked max_streams: Annotated[int, Form()] = 0, ): # Validate inputs if not name or not name.strip(): raise HTTPException(400, "Name is required") if source_type not in ("xtream", "m3u", "epg"): raise HTTPException(400, "Invalid source type") parsed_url = urllib.parse.urlparse(url) if parsed_url.scheme not in ("http", "https"): raise HTTPException(400, "URL must use http or https") if len(name) > 200: raise HTTPException(400, "Name too long") # Parse schedule times (comma-separated HH:MM) schedule_list = [] for t in epg_schedule.split(","): t = t.strip() if t and re.match(r"^\d{1,2}:\d{2}$", t): schedule_list.append(t.zfill(5)) # Normalize to HH:MM settings = load_server_settings() for s in settings.get("sources", []): if s["id"] == source_id: s["name"] = name s["type"] = source_type s["url"] = url.rstrip("/") s["username"] = username s["password"] = password s["epg_timeout"] = max(1, min(3600, epg_timeout)) s["epg_schedule"] = schedule_list s["epg_enabled"] = epg_enabled == "on" or source_type == "epg" s["epg_url"] = epg_url.strip() s["deinterlace_fallback"] = deinterlace_fallback == 
"on" s["max_streams"] = max(0, max_streams) break save_server_settings(settings) clear_all_caches() return {"ok": True} @app.post("/settings/delete/{source_id}") async def settings_delete_source( source_id: str, _user: Annotated[dict, Depends(require_admin)], ): settings = load_server_settings() settings["sources"] = [s for s in settings.get("sources", []) if s["id"] != source_id] save_server_settings(settings) # Clear all caches including EPG data for this source epg.clear_source(source_id) clear_all_file_caches() return RedirectResponse("/settings", status_code=303) @app.get("/guide/refresh") async def guide_refresh(_user: Annotated[dict, Depends(require_auth)]): """Refresh guide data in background (stale-while-revalidate).""" def refresh_live(): try: log.info("Live refresh: fetching channels") cats, streams, epg_urls = load_all_live_data() with get_cache_lock(): get_cache()["live_categories"] = cats get_cache()["live_streams"] = streams get_cache()["epg_urls"] = epg_urls save_file_cache("live_data", {"cats": cats, "streams": streams, "epg_urls": epg_urls}) log.info("Live refresh: complete (%d categories, %d streams)", len(cats), len(streams)) except Exception as e: log.error("Live refresh failed: %s", e) finally: get_refresh_in_progress().discard("live_refresh") def refresh_epg(): try: epg_urls = get_cache().get("epg_urls", []) if epg_urls: log.info("EPG refresh: fetching %d sources", len(epg_urls)) epg.clear() count = _fetch_all_epg(epg_urls) with get_cache_lock(): get_cache().pop("epg_error", None) log.info("EPG refresh: complete (%d programs)", count) else: log.warning("EPG refresh: no EPG URLs available") except Exception as e: log.error("EPG refresh failed: %s", e) with get_cache_lock(): get_cache()["epg_error"] = str(e) finally: get_refresh_in_progress().discard("epg_refresh") # Set flags before starting threads to avoid race with status polling if "live_refresh" not in get_refresh_in_progress(): get_refresh_in_progress().add("live_refresh") 
threading.Thread(target=refresh_live, daemon=True).start() if "epg_refresh" not in get_refresh_in_progress(): get_refresh_in_progress().add("epg_refresh") threading.Thread(target=refresh_epg, daemon=True).start() return RedirectResponse("/guide?refreshing=1", status_code=303) @app.get("/guide/refresh-status") async def guide_refresh_status(_user: Annotated[dict, Depends(require_auth)]): """Return refresh status for polling.""" return { "live": "live_refresh" in get_refresh_in_progress(), "epg": "epg_refresh" in get_refresh_in_progress(), } @app.post("/settings/refresh/{source_id}/{refresh_type}") async def settings_refresh_source( source_id: str, refresh_type: str, _user: Annotated[dict, Depends(require_admin)], ): """Refresh a specific data type for a single source.""" sources = get_sources() source = next((s for s in sources if s.id == source_id), None) if not source: return {"error": "Source not found"} key = f"{source_id}_{refresh_type}" if key in get_refresh_in_progress(): return {"status": "already_running"} get_refresh_in_progress().add(key) def do_refresh(): try: if refresh_type == "live": log.info("Refreshing live data for source: %s", source.name) cats, streams, epg_url, timeout = fetch_source_live_data(source) # Update cache by replacing this source's data with get_cache_lock(): existing_cats = [ c for c in get_cache().get("live_categories", []) if c.get("source_id") != source_id ] existing_streams = [ s for s in get_cache().get("live_streams", []) if s.get("source_id") != source_id ] existing_epg = [e for e in get_cache().get("epg_urls", []) if e[2] != source_id] new_cats = existing_cats + cats new_streams = existing_streams + streams new_epg = existing_epg + ([(epg_url, timeout, source_id)] if epg_url else []) get_cache()["live_categories"] = new_cats get_cache()["live_streams"] = new_streams get_cache()["epg_urls"] = new_epg # Save to file cache save_file_cache( "live_data", {"cats": new_cats, "streams": new_streams, "epg_urls": new_epg}, ) log.info( 
"Live refresh complete for %s: %d cats, %d streams", source.name, len(cats), len(streams), ) elif refresh_type == "epg": log.info( "Refreshing EPG for source: %s (timeout=%ds)", source.name, source.epg_timeout ) epg_url = source.epg_url or (source.url if source.type == "epg" else "") if epg_url: epg.clear_source(source_id) count = _fetch_all_epg([(epg_url, source.epg_timeout, source_id)]) log.info("EPG refresh complete for %s: %d programs", source.name, count) else: log.warning("No EPG URL for source: %s", source.name) elif refresh_type == "vod" and source.type == "xtream": log.info("Refreshing VOD for source: %s", source.name) new_cats, new_streams = fetch_source_vod_data(source) # Merge with existing data from other sources existing_cats, existing_streams = load_vod_data() # Remove old data from this source, keep others merged_cats = [c for c in existing_cats if c.get("source_id") != source_id] merged_streams = [s for s in existing_streams if s.get("source_id") != source_id] # Add new data from this source merged_cats.extend(new_cats) merged_streams.extend(new_streams) with get_cache_lock(): get_cache().pop("vod_categories", None) get_cache().pop("vod_streams", None) for f in CACHE_DIR.glob("vod_data*.json"): f.unlink(missing_ok=True) save_file_cache("vod_data", {"cats": merged_cats, "streams": merged_streams}) log.info( "VOD refresh complete for %s: %d cats, %d streams (total: %d)", source.name, len(new_cats), len(new_streams), len(merged_streams), ) elif refresh_type == "m3u" and source.type == "m3u": log.info("Refreshing M3U playlist for source: %s", source.name) cats, streams, detected_epg_url = fetch_m3u(source.url, source.id) update_source_epg_url(source_id, detected_epg_url) with get_cache_lock(): existing_cats = [ c for c in get_cache().get("live_categories", []) if c.get("source_id") != source_id ] existing_streams = [ s for s in get_cache().get("live_streams", []) if s.get("source_id") != source_id ] new_cats = existing_cats + cats new_streams = 
existing_streams + streams get_cache()["live_categories"] = new_cats get_cache()["live_streams"] = new_streams epg_urls = get_cache().get("epg_urls", []) save_file_cache( "live_data", { "cats": new_cats, "streams": new_streams, "epg_urls": epg_urls, }, ) log.info( "M3U refresh complete for %s: %d cats, %d streams", source.name, len(cats), len(streams), ) except Exception as e: log.error("Source refresh failed (%s/%s): %s", source.name, refresh_type, e) finally: get_refresh_in_progress().discard(key) threading.Thread(target=do_refresh, daemon=True).start() return {"status": "started", "key": key} @app.get("/settings/refresh-status") async def settings_refresh_status(_user: Annotated[dict, Depends(require_auth)]): """Return per-source refresh status.""" statuses: dict[str, Any] = {} for key in list(get_refresh_in_progress()): if "_" in key: # Format: source_id_type (e.g., "ota_epg" or "src_123_epg") parts = key.rsplit("_", 1) if len(parts) == 2: source_id, rtype = parts statuses.setdefault(source_id, {})[rtype] = True # Report global guide_load as affecting all sources if "guide_load" in get_refresh_in_progress(): statuses["_global"] = {"live": True, "epg": True} return statuses @app.post("/settings/captions") async def settings_captions( user: Annotated[dict, Depends(require_auth)], enabled: Annotated[str, Form()] = "", ): username = user.get("sub", "") user_settings = load_user_settings(username) user_settings["captions_enabled"] = enabled == "on" save_user_settings(username, user_settings) return {"ok": True} @app.post("/api/cast-log") async def cast_log_endpoint(request: Request): """Log cast events from client (debug mode only).""" if log.isEnabledFor(logging.DEBUG): body = await request.body() # Sanitize: limit length, single line, printable chars only msg = body.decode("utf-8", errors="replace")[:2048] msg = "".join(c if c.isprintable() and c != "\n" else "?" 
for c in msg) log.debug(f"[CAST] {msg}") return {"ok": True} @app.get("/api/user-prefs") async def get_user_prefs(user: Annotated[dict, Depends(require_auth)]): """Get user preferences (favorites, cc_lang, cc_style, cast_host, virtual_scroll).""" username = user.get("sub", "") settings = load_user_settings(username) return { "favorites": settings.get("favorites", {}), "cc_lang": settings.get("cc_lang", ""), "cc_style": settings.get("cc_style", {}), "cast_host": settings.get("cast_host", ""), "virtual_scroll": settings.get("virtual_scroll", True), } @app.post("/api/user-prefs") async def save_user_prefs( request: Request, user: Annotated[dict, Depends(require_auth)], ): """Save user preferences (partial update).""" username = user.get("sub", "") body = await request.body() if len(body) > 64 * 1024: # 64KB limit raise HTTPException(400, "Request too large") data = json.loads(body) settings = load_user_settings(username) for key in ( "favorites", "cc_lang", "cc_style", "cast_host", "virtual_scroll", "guide_selected_cats", ): if key in data: settings[key] = data[key] save_user_settings(username, settings) return {"ok": True} def _fetch_logo(url: str, timeout: int = 10) -> tuple[bytes, str]: """Fetch logo synchronously. 
Returns (data, content_type).""" from util import safe_urlopen with safe_urlopen(url, timeout=timeout) as resp: content_type = resp.headers.get("Content-Type", "") if not content_type.startswith("image/"): raise ValueError("URL is not an image") data = resp.read(LOGO_MAX_SIZE) if len(data) >= LOGO_MAX_SIZE: raise ValueError("Image too large") return data, content_type @app.get("/api/logo") async def get_logo( url: str, _user: Annotated[dict, Depends(require_auth)], source: str = "default", ): """Proxy and cache external logos to avoid mixed-content issues.""" if not url: raise HTTPException(400, "Missing url parameter") # Check cache first cached = get_cached_logo(source, url) if cached: return FileResponse(cached, headers={"Cache-Control": f"max-age={LOGO_BROWSER_TTL}"}) # Validate URL scheme parsed = urllib.parse.urlparse(url) if parsed.scheme not in ("http", "https"): raise HTTPException(400, "Invalid URL scheme") # Fetch the logo (in thread to avoid blocking) try: data, content_type = await asyncio.to_thread(_fetch_logo, url) path = save_logo(source, url, data, content_type) return FileResponse(path, headers={"Cache-Control": f"max-age={LOGO_BROWSER_TTL}"}) except ValueError as e: raise HTTPException(400, str(e)) from None except urllib.error.URLError as e: log.debug("Logo fetch failed for %s: %s", url, e) raise HTTPException(502, "Failed to fetch logo") from None except Exception as e: log.debug("Logo fetch error for %s: %s", url, e) raise HTTPException(500, "Logo fetch error") from None @app.post("/settings/transcode") async def settings_transcode( _user: Annotated[dict, Depends(require_admin)], transcode_mode: Annotated[str, Form()], transcode_hw: Annotated[str, Form()], max_resolution: Annotated[str, Form()] = "1080p", quality: Annotated[str, Form()] = "high", vod_transcode_cache_mins: Annotated[int, Form()] = 60, live_transcode_cache_secs: Annotated[int, Form()] = 0, live_dvr_mins: Annotated[int, Form()] = 0, transcode_dir: Annotated[str, Form()] = "", 
probe_live: Annotated[str | None, Form()] = None, probe_movies: Annotated[str | None, Form()] = None, probe_series: Annotated[str | None, Form()] = None, sr_model: Annotated[str, Form()] = "", ): settings = load_server_settings() settings["transcode_mode"] = transcode_mode # Validate sr_model against available models available_models = get_sr_models() settings["sr_model"] = sr_model if sr_model in available_models else "" settings["transcode_hw"] = transcode_hw settings["max_resolution"] = max_resolution settings["quality"] = quality if quality in ("high", "medium", "low") else "high" settings["vod_transcode_cache_mins"] = max(0, vod_transcode_cache_mins) settings["live_transcode_cache_secs"] = max(0, live_transcode_cache_secs) settings["live_dvr_mins"] = max(0, live_dvr_mins) if transcode_dir: settings["transcode_dir"] = transcode_dir elif "transcode_dir" in settings: del settings["transcode_dir"] # Use default settings["probe_live"] = probe_live == "on" settings["probe_movies"] = probe_movies == "on" settings["probe_series"] = probe_series == "on" save_server_settings(settings) return {"ok": True} @app.post("/settings/refresh-encoders") async def settings_refresh_encoders( _user: Annotated[dict, Depends(require_admin)], ): """Re-detect available hardware encoders.""" encoders = refresh_encoders() return {"ok": True, "encoders": encoders} @app.post("/settings/user-agent") async def settings_user_agent( _user: Annotated[dict, Depends(require_admin)], preset: Annotated[str, Form()], custom: Annotated[str, Form()] = "", ): valid_presets = {"default", "vlc", "chrome", "tivimate", "custom"} if preset not in valid_presets: preset = "default" settings = load_server_settings() settings["user_agent_preset"] = preset settings["user_agent_custom"] = custom save_server_settings(settings) return {"ok": True} def _enrich_probe_cache_stats(stats: list[dict], xtream: Any) -> list[dict]: """Enrich probe cache stats with series/episode names (blocking).""" for entry in stats: 
series: dict | None = None if not entry.get("name") or entry.get("episodes"): cache_key = f"series_info_{entry['series_id']}" with contextlib.suppress(Exception): series = get_cached_info( cache_key, lambda sid=entry["series_id"]: xtream.get_series_info(sid) ) if series: if not entry.get("name") and series.get("info"): entry["name"] = series["info"].get("name", "") ep_map: dict[int, str] = {} for season_num, eps in (series.get("episodes") or {}).items(): for ep in eps: eid = ep.get("id") if eid: ep_num = ep.get("episode_num", 0) title = ep.get("title", "") title = re.sub(r"^S\d+E\d+\s*-\s*", "", title) if " - " in title: title = title.split(" - ")[-1] ep_map[int(eid)] = ( f"S{int(season_num):02d}E{int(ep_num):02d} {title.strip()}" ) for ep in entry.get("episodes", []): ep_id = ep.get("episode_id") if ep_id in ep_map: ep["name"] = ep_map[ep_id] return stats @app.get("/settings/probe-cache") async def get_probe_cache( _user: Annotated[dict, Depends(require_auth)], response: Response, ): """Get probe cache stats for settings UI.""" response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate" stats = ffmpeg_command.get_series_probe_cache_stats() xtream = get_first_xtream_client() if not xtream: return {"series": stats} stats = await asyncio.to_thread(_enrich_probe_cache_stats, stats, xtream) return {"series": stats} @app.post("/settings/probe-cache/clear") async def clear_probe_cache(_user: Annotated[dict, Depends(require_admin)]): """Clear all probe caches.""" count = ffmpeg_command.clear_all_probe_cache() return {"ok": True, "cleared": count} @app.post("/settings/probe-cache/clear/{series_id}") async def clear_series_probe_cache( series_id: int, _user: Annotated[dict, Depends(require_admin)], episode_id: int | None = None, ): """Clear probe cache for a specific series or episode.""" ffmpeg_command.invalidate_series_probe_cache(series_id, episode_id) return {"ok": True} @app.post("/settings/probe-cache/clear-mru/{series_id}") async def clear_series_mru( 
series_id: int, _user: Annotated[dict, Depends(require_admin)], ): """Clear only the MRU for a series, keeping episode cache intact.""" ffmpeg_command.clear_series_mru(series_id) return {"ok": True} @app.post("/settings/data-cache/clear") async def clear_data_cache(_user: Annotated[dict, Depends(require_admin)]): """Clear all data file caches (live, VOD, series) and memory cache.""" count = clear_all_file_caches() return {"ok": True, "cleared": count} @app.get("/api/settings") async def get_settings_api(_user: Annotated[dict, Depends(require_auth)]): return load_server_settings() @app.post("/api/settings") async def update_settings_api( request: Request, _user: Annotated[dict, Depends(require_admin)], ): data = await request.json() # Whitelist allowed keys - never allow users/secret_key to be overwritten allowed_keys = { "transcode_mode", "transcode_hw", "vod_transcode_cache_mins", "probe_live", "probe_movies", "probe_series", "vod_order", "series_order", } settings = load_server_settings() for key in allowed_keys: if key in data: settings[key] = data[key] save_server_settings(settings) return {"status": "ok"} @app.post("/api/watch-position") async def save_watch_position_api( request: Request, user: Annotated[dict, Depends(require_auth)], ): """Save watch position for a stream (per-user).""" username = user.get("sub", "") data = await request.json() url = data.get("url", "") position = float(data.get("position", 0)) duration = float(data.get("duration", 0)) if url and position >= 0: save_watch_position(username, url, position, duration) return {"status": "ok"} @app.get("/api/watch-position") async def get_watch_position_api( user: Annotated[dict, Depends(require_auth)], url: str, ): """Get watch position for a stream (per-user).""" username = user.get("sub", "") entry = get_watch_position(username, url) if entry: return {"position": entry.get("position", 0), "duration": entry.get("duration", 0)} return {"position": 0, "duration": 0} # User management endpoints 
@app.post("/settings/users/delete/{username}") async def settings_delete_user( username: str, user: Annotated[dict, Depends(require_auth)], password: Annotated[str, Form()] = "", ): """Delete a user. Self-deletion requires password. Other users require admin.""" current_user = user.get("sub", "") if username == current_user: if not password or not auth.verify_password(username, password): raise HTTPException(400, "Password required to delete your own account") auth.delete_user(username) response = RedirectResponse("/login", status_code=303) response.delete_cookie("token") return response # Deleting other users requires admin if not auth.is_admin(current_user): raise HTTPException(403, "Admin access required") if not auth.delete_user(username): raise HTTPException(404, "User not found") return RedirectResponse("/settings", status_code=303) @app.post("/settings/users/add") async def settings_add_user( _user: Annotated[dict, Depends(require_admin)], username: Annotated[str, Form()], password: Annotated[str, Form()], admin: Annotated[str, Form()] = "", max_streams_per_source: Annotated[str | None, Form()] = None, unavailable_groups: Annotated[str | None, Form()] = None, ): """Add a new user.""" username = username.strip() if not username or len(username) < 2: raise HTTPException(400, "Username must be at least 2 characters") if len(password) < 8: raise HTTPException(400, "Password must be at least 8 characters") if username in auth.get_all_usernames(): raise HTTPException(400, "User already exists") auth.create_user(username, password, admin=admin == "on") # Apply limits if provided parsed_max_streams = None if max_streams_per_source: with contextlib.suppress(json.JSONDecodeError): parsed_max_streams = json.loads(max_streams_per_source) parsed_unavailable = None if unavailable_groups: with contextlib.suppress(json.JSONDecodeError): parsed_unavailable = json.loads(unavailable_groups) if parsed_max_streams or parsed_unavailable: auth.set_user_limits(username, 
parsed_max_streams, parsed_unavailable) return {"status": "ok"} @app.post("/settings/users/password") async def settings_change_own_password( user: Annotated[dict, Depends(require_auth)], current_password: Annotated[str, Form()], new_password: Annotated[str, Form()], ): """Change own password. Requires current password verification.""" username = user.get("sub", "") if not auth.verify_password(username, current_password): raise HTTPException(400, "Current password is incorrect") if len(new_password) < 8: raise HTTPException(400, "Password must be at least 8 characters") if not auth.change_password(username, new_password): raise HTTPException(404, "User not found") return {"status": "ok"} @app.post("/settings/users/password/{target_user}") async def settings_change_password( target_user: str, user: Annotated[dict, Depends(require_auth)], new_password: Annotated[str, Form()], ): """Change a user's password. Own password or admin required.""" current_user = user.get("sub", "") if target_user != current_user and not auth.is_admin(current_user): raise HTTPException(403, "Admin access required") if len(new_password) < 8: raise HTTPException(400, "Password must be at least 8 characters") if not auth.change_password(target_user, new_password): raise HTTPException(404, "User not found") return {"status": "ok"} @app.post("/settings/users/admin/{target_user}") async def settings_set_admin( target_user: str, _user: Annotated[dict, Depends(require_admin)], admin: Annotated[str, Form()] = "", ): """Set admin status for a user.""" if not auth.set_admin(target_user, admin == "on"): raise HTTPException(404, "User not found") return {"status": "ok"} @app.post("/settings/users/limits/{target_user}") async def settings_set_user_limits( target_user: str, _user: Annotated[dict, Depends(require_admin)], max_streams_per_source: Annotated[str | None, Form()] = None, # JSON object string unavailable_groups: Annotated[str | None, Form()] = None, # JSON array string ): """Set stream limits 
and group restrictions for a user.""" parsed_max_streams = None if max_streams_per_source is not None: try: parsed_max_streams = json.loads(max_streams_per_source) if not isinstance(parsed_max_streams, dict): raise HTTPException(400, "max_streams_per_source must be a JSON object") except json.JSONDecodeError as err: raise HTTPException(400, "Invalid JSON for max_streams_per_source") from err parsed_unavailable = None if unavailable_groups is not None: try: parsed_unavailable = json.loads(unavailable_groups) if not isinstance(parsed_unavailable, list): raise HTTPException(400, "unavailable_groups must be a JSON array") except json.JSONDecodeError as err: raise HTTPException(400, "Invalid JSON for unavailable_groups") from err if not auth.set_user_limits(target_user, parsed_max_streams, parsed_unavailable): raise HTTPException(404, "User not found") return {"status": "ok"} if __name__ == "__main__": import argparse import uvicorn # pyright: ignore[reportMissingImports] parser = argparse.ArgumentParser(description="IPTV Web App") parser.add_argument("--port", type=int, default=8000, help="Port to listen on") parser.add_argument("--debug", action="store_true", help="Enable debug logging") parser.add_argument( "--https", nargs="?", const="", metavar="DOMAIN", help="Enable HTTPS (auto-detect domain, or specify one)", ) parser.add_argument("--cert", help="SSL certificate file (e.g., fullchain.pem)") parser.add_argument("--key", help="SSL private key file (e.g., privkey.pem)") args = parser.parse_args() # LOG_LEVEL env var takes precedence, then --debug flag, then default INFO log_level_env = os.environ.get("LOG_LEVEL", "").upper() if log_level_env in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"): log_level = getattr(logging, log_level_env) else: log_level = logging.DEBUG if args.debug else logging.INFO logging.basicConfig( level=log_level, format="%(asctime)s | %(levelname)s | %(message)s", datefmt="%H:%M:%S", force=True, ) ssl_args = {} if args.cert and args.key: 
ssl_args = {"ssl_certfile": args.cert, "ssl_keyfile": args.key} elif args.https is not None: live_dir = pathlib.Path("/etc/letsencrypt/live") if args.https: domain = args.https else: # Auto-detect first domain domains = [ d.name for d in live_dir.iterdir() if d.is_dir() and (d / "fullchain.pem").exists() ] if not domains: raise SystemExit("No Let's Encrypt certs found in /etc/letsencrypt/live/") domain = domains[0] cert = live_dir / domain / "fullchain.pem" key = live_dir / domain / "privkey.pem" if not cert.exists(): raise SystemExit(f"Cert not found: {cert}") log.info("Using Let's Encrypt certs for %s", domain) ssl_args = {"ssl_certfile": str(cert), "ssl_keyfile": str(key)} uv_log = "debug" if args.debug else "info" uvicorn.run( app, host="0.0.0.0", port=args.port, access_log=args.debug, log_level=uv_log, log_config=None, # preserve our basicConfig timeout_graceful_shutdown=5, proxy_headers=True, forwarded_allow_ips="*", **ssl_args, # pyright: ignore[reportArgumentType] ) ================================================ FILE: main_test.py ================================================ """Tests for main.py - FastAPI routes.""" from __future__ import annotations from pathlib import Path from unittest.mock import MagicMock, patch import json import pytest import cache as cache_module import m3u as m3u_module @pytest.fixture def mock_deps(): """Mock all external dependencies before importing main.""" with ( patch.dict( "sys.modules", {"defusedxml": MagicMock(), "defusedxml.ElementTree": MagicMock()} ), patch("cache.CACHE_DIR", Path("/tmp/test_cache")), patch("cache.SERVER_SETTINGS_FILE", Path("/tmp/test_cache/server_settings.json")), patch("cache.USERS_DIR", Path("/tmp/test_cache/users")), ): yield @pytest.fixture def client(tmp_path: Path, mock_deps): """Create test client with mocked dependencies.""" from fastapi.testclient import TestClient # Patch paths before importing main with ( patch("cache.CACHE_DIR", tmp_path), patch("cache.SERVER_SETTINGS_FILE", tmp_path 
/ "server_settings.json"), patch("cache.USERS_DIR", tmp_path / "users"), patch("auth.CACHE_DIR", tmp_path), patch("auth.SERVER_SETTINGS_FILE", tmp_path / "server_settings.json"), patch("auth.USERS_DIR", tmp_path / "users"), patch("epg.init"), patch("ffmpeg_command.init"), patch("ffmpeg_session.cleanup_and_recover_sessions"), ): (tmp_path / "users").mkdir(exist_ok=True) import main # Disable background loading cache_module.get_cache().clear() yield TestClient(main.app) @pytest.fixture def auth_client(tmp_path: Path, mock_deps): """Create test client with a logged-in user.""" from fastapi.testclient import TestClient with ( patch("cache.CACHE_DIR", tmp_path), patch("cache.SERVER_SETTINGS_FILE", tmp_path / "server_settings.json"), patch("cache.USERS_DIR", tmp_path / "users"), patch("auth.CACHE_DIR", tmp_path), patch("auth.SERVER_SETTINGS_FILE", tmp_path / "server_settings.json"), patch("auth.USERS_DIR", tmp_path / "users"), patch("epg.init"), patch("ffmpeg_command.init"), patch("ffmpeg_session.cleanup_and_recover_sessions"), ): (tmp_path / "users").mkdir(exist_ok=True) import auth import main cache_module.get_cache().clear() client = TestClient(main.app) # Create user and get token auth.create_user("testuser", "testpass123") token = auth.create_token({"sub": "testuser"}) client.cookies.set("token", token) yield client class TestSetup: """Tests for initial setup flow.""" def test_setup_page_shown_when_no_users(self, client): resp = client.get("/setup", follow_redirects=False) assert resp.status_code == 200 assert b"setup" in resp.content.lower() or b"Create" in resp.content def test_setup_redirects_when_users_exist(self, client, tmp_path): import auth auth.create_user("admin", "password123") resp = client.get("/setup", follow_redirects=False) assert resp.status_code == 303 assert resp.headers["location"] == "/login" def test_setup_creates_user(self, client): resp = client.post( "/setup", data={"username": "admin", "password": "password123", "confirm": "password123"}, 
follow_redirects=False, ) assert resp.status_code == 303 assert resp.headers["location"] == "/login" import auth assert auth.verify_password("admin", "password123") def test_setup_validates_username_length(self, client): resp = client.post( "/setup", data={"username": "ab", "password": "password123", "confirm": "password123"}, ) assert resp.status_code == 200 assert b"at least 3" in resp.content def test_setup_validates_password_length(self, client): resp = client.post( "/setup", data={"username": "admin", "password": "short", "confirm": "short"}, ) assert resp.status_code == 200 assert b"at least 8" in resp.content def test_setup_validates_password_match(self, client): resp = client.post( "/setup", data={"username": "admin", "password": "password123", "confirm": "different"}, ) assert resp.status_code == 200 assert b"do not match" in resp.content class TestLogin: """Tests for login flow.""" def test_login_page_redirects_to_setup_when_no_users(self, client): resp = client.get("/login", follow_redirects=False) assert resp.status_code == 303 assert resp.headers["location"] == "/setup" def test_login_page_shown_when_users_exist(self, client, tmp_path): import auth auth.create_user("admin", "password123") resp = client.get("/login") assert resp.status_code == 200 def test_login_success_sets_cookie(self, client, tmp_path): import auth auth.create_user("admin", "password123") resp = client.post( "/login", data={"username": "admin", "password": "password123"}, follow_redirects=False, ) assert resp.status_code == 303 assert "token" in resp.cookies def test_login_failure_returns_401(self, client, tmp_path): import auth auth.create_user("admin", "password123") resp = client.post( "/login", data={"username": "admin", "password": "wrongpassword"}, follow_redirects=False, ) assert resp.status_code == 303 assert "error=invalid" in resp.headers["location"] class TestLogout: """Tests for logout.""" def test_logout_clears_cookie(self, auth_client): resp = auth_client.get("/logout", 
follow_redirects=False) assert resp.status_code == 303 assert resp.headers["location"] == "/login" class TestAuthRequired: """Tests for auth-protected routes.""" def test_index_redirects_to_login(self, client, tmp_path): import auth auth.create_user("admin", "password123") resp = client.get("/", follow_redirects=False) assert resp.status_code == 303 assert resp.headers["location"] == "/login" def test_guide_redirects_to_login(self, client, tmp_path): import auth auth.create_user("admin", "password123") resp = client.get("/guide", follow_redirects=False) assert resp.status_code == 303 assert resp.headers["location"] == "/login" def test_vod_redirects_to_login(self, client, tmp_path): import auth auth.create_user("admin", "password123") resp = client.get("/vod", follow_redirects=False) assert resp.status_code == 303 assert resp.headers["location"] == "/login" def test_series_redirects_to_login(self, client, tmp_path): import auth auth.create_user("admin", "password123") resp = client.get("/series", follow_redirects=False) assert resp.status_code == 303 assert resp.headers["location"] == "/login" class TestIndex: """Tests for index route.""" def test_index_redirects_to_guide(self, auth_client): resp = auth_client.get("/", follow_redirects=False) assert resp.status_code == 303 assert resp.headers["location"] == "/guide" class TestFavicon: """Tests for favicon.""" def test_favicon_returns_204(self, client): resp = client.get("/favicon.ico") assert resp.status_code == 204 class TestGuide: """Tests for guide page.""" def test_guide_shows_loading_when_no_cache(self, auth_client): with patch("main.load_file_cache", return_value=None): resp = auth_client.get("/guide") assert resp.status_code == 200 # Should show loading state assert b"loading" in resp.content.lower() or b"Loading" in resp.content def test_guide_shows_channels_from_cache(self, auth_client): cache_module.get_cache()["live_categories"] = [ {"category_id": "1", "category_name": "News"} ] 
cache_module.get_cache()["live_streams"] = [ {"stream_id": 1, "name": "CNN", "category_ids": ["1"], "epg_channel_id": ""} ] with patch("main.epg.has_programs", return_value=True): resp = auth_client.get("/guide?cats=1") assert resp.status_code == 200 def test_guide_uses_saved_filter(self, auth_client, tmp_path): user_dir = tmp_path / "users" / "testuser" user_dir.mkdir(parents=True, exist_ok=True) (user_dir / "settings.json").write_text(json.dumps({"guide_filter": ["1", "2"]})) cache_module.get_cache()["live_categories"] = [] cache_module.get_cache()["live_streams"] = [] # Guide now renders directly using saved filter (no redirect) with patch("main.epg.has_programs", return_value=True): resp = auth_client.get("/guide") assert resp.status_code == 200 class TestVod: """Tests for VOD page.""" def test_vod_shows_loading_when_no_cache(self, auth_client): with patch("main.load_file_cache", return_value=None): resp = auth_client.get("/vod") assert resp.status_code == 200 def test_vod_shows_movies_from_cache(self, auth_client): cache_module.get_cache()["vod_categories"] = [ {"category_id": "10", "category_name": "Movies", "source_id": "src1"} ] cache_module.get_cache()["vod_streams"] = [ {"stream_id": 100, "name": "Movie 1", "category_id": "10", "source_id": "src1"} ] resp = auth_client.get("/vod") assert resp.status_code == 200 def test_vod_filters_by_category(self, auth_client): cache_module.get_cache()["vod_categories"] = [ {"category_id": "10", "category_name": "Action", "source_id": "src1"}, {"category_id": "20", "category_name": "Comedy", "source_id": "src1"}, ] cache_module.get_cache()["vod_streams"] = [ {"stream_id": 100, "name": "Action Movie", "category_id": "10", "source_id": "src1"}, {"stream_id": 101, "name": "Comedy Movie", "category_id": "20", "source_id": "src1"}, ] resp = auth_client.get("/vod?category=10") assert resp.status_code == 200 def test_vod_sorts_by_alpha(self, auth_client): cache_module.get_cache()["vod_categories"] = [] 
cache_module.get_cache()["vod_streams"] = [ {"stream_id": 1, "name": "Zebra", "source_id": "src1"}, {"stream_id": 2, "name": "Apple", "source_id": "src1"}, ] resp = auth_client.get("/vod?sort=alpha") assert resp.status_code == 200 class TestSeries: """Tests for series page.""" def test_series_shows_loading_when_no_cache(self, auth_client): with patch("main.load_file_cache", return_value=None): resp = auth_client.get("/series") assert resp.status_code == 200 def test_series_shows_list_from_cache(self, auth_client): cache_module.get_cache()["series_categories"] = [ {"category_id": "30", "category_name": "Drama", "source_id": "src1"} ] cache_module.get_cache()["series"] = [ {"series_id": 200, "name": "Show 1", "category_id": "30", "source_id": "src1"} ] resp = auth_client.get("/series") assert resp.status_code == 200 class TestSearch: """Tests for search page.""" def test_search_page_renders(self, auth_client): cache_module.get_cache()["live_streams"] = [] cache_module.get_cache()["vod_streams"] = [] cache_module.get_cache()["series"] = [] resp = auth_client.get("/search") assert resp.status_code == 200 def test_search_finds_live_streams(self, auth_client): cache_module.get_cache()["live_streams"] = [ {"stream_id": 1, "name": "CNN News"}, {"stream_id": 2, "name": "BBC World"}, ] cache_module.get_cache()["live_categories"] = [] cache_module.get_cache()["epg_urls"] = [] cache_module.get_cache()["vod_streams"] = [] cache_module.get_cache()["series"] = [] resp = auth_client.get("/search?q=CNN&live=true") assert resp.status_code == 200 def test_search_regex_mode(self, auth_client): cache_module.get_cache()["live_streams"] = [ {"stream_id": 1, "name": "CNN News"}, {"stream_id": 2, "name": "CNBC Finance"}, ] cache_module.get_cache()["live_categories"] = [] cache_module.get_cache()["epg_urls"] = [] cache_module.get_cache()["vod_streams"] = [] cache_module.get_cache()["series"] = [] resp = auth_client.get("/search?q=CN.*®ex=true&live=true") assert resp.status_code == 200 def 
test_search_rejects_long_regex(self, auth_client): cache_module.get_cache()["live_streams"] = [] resp = auth_client.get(f"/search?q={'a' * 101}®ex=true&live=true") assert resp.status_code == 400 class TestSettings: """Tests for settings page.""" def test_settings_page_renders(self, auth_client): cache_module.get_cache()["live_categories"] = [] with patch("main.load_file_cache", return_value=None): resp = auth_client.get("/settings") assert resp.status_code == 200 def test_settings_guide_filter(self, auth_client): resp = auth_client.post( "/settings/guide-filter", json={"cats": ["1", "2", "3"]}, ) assert resp.status_code == 200 assert resp.json()["status"] == "ok" def test_settings_captions(self, auth_client): resp = auth_client.post( "/settings/captions", data={"enabled": "on"}, ) assert resp.status_code == 200 assert resp.json()["ok"] is True def test_settings_transcode(self, auth_client): resp = auth_client.post( "/settings/transcode", data={ "transcode_mode": "auto", "transcode_hw": "nvidia", "vod_transcode_cache_mins": 60, }, ) assert resp.status_code == 200 assert resp.json()["ok"] is True class TestAddSource: """Tests for adding sources.""" def test_add_xtream_source(self, auth_client): with patch("main.clear_all_caches"): resp = auth_client.post( "/settings/add", data={ "name": "Test Provider", "source_type": "xtream", "url": "http://example.com", "username": "user", "password": "pass", "epg_timeout": 120, }, follow_redirects=False, ) assert resp.status_code == 303 def test_add_m3u_source(self, auth_client): with patch("main.clear_all_caches"): resp = auth_client.post( "/settings/add", data={ "name": "M3U Playlist", "source_type": "m3u", "url": "http://example.com/playlist.m3u", "epg_timeout": 120, }, follow_redirects=False, ) assert resp.status_code == 303 def test_add_source_validates_type(self, auth_client): resp = auth_client.post( "/settings/add", data={ "name": "Bad Source", "source_type": "invalid", "url": "http://example.com", }, ) assert 
resp.status_code == 400 def test_add_source_validates_url_scheme(self, auth_client): resp = auth_client.post( "/settings/add", data={ "name": "Bad Source", "source_type": "xtream", "url": "ftp://example.com", }, ) assert resp.status_code == 400 def test_add_source_validates_name_length(self, auth_client): resp = auth_client.post( "/settings/add", data={ "name": "x" * 201, "source_type": "xtream", "url": "http://example.com", }, ) assert resp.status_code == 400 class TestDeleteSource: """Tests for deleting sources.""" def test_delete_source(self, auth_client, tmp_path): settings_file = tmp_path / "server_settings.json" settings_file.write_text( json.dumps( { "sources": [ { "id": "src_123", "name": "Test", "type": "xtream", "url": "http://example.com", } ] } ) ) with patch("main.clear_all_caches"): resp = auth_client.post("/settings/delete/src_123", follow_redirects=False) assert resp.status_code == 303 class TestUserPrefs: """Tests for user preferences API.""" def test_get_user_prefs(self, auth_client): resp = auth_client.get("/api/user-prefs") assert resp.status_code == 200 data = resp.json() assert "favorites" in data assert "cc_lang" in data def test_save_user_prefs(self, auth_client): resp = auth_client.post( "/api/user-prefs", json={"cc_lang": "eng", "cast_host": "192.168.1.100"}, ) assert resp.status_code == 200 assert resp.json()["ok"] is True class TestWatchPosition: """Tests for watch position API.""" def test_save_watch_position(self, auth_client): resp = auth_client.post( "/api/watch-position", json={"url": "http://example.com/movie.mkv", "position": 1234.5, "duration": 7200}, ) assert resp.status_code == 200 def test_get_watch_position(self, auth_client): # Save first auth_client.post( "/api/watch-position", json={"url": "http://example.com/movie.mkv", "position": 1234.5, "duration": 7200}, ) resp = auth_client.get("/api/watch-position?url=http://example.com/movie.mkv") assert resp.status_code == 200 data = resp.json() assert data["position"] == 1234.5 
assert data["duration"] == 7200 def test_get_watch_position_not_found(self, auth_client): resp = auth_client.get("/api/watch-position?url=http://example.com/unknown.mkv") assert resp.status_code == 200 data = resp.json() assert data["position"] == 0 class TestUserManagement: """Tests for user management endpoints.""" def test_delete_user(self, auth_client, tmp_path): import auth auth.create_user("otheruser", "password123") resp = auth_client.post("/settings/users/delete/otheruser", follow_redirects=False) assert resp.status_code == 303 def test_cannot_delete_self(self, auth_client): resp = auth_client.post("/settings/users/delete/testuser") assert resp.status_code == 400 def test_change_password(self, auth_client): resp = auth_client.post( "/settings/users/password", data={"current_password": "testpass123", "new_password": "newpass456"}, ) assert resp.status_code == 200 def test_change_password_wrong_current(self, auth_client): resp = auth_client.post( "/settings/users/password", data={"current_password": "wrongpass", "new_password": "newpass456"}, ) assert resp.status_code == 400 class TestPlaylistXspf: """Tests for XSPF playlist generation.""" def test_playlist_xspf(self, auth_client): resp = auth_client.get("/playlist.xspf?url=http://example.com/stream.m3u8") assert resp.status_code == 200 assert b"=3.10" authors = [ { name = "Joshua V. 
Dillon" }, ] keywords = ["iptv", "streaming", "epg", "hls", "transcoding", "chromecast"] classifiers = [ "Development Status :: 4 - Beta", "Environment :: Web Environment", "Framework :: FastAPI", "Intended Audience :: End Users/Desktop", "License :: OSI Approved :: Apache Software License", "Operating System :: POSIX :: Linux", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Topic :: Multimedia :: Video :: Display", ] dependencies = [ "fastapi>=0.115", "uvicorn[standard]>=0.32", "jinja2>=3.1", "python-multipart>=0.0.9", "cryptography>=43.0", "defusedxml>=0.7", ] [dependency-groups] dev = [ "basedpyright>=1.10", "httpx>=0.27", "pytest>=7.0", "pytest-asyncio>=0.23", "ruff>=0.8", ] ai_upscale = [ "torch>=2.0", "onnx>=1.14", "tensorrt>=10.0", ] [project.urls] Repository = "https://github.com/jvdillon/netv" [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] include = ["*.py"] [tool.uv] package = false [tool.ruff] fix = true show-fixes = true line-length = 100 output-format = "concise" target-version = "py311" [tool.ruff.lint] select = ["E", "F", "I", "UP", "B", "SIM"] ignore = ["E501", "SIM108"] [tool.ruff.lint.isort] case-sensitive = false combine-as-imports = true force-wrap-aliases = true split-on-trailing-comma = true default-section = "third-party" detect-same-package = true force-single-line = false force-sort-within-sections = false from-first = true lines-between-types = 1 lines-after-imports = 2 # extra-standard-library = ["typing","typing_extensions"] known-first-party = ["core", "ae", "gm", "data", "tools"] no-lines-before = ["future"] #, "standard-library"] order-by-type = true relative-imports-order = "closest-to-furthest" section-order = [ "future", "standard-library", "third-party", "first-party", "local-folder", ] [tool.basedpyright] typeCheckingMode = "standard" pythonVersion = 
"3.11" exclude = ["tools/", ".venv/"] # Optional scripts + virtual env [tool.pytest.ini_options] testpaths = ["."] python_files = ["*_test.py"] pythonpath = ["."] filterwarnings = ["ignore::pytest.PytestUnraisableExceptionWarning:coroutine.*was never awaited"] ================================================ FILE: static/js/app.js ================================================ // Keyboard navigation for 10-foot UI (function() { 'use strict'; // Pages with custom arrow key handling const customNavPages = ['/play/', '/guide']; const hasCustomNav = customNavPages.some(p => location.pathname.startsWith(p)); // ============================================================ // Focus Management // ============================================================ function getFocusables(container = document) { return Array.from(container.querySelectorAll( 'a[href]:not([disabled]), button:not([disabled]), input:not([disabled]), select:not([disabled]), [tabindex="0"], .focusable' )).filter(el => el.offsetParent !== null); // visible only } function getGridInfo(element) { const grid = element.closest('.grid'); if (!grid) return null; const items = Array.from(grid.querySelectorAll('[data-nav="grid"]')); const index = items.indexOf(element); if (index === -1) return null; // Detect columns by comparing Y positions let cols = 1; if (items.length > 1) { const firstTop = items[0].getBoundingClientRect().top; for (let i = 1; i < items.length; i++) { if (items[i].getBoundingClientRect().top > firstTop + 5) { cols = i; break; } } if (cols === 1) cols = items.length; } return { items, index, cols }; } function moveFocus(direction) { const current = document.activeElement; const focusables = getFocusables(); const currentIndex = focusables.indexOf(current); // Try grid navigation first const gridInfo = getGridInfo(current); if (gridInfo && gridInfo.cols > 1) { const { items, index, cols } = gridInfo; let nextIndex = -1; switch (direction) { case 'up': nextIndex = index - cols; break; case 
'down': nextIndex = index + cols; break; case 'left': nextIndex = index - 1; break; case 'right': nextIndex = index + 1; break; } if (nextIndex >= 0 && nextIndex < items.length) { items[nextIndex].focus(); items[nextIndex].scrollIntoView({ block: 'nearest', behavior: 'smooth' }); return true; } // At grid edge - don't wrap for up/down if (direction === 'up' || direction === 'down') return false; } // Linear navigation fallback let nextElement = null; if (direction === 'up' || direction === 'left') { nextElement = focusables[currentIndex - 1]; } else { nextElement = focusables[currentIndex + 1]; } if (nextElement) { nextElement.focus(); nextElement.scrollIntoView({ block: 'nearest', behavior: 'smooth' }); return true; } return false; } // ============================================================ // Initial Focus // ============================================================ function setInitialFocus() { // Skip if something is already focused (other than body) if (document.activeElement && document.activeElement !== document.body) return; // Priority: [autofocus], first grid item, first focusable in main const autofocus = document.querySelector('[autofocus]'); if (autofocus) { autofocus.focus(); return; } const mainContent = document.querySelector('main'); if (!mainContent) return; const gridItem = mainContent.querySelector('[data-nav="grid"]'); if (gridItem) { gridItem.focus(); return; } const firstFocusable = getFocusables(mainContent)[0]; if (firstFocusable) firstFocusable.focus(); } // ============================================================ // Favorites Toggle // ============================================================ function toggleFocusedFavorite() { const el = document.activeElement; if (!el) return false; // Check for movie card const movieCard = el.closest('.movie-card'); if (movieCard) { const btn = movieCard.querySelector('.fav-btn, .fav-btn-movie'); if (btn) { btn.click(); return true; } } // Check for series card const seriesCard = 
el.closest('.series-card'); if (seriesCard) { const btn = seriesCard.querySelector('.fav-btn, .fav-btn-series'); if (btn) { btn.click(); return true; } } // Check for favorites tile (in favorites view) const tile = el.closest('.vod-tile, .series-tile'); if (tile) { const btn = tile.querySelector('button'); if (btn) { btn.click(); return true; } } // Check for detail page favorite button const favBtn = document.getElementById('fav-btn'); if (favBtn) { favBtn.click(); return true; } return false; } // ============================================================ // Keyboard Handler // ============================================================ document.addEventListener('keydown', (e) => { const isInput = e.target.tagName === 'INPUT' || e.target.tagName === 'TEXTAREA'; const isSelect = e.target.tagName === 'SELECT'; // Input field handling if (isInput) { if (e.key === 'Escape') { e.target.blur(); return; } // Allow down arrow to escape search input if (e.key === 'ArrowDown' && e.target.type === 'text') { const mainContent = document.querySelector('main'); const firstResult = mainContent?.querySelector('[data-nav="grid"]'); if (firstResult) { e.preventDefault(); firstResult.focus(); return; } } // Let other keys work normally in inputs return; } // Select handling - let arrows work for options if (isSelect && (e.key === 'ArrowUp' || e.key === 'ArrowDown')) { return; } switch (e.key) { case 'ArrowUp': case 'ArrowDown': case 'ArrowLeft': case 'ArrowRight': // Skip if page has custom navigation or Alt pressed (browser nav) if (hasCustomNav || e.altKey) return; e.preventDefault(); const dir = e.key.replace('Arrow', '').toLowerCase(); moveFocus(dir); break; case 'Enter': { const el = document.activeElement; if (el?.href) { e.preventDefault(); if (e.ctrlKey || e.metaKey) { window.open(el.href, '_blank'); } else { window.location.href = el.href; } } else if (el?.click && el.tagName !== 'A' && el.tagName !== 'BUTTON') { e.preventDefault(); el.click(); } break; } case 'f': case 
'F': if (toggleFocusedFavorite()) { e.preventDefault(); } break; case 'Escape': // Only handle if focus is on a known focusable element (not during browser find dialog, etc.) if (!document.activeElement || document.activeElement === document.body) return; e.preventDefault(); if (document.activeElement?.closest('nav')) { // In nav - go to main content const mainFocusable = document.querySelector('main [data-nav="grid"], main .focusable, main a[href], main button'); if (mainFocusable) mainFocusable.focus(); } else { // In content - go to nav const navLink = document.querySelector('nav .nav-link'); if (navLink) navLink.focus(); } break; case 'Backspace': // Go back unless on root pages or in input const rootPages = ['/', '/guide', '/vod', '/series', '/search', '/settings']; if (!rootPages.includes(location.pathname)) { e.preventDefault(); history.back(); } break; } }); // Set initial focus after page load if (document.readyState === 'loading') { document.addEventListener('DOMContentLoaded', setInitialFocus); } else { setTimeout(setInitialFocus, 0); } })(); ================================================ FILE: static/js/favorites-grid.js ================================================ // Shared Favorites Grid Module for VOD/Series pages // Requires: window.FAVORITES_CONFIG = { type: 'movies'|'series', favorites, cardClass, tileClass, detailUrl, orderKey } (function() { 'use strict'; const cfg = window.FAVORITES_CONFIG; if (!cfg) return; function escapeHtml(s) { if (!s) return ''; return String(s).replace(/&/g, '&').replace(//g, '>').replace(/"/g, '"').replace(/'/g, '''); } function escapeAttr(s) { if (!s) return ''; return String(s).replace(/&/g, '&').replace(/"/g, '"').replace(/'/g, '''); } window.favorites = cfg.favorites; function getFavorites() { return window.favorites[cfg.type] || {}; } function saveFavorites() { fetch('/api/user-prefs', { method: 'POST', headers: {'Content-Type': 'application/json'}, body: JSON.stringify({favorites: window.favorites}) }); } 
window.toggleFavorite = function(id, name, cover, ext) { const favs = window.favorites[cfg.type]; if (favs[id]) { delete favs[id]; } else { favs[id] = cfg.type === 'movies' ? { name, cover, ext } : { name, cover }; } saveFavorites(); updateFavoriteButtons(); if (typeof window.renderFavorites === 'function') window.renderFavorites(); }; window.updateFavoriteButtons = function() { const favs = getFavorites(); document.querySelectorAll('.fav-btn').forEach(btn => { const card = btn.closest('.' + cfg.cardClass); const id = card?.dataset[cfg.type === 'movies' ? 'movieId' : 'seriesId']; btn.textContent = favs[id] ? '★' : '☆'; btn.classList.toggle('text-yellow-400', !!favs[id]); }); }; // Browse view handlers if (cfg.isBrowseView) { window.updateBrowseUrl = function() { const cat = document.getElementById('category-select').value; const sort = document.getElementById('sort-select').value; const params = new URLSearchParams(); if (cat) params.set('category', cat); params.set('sort', sort); window.location.href = cfg.baseUrl + '?' 
+ params; }; const catSel = document.getElementById('category-select'); const sortSel = document.getElementById('sort-select'); if (catSel) catSel.addEventListener('change', updateBrowseUrl); if (sortSel) sortSel.addEventListener('change', updateBrowseUrl); updateFavoriteButtons(); } // Favorites view handlers if (!cfg.isBrowseView) { async function getOrder() { try { const resp = await fetch('/api/settings'); const settings = await resp.json(); return settings[cfg.orderKey] || []; } catch (e) { console.error('Failed to get order:', e); return []; } } async function saveOrder(order) { try { const resp = await fetch('/api/settings'); const settings = await resp.json(); settings[cfg.orderKey] = order; await fetch('/api/settings', { method: 'POST', headers: {'Content-Type': 'application/json'}, body: JSON.stringify(settings) }); } catch (e) { console.error('Failed to save order:', e); } } window.renderFavorites = async function() { const favs = getFavorites(); const grid = document.getElementById('favorites-grid'); const noFavs = document.getElementById('no-favorites'); let ids = Object.keys(favs); if (ids.length === 0) { grid.innerHTML = ''; noFavs.classList.remove('hidden'); return; } const order = await getOrder(); const orderedIds = order.filter(id => favs[id]); const unorderedIds = ids.filter(id => !order.includes(id)); ids = [...orderedIds, ...unorderedIds]; noFavs.classList.add('hidden'); grid.innerHTML = ids.map(id => { const f = favs[id]; const safeId = escapeAttr(id); const safeCover = escapeAttr(f.cover); const safeName = escapeHtml(f.name); return ` `; }).join(''); initDragDrop(); }; function initDragDrop() { const grid = document.getElementById('favorites-grid'); let draggedEl = null; let touchStartY = 0; let touchStartX = 0; let longPressTimer = null; let isDragging = false; grid.addEventListener('contextmenu', (e) => { if (e.target.closest('.' + cfg.tileClass)) e.preventDefault(); }); grid.querySelectorAll('.' 
+ cfg.tileClass).forEach(tile => { tile.draggable = true; tile.addEventListener('dragstart', () => { draggedEl = tile; tile.classList.add('opacity-50'); }); tile.addEventListener('dragend', () => { tile.classList.remove('opacity-50'); draggedEl = null; saveCurrentOrder(); }); tile.addEventListener('dragover', (e) => { e.preventDefault(); if (draggedEl && draggedEl !== tile) { const rect = tile.getBoundingClientRect(); const midpoint = rect.left + rect.width / 2; grid.insertBefore(draggedEl, e.clientX < midpoint ? tile : tile.nextSibling); } }); tile.addEventListener('touchstart', (e) => { touchStartX = e.touches[0].clientX; touchStartY = e.touches[0].clientY; longPressTimer = setTimeout(() => { isDragging = true; draggedEl = tile; tile.classList.add('opacity-50', 'ring-2', 'ring-blue-500'); navigator.vibrate?.(50); }, 400); }, {passive: true}); tile.addEventListener('touchmove', (e) => { if (longPressTimer && !isDragging) { const dx = Math.abs(e.touches[0].clientX - touchStartX); const dy = Math.abs(e.touches[0].clientY - touchStartY); if (dx > 10 || dy > 10) { clearTimeout(longPressTimer); longPressTimer = null; } } if (!isDragging || !draggedEl) return; e.preventDefault(); const touch = e.touches[0]; const target = document.elementFromPoint(touch.clientX, touch.clientY)?.closest('.' + cfg.tileClass); if (target && target !== draggedEl) { const rect = target.getBoundingClientRect(); const midpoint = rect.left + rect.width / 2; grid.insertBefore(draggedEl, touch.clientX < midpoint ? target : target.nextSibling); } }, {passive: false}); tile.addEventListener('touchend', () => { clearTimeout(longPressTimer); longPressTimer = null; if (isDragging && draggedEl) { draggedEl.classList.remove('opacity-50', 'ring-2', 'ring-blue-500'); draggedEl = null; isDragging = false; saveCurrentOrder(); } }); }); } async function saveCurrentOrder() { const grid = document.getElementById('favorites-grid'); const order = Array.from(grid.querySelectorAll('.' 
+ cfg.tileClass)).map(tile => tile.dataset.id); await saveOrder(order); } renderFavorites(); } })(); ================================================ FILE: static/js/player.js ================================================ // IPTV Player Module // Requires: Hls.js, window.PLAYER_CONFIG (function() { 'use strict'; const cfg = window.PLAYER_CONFIG; const video = document.getElementById('video'); const loading = document.getElementById('loading'); const error = document.getElementById('error'); const ccBtn = document.getElementById('toggle-cc'); const settingsMenu = document.getElementById('settings-menu'); // State let transcodeSessionId = null; let currentHls = null; let isTranscoding = false; let ccEnabled = cfg.captionsEnabled; let subtitlePollTimerId = null; let transcodedDuration = 0; let totalDuration = 0; let seekInProgress = false; let seekOffset = 0; let currentSubtitles = null; let activeTrackStates = null; let progressPollTimerId = null; let lastSavedPosition = 0; let savePositionTimeout = null; let autoMutedByPolicy = false; // ============================================================ // Utilities // ============================================================ function formatTime(seconds) { const h = Math.floor(seconds / 3600); const m = Math.floor((seconds % 3600) / 60); const s = Math.floor(seconds % 60); if (h > 0) return `${h}:${m.toString().padStart(2, '0')}:${s.toString().padStart(2, '0')}`; return `${m}:${s.toString().padStart(2, '0')}`; } function parseTime(str) { str = str.trim(); if (!str) return 0; const parts = str.split(':').map(Number); if (parts.some(isNaN)) return 0; if (parts.length === 3) return parts[0] * 3600 + parts[1] * 60 + parts[2]; if (parts.length === 2) return parts[0] * 60 + parts[1]; const n = parts[0]; if (totalDuration >= 3600 && n * 3600 <= totalDuration) return n * 3600; if (n * 60 <= totalDuration) return n * 60; return n; } function parseVttTime(str) { const parts = str.split(':'); if (parts.length === 3) return 
parseFloat(parts[0]) * 3600 + parseFloat(parts[1]) * 60 + parseFloat(parts[2]); if (parts.length === 2) return parseFloat(parts[0]) * 60 + parseFloat(parts[1]); return 0; } function parseVttCues(vttText) { const cues = []; const lines = vttText.split('\n'); let i = 0; while (i < lines.length && !lines[i].includes('-->')) i++; while (i < lines.length) { const line = lines[i]; if (line.includes('-->')) { const [startStr, endStr] = line.split('-->').map(s => s.trim().split(' ')[0]); const start = parseVttTime(startStr); const end = parseVttTime(endStr); i++; const textLines = []; while (i < lines.length && lines[i].trim() !== '') { textLines.push(lines[i]); i++; } if (textLines.length > 0) cues.push({start, end, text: textLines.join('\n')}); } i++; } return cues; } // ============================================================ // UI Helpers // ============================================================ function hideLoading() { loading.classList.add('hidden'); } function showLoading() { loading.classList.remove('hidden'); } function showError() { hideLoading(); error.classList.remove('hidden'); } function updateTranscodeCheck() { const check = document.getElementById('tc-check'); if (check) check.textContent = isTranscoding ? 
'✓' : ''; } function updateCcButton() { ccBtn.classList.toggle('active', ccEnabled); } function updatePlayIcon() { const playIcon = document.getElementById('play-icon'); const pauseIcon = document.getElementById('pause-icon'); if (playIcon && pauseIcon) { playIcon.classList.toggle('hidden', !video.paused); pauseIcon.classList.toggle('hidden', video.paused); } } function updateMuteIcon() { const volIcon = document.getElementById('vol-icon'); const mutedIcon = document.getElementById('muted-icon'); if (volIcon && mutedIcon) { volIcon.classList.toggle('hidden', video.muted); mutedIcon.classList.toggle('hidden', !video.muted); } } function updateFullscreenIcon() { const fsEnter = document.getElementById('fs-enter'); const fsExit = document.getElementById('fs-exit'); const isFs = !!document.fullscreenElement; if (fsEnter && fsExit) { fsEnter.classList.toggle('hidden', isFs); fsExit.classList.toggle('hidden', !isFs); } } function disableCcButton() { ccBtn.disabled = true; ccBtn.classList.remove('active'); } function enableCcButton() { ccBtn.disabled = false; updateCcButton(); } // ============================================================ // HLS Configuration // ============================================================ // Custom cueHandler for CEA-608 caption positioning // See: https://github.com/video-dev/hls.js/issues/654 const customCueHandler = { newCue(track, startTime, endTime, captionScreen) { const lines = []; for (let r = 0; r < 15; r++) { const row = captionScreen.rows[r]; let text = ''; for (let c = 0; c < 32; c++) { text += row.chars[c]?.uchar || ' '; } text = text.trim(); if (text) lines.push(text); } if (lines.length === 0) return []; const cue = new VTTCue(startTime, endTime, lines.join('\n')); cue.line = -2; cue.align = 'center'; track.addCue(cue); return [cue]; } }; function createHlsConfig(options = {}) { const base = { enableWorker: true, lowLatencyMode: false, enableCEA708Captions: true, subtitleDisplay: ccEnabled, cueHandler: customCueHandler, 
manifestLoadingRetryDelay: 1000, levelLoadingRetryDelay: 1000, fragLoadingRetryDelay: 1000, }; if (options.forSeek) { return { ...base, liveSyncDurationCount: 0, startPosition: 0, manifestLoadingMaxRetry: 30, levelLoadingMaxRetry: 30, fragLoadingMaxRetry: 30, }; } return { ...base, liveSyncDurationCount: options.isVod ? 0 : 3, startPosition: options.isVod ? 0 : -1, }; } // ============================================================ // Captions // ============================================================ function applyCaptionStyles() { const s = cfg.ccStyle || {}; const hexToRgba = (hex, opacity) => { if (hex === 'transparent') return 'transparent'; const r = parseInt(hex.slice(1,3), 16); const g = parseInt(hex.slice(3,5), 16); const b = parseInt(hex.slice(5,7), 16); return `rgba(${r},${g},${b},${opacity})`; }; const color = hexToRgba(s.cc_color || '#ffffff', 1); const shadow = s.cc_shadow || '0 0 4px black, 0 0 4px black'; const bg = hexToRgba(s.cc_bg || '#000000', s.cc_bg_opacity || 0.75); const size = s.cc_size || '1em'; const font = s.cc_font || 'inherit'; let style = document.getElementById('cc-style'); if (!style) { style = document.createElement('style'); style.id = 'cc-style'; document.head.appendChild(style); } const sizeMultiplier = parseFloat(size) || 1; const infoSize = (2.5 * sizeMultiplier) + 'vh'; style.textContent = ` video::cue { color: ${color} !important; text-shadow: ${shadow} !important; background-color: ${bg} !important; font-size: ${size} !important; font-family: ${font} !important; } #info-overlay { font-size: ${infoSize}; max-width: 50em; } `; } function getPreferredSubtitleTrack(tracks) { const prefLang = cfg.ccLang || ''; if (!prefLang || tracks.length === 0) return 0; // Handle CC1-CC4 (CEA-608 channels) if (/^cc[1-4]$/i.test(prefLang)) { const ccNum = prefLang.toUpperCase(); const idx = tracks.findIndex(t => (t.name || t.label || '').toUpperCase().includes(ccNum)); if (idx >= 0) return idx; // Fallback: CC1 is usually index 0, CC2 
is index 1, etc. const num = parseInt(prefLang.slice(2)) - 1; return num < tracks.length ? num : 0; } const langNames = {en: 'english', es: 'spanish', fr: 'french', de: 'german', it: 'italian', pt: 'portuguese', ja: 'japanese', ko: 'korean', zh: 'chinese'}; let idx = tracks.findIndex(t => t.lang && t.lang.toLowerCase().startsWith(prefLang)); if (idx >= 0) return idx; const prefName = langNames[prefLang]; if (prefName) { idx = tracks.findIndex(t => (t.name || t.label) && (t.name || t.label).toLowerCase().includes(prefName)); if (idx >= 0) return idx; } return 0; } function applyCaptionsSetting() { const tracks = Array.from(video.textTracks).filter(t => (t.kind === 'subtitles' || t.kind === 'captions') && t.mode !== 'disabled'); if (!ccEnabled) { tracks.forEach(t => t.mode = 'hidden'); return; } if (tracks.length === 0) return; const prefLang = cfg.ccLang || ''; const langNames = {en: 'english', es: 'spanish', fr: 'french', de: 'german', it: 'italian', pt: 'portuguese', ja: 'japanese', ko: 'korean', zh: 'chinese'}; let preferredIdx = 0; if (prefLang) { const idx = tracks.findIndex(t => t.language && t.language.toLowerCase().startsWith(prefLang)); if (idx >= 0) preferredIdx = idx; else { const prefName = langNames[prefLang]; if (prefName) { const nameIdx = tracks.findIndex(t => t.label && t.label.toLowerCase().includes(prefName)); if (nameIdx >= 0) preferredIdx = nameIdx; } } } tracks.forEach((t, i) => t.mode = i === preferredIdx ? 
'showing' : 'hidden'); } function startSubtitlePolling(subtitles, prefIdx) { if (subtitlePollTimerId) { clearInterval(subtitlePollTimerId); clearTimeout(subtitlePollTimerId); subtitlePollTimerId = null; } if (activeTrackStates && activeTrackStates.length === subtitles.length) { for (const ts of activeTrackStates) { if (ts.track.mode === 'disabled') ts.track.mode = 'hidden'; const cues = ts.track.cues; if (cues) while (cues.length > 0) ts.track.removeCue(cues[0]); ts.addedCues.clear(); ts.retryCount = 0; } } else { for (let i = 0; i < video.textTracks.length; i++) { video.textTracks[i].mode = 'disabled'; } activeTrackStates = subtitles.map((sub) => ({ url: sub.url, track: video.addTextTrack('subtitles', sub.label, sub.lang), addedCues: new Set(), retryCount: 0, })); } activeTrackStates.forEach((ts, i) => { ts.track.mode = (ccEnabled && i === prefIdx) ? 'showing' : 'hidden'; }); const poll = async () => { for (let i = 0; i < activeTrackStates.length; i++) { const ts = activeTrackStates[i]; if (ts.retryCount > 120) continue; try { const resp = await fetch(ts.url + '?t=' + Date.now()); if (!resp.ok) { ts.retryCount++; continue; } const vtt = await resp.text(); const cues = parseVttCues(vtt); for (const cue of cues) { const key = `${cue.start}-${cue.end}`; if (!ts.addedCues.has(key)) { try { ts.track.addCue(new VTTCue(cue.start, cue.end, cue.text)); ts.addedCues.add(key); } catch (e) {} } } ts.retryCount = 0; } catch (e) { ts.retryCount++; } } }; let pollCount = 0; const doPoll = async () => { await poll(); pollCount++; subtitlePollTimerId = pollCount < 20 ? 
setTimeout(doPoll, 500) : setInterval(poll, 5000); }; doPoll(); } // ============================================================ // Position Tracking // ============================================================ function savePosition() { const actualTime = video.currentTime + seekOffset; if (!cfg.isVod || actualTime < 5) return; if (video.currentTime < 1 && seekOffset > 0) return; if (Math.abs(actualTime - lastSavedPosition) < 5) return; lastSavedPosition = actualTime; const data = JSON.stringify({ url: cfg.rawUrl, position: actualTime, duration: totalDuration }); if (document.visibilityState === 'hidden') { navigator.sendBeacon('/api/watch-position', new Blob([data], {type: 'application/json'})); } else { fetch('/api/watch-position', { method: 'POST', headers: {'Content-Type': 'application/json'}, body: data }); } } function restorePosition() { if (!cfg.isVod) return; const savedTime = cfg.serverResumePosition; if (!savedTime || savedTime <= 5) return; if (transcodedDuration < 30) return; const rangeStart = seekOffset; const rangeEnd = seekOffset + Math.max(0, transcodedDuration - 10); if (savedTime < rangeStart) return; const targetTime = Math.min(savedTime, rangeEnd); video.currentTime = targetTime - seekOffset; } function setupPositionTracking() { if (!cfg.isVod) return; video.addEventListener('timeupdate', () => { if (savePositionTimeout) return; savePositionTimeout = setTimeout(() => { savePosition(); savePositionTimeout = null; }, 10000); }); video.addEventListener('pause', savePosition); video.addEventListener('ended', savePosition); video.addEventListener('ended', () => { if (autoNextEnabled && cfg.nextEpisodeUrl) window.location.href = cfg.nextEpisodeUrl; }); document.addEventListener('visibilitychange', () => { if (document.visibilityState === 'hidden') savePosition(); }); } // ============================================================ // Progress Polling // ============================================================ function startProgressPolling() 
{ if (progressPollTimerId) clearInterval(progressPollTimerId); const poll = async () => { if (!transcodeSessionId) return; try { const resp = await fetch('/transcode/progress/' + transcodeSessionId); if (resp.ok) { const data = await resp.json(); transcodedDuration = data.duration || 0; } } catch (e) {} }; poll(); progressPollTimerId = setInterval(poll, 2000); } function stopProgressPolling() { if (progressPollTimerId) { clearInterval(progressPollTimerId); progressPollTimerId = null; } } // ============================================================ // Transcode Management // ============================================================ async function cleanupTranscode() { if (document.pictureInPictureElement === video) return; stopProgressPolling(); if (subtitlePollTimerId) { clearInterval(subtitlePollTimerId); subtitlePollTimerId = null; } if (currentHls) { currentHls.destroy(); currentHls = null; } transcodedDuration = 0; totalDuration = 0; seekInProgress = false; seekOffset = 0; currentSubtitles = null; activeTrackStates = null; document.getElementById('menu-jump')?.classList.add('hidden'); document.getElementById('seek-container').classList.add('hidden'); if (transcodeSessionId) { const sessionToStop = transcodeSessionId; transcodeSessionId = null; try { await fetch('/transcode/' + sessionToStop, {method: 'DELETE'}); } catch (e) { console.error('Cleanup error:', e); } } } function cleanupTranscodeSync() { if (document.pictureInPictureElement === video) return; stopProgressPolling(); if (subtitlePollTimerId) { clearInterval(subtitlePollTimerId); subtitlePollTimerId = null; } if (currentHls) { currentHls.destroy(); currentHls = null; } if (transcodeSessionId) { const blob = new Blob([], {type: 'application/json'}); navigator.sendBeacon('/transcode/' + transcodeSessionId + '/stop', blob); transcodeSessionId = null; } } async function handleSeekToPosition(targetTime) { if (!transcodeSessionId || seekInProgress) return false; seekInProgress = true; showLoading(); 
error.classList.add('hidden'); video.pause(); video.src = ''; if (subtitlePollTimerId) { clearInterval(subtitlePollTimerId); clearTimeout(subtitlePollTimerId); subtitlePollTimerId = null; } if (currentHls) { currentHls.destroy(); currentHls = null; } try { const resp = await fetch('/transcode/seek/' + transcodeSessionId + '?time=' + targetTime); if (!resp.ok) throw new Error('Seek failed: ' + resp.status); transcodedDuration = 0; seekOffset = targetTime; const hls = new Hls(createHlsConfig({forSeek: true})); currentHls = hls; hls.on(Hls.Events.MANIFEST_PARSED, () => { hideLoading(); error.classList.add('hidden'); seekInProgress = false; savePosition(); if (currentSubtitles && currentSubtitles.length > 0) { startSubtitlePolling(currentSubtitles, getPreferredSubtitleTrack(currentSubtitles)); } video.play().catch(() => {}); }); let recoveryAttempts = 0; hls.on(Hls.Events.ERROR, (event, data) => { if (data.fatal) { recoveryAttempts++; if (recoveryAttempts <= 3) { if (data.type === Hls.ErrorTypes.NETWORK_ERROR) hls.startLoad(); else if (data.type === Hls.ErrorTypes.MEDIA_ERROR) hls.recoverMediaError(); else { console.error('[SEEK] HLS error:', data); showError(); } } else { console.error('[SEEK] HLS error after retries:', data); showError(); } } }); hls.loadSource('/transcode/' + transcodeSessionId + '/stream.m3u8'); hls.attachMedia(video); return true; } catch (e) { console.error('[SEEK] Error:', e); seekInProgress = false; showError(); return false; } } async function startTranscode(onError) { showLoading(); await cleanupTranscode(); try { let url = '/transcode/start?url=' + encodeURIComponent(cfg.rawUrl) + '&content_type=' + cfg.streamType; if (cfg.seriesId) url += '&series_id=' + cfg.seriesId; if (cfg.episodeId) url += '&episode_id=' + cfg.episodeId; if (cfg.seriesName) url += '&series_name=' + encodeURIComponent(cfg.seriesName); if (cfg.deinterlaceFallback !== undefined) url += '&deinterlace_fallback=' + (cfg.deinterlaceFallback ? 
'1' : '0'); if (cfg.sourceId) url += '&source_id=' + encodeURIComponent(cfg.sourceId); const resp = await fetch(url); if (!resp.ok) throw new Error('Transcode start failed: ' + resp.status); const data = await resp.json(); transcodeSessionId = data.session_id; isTranscoding = true; updateTranscodeCheck(); totalDuration = data.duration || 0; seekOffset = data.seek_offset || 0; transcodedDuration = data.transcoded_duration || 0; currentSubtitles = data.subtitles || null; if (cfg.isVod) { document.getElementById('menu-jump')?.classList.remove('hidden'); document.getElementById('progress-container')?.classList.remove('hidden'); if (totalDuration > 0) { document.getElementById('seek-duration').textContent = '/ ' + formatTime(totalDuration); document.getElementById('time-duration').textContent = formatTime(totalDuration); } enableCcButton(); } playWithUrl(data.playlist, onError, data.subtitles); } catch (e) { console.error('[TC] Error:', e); if (onError) onError(); else showError(); } } // ============================================================ // Playback // ============================================================ function playWithUrl(url, onError, subtitles) { showLoading(); const useHls = Hls.isSupported() && ( url.includes('.m3u8') || url.includes('/live/') || url.includes('/transcode') ); if (!useHls) { video.src = url; video.addEventListener('loadedmetadata', function() { hideLoading(); error.classList.add('hidden'); if (video.textTracks.length === 0) disableCcButton(); applyCaptionsSetting(); restorePosition(); video.play().catch(() => { if (!video.muted) { autoMutedByPolicy = true; video.muted = true; } video.play(); }); }, { once: true }); video.addEventListener('error', function() { if (onError) onError(); else showError(); }, { once: true }); return; } const isVodUrl = url.includes('/transcode'); const hls = new Hls(createHlsConfig({isVod: isVodUrl})); currentHls = hls; let recoveryAttempts = 0; let hasLoaded = false; hls.loadSource(url); 
hls.attachMedia(video); // Timeout for initial load (Auto mode only) let loadTimeout = null; if (onError) { loadTimeout = setTimeout(() => { if (!hasLoaded) { console.log('[AUTO] Load timeout, triggering transcode'); hls.destroy(); currentHls = null; onError(); } }, 10000); } // Check for missing audio (Auto mode only) if (onError) { let audioChecked = false; video.addEventListener('timeupdate', function checkAudio() { if (audioChecked || video.currentTime < 1) return; audioChecked = true; video.removeEventListener('timeupdate', checkAudio); let hasAudio = false; if (typeof video.webkitAudioDecodedByteCount !== 'undefined') { hasAudio = video.webkitAudioDecodedByteCount > 0; } else if (typeof video.mozHasAudio !== 'undefined') { hasAudio = video.mozHasAudio; } else if (video.audioTracks && video.audioTracks.length > 0) { hasAudio = true; } console.log('[AUTO] Audio check: hasAudio=' + hasAudio + ', webkitAudioDecodedByteCount=' + video.webkitAudioDecodedByteCount); if (!hasAudio) { console.log('[AUTO] No audio detected, triggering transcode'); hls.destroy(); currentHls = null; onError(); } }); } hls.on(Hls.Events.MANIFEST_PARSED, () => { if (loadTimeout) clearTimeout(loadTimeout); hideLoading(); error.classList.add('hidden'); hasLoaded = true; recoveryAttempts = 0; if (subtitles && subtitles.length > 0) { startSubtitlePolling(subtitles, getPreferredSubtitleTrack(subtitles)); } if (cfg.captionsEnabled && hls.subtitleTracks.length > 0) { hls.subtitleTrack = getPreferredSubtitleTrack(hls.subtitleTracks); } if (transcodeSessionId && isVodUrl) startProgressPolling(); restorePosition(); video.play().catch(() => { if (!video.muted) { autoMutedByPolicy = true; video.muted = true; } video.play(); }); setTimeout(() => { if (hls.subtitleTracks.length === 0 && video.textTracks.length === 0) disableCcButton(); }, 1000); }); hls.on(Hls.Events.SUBTITLE_TRACKS_UPDATED, () => { if (hls.subtitleTracks.length > 0 && ccBtn.disabled) { enableCcButton(); if (cfg.captionsEnabled) 
hls.subtitleTrack = getPreferredSubtitleTrack(hls.subtitleTracks); } else if (hls.subtitleTracks.length === 0 && !ccBtn.disabled) { disableCcButton(); } }); video.textTracks.addEventListener('addtrack', (e) => { if (e.track.kind === 'captions' || e.track.kind === 'subtitles') { applyCaptionsSetting(); if (ccBtn.disabled) enableCcButton(); } }); hls.on(Hls.Events.ERROR, (event, data) => { if (data.fatal) { recoveryAttempts++; if (recoveryAttempts <= 3) { if (data.type === Hls.ErrorTypes.NETWORK_ERROR) hls.startLoad(); else if (data.type === Hls.ErrorTypes.MEDIA_ERROR) hls.recoverMediaError(); else { hls.destroy(); currentHls = null; if (!hasLoaded && onError) onError(); else showError(); } } else { hls.destroy(); currentHls = null; if (!hasLoaded && onError) onError(); else showError(); } } }); } // ============================================================ // Controls // ============================================================ function setupKeyboardControls() { document.addEventListener('keydown', (e) => { // In seek input: allow player hotkeys, block other non-time chars if (e.target.id === 'seek-input') { const passthrough = ['j', 'm', 'f', ' ', 'k', 'c', 'i', 'Escape']; if (!passthrough.includes(e.key) && !/^[0-9:]$/.test(e.key) && !['Backspace', 'Delete', 'ArrowLeft', 'ArrowRight', 'Home', 'End', 'Tab', 'Enter'].includes(e.key)) { e.preventDefault(); return; } // Let passthrough keys fall through to main handler below if (!passthrough.includes(e.key)) return; } // Skip other input fields entirely else if (e.target.tagName === 'INPUT' || e.target.tagName === 'TEXTAREA') { if (e.key === 'Escape') e.target.blur(); return; } switch(e.key) { case ' ': case 'k': if (e.target.tagName === 'BUTTON' || e.target.tagName === 'A') return; e.preventDefault(); video.paused ? 
video.play() : video.pause(); break; case 'ArrowLeft': e.preventDefault(); video.currentTime -= 10; break; case 'ArrowRight': e.preventDefault(); video.currentTime += 10; break; case 'ArrowUp': e.preventDefault(); video.volume = Math.min(1, video.volume + 0.1); break; case 'ArrowDown': e.preventDefault(); video.volume = Math.max(0, video.volume - 0.1); break; case 'f': e.preventDefault(); suppressShowControls = true; setTimeout(() => suppressShowControls = false, 150); if (document.fullscreenElement) document.exitFullscreen(); else document.getElementById('player-container').requestFullscreen(); break; case 'm': video.muted = !video.muted; updateMuteIcon(); break; case 't': document.getElementById('menu-transcode')?.click(); break; case 'c': ccBtn.click(); break; case 'i': document.getElementById('info-btn')?.click(); break; case 'a': document.getElementById('cast-btn')?.click(); break; case 'x': document.getElementById('menu-restart')?.click(); break; case 'j': e.preventDefault(); if (cfg.isVod) document.getElementById('jump-btn')?.click(); break; case 'n': document.getElementById('autonext-btn')?.click(); break; case 'p': document.getElementById('pip-btn')?.click(); break; case 'h': const container = document.getElementById('player-container'); if (container.classList.contains('controls-visible')) { container.classList.remove('controls-visible'); clearTimeout(activityTimeoutId); activityTimeoutId = null; } else { container.classList.add('controls-visible'); clearTimeout(activityTimeoutId); activityTimeoutId = setTimeout(() => { if (!settingsMenu.classList.contains('open')) { container.classList.remove('controls-visible'); activityTimeoutId = null; } }, 3000); } break; case 'Escape': { const seekContainer = document.getElementById('seek-container'); const infoOverlay = document.getElementById('info-overlay'); if (seekContainer && !seekContainer.classList.contains('hidden')) { seekContainer.classList.add('hidden'); } else if 
(infoOverlay?.classList.contains('pinned')) { infoOverlay.classList.remove('pinned'); document.getElementById('info-btn')?.classList.remove('active'); saveSettings({ infoPinned: false }); } else if (settingsMenu.classList.contains('open')) { settingsMenu.classList.remove('open'); } else if (!document.fullscreenElement) { if (history.length > 1) history.back(); else window.location.href = '/guide'; } break; } } }); } let activityTimeoutId = null; let autoNextEnabled = true; let suppressShowControls = false; // Persistent settings const STORAGE_KEY = 'playerSettings'; function loadSettings() { try { return JSON.parse(localStorage.getItem(STORAGE_KEY)) || {}; } catch { return {}; } } function saveSettings(updates) { const settings = { ...loadSettings(), ...updates }; localStorage.setItem(STORAGE_KEY, JSON.stringify(settings)); } function setupActivityTracking() { const container = document.getElementById('player-container'); const HIDE_DELAY = 3000; function showControls() { if (suppressShowControls) return; container.classList.add('controls-visible'); clearTimeout(activityTimeoutId); activityTimeoutId = setTimeout(hideControls, HIDE_DELAY); } function hideControls() { if (settingsMenu.classList.contains('open')) return; if (document.getElementById('seek-container')?.classList.contains('hidden') === false) return; container.classList.remove('controls-visible'); activityTimeoutId = null; } container.addEventListener('mousemove', showControls); container.addEventListener('mouseenter', showControls); container.addEventListener('click', showControls); container.addEventListener('mouseleave', () => { clearTimeout(activityTimeoutId); activityTimeoutId = null; hideControls(); }); showControls(); } function setupButtonHandlers() { const seekContainer = document.getElementById('seek-container'); const seekInput = document.getElementById('seek-input'); // Play/Pause button document.getElementById('play-btn')?.addEventListener('click', () => { video.paused ? 
video.play() : video.pause(); }); // Click video to toggle play/pause and show controls video.addEventListener('click', (e) => { if (e.target !== video) return; video.paused ? video.play() : video.pause(); const container = document.getElementById('player-container'); container.classList.add('controls-visible'); clearTimeout(activityTimeoutId); activityTimeoutId = setTimeout(() => { if (!settingsMenu.classList.contains('open') && seekContainer.classList.contains('hidden')) { container.classList.remove('controls-visible'); activityTimeoutId = null; } }, 3000); }); // Mute button document.getElementById('mute-btn')?.addEventListener('click', () => { video.muted = !video.muted; updateMuteIcon(); saveSettings({ muted: video.muted }); }); // Volume slider const volSlider = document.getElementById('volume-slider'); volSlider?.addEventListener('input', () => { video.volume = parseFloat(volSlider.value); video.muted = false; updateMuteIcon(); saveSettings({ volume: video.volume, muted: false }); }); video.addEventListener('volumechange', () => { if (volSlider) volSlider.value = video.muted ? 
0 : video.volume; }); // Jump button document.getElementById('jump-btn')?.addEventListener('click', (e) => { e.stopPropagation(); seekContainer.classList.toggle('hidden'); if (!seekContainer.classList.contains('hidden')) { // Show controls and prevent auto-hide while jump input is active document.getElementById('player-container').classList.add('controls-visible'); clearTimeout(activityTimeoutId); activityTimeoutId = null; seekInput.value = ''; seekInput.focus(); } }); // Auto-next button const autoNextBtn = document.getElementById('autonext-btn'); autoNextBtn?.addEventListener('click', (e) => { e.stopPropagation(); autoNextEnabled = !autoNextEnabled; autoNextBtn.classList.toggle('active', autoNextEnabled); }); // Info button const infoOverlay = document.getElementById('info-overlay'); const infoBtn = document.getElementById('info-btn'); infoBtn?.addEventListener('click', (e) => { e.stopPropagation(); const wasPinned = infoOverlay.classList.contains('pinned'); infoOverlay.classList.toggle('pinned'); infoBtn.classList.toggle('active', !wasPinned); saveSettings({ infoPinned: !wasPinned }); if (wasPinned && !activityTimeoutId) { document.getElementById('player-container').classList.remove('controls-visible'); } }); // CC button ccBtn.addEventListener('click', function(e) { e.stopPropagation(); ccEnabled = !ccEnabled; if (activeTrackStates && activeTrackStates.length > 0) { const prefIdx = getPreferredSubtitleTrack(activeTrackStates.map(ts => ({lang: ts.track.language, label: ts.track.label}))); activeTrackStates.forEach((ts, i) => ts.track.mode = (ccEnabled && i === prefIdx) ? 'showing' : 'hidden'); } else { const tracks = Array.from(video.textTracks).filter(t => (t.kind === 'subtitles' || t.kind === 'captions') && t.mode !== 'disabled'); const prefIdx = getPreferredSubtitleTrack(tracks.map(t => ({lang: t.language, label: t.label}))); tracks.forEach((t, i) => t.mode = (ccEnabled && i === prefIdx) ? 
'showing' : 'hidden'); } if (currentHls) { currentHls.subtitleDisplay = ccEnabled; if (ccEnabled && currentHls.subtitleTracks?.length > 0) { currentHls.subtitleTrack = getPreferredSubtitleTrack(currentHls.subtitleTracks); } } updateCcButton(); saveSettings({ ccEnabled }); }); // PiP button const pipBtn = document.getElementById('pip-btn'); if (pipBtn && document.pictureInPictureEnabled) { pipBtn.addEventListener('click', async (e) => { e.stopPropagation(); try { if (document.pictureInPictureElement) { await document.exitPictureInPicture(); } else { await video.requestPictureInPicture(); } } catch (err) { console.error('[PiP] Error:', err); } }); } else if (pipBtn) { pipBtn.style.display = 'none'; } // Settings menu toggle document.getElementById('settings-btn')?.addEventListener('click', () => { settingsMenu.classList.toggle('open'); }); // Close settings menu when clicking outside document.addEventListener('click', (e) => { if (!e.target.closest('#settings-btn') && !e.target.closest('#settings-menu')) { settingsMenu.classList.remove('open'); } }); // CC Track selection let selectedCcTrackIdx = 0; const ccTracksMenuItem = document.getElementById('menu-cc-tracks'); function getCcTracks() { const tracks = []; if (activeTrackStates?.length > 0) { activeTrackStates.forEach((ts, i) => tracks.push({ idx: i, label: ts.track.label || `Track ${i + 1}`, lang: ts.track.language })); } else if (currentHls?.subtitleTracks?.length > 0) { currentHls.subtitleTracks.forEach((t, i) => tracks.push({ idx: i, label: t.name || `Track ${i + 1}`, lang: t.lang })); } else { Array.from(video.textTracks).filter(t => t.kind === 'subtitles' || t.kind === 'captions') .forEach((t, i) => tracks.push({ idx: i, label: t.label || `Track ${i + 1}`, lang: t.language })); } return tracks; } function selectCcTrack(idx) { ccEnabled = true; updateCcButton(); if (activeTrackStates?.length > 0) { activeTrackStates.forEach((ts, i) => ts.track.mode = i === idx ? 
'showing' : 'hidden'); } else if (currentHls?.subtitleTracks?.length > 0) { currentHls.subtitleTrack = idx; currentHls.subtitleDisplay = true; // iOS: also set TextTrack.mode directly (HLS.js property change doesn't always trigger render) const tracks = Array.from(video.textTracks).filter( t => t.kind === 'subtitles' || t.kind === 'captions' ); tracks.forEach((t, i) => t.mode = i === idx ? 'showing' : 'hidden'); } else { const tracks = Array.from(video.textTracks).filter(t => t.kind === 'subtitles' || t.kind === 'captions'); tracks.forEach((t, i) => t.mode = i === idx ? 'showing' : 'hidden'); } } function updateCcTracksLabel() { if (!ccTracksMenuItem) return; const tracks = getCcTracks(); const label = tracks.length > 0 && tracks[selectedCcTrackIdx] ? `CC: ${tracks[selectedCcTrackIdx].label}` : 'CC Track'; ccTracksMenuItem.innerHTML = `${label}`; } ccTracksMenuItem?.addEventListener('click', () => { const tracks = getCcTracks(); if (tracks.length === 0) return; selectedCcTrackIdx = (selectedCcTrackIdx + 1) % tracks.length; selectCcTrack(selectedCcTrackIdx); updateCcTracksLabel(); }); // Settings menu items document.getElementById('menu-transcode')?.addEventListener('click', async () => { settingsMenu.classList.remove('open'); video.pause(); video.src = ''; error.classList.add('hidden'); if (isTranscoding) { await cleanupTranscode(); isTranscoding = false; updateTranscodeCheck(); playWithUrl(cfg.rawUrl); } else { await startTranscode(); } }); document.getElementById('menu-restart')?.addEventListener('click', async () => { settingsMenu.classList.remove('open'); video.pause(); video.src = ''; error.classList.add('hidden'); try { await fetch('/transcode-clear?url=' + encodeURIComponent(cfg.rawUrl), {method: 'DELETE'}); await cleanupTranscode(); isTranscoding = false; await startTranscode(); } catch (e) { console.error('[X] Error:', e); showError(); } }); document.getElementById('menu-jump')?.addEventListener('click', () => { settingsMenu.classList.remove('open'); 
seekContainer.classList.toggle('hidden'); if (!seekContainer.classList.contains('hidden')) { // Show controls and prevent auto-hide while jump input is active document.getElementById('player-container').classList.add('controls-visible'); clearTimeout(activityTimeoutId); activityTimeoutId = null; seekInput.value = ''; seekInput.focus(); } }); document.getElementById('menu-url')?.addEventListener('click', () => { settingsMenu.classList.remove('open'); if (navigator.clipboard) { navigator.clipboard.writeText(cfg.rawUrl).catch(fallback); } else { fallback(); } function fallback() { const ta = document.createElement('textarea'); ta.value = cfg.rawUrl; ta.style.position = 'fixed'; ta.style.opacity = '0'; document.body.appendChild(ta); ta.select(); document.execCommand('copy'); document.body.removeChild(ta); } }); document.getElementById('menu-external')?.addEventListener('click', () => { settingsMenu.classList.remove('open'); video.pause(); video.src = ''; window.location.href = '/playlist.xspf?url=' + encodeURIComponent(cfg.rawUrl); }); // Fullscreen button document.getElementById('fullscreen-btn')?.addEventListener('click', () => { if (document.fullscreenElement) document.exitFullscreen(); else document.getElementById('player-container').requestFullscreen(); }); // Fullscreen change listener document.addEventListener('fullscreenchange', updateFullscreenIcon); // Video state listeners video.addEventListener('play', updatePlayIcon); video.addEventListener('pause', updatePlayIcon); video.addEventListener('volumechange', updateMuteIcon); // Prevent video element's native spacebar handling video.addEventListener('keydown', (e) => { if (e.key === ' ') e.preventDefault(); }); // Mousewheel volume control document.getElementById('player-container')?.addEventListener('wheel', (e) => { e.preventDefault(); video.volume = Math.max(0, Math.min(1, video.volume + (e.deltaY < 0 ? 
0.05 : -0.05))); saveSettings({ volume: video.volume }); }, { passive: false }); // Progress bar const progressBar = document.getElementById('progress-bar'); const progressPlayed = document.getElementById('progress-played'); const progressHandle = document.getElementById('progress-handle'); const progressBuffered = document.getElementById('progress-buffered'); const timeCurrent = document.getElementById('time-current'); const timeDuration = document.getElementById('time-duration'); function updateProgress() { const duration = totalDuration || video.duration || 0; if (!duration) return; const currentTime = video.currentTime + seekOffset; const pct = (currentTime / duration) * 100; if (progressPlayed) progressPlayed.style.width = pct + '%'; if (progressHandle) progressHandle.style.left = pct + '%'; if (timeCurrent) timeCurrent.textContent = formatTime(currentTime); if (timeDuration) timeDuration.textContent = formatTime(duration); } video.addEventListener('timeupdate', updateProgress); video.addEventListener('loadedmetadata', () => { if (cfg.isVod && !transcodeSessionId) { document.getElementById('progress-container')?.classList.remove('hidden'); document.getElementById('menu-jump')?.classList.remove('hidden'); } updateProgress(); }); progressBar?.addEventListener('click', async (e) => { const rect = progressBar.getBoundingClientRect(); const pct = (e.clientX - rect.left) / rect.width; const duration = totalDuration || video.duration || 0; if (!duration) return; const targetTime = pct * duration; if (transcodeSessionId) { const actualTranscodedEnd = seekOffset + transcodedDuration; if (targetTime >= seekOffset && targetTime <= actualTranscodedEnd + 10) { video.currentTime = targetTime - seekOffset; } else { await handleSeekToPosition(targetTime); } } else { video.currentTime = targetTime; } }); // Seek input handler - filter chars here (must be at input level to block typing) seekInput?.addEventListener('keydown', async function(e) { // Only allow: digits, colon, 
navigation keys, Enter // Hotkeys and other chars: preventDefault (hotkeys will still bubble to global handler) const typeable = /^[0-9:]$/.test(e.key) || ['Backspace', 'Delete', 'ArrowLeft', 'ArrowRight', 'Home', 'End', 'Tab'].includes(e.key); if (!typeable) e.preventDefault(); if (e.key !== 'Enter') return; e.preventDefault(); const targetTime = parseTime(seekInput.value); if (targetTime < 0 || targetTime > totalDuration) { seekInput.classList.add('ring-2', 'ring-red-500'); setTimeout(() => seekInput.classList.remove('ring-2', 'ring-red-500'), 500); return; } seekContainer.classList.add('hidden'); const actualTranscodedEnd = seekOffset + transcodedDuration; if (targetTime >= seekOffset && targetTime <= actualTranscodedEnd + 10) { video.currentTime = targetTime - seekOffset; return; } await handleSeekToPosition(targetTime); }); } // ============================================================ // Cast (Chromecast) // ============================================================ function setupCast() { if (!cfg.isHttps) return; const castBtn = document.getElementById('cast-btn'); if (!castBtn) return; function getCastUrl() { const host = cfg.castHost || window.location.host; const proto = window.location.protocol; if (transcodeSessionId) { return proto + '//' + host + '/transcode/' + transcodeSessionId + '/stream.m3u8'; } if (cfg.rawUrl.includes('localhost') || cfg.rawUrl.includes('127.0.0.1')) { return cfg.rawUrl.replace(/localhost|127\.0\.0\.1/, host.split(':')[0]); } return cfg.rawUrl; } function castLog(msg) { fetch('/api/cast-log', {method: 'POST', body: msg}).catch(() => {}); } function initCast() { cast.framework.CastContext.getInstance().setOptions({ receiverApplicationId: chrome.cast.media.DEFAULT_MEDIA_RECEIVER_APP_ID, autoJoinPolicy: chrome.cast.AutoJoinPolicy.ORIGIN_SCOPED, }); castBtn.disabled = false; cast.framework.CastContext.getInstance().addEventListener( cast.framework.CastContextEventType.SESSION_STATE_CHANGED, (e) => { const connected = 
e.sessionState === cast.framework.SessionState.SESSION_STARTED || e.sessionState === cast.framework.SessionState.SESSION_RESUMED; castBtn.classList.toggle('active', connected); } ); } function loadMediaToCast() { const session = cast.framework.CastContext.getInstance().getCurrentSession(); if (!session) { castLog('No active session'); return; } const url = getCastUrl(); castLog('URL: ' + url + ' isVod=' + cfg.isVod + ' seek=' + (video.currentTime + seekOffset).toFixed(1)); const mediaInfo = new chrome.cast.media.MediaInfo(url, 'application/x-mpegurl'); mediaInfo.streamType = chrome.cast.media.StreamType.LIVE; mediaInfo.metadata = new chrome.cast.media.GenericMediaMetadata(); mediaInfo.metadata.title = cfg.mediaTitle; if (chrome.cast.media.HlsSegmentFormat) mediaInfo.hlsSegmentFormat = chrome.cast.media.HlsSegmentFormat.TS; if (chrome.cast.media.HlsVideoSegmentFormat) mediaInfo.hlsVideoSegmentFormat = chrome.cast.media.HlsVideoSegmentFormat.MPEG2_TS; const request = new chrome.cast.media.LoadRequest(mediaInfo); request.autoplay = true; request.currentTime = video.currentTime + seekOffset; castLog('streamType=' + mediaInfo.streamType); session.loadMedia(request).then( () => { castLog('Media loaded OK'); video.pause(); const media = session.getMediaSession(); if (media) { media.addUpdateListener((isAlive) => { if (!isAlive) { castLog('Session ended'); return; } const ps = media.playerState; const idle = media.idleReason; castLog('State: ' + ps + (idle ? 
' (' + idle + ')' : '')); }); } }, (e) => { const code = e?.code || 'unknown'; const desc = e?.description || e?.message || String(e); castLog('LOAD FAILED: code=' + code + ' desc=' + desc); } ); } let castDialogClosedAt = 0; castBtn.addEventListener('click', function(e) { e.stopPropagation(); e.preventDefault(); this.blur(); settingsMenu.classList.remove('open'); if (!window.cast || !cast.framework) { alert('Cast not available.\n\nTry accessing via your LAN IP instead of 0.0.0.0'); return; } if (Date.now() - castDialogClosedAt < 1000) return; const ctx = cast.framework.CastContext.getInstance(); ctx.requestSession().then( () => { castDialogClosedAt = Date.now(); loadMediaToCast(); }, () => { castDialogClosedAt = Date.now(); } ); }); let pollCount = 0; const castPoll = setInterval(() => { if (window.cast && cast.framework) { clearInterval(castPoll); console.log('[CAST] SDK ready'); initCast(); } else if (++pollCount > 30) { clearInterval(castPoll); console.log('[CAST] SDK timeout'); castBtn.disabled = false; } }, 100); } // ============================================================ // Initialization // ============================================================ function init() { // Restore persistent settings const settings = loadSettings(); if (settings.volume !== undefined) video.volume = settings.volume; if (settings.muted !== undefined) video.muted = settings.muted; if (settings.ccEnabled !== undefined) ccEnabled = settings.ccEnabled; applyCaptionStyles(); setupPositionTracking(); setupKeyboardControls(); setupButtonHandlers(); setupActivityTracking(); setupCast(); updateTranscodeCheck(); updateCcButton(); updateMuteIcon(); document.getElementById('volume-slider').value = video.muted ? 
0 : video.volume; // Restore volume on first user interaction if auto-muted by browser policy function restoreVolumeOnInteraction() { if (autoMutedByPolicy) { video.muted = false; updateMuteIcon(); autoMutedByPolicy = false; } } document.addEventListener('click', restoreVolumeOnInteraction, { once: true }); document.addEventListener('keydown', restoreVolumeOnInteraction, { once: true }); // Restore info pinned state if (settings.infoPinned) { document.getElementById('info-overlay')?.classList.add('pinned'); document.getElementById('info-btn')?.classList.add('active'); } window.addEventListener('beforeunload', cleanupTranscodeSync); window.addEventListener('pagehide', cleanupTranscodeSync); // Start playback based on transcode mode if (cfg.transcodeMode === 'always') { startTranscode(); } else if (cfg.transcodeMode === 'never') { playWithUrl(cfg.rawUrl); } else { playWithUrl(cfg.rawUrl, () => { error.classList.add('hidden'); startTranscode(); }); } } init(); })(); ================================================ FILE: static/js/settings.js ================================================ // Settings Page Module (function() { 'use strict'; const cfg = window.SETTINGS_CONFIG || {}; // ============================================================ // Shared Helpers // ============================================================ function escapeHtml(s) { if (!s) return ''; return s.replace(/&/g, '&').replace(//g, '>').replace(/"/g, '"'); } function showFeedback(el, success) { if (!el) return; const cls = success ? 'ring-green-500' : 'ring-red-500'; el.classList.add('ring-2', cls); setTimeout(() => el.classList.remove('ring-2', cls), success ? 
500 : 1000); } async function saveWithFeedback(url, options, feedbackEl) { try { const resp = await fetch(url, options); showFeedback(feedbackEl, resp.ok); return resp; } catch (e) { console.error('Save failed:', e); showFeedback(feedbackEl, false); return null; } } function getFeedbackEl(el) { if (!el) return null; if (el.type === 'radio' || el.type === 'checkbox') return el.closest('label') || el; return el; } // ============================================================ // Drag-Drop Helper // ============================================================ function setupDragDrop(containerSelector, chipSelector, onDrop) { let draggedChip = null; document.querySelectorAll(chipSelector).forEach(chip => { chip.addEventListener('dragstart', e => { draggedChip = chip; e.dataTransfer.effectAllowed = 'move'; chip.classList.add('opacity-50'); }); chip.addEventListener('dragend', () => { chip.classList.remove('opacity-50'); draggedChip = null; }); chip.addEventListener('dragover', e => { e.preventDefault(); if (draggedChip && draggedChip !== chip) { chip.classList.add('border-t-2', 'border-blue-500'); } }); chip.addEventListener('dragleave', () => { chip.classList.remove('border-t-2', 'border-blue-500'); }); chip.addEventListener('drop', e => { e.preventDefault(); e.stopPropagation(); chip.classList.remove('border-t-2', 'border-blue-500'); if (draggedChip && draggedChip !== chip) { chip.parentElement.insertBefore(draggedChip, chip); onDrop?.(chip.parentElement, draggedChip); } }); }); document.querySelectorAll(containerSelector).forEach(container => { container.addEventListener('dragover', e => { e.preventDefault(); e.dataTransfer.dropEffect = 'move'; container.classList.add('border-blue-500'); }); container.addEventListener('dragleave', e => { if (!container.contains(e.relatedTarget)) { container.classList.remove('border-blue-500'); } }); container.addEventListener('drop', e => { e.preventDefault(); container.classList.remove('border-blue-500'); if (draggedChip && 
draggedChip.parentElement !== container) { container.appendChild(draggedChip); onDrop?.(container, draggedChip); } }); }); } function setupSearch(inputId, clearBtnId, chipSelector) { const input = document.getElementById(inputId); const clearBtn = document.getElementById(clearBtnId); if (!input) return; function apply() { const q = input.value.toLowerCase(); document.querySelectorAll(chipSelector).forEach(el => { el.style.display = el.textContent.toLowerCase().includes(q) ? '' : 'none'; }); clearBtn?.classList.toggle('hidden', !input.value); } input.addEventListener('input', apply); clearBtn?.addEventListener('click', () => { input.value = ''; apply(); }); } // ============================================================ // Global Functions (used by inline handlers) // ============================================================ window.togglePwdVis = function(btn) { const input = btn.parentElement.querySelector('input[type="password"], input[type="text"]'); if (!input) return; const isPassword = input.type === 'password'; input.type = isPassword ? 'text' : 'password'; btn.querySelector('.eye-off')?.classList.toggle('hidden', isPassword); btn.querySelector('.eye-on')?.classList.toggle('hidden', !isPassword); }; window.toggleSourceFields = function(select) { const form = select.closest('form'); const isXtream = select.value === 'xtream'; const isEpg = select.value === 'epg'; form.querySelector('.xtream-fields')?.style.setProperty('display', isXtream ? 'grid' : 'none'); form.querySelector('.non-epg-only')?.style.setProperty('display', isEpg ? 'none' : 'block'); form.querySelector('.epg-url-field')?.style.setProperty('display', isEpg ? 
'none' : 'block'); }; window.showDeleteSelfModal = function() { document.getElementById('delete-self-modal')?.classList.remove('hidden'); const pwInput = document.getElementById('delete-self-password'); if (pwInput) { pwInput.value = ''; pwInput.focus(); } const msg = document.getElementById('delete-self-msg'); if (msg) msg.textContent = ''; }; window.hideDeleteSelfModal = function() { document.getElementById('delete-self-modal')?.classList.add('hidden'); }; window.submitDeleteSelf = async function(e) { e.preventDefault(); const pw = document.getElementById('delete-self-password')?.value; const msgEl = document.getElementById('delete-self-msg'); if (!pw) return; const form = new FormData(); form.append('password', pw); try { const resp = await fetch('/settings/users/delete/' + cfg.currentUser, { method: 'POST', body: form }); if (resp.ok || resp.redirected) { window.location.href = '/login'; } else { const data = await resp.json(); if (msgEl) msgEl.textContent = data.detail || 'Failed'; } } catch (e) { console.error('Delete self failed:', e); if (msgEl) msgEl.textContent = 'Request failed'; } }; // ============================================================ // Add Source Type Select // ============================================================ function setupSourceTypeSelect() { const typeSelect = document.getElementById('source-type'); if (!typeSelect) return; typeSelect.addEventListener('change', function() { const isXtream = this.value === 'xtream'; const isM3u = this.value === 'm3u'; const isEpg = this.value === 'epg'; document.getElementById('xtream-fields')?.style.setProperty('display', isXtream ? 'grid' : 'none'); document.getElementById('epg-enabled-field')?.style.setProperty('display', isEpg ? 'none' : 'block'); const deinterlaceField = document.getElementById('deinterlace-field'); if (deinterlaceField) { deinterlaceField.style.display = isEpg ? 
'none' : 'block'; const cb = deinterlaceField.querySelector('input[name="deinterlace_fallback"]'); if (cb) cb.checked = isM3u; } document.getElementById('max-streams-field')?.style.setProperty('display', isEpg ? 'none' : 'block'); const urlInput = document.querySelector('#add-source-form input[name="url"]'); if (urlInput) { const placeholders = { xtream: 'https://server.com', m3u: 'http://server.com/playlist.m3u', epg: 'http://server.com/epg.xml' }; urlInput.placeholder = placeholders[this.value] || placeholders.xtream; } }); } // ============================================================ // Source Edit Auto-Save // ============================================================ function setupSourceEditForms() { document.querySelectorAll('.source-edit-form').forEach(form => { const sourceId = form.dataset.sourceId; if (!sourceId) return; form.querySelectorAll('input, select').forEach(el => { if (el.type === 'button' || el.type === 'submit') return; el.addEventListener('change', async function() { await saveWithFeedback( `/settings/edit/${sourceId}`, { method: 'POST', body: new FormData(form) }, getFeedbackEl(this) ); }); }); form.addEventListener('submit', e => e.preventDefault()); }); // Delete source buttons document.querySelectorAll('.delete-source-btn').forEach(btn => { btn.addEventListener('click', async () => { const sourceId = btn.dataset.sourceId; if (!confirm('Delete this source?')) return; btn.disabled = true; btn.textContent = 'Deleting...'; try { const resp = await fetch(`/settings/delete/${sourceId}`, { method: 'POST' }); if (resp.ok) location.reload(); else throw new Error('Delete failed'); } catch { btn.disabled = false; btn.textContent = 'Delete'; } }); }); } // ============================================================ // Live TV Category Filter // ============================================================ function setupCategoryFilter() { const availableContainer = document.getElementById('available-cats'); const unavailableContainer = 
document.getElementById('unavailable-cats');
        if (!availableContainer || !unavailableContainer) return;

        // Persist the chip order of the "available" container as the selected category IDs.
        // `container` only selects which element flashes the save feedback.
        async function save(container) {
            const cats = Array.from(availableContainer.querySelectorAll('.cat-chip')).map(el => el.dataset.id);
            await saveWithFeedback(
                '/settings/guide-filter',
                { method: 'POST', headers: {'Content-Type': 'application/json'}, body: JSON.stringify({cats}) },
                container
            );
        }

        // Initialize order from config
        // (chips listed in cfg.selectedCats move from "unavailable" to "available", in that order)
        const chipById = {};
        unavailableContainer.querySelectorAll('.cat-chip').forEach(el => chipById[el.dataset.id] = el);
        (cfg.selectedCats || []).forEach(catId => {
            if (chipById[catId]) availableContainer.appendChild(chipById[catId]);
        });

        setupDragDrop('#available-cats, #unavailable-cats', '#filters .cat-chip', save);
        setupSearch('cat-search', 'cat-search-clear', '#filters .cat-chip');

        // "Move all" buttons only move chips not hidden by the search filter
        // (setupSearch hides chips via an inline display:none style).
        document.getElementById('cat-move-all-right')?.addEventListener('click', async () => {
            availableContainer.querySelectorAll('.cat-chip:not([style*="display: none"])').forEach(c => unavailableContainer.appendChild(c));
            await save(unavailableContainer);
        });
        document.getElementById('cat-move-all-left')?.addEventListener('click', async () => {
            unavailableContainer.querySelectorAll('.cat-chip:not([style*="display: none"])').forEach(c => availableContainer.appendChild(c));
            await save(availableContainer);
        });
    }

    // ============================================================
    // VOD Category Filter
    // ============================================================

    // Same behavior as the Live TV filter, for movie categories (/settings/vod-filter).
    function setupVodCategoryFilter() {
        const availableContainer = document.getElementById('available-vod-cats');
        const unavailableContainer = document.getElementById('unavailable-vod-cats');
        if (!availableContainer || !unavailableContainer) return;

        // Persist the "available" chip order; `container` receives the save feedback.
        async function save(container) {
            const cats = Array.from(availableContainer.querySelectorAll('.vod-cat-chip')).map(el => el.dataset.id);
            await saveWithFeedback(
                '/settings/vod-filter',
                { method: 'POST', headers: {'Content-Type': 'application/json'}, body: JSON.stringify({cats})
            },
                container
            );
        }

        // Initialize order from config
        const chipById = {};
        unavailableContainer.querySelectorAll('.vod-cat-chip').forEach(el => chipById[el.dataset.id] = el);
        (cfg.selectedVodCats || []).forEach(catId => {
            if (chipById[catId]) availableContainer.appendChild(chipById[catId]);
        });

        setupDragDrop('#available-vod-cats, #unavailable-vod-cats', '#vod-filters .vod-cat-chip', save);
        setupSearch('vod-cat-search', 'vod-cat-search-clear', '#vod-filters .vod-cat-chip');

        // "Move all" buttons skip chips hidden by the search filter.
        document.getElementById('vod-cat-move-all-right')?.addEventListener('click', async () => {
            availableContainer.querySelectorAll('.vod-cat-chip:not([style*="display: none"])').forEach(c => unavailableContainer.appendChild(c));
            await save(unavailableContainer);
        });
        document.getElementById('vod-cat-move-all-left')?.addEventListener('click', async () => {
            unavailableContainer.querySelectorAll('.vod-cat-chip:not([style*="display: none"])').forEach(c => availableContainer.appendChild(c));
            await save(availableContainer);
        });
    }

    // ============================================================
    // Series Category Filter
    // ============================================================

    // Same behavior as the Live TV filter, for series categories (/settings/series-filter).
    function setupSeriesCategoryFilter() {
        const availableContainer = document.getElementById('available-series-cats');
        const unavailableContainer = document.getElementById('unavailable-series-cats');
        if (!availableContainer || !unavailableContainer) return;

        // Persist the "available" chip order; `container` receives the save feedback.
        async function save(container) {
            const cats = Array.from(availableContainer.querySelectorAll('.series-cat-chip')).map(el => el.dataset.id);
            await saveWithFeedback(
                '/settings/series-filter',
                { method: 'POST', headers: {'Content-Type': 'application/json'}, body: JSON.stringify({cats}) },
                container
            );
        }

        // Initialize order from config
        const chipById = {};
        unavailableContainer.querySelectorAll('.series-cat-chip').forEach(el => chipById[el.dataset.id] = el);
        (cfg.selectedSeriesCats || []).forEach(catId => {
            if (chipById[catId]) availableContainer.appendChild(chipById[catId]);
        });
setupDragDrop('#available-series-cats, #unavailable-series-cats', '#series-filters .series-cat-chip', save); setupSearch('series-cat-search', 'series-cat-search-clear', '#series-filters .series-cat-chip'); document.getElementById('series-cat-move-all-right')?.addEventListener('click', async () => { availableContainer.querySelectorAll('.series-cat-chip:not([style*="display: none"])').forEach(c => unavailableContainer.appendChild(c)); await save(unavailableContainer); }); document.getElementById('series-cat-move-all-left')?.addEventListener('click', async () => { unavailableContainer.querySelectorAll('.series-cat-chip:not([style*="display: none"])').forEach(c => availableContainer.appendChild(c)); await save(availableContainer); }); } // ============================================================ // Chrome CC Link Copy // ============================================================ function setupChromeCcLink() { const el = document.getElementById('chrome-cc-link'); if (!el) return; el.addEventListener('click', async () => { const text = 'chrome://settings/captions'; const orig = el.textContent; try { await navigator.clipboard.writeText(text); el.textContent = 'Copied!'; } catch (e) { console.error('Copy failed:', e); el.textContent = 'Failed'; } setTimeout(() => el.textContent = orig, 1500); }); } // ============================================================ // Caption Settings // ============================================================ function setupCaptionSettings() { const preview = document.getElementById('cc-preview'); const selects = document.querySelectorAll('.cc-setting'); const langSelect = document.getElementById('cc-lang-pref'); const enabledCb = document.getElementById('captions-enabled'); let ccStyle = cfg.ccStyle || {}; function hexToRgba(hex, opacity) { if (hex === 'transparent') return 'transparent'; const r = parseInt(hex.slice(1,3), 16); const g = parseInt(hex.slice(3,5), 16); const b = parseInt(hex.slice(5,7), 16); return 
`rgba(${r},${g},${b},${opacity})`; } function updatePreview() { if (!preview) return; preview.style.color = hexToRgba(ccStyle.cc_color || '#ffffff', 1); preview.style.textShadow = ccStyle.cc_shadow || '0 0 4px black, 0 0 4px black'; preview.style.backgroundColor = hexToRgba(ccStyle.cc_bg || '#000000', ccStyle.cc_bg_opacity ?? 0.5); preview.style.fontSize = ccStyle.cc_size || '1em'; preview.style.fontFamily = ccStyle.cc_font || 'inherit'; } selects.forEach(sel => { if (ccStyle[sel.dataset.setting]) sel.value = ccStyle[sel.dataset.setting]; sel.addEventListener('change', async function() { ccStyle[this.dataset.setting] = this.value; updatePreview(); await saveWithFeedback( '/api/user-prefs', { method: 'POST', headers: {'Content-Type': 'application/json'}, body: JSON.stringify({cc_style: ccStyle}) }, this ); }); }); updatePreview(); if (langSelect) { if (cfg.ccLang) langSelect.value = cfg.ccLang; langSelect.addEventListener('change', async function() { await saveWithFeedback( '/api/user-prefs', { method: 'POST', headers: {'Content-Type': 'application/json'}, body: JSON.stringify({cc_lang: this.value}) }, this ); }); } if (enabledCb) { enabledCb.addEventListener('change', async function() { const form = new FormData(); if (this.checked) form.append('enabled', 'on'); await saveWithFeedback('/settings/captions', { method: 'POST', body: form }, getFeedbackEl(this)); }); } } // ============================================================ // Guide Settings // ============================================================ function setupGuideSettings() { const virtualScrollCb = document.getElementById('virtual-scroll'); if (virtualScrollCb) { virtualScrollCb.addEventListener('change', async function() { await saveWithFeedback( '/api/user-prefs', { method: 'POST', headers: {'Content-Type': 'application/json'}, body: JSON.stringify({virtual_scroll: this.checked}) }, getFeedbackEl(this) ); }); } } // ============================================================ // Transcode & 
User-Agent Settings // ============================================================ function setupTranscodeSettings() { const container = document.getElementById('transcode-settings'); if (!container) return; // Collect all transcode-related inputs (in container + probe checkboxes + transcode_dir) const transcodeInputs = [ ...container.querySelectorAll('.setting-input'), ...document.querySelectorAll('input[name="probe_live"], input[name="probe_movies"], input[name="probe_series"]'), document.querySelector('input[name="transcode_dir"]') ].filter(Boolean); async function save(triggerEl) { const form = new FormData(); // Auto-collect all transcode inputs by type transcodeInputs.forEach(el => { if (!el.name) return; if (el.type === 'checkbox') { if (el.checked) form.append(el.name, 'on'); } else if (el.type === 'radio') { if (el.checked) form.append(el.name, el.value); } else { form.append(el.name, el.value); } }); await saveWithFeedback('/settings/transcode', { method: 'POST', body: form }, getFeedbackEl(triggerEl)); } // Auto-enable "Always" transcode when AI Upscale is enabled const srInputs = container.querySelectorAll('input[name="sr_model"]'); const alwaysRadio = container.querySelector('input[name="transcode_mode"][value="always"]'); srInputs.forEach(sr => { sr.addEventListener('change', function() { if (this.value !== '' && alwaysRadio && !alwaysRadio.checked) { alwaysRadio.checked = true; showFeedback(getFeedbackEl(alwaysRadio), true); } }); }); transcodeInputs.forEach(el => { el.addEventListener('change', function() { save(this); }); }); // Re-detect hardware button const refreshBtn = document.getElementById('refresh-encoders-btn'); if (refreshBtn) { refreshBtn.addEventListener('click', async () => { refreshBtn.disabled = true; refreshBtn.textContent = 'Detecting...'; try { const resp = await fetch('/settings/refresh-encoders', { method: 'POST' }); if (resp.ok) { const { encoders = {} } = await resp.json(); // Map encoder detection to radio button 
enable/disable states const radioStates = { 'nvenc+vaapi': encoders.nvenc && encoders.vaapi, 'nvenc+software': encoders.nvenc, 'amf+vaapi': encoders.amf && encoders.vaapi, 'amf+software': encoders.amf, 'qsv': encoders.qsv, 'vaapi': encoders.vaapi, 'software': true, // Always available }; Object.entries(radioStates).forEach(([value, enabled]) => { const radio = container.querySelector(`input[name="transcode_hw"][value="${value}"]`); const label = radio?.closest('label'); if (radio && label) { radio.disabled = !enabled; label.classList.toggle('opacity-40', !enabled); } }); refreshBtn.textContent = 'Done!'; } else { refreshBtn.textContent = 'Failed'; } } catch (e) { console.error('Refresh encoders failed:', e); refreshBtn.textContent = 'Failed'; } setTimeout(() => { refreshBtn.textContent = 'Re-detect Hardware'; refreshBtn.disabled = false; }, 1500); }); } } function setupUserAgentSettings() { const container = document.getElementById('user-agent-settings'); if (!container) return; const customContainer = document.getElementById('custom-user-agent-container'); const customInput = container.querySelector('input[name="user_agent_custom"]'); async function save(triggerEl) { const form = new FormData(); form.append('preset', container.querySelector('input[name="user_agent_preset"]:checked')?.value || 'default'); form.append('custom', customInput?.value || ''); await saveWithFeedback('/settings/user-agent', { method: 'POST', body: form }, getFeedbackEl(triggerEl)); } container.querySelectorAll('input[name="user_agent_preset"]').forEach(radio => { radio.addEventListener('change', function() { customContainer?.classList.toggle('hidden', this.value !== 'custom'); save(this); }); }); customInput?.addEventListener('change', function() { save(this); }); } // ============================================================ // Data & Probe Cache // ============================================================ function setupDataCache() { const clearBtn = 
document.getElementById('clear-data-cache');
  if (!clearBtn) return;
  clearBtn.addEventListener('click', async () => {
    clearBtn.disabled = true;
    clearBtn.textContent = 'Deleting...';
    const resp = await saveWithFeedback('/settings/data-cache/clear', { method: 'POST' }, clearBtn);
    // NOTE(review): assumes saveWithFeedback (defined elsewhere) resolves to the fetch
    // Response, or something falsy on failure — confirm.
    clearBtn.textContent = resp?.ok ? 'Deleted!' : 'Failed';
    // Restore the label and re-enable the button after 2s.
    setTimeout(() => {
      clearBtn.textContent = 'Delete';
      clearBtn.disabled = false;
    }, 2000);
  });
}

/**
 * Render the probe-cache list from GET /settings/probe-cache (one entry per
 * series, with its episodes and MRU episode). Also wires the "clear all" button
 * and the delegated per-series / per-MRU / per-episode clear buttons; each
 * clear POSTs and then reloads the list.
 *
 * NOTE(review): the string/template literals below appear to have lost their
 * HTML markup. The delegated click handler looks for buttons with
 * .clear-series / .clear-mru / .clear-episode and data-series / data-episode,
 * but none of the visible literals produce such elements. Several quoted
 * strings also span lines, which is a syntax error as written. Confirm against
 * the real file before editing.
 */
function setupProbeCache() {
  const listEl = document.getElementById('probe-cache-list');
  const clearAllBtn = document.getElementById('clear-all-probe-cache');
  if (!listEl) return;

  // Format a duration in seconds as "XhYm" (or "Ym" under an hour); '' for missing/non-positive values.
  function formatDuration(secs) {
    if (!secs || secs <= 0) return '';
    const h = Math.floor(secs / 3600);
    const m = Math.floor((secs % 3600) / 60);
    return h > 0 ? `${h}h${m}m` : `${m}m`;
  }

  // Fetch the cache summary and re-render the whole list (or an empty/error message).
  function loadCache() {
    fetch('/settings/probe-cache')
      .then(r => r.json())
      .then(data => {
        const series = data.series || [];
        if (series.length === 0) {
          listEl.innerHTML = '
No cached probes
';
          return;
        }
        listEl.innerHTML = series.map(s => {
          const name = escapeHtml(s.name) || `Series ${s.series_id}`;
          const episodes = s.episodes || [];
          // s.mru holds an episode_id; fall back to a generic label if that episode isn't listed.
          const mruEp = s.mru != null ? episodes.find(ep => ep.episode_id === s.mru) : null;
          const mruName = mruEp ? escapeHtml(mruEp.name) || `Episode ${s.mru}` : (s.mru != null ? `Episode ${s.mru}` : null);
          return `
${name} ${s.episode_count} ep${s.episode_count > 1 ? 's' : ''} ${escapeHtml(s.video_codec || '')}/${escapeHtml(s.audio_codec || '')} ${s.subtitle_count > 0 ? `+${s.subtitle_count} subs` : ''}
${mruName ? `
MRU: ${mruName}
` : ''} ${episodes.map(ep => `
${escapeHtml(ep.name) || 'Episode ' + ep.episode_id}${ep.duration ? ` (${formatDuration(ep.duration)})` : ''}${ep.subtitle_count ? ` +${ep.subtitle_count} subs` : ''}
`).join('')}
`;
        }).join('');
      })
      .catch(() => {
        listEl.innerHTML = '
Failed to load
';
      });
  }

  // NOTE(review): these clear requests have no .catch, so a network error becomes an unhandled rejection.
  clearAllBtn?.addEventListener('click', () => {
    fetch('/settings/probe-cache/clear', { method: 'POST' }).then(() => loadCache());
  });

  // Event delegation for dynamically created buttons
  listEl.addEventListener('click', (e) => {
    const btn = e.target.closest('button');
    if (!btn) return;
    e.stopPropagation();
    if (btn.classList.contains('clear-series')) {
      fetch(`/settings/probe-cache/clear/${btn.dataset.series}`, { method: 'POST' }).then(() => loadCache());
    } else if (btn.classList.contains('clear-mru')) {
      fetch(`/settings/probe-cache/clear-mru/${btn.dataset.series}`, { method: 'POST' }).then(() => loadCache());
    } else if (btn.classList.contains('clear-episode')) {
      fetch(`/settings/probe-cache/clear/${btn.dataset.series}?episode_id=${btn.dataset.episode}`, { method: 'POST' }).then(() => loadCache());
    }
  });

  loadCache();
}

// ============================================================
// Source Refresh Buttons
// ============================================================

/**
 * Wire the per-source .refresh-btn buttons. A click POSTs
 * /settings/refresh/{sourceId}/{type} and starts polling
 * /settings/refresh-status every second. Polling toggles each button's
 * 'active' class from the per-source and '_global' statuses, and stops once no
 * tracked refresh is active. On load, polling starts if the status endpoint
 * reports anything.
 */
function setupRefreshButtons() {
  // Keys are `${sourceId}_${refreshType}` for refreshes currently shown as active.
  const activeRefreshes = new Set();
  let pollInterval = null;

  function updateButtonStates(statuses) {
    const globalStatus = statuses._global || {};
    document.querySelectorAll('[data-source-id]').forEach(container => {
      const sourceId = container.dataset.sourceId;
      const sourceStatuses = statuses[sourceId] || {};
      container.querySelectorAll('.refresh-btn').forEach(btn => {
        const refreshType = btn.dataset.refresh;
        const isActive = !!sourceStatuses[refreshType] || !!globalStatus[refreshType];
        btn.classList.toggle('active', isActive);
        if (isActive) activeRefreshes.add(`${sourceId}_${refreshType}`);
        else activeRefreshes.delete(`${sourceId}_${refreshType}`);
      });
    });
    // Nothing left running: stop polling.
    if (activeRefreshes.size === 0 && pollInterval) {
      clearInterval(pollInterval);
      pollInterval = null;
    }
  }

  function pollStatus() {
    fetch('/settings/refresh-status').then(r => r.json()).then(updateButtonStates).catch(() => {});
  }

  // Idempotent: only one interval runs at a time; polls once immediately.
  function startPolling() {
    if (!pollInterval) {
      pollInterval = setInterval(pollStatus, 1000);
      pollStatus();
    }
  }

  document.querySelectorAll('[data-source-id] .refresh-btn').forEach(btn => {
    btn.addEventListener('click', () => {
      const container = btn.closest('[data-source-id]');
      const sourceId = container.dataset.sourceId;
      // Optimistically mark active; the poller will correct it.
      btn.classList.add('active');
      activeRefreshes.add(`${sourceId}_${btn.dataset.refresh}`);
      fetch(`/settings/refresh/${sourceId}/${btn.dataset.refresh}`, { method: 'POST' })
        .then(() => startPolling())
        .catch(() => btn.classList.remove('active'));
    });
  });

  // Resume polling for refreshes already running when the page loads.
  fetch('/settings/refresh-status').then(r => r.json()).then(statuses => {
    if (Object.keys(statuses).length > 0) {
      updateButtonStates(statuses);
      startPolling();
    }
  }).catch(() => {});
}

// ============================================================
// User Management
// ============================================================

function setupUserForms() {
  // Add User form
  const addUserForm = document.getElementById('add-user-form');
  if (addUserForm) {
    // Passing null as the callback (third arg) — presumably no save on drop for the add form; verify setupDragDrop.
    setupDragDrop(
      '#add-user-available-groups, #add-user-unavailable-groups',
      '.add-user-group-chip',
      null
    );
    setupSearch('add-user-group-search', 'add-user-group-search-clear', '.add-user-group-chip');
    // "Block all" / "Allow all" move only chips not hidden by the search filter.
    document.getElementById('add-user-block-all')?.addEventListener('click', () => {
      const avail = document.getElementById('add-user-available-groups');
      const unavail = document.getElementById('add-user-unavailable-groups');
      avail?.querySelectorAll('.add-user-group-chip:not([style*="display: none"])').forEach(c => unavail?.appendChild(c));
    });
    document.getElementById('add-user-allow-all')?.addEventListener('click', () => {
      const avail = document.getElementById('add-user-available-groups');
      const unavail = document.getElementById('add-user-unavailable-groups');
      unavail?.querySelectorAll('.add-user-group-chip:not([style*="display: none"])').forEach(c => avail?.appendChild(c));
    });
    addUserForm.addEventListener('submit', async function(e) {
      e.preventDefault();
      const form = new FormData(this);
      const maxStreamsPerSource = {};
document.querySelectorAll('.add-user-source-max-streams').forEach(inp => { const val = parseInt(inp.value) || 0; if (val > 0) maxStreamsPerSource[inp.dataset.sourceId] = val; }); form.append('max_streams_per_source', JSON.stringify(maxStreamsPerSource)); const unavailableGroups = Array.from( document.querySelectorAll('#add-user-unavailable-groups .add-user-group-chip') ).map(c => c.dataset.groupId); form.append('unavailable_groups', JSON.stringify(unavailableGroups)); const msgEl = document.getElementById('add-user-msg'); try { const resp = await fetch('/settings/users/add', { method: 'POST', body: form }); if (resp.ok) { if (msgEl) { msgEl.textContent = 'Added'; msgEl.className = 'text-sm text-green-400'; } this.reset(); setTimeout(() => location.reload(), 500); } else { const data = await resp.json(); if (msgEl) { msgEl.textContent = data.detail || 'Failed'; msgEl.className = 'text-sm text-red-400'; } } } catch (e) { console.error('Add user failed:', e); if (msgEl) { msgEl.textContent = 'Request failed'; msgEl.className = 'text-sm text-red-400'; } } msgEl?.classList.remove('hidden'); setTimeout(() => { if (msgEl) msgEl.className = 'text-sm hidden'; }, 3000); }); } // Password inputs document.querySelectorAll('.password-input').forEach(input => { input.addEventListener('change', async function() { const username = this.closest('[data-username]')?.dataset.username; if (!username || this.value.length < 8) { showFeedback(this, false); return; } const form = new FormData(); form.append('new_password', this.value); const resp = await saveWithFeedback(`/settings/users/password/${username}`, { method: 'POST', body: form }, this); if (resp?.ok) this.value = ''; }); }); // Admin toggles document.querySelectorAll('.admin-toggle').forEach(checkbox => { checkbox.addEventListener('change', async function() { const username = this.closest('[data-username]')?.dataset.username; if (!username) return; const form = new FormData(); if (this.checked) form.append('admin', 'on'); try { 
const resp = await fetch(`/settings/users/admin/${username}`, { method: 'POST', body: form }); if (resp.ok) location.reload(); else this.checked = !this.checked; } catch (e) { console.error('Admin toggle failed:', e); this.checked = !this.checked; } }); }); // Max streams per source document.querySelectorAll('.user-source-max-streams').forEach(input => { input.addEventListener('change', async function() { const container = this.closest('.user-max-streams-container'); const username = container?.dataset.username; if (!username) return; const maxStreamsPerSource = {}; container.querySelectorAll('.user-source-max-streams').forEach(inp => { const val = parseInt(inp.value) || 0; if (val > 0) maxStreamsPerSource[inp.dataset.sourceId] = val; }); const form = new FormData(); form.append('max_streams_per_source', JSON.stringify(maxStreamsPerSource)); await saveWithFeedback(`/settings/users/limits/${username}`, { method: 'POST', body: form }, this); }); }); // Group restrictions setupUserGroupDragDrop(); } function setupUserGroupDragDrop() { async function saveGroups(username, feedbackContainer) { const unavailableContainer = document.querySelector(`.user-unavailable-groups[data-username="${username}"]`); const unavailableGroups = Array.from(unavailableContainer?.querySelectorAll('.group-chip') || []) .map(c => c.dataset.groupId); const form = new FormData(); form.append('unavailable_groups', JSON.stringify(unavailableGroups)); await saveWithFeedback(`/settings/users/limits/${username}`, { method: 'POST', body: form }, feedbackContainer); } setupDragDrop('.user-available-groups, .user-unavailable-groups', '.group-chip', (container) => { const username = container.dataset.username; if (username) saveGroups(username, container); }); // Search per user document.querySelectorAll('.user-group-search').forEach(input => { const username = input.dataset.username; const clearBtn = document.querySelector(`.user-group-search-clear[data-username="${username}"]`); function apply() { 
const q = input.value.toLowerCase(); [`.user-available-groups[data-username="${username}"]`, `.user-unavailable-groups[data-username="${username}"]`].forEach(sel => { document.querySelectorAll(`${sel} .group-chip`).forEach(chip => { chip.style.display = chip.textContent.toLowerCase().includes(q) ? '' : 'none'; }); }); clearBtn?.classList.toggle('hidden', !input.value); } input.addEventListener('input', apply); clearBtn?.addEventListener('click', () => { input.value = ''; apply(); }); }); // Move all buttons document.querySelectorAll('.group-move-all-unavailable').forEach(btn => { btn.addEventListener('click', async () => { const username = btn.dataset.username; if (!username) return; const avail = document.querySelector(`.user-available-groups[data-username="${username}"]`); const unavail = document.querySelector(`.user-unavailable-groups[data-username="${username}"]`); avail?.querySelectorAll('.group-chip:not([style*="display: none"])').forEach(c => unavail?.appendChild(c)); await saveGroups(username, unavail); }); }); document.querySelectorAll('.group-move-all-available').forEach(btn => { btn.addEventListener('click', async () => { const username = btn.dataset.username; if (!username) return; const avail = document.querySelector(`.user-available-groups[data-username="${username}"]`); const unavail = document.querySelector(`.user-unavailable-groups[data-username="${username}"]`); unavail?.querySelectorAll('.group-chip:not([style*="display: none"])').forEach(c => avail?.appendChild(c)); await saveGroups(username, avail); }); }); } // ============================================================ // Init // ============================================================ function init() { setupSourceTypeSelect(); setupSourceEditForms(); setupCategoryFilter(); setupVodCategoryFilter(); setupSeriesCategoryFilter(); setupChromeCcLink(); setupCaptionSettings(); setupGuideSettings(); setupTranscodeSettings(); setupUserAgentSettings(); setupDataCache(); setupProbeCache(); 
setupRefreshButtons(); setupUserForms(); } init(); })(); ================================================ FILE: static/js/virtual-guide.js ================================================ /** * Virtual scrolling for the TV guide. * Only renders rows that are visible (plus buffer), fetches more as needed. */ // Configuration constants const VIRTUAL_GUIDE_DEFAULTS = { ROW_HEIGHT_DESKTOP: 64, // 4rem in pixels ROW_HEIGHT_MOBILE: 40, // 2.5rem in pixels BUFFER_SIZE: 50, // Rows to load above/below viewport MAX_CACHE_SIZE: 500, // Evict cache beyond this MAX_RETRIES: 3, // Retry failed fetches this many times MOBILE_BREAKPOINT: 512, // Width below which is considered mobile SCROLL_DIRECTION_THRESHOLD: 5, // Min scroll delta to register direction RENDER_DEBOUNCE_MS: 16, // ~60fps for smooth visual update FETCH_DEBOUNCE_MS: 150, // Wait for scroll to settle before fetching RESIZE_DEBOUNCE_MS: 100, // Debounce window resize handler }; class VirtualGuide { constructor(options) { const D = VIRTUAL_GUIDE_DEFAULTS; this.container = options.container; this.rowHeight = options.rowHeight || D.ROW_HEIGHT_DESKTOP; this.rowHeightMobile = options.rowHeightMobile || D.ROW_HEIGHT_MOBILE; this.totalRows = options.totalRows; this.bufferSize = options.bufferSize || D.BUFFER_SIZE; this.maxCacheSize = options.maxCacheSize || D.MAX_CACHE_SIZE; this.initialRows = options.initialRows || []; this.offset = options.offset || 0; this.cats = options.cats || ''; this.logoUrlFilter = options.logoUrlFilter || (url => url); // State this.cache = new Map(); // row index -> row data this.failedRanges = new Map(); // range key -> retry count this.maxRetries = D.MAX_RETRIES; this.needsRecheck = false; this.renderedRange = { start: 0, end: 0 }; this.pendingFetch = null; this.pendingFetchRange = null; this.scrollDebounce = null; this.fetchDebounce = null; this.renderDebounce = null; this.recheckTimeout = null; this.isMobile = window.innerWidth < D.MOBILE_BREAKPOINT; this.lastScrollTop = 0; 
this.scrollDirection = 'down'; // DOM elements this.viewport = null; this.content = null; this.spacer = null; this.init(); } get currentRowHeight() { return this.isMobile ? this.rowHeightMobile : this.rowHeight; } get visibleCount() { if (!this.viewport) return 30; return Math.ceil(this.viewport.clientHeight / this.currentRowHeight) + 1; } init() { // Cache initial SSR rows for (const row of this.initialRows) { this.cache.set(row.index, row); } // Set up virtual scroll container this.setupDOM(); this.bindEvents(); // Handle scroll position restoration // Check if there's a saved scroll position that's beyond initial rows const scrollKey = 'guide_scroll'; const savedScroll = sessionStorage.getItem(scrollKey); if (savedScroll && this.viewport) { const scrollTop = parseInt(savedScroll); const firstVisible = Math.floor(scrollTop / this.currentRowHeight); // If saved position is beyond initial batch, fetch first then scroll if (firstVisible >= this.initialRows.length) { // Fetch data for the saved position, then restore scroll const start = Math.max(0, firstVisible - this.bufferSize); const end = Math.min(this.totalRows, firstVisible + this.visibleCount + this.bufferSize); this.fetchMissingRanges([{ start, end }]).then(() => { this.viewport.scrollTop = scrollTop; this.renderedRange = { start, end }; this.render(); }); return; // Don't do normal init flow } } // If we have more rows than initial batch, enable virtual scrolling if (this.totalRows > this.initialRows.length) { this.updateVisibleRange(); } } setupDOM() { // Find the scroll container (the overflow-y-auto div) this.viewport = this.container.querySelector('.overflow-y-auto'); if (!this.viewport) return; // Create spacer for full height scrollbar this.spacer = document.createElement('div'); this.spacer.className = 'virtual-spacer'; this.spacer.style.height = `${this.totalRows * this.currentRowHeight}px`; this.spacer.style.position = 'absolute'; this.spacer.style.top = '0'; this.spacer.style.left = '0'; 
this.spacer.style.right = '0'; this.spacer.style.pointerEvents = 'none'; // Create content container this.content = document.createElement('div'); this.content.className = 'virtual-content'; this.content.style.position = 'relative'; this.content.style.zIndex = '1'; // Move existing rows into content container const existingRows = this.viewport.querySelectorAll('.guide-row'); existingRows.forEach(row => this.content.appendChild(row)); // Set viewport to relative positioning this.viewport.style.position = 'relative'; // Add spacer and content to viewport this.viewport.appendChild(this.spacer); this.viewport.insertBefore(this.content, this.spacer); // Set initial rendered range based on SSR content this.renderedRange = { start: 0, end: this.initialRows.length }; } bindEvents() { if (!this.viewport) return; // Scroll handler with RAF for smooth updates let ticking = false; this.viewport.addEventListener('scroll', () => { if (!ticking) { requestAnimationFrame(() => { this.onScroll(); ticking = false; }); ticking = true; } }, { passive: true }); // Handle resize const D = VIRTUAL_GUIDE_DEFAULTS; let resizeTimer; window.addEventListener('resize', () => { clearTimeout(resizeTimer); resizeTimer = setTimeout(() => { const wasMobile = this.isMobile; this.isMobile = window.innerWidth < D.MOBILE_BREAKPOINT; if (wasMobile !== this.isMobile) { // Row height changed, update spacer this.spacer.style.height = `${this.totalRows * this.currentRowHeight}px`; this.updateVisibleRange(); } }, D.RESIZE_DEBOUNCE_MS); }); } onScroll() { const D = VIRTUAL_GUIDE_DEFAULTS; // Clear any pending debounce clearTimeout(this.fetchDebounce); clearTimeout(this.renderDebounce); const scrollTop = this.viewport.scrollTop; const firstVisible = Math.floor(scrollTop / this.currentRowHeight); const lastVisible = firstVisible + this.visibleCount; // Track scroll direction const scrollDelta = scrollTop - (this.lastScrollTop || 0); this.lastScrollTop = scrollTop; if (Math.abs(scrollDelta) > 
D.SCROLL_DIRECTION_THRESHOLD) { this.scrollDirection = scrollDelta > 0 ? 'down' : 'up'; } // Calculate desired range with buffer const desiredStart = Math.max(0, firstVisible - this.bufferSize); const desiredEnd = Math.min(this.totalRows, lastVisible + this.bufferSize); // Check if we need to update rendered range const needsRender = desiredStart < this.renderedRange.start || desiredEnd > this.renderedRange.end; if (needsRender) { // Render immediately with whatever we have (placeholders for missing) this.renderDebounce = setTimeout(() => { this.renderedRange = { start: desiredStart, end: desiredEnd }; this.render(); }, D.RENDER_DEBOUNCE_MS); // Debounce fetching - wait for scroll to settle before fetching this.fetchDebounce = setTimeout(() => { this.updateVisibleRange(); }, D.FETCH_DEBOUNCE_MS); } } async updateVisibleRange() { const scrollTop = this.viewport.scrollTop; const firstVisible = Math.floor(scrollTop / this.currentRowHeight); const lastVisible = firstVisible + this.visibleCount; // Calculate ranges: visible, forward buffer, backward buffer const visibleStart = Math.max(0, firstVisible); const visibleEnd = Math.min(this.totalRows, lastVisible + 1); const bufferStart = Math.max(0, firstVisible - this.bufferSize); const bufferEnd = Math.min(this.totalRows, lastVisible + this.bufferSize); // Priority fetch order based on scroll direction const fetchOrder = []; // 1. Always fetch visible rows first const visibleMissing = this.findMissingRanges(visibleStart, visibleEnd); if (visibleMissing.length > 0) { fetchOrder.push({ ranges: visibleMissing, priority: 'visible' }); } // 2. Fetch buffer in scroll direction // 3. 
Fetch buffer in opposite direction if (this.scrollDirection === 'down') { const forwardMissing = this.findMissingRanges(visibleEnd, bufferEnd); const backwardMissing = this.findMissingRanges(bufferStart, visibleStart); if (forwardMissing.length > 0) fetchOrder.push({ ranges: forwardMissing, priority: 'forward' }); if (backwardMissing.length > 0) fetchOrder.push({ ranges: backwardMissing, priority: 'backward' }); } else { const backwardMissing = this.findMissingRanges(bufferStart, visibleStart); const forwardMissing = this.findMissingRanges(visibleEnd, bufferEnd); if (backwardMissing.length > 0) fetchOrder.push({ ranges: backwardMissing, priority: 'backward' }); if (forwardMissing.length > 0) fetchOrder.push({ ranges: forwardMissing, priority: 'forward' }); } // Fetch in priority order, re-rendering after each batch for (const batch of fetchOrder) { const success = await this.fetchMissingRanges(batch.ranges); // Re-render after each batch so visible content appears first this.renderedRange = { start: bufferStart, end: bufferEnd }; this.render(); if (!success) { // A pending fetch exists or we just aborted one - don't fetch lower-priority buffers // The recheck timeout will re-call updateVisibleRange with correct priorities break; } } // Always do a final render to ensure current position is shown // This handles the case where fetchOrder is empty (all rows cached) this.renderedRange = { start: bufferStart, end: bufferEnd }; this.render(); } findMissingRanges(start, end) { const ranges = []; let rangeStart = null; for (let i = start; i < end; i++) { if (!this.cache.has(i)) { if (rangeStart === null) rangeStart = i; } else if (rangeStart !== null) { ranges.push({ start: rangeStart, end: i }); rangeStart = null; } } if (rangeStart !== null) { ranges.push({ start: rangeStart, end }); } return ranges; } /** * Fetch missing rows for the given ranges. 
* @returns {Promise} true if fetch completed (or nothing to fetch), * false if a pending fetch blocked us or we aborted one */ async fetchMissingRanges(ranges) { // Early return if no ranges to fetch if (!ranges || ranges.length === 0) { return true; } // Merge into a single request for simplicity const overallStart = Math.min(...ranges.map(r => r.start)); const overallEnd = Math.max(...ranges.map(r => r.end)); // If there's a pending fetch, check if it's for a relevant range if (this.pendingFetch) { if (this.pendingFetchRange) { const p = this.pendingFetchRange; const overlaps = !(overallEnd < p.start || overallStart > p.end); if (overlaps) { // Pending fetch will give us some useful data, let it finish // The recheck timeout will catch any remaining gaps this.needsRecheck = true; return false; } } // Non-overlapping or orphaned pending fetch - abort it // Return false so caller doesn't continue to lower-priority fetches this.pendingFetch.abort(); this.pendingFetch = null; this.pendingFetchRange = null; // CRITICAL: Schedule recheck ourselves since the aborted fetch's finally // block won't do it (we already set pendingFetch = null) clearTimeout(this.recheckTimeout); this.recheckTimeout = setTimeout(() => { this.updateVisibleRange(); }, 50); return false; } const controller = new AbortController(); this.pendingFetch = controller; this.pendingFetchRange = { start: overallStart, end: overallEnd }; // Clear needsRecheck since we're now fetching what we need this.needsRecheck = false; try { const params = new URLSearchParams({ start: overallStart, count: overallEnd - overallStart, offset: this.offset }); // Pass cats if set (for temporary dropdown filters) if (this.cats) { params.set('cats', this.cats); } const resp = await fetch(`/api/guide/rows?${params}`, { signal: controller.signal }); if (!resp.ok) throw new Error(`HTTP ${resp.status}`); const data = await resp.json(); // Cache the fetched rows for (const row of data.rows) { this.cache.set(row.index, row); } // 
Clear any failure tracking for this range const rangeKey = `${overallStart}-${overallEnd}`; this.failedRanges.delete(rangeKey); // Evict old cache entries to prevent memory growth this.pruneCache(); } catch (e) { if (e.name !== 'AbortError') { console.error('Failed to fetch guide rows:', e); // Track failed range for retry const rangeKey = `${overallStart}-${overallEnd}`; const retryCount = (this.failedRanges.get(rangeKey) || 0) + 1; if (retryCount < this.maxRetries) { this.failedRanges.set(rangeKey, retryCount); // Schedule retry after delay setTimeout(() => { this.failedRanges.delete(rangeKey); this.updateVisibleRange(); }, 1000 * retryCount); // Exponential backoff: 1s, 2s, 3s } else { // Max retries reached - clear tracking this.failedRanges.delete(rangeKey); console.error(`Failed to fetch rows ${overallStart}-${overallEnd} after ${this.maxRetries} retries`); } } } finally { if (this.pendingFetch === controller) { this.pendingFetch = null; this.pendingFetchRange = null; // Always recheck after fetch completes to catch any gaps // Use a small delay to batch multiple rapid rechecks clearTimeout(this.recheckTimeout); this.recheckTimeout = setTimeout(() => { this.updateVisibleRange(); }, 50); } } return true; } render() { if (!this.content) return; const html = []; for (let i = this.renderedRange.start; i < this.renderedRange.end; i++) { const row = this.cache.get(i); if (row) { html.push(this.renderRow(row, i)); } else { html.push(this.renderPlaceholder(i)); } } // Position content at the right scroll offset this.content.style.transform = `translateY(${this.renderedRange.start * this.currentRowHeight}px)`; this.content.innerHTML = html.join(''); } /** * Evict cached rows far from current view to prevent unbounded memory growth. * Keeps rows within 2x buffer distance from current rendered range. 
*/ pruneCache() { if (this.cache.size <= this.maxCacheSize) { return; } const center = Math.floor((this.renderedRange.start + this.renderedRange.end) / 2); const keepDistance = this.bufferSize * 2; // Collect indices to remove (those far from current view) const toRemove = []; for (const index of this.cache.keys()) { const distance = Math.abs(index - center); if (distance > keepDistance) { toRemove.push(index); } } // Remove furthest first until under max size toRemove.sort((a, b) => Math.abs(b - center) - Math.abs(a - center)); const removeCount = Math.min(toRemove.length, this.cache.size - this.maxCacheSize); for (let i = 0; i < removeCount; i++) { this.cache.delete(toRemove[i]); } } renderPlaceholder(index) { const height = this.currentRowHeight; const isMobile = this.isMobile; if (isMobile) { return `
`; } return `
`; } renderRow(row, index) { const ch = row.channel; const iconUrl = ch.icon ? this.logoUrlFilter(ch.icon) : ''; const height = this.currentRowHeight; // Escape HTML in text content const escapeHtml = (str) => { if (!str) return ''; return str.replace(/&/g, '&') .replace(//g, '>') .replace(/"/g, '"') .replace(/'/g, '''); }; // Desktop programs let programsDesktop = ''; if (row.programs && row.programs.length > 0) { programsDesktop = row.programs.map((prog, pIdx) => `
${escapeHtml(prog.title)}
${escapeHtml(prog.desc)}
`).join(''); } else { programsDesktop = `
No program info
`; } // Mobile programs let programsMobile = ''; if (row.programs_mobile && row.programs_mobile.length > 0) { programsMobile = row.programs_mobile.map((prog, pIdx) => `
${escapeHtml(prog.title)}
`).join(''); } else { programsMobile = `
No info
`; } return `
${iconUrl ? `` : ''} ${escapeHtml(ch.name)}
${programsMobile}
${programsDesktop}
`; } /** * Clean up resources when the virtual guide is no longer needed. * Call this before removing/reinitializing to prevent memory leaks. */ destroy() { // Clear all timers clearTimeout(this.fetchDebounce); clearTimeout(this.renderDebounce); clearTimeout(this.scrollDebounce); clearTimeout(this.recheckTimeout); // Abort any pending fetch if (this.pendingFetch) { this.pendingFetch.abort(); this.pendingFetch = null; this.pendingFetchRange = null; } // Clear caches this.cache.clear(); this.failedRanges.clear(); // Clear DOM references if (this.content) { this.content.innerHTML = ''; } this.viewport = null; this.content = null; this.spacer = null; this.container = null; // Note: Event listeners on window (resize) are not removed // as they use anonymous functions. For full cleanup, would need // to store references to bound handlers in constructor. } } // Export for use window.VirtualGuide = VirtualGuide; ================================================ FILE: templates/base.html ================================================ {% block title %}neTV{% endblock %} {% block head_extra %}{% endblock %}
{% block content %}{% endblock %}
{% block scripts %}{% endblock %} ================================================ FILE: templates/error.html ================================================ {% extends "base.html" %} {% block title %}Error{% endblock %} {% block content %}
⚠️

{{ title }}

{{ message }}

{% endblock %} ================================================ FILE: templates/guide.html ================================================ {% extends "base.html" %} {% block title %}Live TV - neTV{% endblock %} {% block head_extra %} {% if loading and not request.query_params.get('refreshing') %} {% endif %} {% endblock %} {% block content %}

Live TV

{% if saved_filter or selected_cats %}
({{ total_count | default(channel_count) }} ch) [edit] {% endif %}
{% if loading_message %}

{{ loading_message }}

{% elif epg_error %}

EPG Error: {{ epg_error }}

Showing channels without program info

{% endif %}
Live TV {% if saved_filter or selected_cats %}
{% endif %}
{% if grid_data %}
{% for marker in time_markers_mobile %}
{{ marker.label }}
{% endfor %}
{% for marker in time_markers %}
{{ marker.label }}
{% endfor %}
{% for row in grid_data %}
{% if row.channel.icon %} {% endif %} {{ row.channel.name }}
{% set row_idx = loop.index0 %}
{% if row.programs_mobile %} {% for prog in row.programs_mobile %}
{{ prog.title }}
{% endfor %} {% else %}
No info
{% endif %}
{% if row.programs %} {% for prog in row.programs %}
{{ prog.title }}
{{ prog.desc }}
{% endfor %} {% else %}
No program info
{% endif %}
{% endfor %}
{% else %}

No channels selected

Go to Settings to configure Live TV category filter

{% endif %}
{% endblock %} {% block scripts %} {% endblock %} ================================================ FILE: templates/login.html ================================================ Login - neTV

neTV

{% if error %}
Invalid username or password
{% endif %}
================================================ FILE: templates/movie_detail.html ================================================ {% extends "base.html" %} {% block title %}{{ movie.name if movie else 'Movie' }} - neTV{% endblock %} {% block content %}
{% if movie %}
{% if movie.cover_big or movie.stream_icon %} {% endif %}

{{ movie.name }}

{% if movie.year %}{{ movie.year }}{% endif %} {% if movie.rating %}★ {{ movie.rating }}/10{% endif %} {% if movie.duration %}{{ movie.duration }}{% endif %}
{% if movie.genre %}
{{ movie.genre }}
{% endif %} {% if movie.plot %}

{{ movie.plot }}

{% endif %} {% if movie.director %}

Director: {{ movie.director }}

{% endif %} {% if movie.cast %}

Cast: {{ movie.cast }}

{% endif %}
▶ Play {% if movie.youtube_trailer %} Trailer {% endif %}
{% else %}

Movie not found

{% endif %}
{% endblock %} {% block scripts %} {% endblock %} ================================================ FILE: templates/player.html ================================================ {% extends "base.html" %} {% block title %}{% if channel_name %}{{ channel_name }}{% if program_title %} — {{ program_title }}{% endif %}{% else %}Now Playing{% endif %}{% endblock %} {% block head_extra %} {% endblock %} {% block content %}
{% if channel_name %}{{ channel_name }}{% if program_title %} — {{ program_title }}{% endif %}{% else %}Now Playing{% endif %}
{% if program_desc %}
{{ program_desc }}
{% endif %}
0:00 0:00
{% if stream_type in ['movie', 'series'] %} {% endif %} {% if stream_type == 'series' and next_episode_url %} {% endif %}
{% if request.url.scheme == 'https' %} {% endif %}
{% endblock %} {% block scripts %} {% endblock %} ================================================ FILE: templates/search.html ================================================ {% extends "base.html" %} {% block title %}Search - neTV{% endblock %} {% block content %}

Search

{% if query %} {% if results.live or results.vod or results.series %} {% endif %} {% if results.live %}

Live Channels ({{ results.live|length }}{% if limit and results.live|length == limit %}+{% endif %})

{% for stream in results.live %} {% if stream.stream_icon %} {% endif %} {{ stream.name }} {% endfor %}
{% endif %} {% if results.vod %}

Movies ({{ results.vod|length }}{% if limit and results.vod|length == limit %}+{% endif %})

{% for movie in results.vod %} {% endfor %}
{% endif %} {% if results.series %}

Series ({{ results.series|length }}{% if limit and results.series|length == limit %}+{% endif %})

{% for s in results.series %} {% endfor %}
{% endif %} {% if not results.live and not results.vod and not results.series %}

No results found for "{{ query }}"

{% endif %} {% endif %}
{% endblock %} {% block scripts %} {% endblock %} ================================================ FILE: templates/series.html ================================================ {% extends "base.html" %} {% block title %}Series - neTV{% endblock %} {% block head_extra %} {% if loading %} {% endif %} {% endblock %} {% block content %} {% if loading %}

Loading series...

{% else %}

Series

Search {% if current_sort %} Favorites {% else %} Browse All {% endif %}
{% if current_sort %}
{% else %}

★ My Favorites

{% endif %}
{% endif %} {% endblock %} {% block scripts %} {% endblock %} ================================================ FILE: templates/series_detail.html ================================================ {% extends "base.html" %} {% block title %}{{ series.info.name if series.info else 'Series' }} - neTV{% endblock %} {% block content %}
{% if series.info %}
{% if series.info.cover %} {% endif %}

{{ series.info.name }}

{% if series.info.year %}{{ series.info.year }}{% endif %} {% if series.info.rating %}★ {{ series.info.rating }}/10{% endif %} {% if series.info.episode_run_time %}{{ series.info.episode_run_time }} min/ep{% endif %}
{% if series.info.genre %}
{{ series.info.genre }}
{% endif %} {% if series.info.plot %}

{{ series.info.plot }}

{% endif %} {% if series.info.director %}

Director: {{ series.info.director }}

{% endif %} {% if series.info.cast %}

Cast: {{ series.info.cast }}

{% endif %}
{% endif %}
{% if series.episodes %} {% set ns = namespace(first_ep=true) %} {% for season_num, episodes in series.episodes.items() %}

Season {{ season_num }}

{% for ep in episodes %}
E{{ ep.episode_num }}
{{ ep.title }}
{% endfor %}
{% endfor %} {% else %}

No episodes available

{% endif %}
{% endblock %} {% block scripts %} {% endblock %} ================================================ FILE: templates/settings.html ================================================ {% extends "base.html" %} {% block title %}Settings - neTV{% endblock %} {% block head_extra %} {% endblock %} {% block content %}

Settings

Client Settings

Live TV Filter

Select and order which channel categories appear in the Live TV guide

Available
Unavailable
{% for cat in live_categories %} {% set src_name = source_names.get(cat.source_id, '') %}
{% if src_name %}[{{ src_name }}] {% endif %}{{ cat.category_name }}
{% endfor %}

Movies Filter

Select which movie categories appear on the Movies page

Available
Unavailable
{% for cat in vod_categories %} {% set src_name = source_names.get(cat.source_id, '') %}
{% if src_name %}[{{ src_name }}] {% endif %}{{ cat.category_name }}
{% endfor %}

Series Filter

Select which series categories appear on the Series page

Available
Unavailable
{% for cat in series_categories %} {% set src_name = source_names.get(cat.source_id, '') %}
{% if src_name %}[{{ src_name }}] {% endif %}{{ cat.category_name }}
{% endfor %}

Guide

When enabled, only visible rows are loaded as you scroll. Disable for smoother scrolling if you have only a few channels.

Closed Captions

Sample Caption Text

Also: chrome://settings/captions

{% if is_admin %}

Server Settings

{% endif %}

{% if is_admin %}Users{% else %}Account{% endif %}

{% for u in all_users %} {% if is_admin or u.username == current_user %}
{{ u.username }} {% if u.username == current_user %}(you){% endif %} {% if u.admin %}admin{% endif %}
Expand
{% if is_admin and all_users|length > 1 %} {% endif %}
Min 8 characters
{% if is_admin %}
{% for src in sources %} {% if src.type != 'epg' %}
{{ src.name }}
{% endif %} {% endfor %}
Available
{% for grp in all_groups %} {% if grp.id not in (u.unavailable_groups or []) %}
{{ grp.name }}
{% endif %} {% endfor %}
Unavailable
{% for grp in all_groups %} {% if grp.id in (u.unavailable_groups or []) %}
{{ grp.name }}
{% endif %} {% endfor %}
{% endif %}
{% if u.username == current_user %} {% elif is_admin %}
{% endif %}
{% endif %} {% endfor %} {% if is_admin %}
Add User
Expand
{% for src in sources %} {% if src.type != 'epg' %}
{{ src.name }}
{% endif %} {% endfor %}
Available
{% for grp in all_groups %}
{{ grp.name }}
{% endfor %}
Unavailable
{% endif %}
{% if is_admin %}

Transcoding

Automatically transcodes when the browser can't play the format natively

NVENC: discrete NVIDIA GPU encoder; AMF: discrete AMD GPU encoder; VAAPI: integrated (on-CPU) GPU decoder (and/or encoder); Software: CPU decoder (and/or encoder)

Limit output resolution (lower = faster, more compatible)

Maximum quality ceiling (won't use more bits than source provides)

{% if sr_available %}AI upscaling via TensorRT (requires NVIDIA GPU). Upscales to max resolution then scales down.{% else %}Not available - run tools/install-ai_upscale.sh to install{% endif %}

Keep transcoded VOD for reuse (0 = disabled)

Keep dead session for reconnect (0 = immediate cleanup)

Allow seeking back in live streams (0 = disabled, ~30s buffer)

User-Agent

HTTP User-Agent header sent when fetching streams. Only affects transcoding; passthrough uses the client's native user-agent.

Data Cache

Directory for HLS segments. Empty = system temp. Use a fast SSD for best performance.

Clear cached channel lists, movies, and series data. Use this after changing source settings or if content access isn't working correctly.

Probe Cache

Probing adds a few seconds delay on first play but enables hardware decode pipeline. Results are cached.

Cached probe results avoid re-probing episodes in the same series.

Loading...

Sources

{% for source in sources %}
{{ source.name }} {{ source.type }}
{{ source.url }}
Expand
1-3600s (default: 120)
Times to auto-refresh (comma-separated HH:MM)
Auto-detected on first refresh, or set manually
Enable for OTA/HDHomeRun, disable for IPTV
0 = unlimited
{% if source.type == 'xtream' %} {% elif source.type == 'm3u' %} {% elif source.type == 'epg' %} {% endif %}
{% endfor %}
Add Source
Expand
Enable for OTA/HDHomeRun, disable for IPTV
0 = unlimited
1-3600s
HH:MM times
{% endif %}
{% endblock %} {% block scripts %} {% endblock %} ================================================ FILE: templates/setup.html ================================================ Setup - neTV

neTV

Create your admin account

{% if error %}
{{ error }}
{% endif %}
================================================ FILE: templates/vod.html ================================================ {% extends "base.html" %} {% block title %}Movies - neTV{% endblock %} {% block head_extra %} {% if loading %} {% endif %} {% endblock %} {% block content %} {% if loading %}

Loading movies...

{% else %}

Movies

Search {% if current_sort %} Favorites {% else %} Browse All {% endif %}
{% if current_sort %}
{% else %}

★ My Favorites

{% endif %}
{% endif %} {% endblock %} {% block scripts %} {% endblock %} ================================================ FILE: testing.py ================================================ """Test utilities.""" import sys import warnings # Suppress unawaited coroutine warnings from AsyncMock in tests. warnings.filterwarnings("ignore", message="coroutine.*was never awaited") def run_tests(test_file: str) -> None: """Run pytest on a test file with standard flags. Usage: if __name__ == "__main__": from testing import run_tests run_tests(__file__) """ import pytest sys.exit( pytest.main( [ test_file, "-v", "-s", "-W", "ignore::pytest.PytestAssertRewriteWarning", *sys.argv[1:], ] ) ) ================================================ FILE: tools/alignm3u.py ================================================ #!/usr/bin/env python3 # pyright: reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false """alignm3u.py -- Align HDHomeRun M3U with XMLTV guide data. Takes an M3U playlist from HDHomeRun and aligns channel IDs with an XMLTV guide file (e.g., from zap2xml.py). Outputs a new M3U with tvg-id attributes set for EPG matching. 
Usage: # First, fetch your HDHomeRun lineup and generate XMLTV: wget http://YOUR_HDHR_IP/lineup.m3u -O lineup.m3u ./zap2xml.py --zip YOUR_ZIP # Then align them: ./alignm3u.py --input lineup.m3u --xmltv xmltv.xml --output ota.m3u # Optionally specify the URL where xmltv.xml will be served: ./alignm3u.py --input lineup.m3u --xmltv xmltv.xml --output ota.m3u \\ --xmltv-url http://your-server/xmltv.xml """ from __future__ import annotations import argparse import collections import pathlib import re import xml.etree.ElementTree as ET # https://en.wikipedia.org/wiki/Call_signs_in_the_United_States#Suffixes _CALLSIGN_REGEX = re.compile(r"^([A-Z]+?)(LD|DT|CD|CA|LP|TV|FM|D)(\d*)$") def parse_callsign(coded_callsign: str) -> tuple[str, str, int]: """Parse FCC callsign into (call, suffix, number).""" result = _CALLSIGN_REGEX.search(coded_callsign.upper()) if not result: return coded_callsign, "", 1 call, suffix, num = result.groups() if call == "KQS" and suffix == "LD": call, suffix = "KQSL", "LD" # Known bug in some data return call, suffix, int(num) if num else 1 def parse_m3u(path: pathlib.Path) -> list[list]: """Parse M3U file into list of [title, attrs, url].""" with open(path) as f: first_line = f.readline().strip() if not first_line.startswith("#EXTM3U"): raise ValueError(f"Invalid M3U file: {path}") entries = [] for line in f: line = line.strip() if not line: continue if not line.startswith("#EXTINF:"): if entries: entries[-1].append(line) continue attrs_str, title = line.split("#EXTINF:")[1].split(",", 1) attrs_str = attrs_str.split("-1 ", 1)[1] if "-1 " in attrs_str else attrs_str attrs_list = re.findall(r'(?:[^\s,"]|"(?:\\.|[^"])*")+', attrs_str) attrs = dict(s.replace('"', "").split("=", 1) for s in attrs_list if "=" in s) entries.append([title.strip(), attrs]) return entries def parse_xmltv_channels(path: pathlib.Path) -> dict[str, tuple[str, ...]]: """Parse XMLTV file and return {channel_id: (display_names...)}.""" channels = {} for elem in 
ET.parse(path).getroot(): if elem.tag == "channel": channel_id = elem.get("id") names = tuple(v.text for v in elem if v.tag == "display-name" and v.text) if channel_id: channels[channel_id] = names return channels def build_lookup(xmltv_channels: dict[str, tuple[str, ...]]) -> dict[str, set[str]]: """Build lookup from channel number/name to channel IDs.""" lookup: dict[str, set[str]] = collections.defaultdict(set) for channel_id, names in xmltv_channels.items(): for name in names: lookup[name].add(channel_id) return lookup def align_channels( m3u: list[list], lookup: dict[str, set[str]], ) -> tuple[list[list], list[list]]: """Align M3U entries with XMLTV channel IDs. Returns (aligned, missing).""" missing = [] for entry in m3u: if len(entry) < 3: continue title, attrs, _ = entry chan_num = attrs.get("channel-number", "") chan_name = attrs.get("tvg-name", title) # Normalize channel number (some have leading digit for ATSC3) if chan_num and float(chan_num) > 100: chan_num = str(float(chan_num[1:])) candidates_num = tuple(lookup.get(chan_num, ())) candidates_name = tuple(lookup.get(chan_name, ())) # Priority: exact match on number > exact match on name > any match if len(candidates_num) == 1: attrs["tvg-id"] = candidates_num[0] elif len(candidates_name) == 1: attrs["tvg-id"] = candidates_name[0] elif candidates_num: attrs["tvg-id"] = candidates_num[0] elif candidates_name: attrs["tvg-id"] = candidates_name[0] else: missing.append(entry) return m3u, missing def write_m3u( m3u: list[list], output: pathlib.Path, xmltv_url: str, group_prefix: str = "OTA", ) -> None: """Write aligned M3U file.""" with open(output, "w") as f: if xmltv_url: print(f'#EXTM3U url-tvg="{xmltv_url}" x-tvg-url="{xmltv_url}"', file=f) else: print("#EXTM3U", file=f) for entry in m3u: if len(entry) < 3: continue title, attrs, url = entry # Use tvg-name as title if available title = attrs.get("tvg-name", title) # Build group-title groups = [group_prefix] if "group-title" in attrs: 
groups.append(attrs["group-title"]) # Mark ATSC3 channels if "channel-id" in attrs and float(attrs["channel-id"]) >= 100: groups.append("ATSC3") attrs["group-title"] = " | ".join(groups) # Format attributes attrs_str = " ".join(f'{k}="{v}"' for k, v in attrs.items()) print(f"#EXTINF:-1 {attrs_str},{title}", file=f) print(url, file=f) def main() -> None: parser = argparse.ArgumentParser( description="Align HDHomeRun M3U with XMLTV guide data.", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Example: wget http://192.168.1.100/lineup.m3u -O lineup.m3u ./zap2xml.py --zip 90210 ./alignm3u.py --input lineup.m3u --xmltv xmltv.xml --output ota.m3u """, ) parser.add_argument( "--input", "-i", type=pathlib.Path, required=True, help="Input M3U file from HDHomeRun", ) parser.add_argument( "--xmltv", "-x", type=pathlib.Path, required=True, help="XMLTV guide file (e.g., from zap2xml.py)", ) parser.add_argument( "--output", "-o", type=pathlib.Path, required=True, help="Output M3U file with aligned tvg-id", ) parser.add_argument( "--xmltv-url", type=str, default="", help="URL where XMLTV file will be served (for M3U header)", ) parser.add_argument( "--group", type=str, default="OTA", help="Group prefix for channels (default: OTA)", ) args = parser.parse_args() # Parse inputs print(f"Reading M3U: {args.input}") m3u = parse_m3u(args.input) print(f" Found {len(m3u)} channels") print(f"Reading XMLTV: {args.xmltv}") xmltv_channels = parse_xmltv_channels(args.xmltv) print(f" Found {len(xmltv_channels)} channels") # Build lookup and align lookup = build_lookup(xmltv_channels) m3u, missing = align_channels(m3u, lookup) if missing: print(f"\nUnable to align {len(missing)} channels:") for entry in missing: title, attrs = entry[0], entry[1] num = attrs.get("channel-number", "?") print(f" {num}: {title}") # Write output print(f"\nWriting: {args.output}") write_m3u(m3u, args.output, args.xmltv_url, args.group) aligned = len(m3u) - len(missing) print(f" Aligned 
{aligned}/{len(m3u)} channels") if __name__ == "__main__": main() ================================================ FILE: tools/export-tensorrt.py ================================================ #!/usr/bin/env python3 """Export upscaling models to TensorRT engines for FFmpeg dnn_processing filter. This script converts AI upscaling models to TensorRT engines (.engine files) that can be loaded by FFmpeg's TensorRT DNN backend. Available models (use --list to see all): 2x models (1080p → 4K): - 2x-liveaction-span Best for live action TV/film 4x models (720p → 4K, 480p → 1080p): - 4x-compact Fast, good quality (SRVGGNetCompact) Usage: # List available models python export-tensorrt.py --list # Export 2x model for live action python export-tensorrt.py --model 2x-liveaction-span -o model.engine # Export with custom height range python export-tensorrt.py --model 2x-liveaction-span --min-height 720 --max-height 1080 Requirements: pip install torch onnx tensorrt Example FFmpeg usage after export: ffmpeg -i input.mp4 -vf "dnn_processing=dnn_backend=tensorrt:model=model.engine" output.mp4 """ from __future__ import annotations from pathlib import Path from typing import TYPE_CHECKING, NotRequired, TypedDict if TYPE_CHECKING: import tensorrt as trt import argparse import sys import tempfile import urllib.request import tensorrt as trt import torch import torch.nn as nn import torch.nn.functional as F class SRVGGNetCompact(nn.Module): """Compact SR network - fast inference, good quality.""" upscale: int body: nn.ModuleList upsampler: nn.PixelShuffle def __init__( self, num_in_ch: int = 3, num_out_ch: int = 3, num_feat: int = 64, num_conv: int = 32, upscale: int = 4, ): super().__init__() self.upscale = upscale self.body = nn.ModuleList() self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) self.body.append(nn.PReLU(num_parameters=num_feat)) for _ in range(num_conv - 2): self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) 
self.body.append(nn.PReLU(num_parameters=num_feat)) self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) self.upsampler = nn.PixelShuffle(upscale) def forward(self, x: torch.Tensor) -> torch.Tensor: out = x for layer in self.body[:-1]: out = layer(out) out = self.body[-1](out) out = self.upsampler(out) return out + F.interpolate(x, scale_factor=self.upscale, mode="nearest") class ResidualDenseBlock(nn.Module): """Residual Dense Block for RRDBNet.""" conv1: nn.Conv2d conv2: nn.Conv2d conv3: nn.Conv2d conv4: nn.Conv2d conv5: nn.Conv2d lrelu: nn.LeakyReLU def __init__(self, nf: int = 64, gc: int = 32): super().__init__() self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1) self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1) self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1) self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1) self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: x1 = self.lrelu(self.conv1(x)) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 * 0.2 + x class RRDB(nn.Module): """Residual in Residual Dense Block.""" rdb1: ResidualDenseBlock rdb2: ResidualDenseBlock rdb3: ResidualDenseBlock def __init__(self, nf: int, gc: int = 32): super().__init__() self.rdb1 = ResidualDenseBlock(nf, gc) self.rdb2 = ResidualDenseBlock(nf, gc) self.rdb3 = ResidualDenseBlock(nf, gc) def forward(self, x: torch.Tensor) -> torch.Tensor: out = self.rdb1(x) out = self.rdb2(out) out = self.rdb3(out) return out * 0.2 + x class RRDBNet(nn.Module): """RRDBNet architecture for Real-ESRGAN - highest quality, slower.""" scale: int conv_first: nn.Conv2d body: nn.Sequential conv_body: nn.Conv2d conv_up1: nn.Conv2d conv_up2: nn.Conv2d conv_hr: nn.Conv2d conv_last: nn.Conv2d lrelu: nn.LeakyReLU def __init__( 
self, num_in_ch: int = 3, num_out_ch: int = 3, num_feat: int = 64, num_block: int = 23, num_grow_ch: int = 32, scale: int = 4, ): super().__init__() self.scale = scale self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) self.body = nn.Sequential(*[RRDB(num_feat, num_grow_ch) for _ in range(num_block)]) self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: feat = self.conv_first(x) body_feat = self.conv_body(self.body(feat)) feat = feat + body_feat feat = self.lrelu( self.conv_up1( F.interpolate( feat, scale_factor=2, mode="nearest", ) ) ) feat = self.lrelu( self.conv_up2( F.interpolate( feat, scale_factor=2, mode="nearest", ) ) ) out = self.conv_last(self.lrelu(self.conv_hr(feat))) return out class ModelInfo(TypedDict): """Type definition for model registry entries.""" description: str filename: str scale: int arch: str url: NotRequired[str] onnx_url: NotRequired[str] MODELS: dict[str, ModelInfo] = { # 2x models - high quality, 1080p → 4K "2x-liveaction-span": { "description": "Live action TV/film - handles compression, preserves grain", "onnx_url": "https://github.com/jcj83429/upscaling/raw/f73a3a02874360ec6ced18f8bdd8e43b5d7bba57/2xLiveActionV1_SPAN/2xLiveActionV1_SPAN_490000.onnx", "filename": "2xLiveActionV1_SPAN.onnx", "scale": 2, "arch": "span", }, # 4x models - 720p → 4K or 480p → 1080p "4x-compact": { "description": "Fast 4x upscale - SRVGGNetCompact", "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", "filename": "realesr-general-x4v3.pth", "scale": 4, "arch": "compact", }, # 4x-realesrgan - not recommended (overly smooths faces) "4x-realesrgan": { "description": 
"RealESRGAN 4x - smooths faces (not recommended)", "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", "filename": "RealESRGAN_x4plus.pth", "scale": 4, "arch": "rrdbnet", }, # NOTE: 4x-rrdbnet was removed because: # - 1080p engine build fails with OOM even on 32GB VRAM (RTX 5090) # - 720p engine causes "Invalid frame dimensions 0x0" errors during playback # - Same weights as 4x-realesrgan but different name } def resolve_model(model_name: str) -> tuple[str, ModelInfo]: """Resolve model name.""" info = MODELS.get(model_name) if info is None: raise ValueError(f"Unknown model: {model_name}") return model_name, info def download_model(model_name: str, cache_dir: Path) -> Path: """Download model weights (ONNX or PTH).""" model_name, info = resolve_model(model_name) # Use .name to prevent path traversal path = cache_dir / Path(info["filename"]).name if path.exists(): print(f"Using cached model: {path}") return path url = info.get("onnx_url") or info.get("url") if url is None: raise ValueError(f"No URL for model: {model_name}") if not url.startswith("https://"): raise ValueError(f"URL must use HTTPS: {url}") print(f"Downloading {info['filename']}...") # Download to a temp file first, then rename to avoid partial downloads temp_path = path.with_suffix(path.suffix + ".tmp") try: with ( urllib.request.urlopen(url, timeout=300) as response, open(temp_path, "wb") as f, ): f.write(response.read()) # Verify the download succeeded and file is not empty file_size = temp_path.stat().st_size if file_size == 0: raise RuntimeError(f"Downloaded file is empty: {temp_path}") temp_path.rename(path) print(f"Downloaded to {path} ({file_size / 1024 / 1024:.1f} MB)") except Exception as e: # Clean up partial download if temp_path.exists(): temp_path.unlink() raise RuntimeError(f"Failed to download model from {url}: {e}") from e return path def list_models() -> None: """Print available models.""" print("\nAvailable models:\n") print(" 2x models 
(1080p → 4K):") for name, info in MODELS.items(): if name.startswith("2x-"): rec = " (recommended)" if name == "2x-liveaction-span" else "" print(f" {name:24s} {info['description']}{rec}") print("\n 4x models (720p → 4K):") for name, info in MODELS.items(): if name.startswith("4x-"): print(f" {name:24s} {info['description']}") print() def get_model_and_onnx( model_name: str, cache_dir: Path | None = None, ) -> tuple[ nn.Module | None, Path | None, int, ]: """Load model and return (model_or_none, onnx_path_or_none, scale). For ONNX-based models, returns (None, onnx_path, scale). For PTH-based models, returns (model, None, scale). """ if cache_dir is None: cache_dir = Path.home() / ".cache" / "ai_upscale" cache_dir.mkdir(parents=True, exist_ok=True) model_name, info = resolve_model(model_name) scale = info["scale"] arch = info["arch"] model_path = download_model(model_name, cache_dir) # ONNX-based models - no PyTorch loading needed if "onnx_url" in info: print(f"Using ONNX model directly: {model_path}") print(f" Architecture: {arch}, Scale: {scale}x") return None, model_path, scale # PTH-based models - load PyTorch print(f"Loading PyTorch model from {model_path}") state_dict = torch.load(model_path, map_location="cpu", weights_only=True) if "params_ema" in state_dict: state_dict = state_dict["params_ema"] elif "params" in state_dict: state_dict = state_dict["params"] # Instantiate model based on explicit architecture if arch == "rrdbnet": model: nn.Module = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale, ) arch_name = "RRDBNet" elif arch == "compact": # Count conv layers to determine num_conv for SRVGGNetCompact num_conv_layers = sum( 1 for k, v in state_dict.items() if "weight" in k and len(v.shape) == 4 ) model = SRVGGNetCompact( num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv_layers, upscale=scale, ) arch_name = "SRVGGNetCompact" else: raise ValueError(f"Unknown architecture: {arch}") 
model.load_state_dict(state_dict) model.eval() params = sum(p.numel() for p in model.parameters()) / 1e6 print(f" Loaded {arch_name} ({params:.2f}M params), Scale: {scale}x") return model, None, scale def export_onnx(model: nn.Module, opt_shape: tuple[int, int], onnx_path: Path | str) -> None: """Export model to ONNX format with dynamic axes.""" opt_w, opt_h = opt_shape print(f"Exporting to ONNX: {onnx_path}") print(f" Optimal shape: 1x3x{opt_h}x{opt_w}") dummy_input = torch.randn(1, 3, opt_h, opt_w, device="cpu") dynamic_axes = { "input": { 2: "height", 3: "width", }, "output": { 2: "out_height", 3: "out_width", }, } torch.onnx.export( model, (dummy_input,), onnx_path, input_names=["input"], output_names=["output"], opset_version=17, do_constant_folding=True, dynamic_axes=dynamic_axes, dynamo=False, ) print(" ONNX export complete (dynamic H/W)") def _get_trt_dtype_map() -> dict[str, trt.DataType]: """Get mapping from precision string to TensorRT DataType.""" dtype_map: dict[str, trt.DataType] = { "fp32": trt.float32, "fp16": trt.float16, } if hasattr(trt, "bfloat16"): dtype_map["bf16"] = trt.bfloat16 return dtype_map def _trt_dtype_str(dtype: trt.DataType) -> str: """Convert TensorRT DataType to human-readable string.""" for name, dt in _get_trt_dtype_map().items(): if dtype == dt: return name.upper() return str(dtype) def build_engine( onnx_path: Path | str, engine_path: Path | str, min_shape: tuple[int, int], opt_shape: tuple[int, int], max_shape: tuple[int, int], precision: str = "fp16", workspace_gb: int = 4, opt_level: int = 3, ) -> None: """Build TensorRT engine from ONNX model with dynamic shapes.""" min_w, min_h = min_shape opt_w, opt_h = opt_shape max_w, max_h = max_shape print(f"Building TensorRT engine: {engine_path}") print(" Dynamic shapes:") print(f" min: {min_w}x{min_h}") print(f" opt: {opt_w}x{opt_h}") print(f" max: {max_w}x{max_h}") print(f" Precision: {precision}") print(f" Workspace: {workspace_gb} GB") logger = trt.Logger(trt.Logger.INFO) 
builder = trt.Builder(logger) network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) parser = trt.OnnxParser(network, logger) with open(onnx_path, "rb") as f: if not parser.parse(f.read()): for i in range(parser.num_errors): print(f" ONNX parse error: {parser.get_error(i)}") raise RuntimeError("Failed to parse ONNX model") config = builder.create_builder_config() config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_gb * (1 << 30)) # Optimization level (0-5, default is 3) # Higher levels enable more aggressive kernel selection/fusion but use more memory config.builder_optimization_level = opt_level print(f" Optimization level: {opt_level}") profile = builder.create_optimization_profile() input_name = network.get_input(0).name profile.set_shape( input_name, min=(1, 3, min_h, min_w), opt=(1, 3, opt_h, opt_w), max=(1, 3, max_h, max_w), ) config.add_optimization_profile(profile) # Set compute precision if precision in ("fp16", "bf16"): if builder.platform_has_fast_fp16: config.set_flag(trt.BuilderFlag.FP16) else: print(" Warning: FP16/BF16 not supported on this platform, using FP32") precision = "fp32" if precision == "bf16": if hasattr(trt.BuilderFlag, "BF16"): config.set_flag(trt.BuilderFlag.BF16) else: print(" Warning: BF16 not supported by TensorRT, using FP16") precision = "fp16" # Set I/O tensor precision (matches compute precision) dtype_map = _get_trt_dtype_map() if precision not in dtype_map: raise ValueError(f"Unknown precision: {precision}") io_dtype = dtype_map[precision] if io_dtype != trt.float32: for i in range(network.num_inputs): network.get_input(i).dtype = io_dtype for i in range(network.num_outputs): network.get_output(i).dtype = io_dtype print(" Building engine (this may take several minutes)...") serialized_engine = builder.build_serialized_network(network, config) if serialized_engine is None: raise RuntimeError("Failed to build TensorRT engine") with open(engine_path, "wb") as f: 
f.write(serialized_engine) print( f" Engine saved: {engine_path} ({Path(engine_path).stat().st_size / 1024 / 1024:.1f} MB)" ) # Verify the built engine has correct I/O types runtime = trt.Runtime(logger) engine = runtime.deserialize_cuda_engine(serialized_engine) print(" Verifying engine I/O:") for i in range(engine.num_io_tensors): name = engine.get_tensor_name(i) dtype = engine.get_tensor_dtype(name) mode = engine.get_tensor_mode(name) dtype_str = _trt_dtype_str(dtype) print(f" {name}: {dtype_str} ({mode})") if dtype != io_dtype: print(f" WARNING: {name} is {dtype_str} but {_trt_dtype_str(io_dtype)} was requested!") def height_to_shape(h: int, aspect: float = 16 / 9) -> tuple[int, int]: """Convert height to (width, height) assuming aspect ratio. Both width and height are aligned to 8 pixels, as required by many neural network architectures with pooling/striding layers. """ # Align height to 8 first h = (h + 7) // 8 * 8 # Calculate width from aligned height w = int(h * aspect) # Align width to 8 w = (w + 7) // 8 * 8 return (w, h) def main() -> None: parser = argparse.ArgumentParser( description="Export AI upscaling models to TensorRT engines", formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument( "--model", "-m", type=str, default="4x-compact", help="Model name (use --list to see available models)", ) parser.add_argument("--list", "-l", action="store_true", help="List available models") parser.add_argument( "--min-height", type=int, default=None, help="Minimum input height (default: auto)", ) parser.add_argument( "--opt-height", type=int, default=None, help="Optimal input height (default: auto)", ) parser.add_argument( "--max-height", type=int, default=None, help="Maximum input height (default: auto)", ) parser.add_argument( "--output", "-o", type=str, default=None, help="Output engine path", ) parser.add_argument( "--precision", "-p", type=str, default="fp16", choices=["fp16", "bf16", "fp32"], help="Model precision for compute and I/O 
tensors (default: fp16)", ) parser.add_argument( "--workspace", type=int, default=8, help="TensorRT workspace size in GB (default: 8)", ) parser.add_argument( "--opt-level", type=int, default=3, choices=[0, 1, 2, 3, 4, 5], help="TensorRT builder optimization level 0-5 (default: 3). Higher = more memory, potentially faster.", ) parser.add_argument( "--onnx-only", action="store_true", help="Only export ONNX, skip TensorRT engine build", ) args = parser.parse_args() if args.list: list_models() return # Get model info for defaults try: model_name, info = resolve_model(args.model) except ValueError as e: print(f"Error: {e}", file=sys.stderr) list_models() sys.exit(1) scale = info["scale"] # Set height defaults based on scale factor if scale == 2: # 2x: input 1080p -> output 4K default_min, default_opt, default_max = 720, 1080, 1080 elif scale == 4: # 4x: input 720p -> output 4K, or 480p -> 1080p default_min, default_opt, default_max = 480, 720, 1080 else: raise ValueError(f"Unsupported scale factor: {scale}") min_h = args.min_height or default_min opt_h = args.opt_height or default_opt max_h = args.max_height or default_max # Validate height constraints if min_h > max_h: raise ValueError(f"--min-height ({min_h}) cannot be greater than --max-height ({max_h})") if opt_h < min_h or opt_h > max_h: raise ValueError( f"--opt-height ({opt_h}) must be between --min-height ({min_h}) and --max-height ({max_h})" ) min_shape = height_to_shape(min_h) opt_shape = height_to_shape(opt_h) max_shape = height_to_shape(max_h) if args.output is None: args.output = f"{model_name}_{opt_h}p_{args.precision}.engine" print("=" * 60) print("AI Upscale: TensorRT Engine Export") print("=" * 60) print(f"Model: {model_name}") print(f" {info['description']}") print() model, existing_onnx, _ = get_model_and_onnx(args.model) # Determine ONNX path if existing_onnx: # Model already has ONNX - use it directly onnx_path = existing_onnx cleanup_onnx = False elif args.onnx_only: # Save ONNX to current 
directory with sensible name onnx_path = Path(f"{model_name}_{opt_h}p.onnx") cleanup_onnx = False else: # Temp file for intermediate ONNX with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: onnx_path = Path(tmp.name) cleanup_onnx = True try: # Export to ONNX if needed (PTH-based models only) if model is not None: export_onnx(model, opt_shape, onnx_path) if args.onnx_only: print(f"\nONNX saved to: {onnx_path}") print("Skipping TensorRT build (--onnx-only). Build later with:") print(f" trtexec --onnx={onnx_path} --saveEngine={args.output} --fp16") return build_engine( onnx_path, args.output, min_shape=min_shape, opt_shape=opt_shape, max_shape=max_shape, precision=args.precision, workspace_gb=args.workspace, opt_level=args.opt_level, ) finally: if cleanup_onnx and (onnx_file := Path(onnx_path)).exists(): onnx_file.unlink() print() print("=" * 60) print("Export complete!") print("=" * 60) print() print(f"Model: {model_name} ({scale}x upscale)") print(f"Engine accepts input heights from {min_h} to {max_h} (16:9)") print() print("Usage with FFmpeg:") print( f' ffmpeg -i input.mp4 -vf "dnn_processing=dnn_backend=tensorrt:model={args.output}" output.mp4' ) if __name__ == "__main__": main() ================================================ FILE: tools/install-ai_upscale.sh ================================================ #!/bin/bash # Build TensorRT engines for AI Upscale # # Prerequisites: uv sync --group ai_upscale # Or: pip install torch onnx tensorrt # # Models sourced from https://openmodeldb.info/ # set -e # Capture script directory (with error handling) SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" || { echo "ERROR: Failed to determine script directory" >&2 exit 1 } PROJECT_DIR="$(dirname "$SCRIPT_DIR")" MODEL_DIR="${MODEL_DIR:-$HOME/ffmpeg_build/models}" MODEL="${MODEL:-recommended}" PRECISION="${PRECISION:-fp16}" # Recursion guard to prevent fork bombs when calling ourselves MAX_RECURSION_DEPTH=10 RECURSION_DEPTH=${RECURSION_DEPTH:-0} 
if [ "$RECURSION_DEPTH" -ge "$MAX_RECURSION_DEPTH" ]; then
    echo "ERROR: Maximum recursion depth ($MAX_RECURSION_DEPTH) exceeded" >&2
    exit 1
fi

# Use uv run if in a uv project, otherwise plain python3
# Note: PYTHON_CMD is an array to handle paths with spaces correctly
if [ -f "$PROJECT_DIR/pyproject.toml" ] && command -v uv >/dev/null 2>&1; then
    PYTHON_CMD=("uv" "run" "--project" "$PROJECT_DIR" "python3")
else
    PYTHON_CMD=("python3")
fi

# Run the selected Python interpreter with the given arguments.
run_python() {
    "${PYTHON_CMD[@]}" "$@"
}

# Build several models by re-running this script once per model.
#   $1   - banner title
#   $2   - message printed after every model has been built
#   $3.. - model names to build
# RECURSION_DEPTH is passed in the child's environment on purpose: a plain
# assignment is not exported when the variable did not already come from the
# environment, so the child would always restart at 0 and the guard above
# could never fire. Each child gets depth+1 (true nesting depth, not a count
# of loop iterations). The child runs via `bash` so the script works even
# without its executable bit (e.g. `bash install-ai_upscale.sh`).
build_model_set() {
    local title="$1"
    local done_msg="$2"
    shift 2

    echo "========================================"
    echo "$title"
    echo "========================================"
    echo ""

    local m
    for m in "$@"; do
        echo ">>> Building $m..."
        RECURSION_DEPTH=$((RECURSION_DEPTH + 1)) MODEL="$m" bash "${BASH_SOURCE[0]}"
        echo ""
    done

    echo "$done_msg"
}

# Validate export-tensorrt.py exists
EXPORT_SCRIPT="$SCRIPT_DIR/export-tensorrt.py"
if [ ! -f "$EXPORT_SCRIPT" ]; then
    echo "ERROR: export-tensorrt.py not found at $EXPORT_SCRIPT" >&2
    exit 1
fi

# Show help
if [ "$1" = "-h" ] || [ "$1" = "--help" ]; then
    echo "Usage: $0 [MODEL]"
    echo ""
    echo "Build TensorRT engines for AI Upscale."
    echo ""
    echo "Arguments:"
    echo " MODEL Model to build (default: $MODEL)"
    echo " 'recommended' - 4x-compact, 2x-liveaction-span"
    echo " 'all' - all models including 4x-realesrgan"
    echo ""
    echo "Environment:"
    echo " MODEL_DIR Output directory (default: \$HOME/ffmpeg_build/models)"
    echo " MODEL Model name (can also be passed as argument)"
    echo " PRECISION Model precision: fp16, bf16, fp32 (default: fp16)"
    echo ""
    echo "Available models:"
    run_python "$EXPORT_SCRIPT" --list
    exit 0
fi

# Allow model to be passed as argument
if [ -n "$1" ]; then
    MODEL="$1"
fi

# Meta-model names expand to a fixed list of models, each built by a child run
case "$MODEL" in
    recommended)
        build_model_set "AI Upscale: Building recommended models" \
            "Done! Recommended models built." \
            4x-compact 2x-liveaction-span
        exit 0
        ;;
    all)
        build_model_set "AI Upscale: Building ALL models" \
            "Done! All models built." \
            4x-compact 2x-liveaction-span 4x-realesrgan
        exit 0
        ;;
esac

echo "========================================"
echo "AI Upscale: TensorRT Engine Builder"
echo "========================================"
echo "Model: $MODEL"
echo "Output: $MODEL_DIR/"
echo ""

# Check dependencies
if ! run_python -c "import torch, onnx, tensorrt" 2>/dev/null; then
    echo "ERROR: Missing dependencies. Install with:"
    echo " uv sync --group ai_upscale"
    echo "Or:"
    echo " pip install torch onnx tensorrt"
    exit 1
fi

# Create output directory with validation
mkdir -p "$MODEL_DIR" || {
    echo "ERROR: Cannot create directory: $MODEL_DIR" >&2
    exit 1
}
if [ ! -w "$MODEL_DIR" ]; then
    echo "ERROR: No write permission for: $MODEL_DIR" >&2
    exit 1
fi

# Check disk space (engines are ~100-500MB each, need at least 2GB free)
REQUIRED_SPACE_KB=$((2 * 1024 * 1024))  # 2GB in KB
# -P requests the POSIX one-line-per-filesystem format (without it, long device
# names may wrap onto a second line and shift the columns); -k fixes the unit at 1K.
AVAILABLE_KB=$(df -Pk "$MODEL_DIR" 2>/dev/null | tail -1 | awk '{print $4}')
if [ -n "$AVAILABLE_KB" ] && [ "$AVAILABLE_KB" -lt "$REQUIRED_SPACE_KB" ] 2>/dev/null; then
    echo "WARNING: Low disk space in $MODEL_DIR ($(( AVAILABLE_KB / 1024 ))MB available, recommend 2GB+)" >&2
fi

# Input resolutions to build engines for (output can be downscaled as needed)
RESOLUTIONS="480 720 1080"

# Sanitize model name for safe filename (remove any path separators)
# Done once before the loop since MODEL doesn't change during iteration
SAFE_MODEL="${MODEL//\//_}"
SAFE_MODEL="${SAFE_MODEL//\\/_}"

# Build engines for common resolutions (FFmpeg TensorRT backend needs fixed shapes)
echo "Building TensorRT engines for resolutions: $RESOLUTIONS"
echo ""

# Use word splitting intentionally here (RESOLUTIONS is space-separated)
# shellcheck disable=SC2086
for res in $RESOLUTIONS; do
    engine="$MODEL_DIR/${SAFE_MODEL}_${res}p_${PRECISION}.engine"

    # -s (exists AND non-empty): a zero-byte leftover is rebuilt, not skipped
    if [ -s "$engine" ]; then
        echo " ${res}p: already exists, skipping"
        continue
    fi

    echo " ${res}p: building..."
    # Capture output to show errors if build fails
    if ! OUTPUT=$(run_python "$EXPORT_SCRIPT" \
        --model "$MODEL" \
        --precision "$PRECISION" \
        --min-height "$res" --opt-height "$res" --max-height "$res" \
        -o "$engine" 2>&1); then
        echo "ERROR building ${res}p engine:" >&2
        echo "$OUTPUT" >&2
        exit 1
    fi

    # Show filtered progress on success
    echo "$OUTPUT" | grep -E "^(Downloading|Using cached|Loading|Using ONNX|Engine saved| )" || true

    # Verify a non-empty engine was actually written
    if [ ! -s "$engine" ]; then
        echo "ERROR: Engine file not created (or empty): $engine" >&2
        echo "Build output:" >&2
        echo "$OUTPUT" >&2
        exit 1
    fi
done

echo ""
echo "========================================"
echo "Installation complete!"
echo "========================================"
echo ""
echo "Engines built:"
# List this model's engines. -print0 / read -d '' keeps each filename intact even
# with spaces or other special characters (parsing `ls -l` columns would not).
find "$MODEL_DIR" -maxdepth 1 -type f -name "${SAFE_MODEL}_*.engine" -print0 2>/dev/null |
    sort -z |
    while IFS= read -r -d '' engine_file; do
        size=$(du -h --apparent-size -- "$engine_file" | cut -f1)
        echo " $(basename "$engine_file") ($size)"
    done
echo ""
echo "To use a different model, run:"
echo " MODEL=2x-liveaction-span $0"
echo " MODEL=4x-compact $0"
echo ""
echo "Test with:"
echo " ffmpeg -init_hw_device cuda=cu -filter_hw_device cu \\"
echo " -f lavfi -i testsrc=duration=3:size=1920x1080:rate=30 \\"
echo " -vf \"format=rgb24,hwupload,dnn_processing=dnn_backend=8:model=$MODEL_DIR/${SAFE_MODEL}_1080p_${PRECISION}.engine\" \\"
echo " -c:v hevc_nvenc test.mp4"

================================================
FILE: tools/install-ffmpeg.sh
================================================
#!/bin/bash
# Build ffmpeg from source with hardware acceleration support
# Supports: NVIDIA NVENC, AMD AMF, Intel QSV/VAAPI, LibTorch DNN, TensorRT DNN
# https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu
#
# Usage Examples:
#
#   # Default build (TensorRT enabled, requires compute_70+ GPU)
#   ./install-ffmpeg.sh
#
#   # Maxwell GPU (compute_52, e.g. GTX 900/TITAN X) - use LibTorch 2.5 instead of TensorRT
#   ENABLE_LIBTORCH=1 ENABLE_TENSORRT=0 LIBTORCH_VERSION=2.5.0 LIBTORCH_VARIANT=cu121 ./install-ffmpeg.sh
#
#   # CPU-only DNN inference (no GPU required)
#   ENABLE_LIBTORCH=1 ENABLE_TENSORRT=0 LIBTORCH_VARIANT=cpu ./install-ffmpeg.sh
#
#   # Specific CUDA version
#   CUDA_VERSION=12.4 ./install-ffmpeg.sh
#
#   # Skip dependency installation (already installed)
#   SKIP_DEPS=1 ./install-ffmpeg.sh
#
#   # Both LibTorch and TensorRT (for testing/comparison)
#   ENABLE_LIBTORCH=1 ENABLE_TENSORRT=1 ./install-ffmpeg.sh
#
set -e

# =============================================================================
# Potentially Viable Pre-built FFmpeg alternatives
#
# DOCKER IMAGES
#   LinuxServer docker-ffmpeg   https://github.com/linuxserver/docker-ffmpeg
#     - Full hardware accel (NVENC, VAAPI, QSV)
#     - Comprehensive codec support, builds libva 2.23+ from source
#     - Used by Dispatcharr, basis for comparison with this script
#
# STATIC BINARIES (Linux)
#   BtbN/FFmpeg-Builds          https://github.com/BtbN/FFmpeg-Builds
#     - Daily automated builds from git master and release branches
#     - GPL/LGPL/nonfree variants, static and shared options
#     - Targets glibc 2.28+ (RHEL 8 / Ubuntu 20.04+)
#     - CUDA support: sm_52+ (Maxwell and newer)
#
#   John Van Sickle             https://johnvansickle.com/ffmpeg/
#     - Static builds for amd64, i686, armhf, arm64
#     - GPL v3 licensed, targets kernel 3.2.0+
#     - Note: static glibc = no DNS resolution (install nscd to fix)
#
# STATIC BINARIES (Windows)
#   gyan.dev                    https://www.gyan.dev/ffmpeg/builds/
#     - Essentials build: common codecs (Win 7+)
#     - Full build: all codecs including bluray, opencl (Win 10+)
#     - Official FFmpeg download page recommendation
#
# SPECIALIZED BUILDS
#   Jellyfin-ffmpeg             https://github.com/jellyfin/jellyfin-ffmpeg
#     - Modified FFmpeg with Jellyfin-specific patches
#     - Optimized for media server transcoding
#     - Ships with Jellyfin packages and Docker images
#     - Recommended only for Jellyfin; other apps should use standard builds
#
# =============================================================================

# =============================================================================
# FFmpeg library reference (checked 2026-01)
#
# Priority: high     = essential for most workflows
#           med      = useful for specific workflows
#           low      = niche use cases
#           subsumed = functionality covered by another library we use
#           legacy   = outdated, superseded by newer codecs
#
# Enable:   src = built from source, apt = use apt package, - = not enabled
#
# Library          | Build  | Pri      | Apt Ver | Latest  | Description
# -----------------|--------|----------|---------|---------|---------------------------
# VIDEO CODECS
# libx264          | src    | high     | 0.164   | 0.165   | H.264/AVC encoder (8/10-bit)
# libx265          | src    | high     | 3.5     | 4.1     | H.265/HEVC encoder (8/10/12-bit)
# libsvtav1        | src    | high     | 1.7.0   | 3.0.2   | AV1 encoder (fast, scalable)
# libaom           | src    | high     | 3.8.2   | 3.13.1  | AV1 reference encoder/decoder
# libdav1d         | src    | high     | 1.4.1   | 1.5.3   | AV1 decoder (fastest)
# libvpx           | apt    | high     | 1.14.0  | 1.14.1  | VP8/VP9 encoder/decoder
# libvvenc         | -      | low      | -       | 1.13.1  | H.266/VVC encoder (too early)
# librav1e         | -      | subsumed | 0.7.1   | 0.8.1   | AV1 encoder - svtav1 faster
# libkvazaar       | -      | subsumed | 2.3.1   | 2.3.2   | HEVC encoder - x265 better
# libopenh264      | -      | subsumed | 2.6.0   | 2.6.0   | H.264 (Cisco) - x264 better
# libxvid          | -      | legacy   | 1.3.7   | 1.3.7   | MPEG-4 Part 2 (obsolete)
# libtheora        | -      | legacy   | 1.2.0a1 | 1.2.0   | Theora codec (obsolete)
#
# IMAGE CODECS
# libwebp          | src    | high     | 1.3.2   | 1.6.0   | WebP image codec
# libjxl           | src    | high     | 0.7.0   | 0.11.1  | JPEG XL (next-gen, HDR)
# libopenjpeg      | -      | low      | 2.5.0   | 2.5.4   | JPEG 2000 (cinema/medical)
# librsvg          | -      | low      | 2.58.0  | 2.61.3  | SVG rasterization
# libsnappy        | -      | low      | 1.1.10  | 1.2.2   | Snappy compression (HAP codec)
#
# AUDIO CODECS
# libfdk-aac       | apt    | high     | 2.0.2   | 2.0.3   | AAC encoder (best quality)
# libmp3lame       | apt    | high     | 3.100   | 3.100   | MP3 encoder
# libopus          | apt    | high     | 1.5.2   | 1.6     | Opus encoder/decoder
# libvorbis        | apt    | high     | 1.3.7   | 1.3.7   | Vorbis encoder/decoder
# librubberband    | apt    | med      | 3.3.0   | 4.0.0   | Audio time-stretch/pitch-shift
# liblc3           | -      | low      | 1.1.3   | 1.1.3   | LC3 Bluetooth audio codec
# libopencore-amr  | -      | legacy   | 0.1.6   | 0.1.6   | AMR-NB/WB (old mobile audio)
#
# SUBTITLE/TEXT
# libass           | apt    | high     | 0.17.3  | 0.17.4  | ASS/SSA subtitle renderer
# libfreetype      | apt    | high     | 2.13.3  | 2.14.1  | Font rendering
# libfontconfig    | apt    | high     | 2.15.0  | 2.17.0  | Font configuration
# libfribidi       | apt    | med      | 1.0.16  | 1.0.16  | BiDi text (RTL languages)
# libharfbuzz      | apt    | med      | 10.2.0  | 12.3.0  | Complex text shaping
#
# FILTERS/PROCESSING
# libzimg          | apt    | high     | 3.0.5   | 3.0.6   | High-quality image scaling
# libsoxr          | apt    | high     | 0.1.3   | 0.1.3   | High-quality audio resampling
# libvmaf          | src    | med      | 2.3.1   | 3.0.0   | Video quality metrics
# libplacebo       | src    | med      | 7.349.0 | 7.351.0 | GPU HDR tone mapping
# libshaderc       | src*   | med      | -       | -       | GLSL->SPIRV compiler (*via Vulkan SDK)
# libvidstab       | apt    | med      | 1.1.0   | 1.1.1   | Video stabilization
# libmysofa        | -      | low      | 1.3.3   | 1.3.3   | HRTF spatial audio (sofalizer)
# libtesseract     | -      | low      | 5.5.0   | 5.5.1   | OCR text extraction
# opencl           | apt    | low      | 2.3.3   | -       | GPU compute filters
#
# HARDWARE ACCEL
# libva            | src    | high     | 2.20.0  | 2.23.0  | VA-API (Intel/AMD) - Xe support
# libvpl           | src    | high     | 2023.3  | 2.16.0  | Intel QuickSync Video
# cuda-nvcc        | src    | high     | -       | -       | NVIDIA CUDA compiler
# nvenc            | src    | high     | -       | -       | NVIDIA hardware encoder
# cuvid            | src    | high     | -       | -       | NVIDIA hardware decoder
# vaapi            | src    | high     | -       | -       | VA-API hwaccel
# nvdec            | src    | med      | -       | -       | NVIDIA hwaccel decode API
# vulkan           | src    | med      | -       | -       | Vulkan GPU compute
# cuda-llvm        | -      | subsumed | -       | -       | CUDA via clang - we use nvcc
# vdpau            | -      | legacy   | 1.5     | 1.5     | NVIDIA VDPAU (use nvdec)
#
# PROTOCOLS/NETWORK
# openssl          | apt    | high     | 3.0.13  | 3.0.15  | TLS/HTTPS support
# libsrt           | apt    | high     | 1.5.3   | 1.5.4   | SRT streaming protocol
# libssh           | -      | low      | 0.10.6  | 0.11.1  | SFTP protocol
# librist          | -      | low      | 0.2.11  | 0.2.11  | RIST broadcast protocol
# libzmq           | -      | low      | 4.3.5   | 4.3.5   | ZeroMQ IPC messaging
# libxml2          | -      | low      | 2.9.14  | 2.13.5  | XML/DASH manifest parsing
#
# INPUT/OUTPUT
# libbluray        | apt    | med      | 1.3.4   | 1.4.0   | Blu-ray disc reading
# libv4l2          | -      | low      | 1.28.1  | 1.28.1  | V4L2 webcam/capture
# alsa             | -      | low      | 1.2.14  | 1.2.14  | Linux ALSA audio input
#
# META FLAGS
# gpl              | yes    | high     | -       | -       | Enable GPL-licensed code
# version3         | yes    | high     | -       | -       | Enable (L)GPL v3 code
# nonfree          | yes    | high     | -       | -       | Enable non-free code (fdk-aac)
#
# =============================================================================

# Hardware acceleration (set to 1 to enable)
ENABLE_NVIDIA_CUDA=${ENABLE_NVIDIA_CUDA:-1}  # NVENC/NVDEC hardware encoding/decoding
ENABLE_AMD_AMF=${ENABLE_AMD_AMF:-1}          # AMD AMF hardware encoding (requires AMD GPU)
ENABLE_LIBTORCH=${ENABLE_LIBTORCH:-0}        # LibTorch DNN backend for AI filters (default off, prefer TensorRT)
ENABLE_TENSORRT=${ENABLE_TENSORRT:-1}        # TensorRT DNN backend for AI filters (fastest)

# LibTorch CUDA variant (only used if ENABLE_LIBTORCH=1)
# LIBTORCH_VARIANT options:
#   "cu126"   - (default) CUDA 12.6 - compatible with LibTorch 2.7+
#               Note: cu126 binaries work on CUDA 12.6+ runtimes (forward compatible)
#   "auto"    - auto-detect from CUDA_VERSION, rounding the minor version DOWN to an
#               even number (PyTorch only releases cu126, cu128, cu130 - even minor versions)
#               Examples: CUDA 12.9 -> cu128, CUDA 12.7 -> cu126, CUDA 13.x -> cu130
#   "cpu"     - CPU-only (no GPU acceleration for DNN filters)
#   "cu124"   - CUDA 12.4 (for older LibTorch 2.5.x)
#   "cu126"   - force CUDA 12.6
#   "cu128"   - force CUDA 12.8
#   "cu130"   - force CUDA 13.0
#   "rocm6.4" - AMD ROCm 6.4 (requires ROCm installed on host)
LIBTORCH_VARIANT=${LIBTORCH_VARIANT:-cu126}

# Optional build components (set to 0 to use apt package instead)
BUILD_LIBPLACEBO=${BUILD_LIBPLACEBO:-1}  # GPU HDR tone mapping (requires Vulkan SDK)
BUILD_LIBX265=${BUILD_LIBX265:-1}        # H.265/HEVC encoder (apt: 3.5, latest: 4.1)
BUILD_LIBAOM=${BUILD_LIBAOM:-1}          # AV1 reference codec (apt: 3.8, latest: 3.13)
BUILD_LIBWEBP=${BUILD_LIBWEBP:-1}        # WebP image codec (apt: 1.3, latest: 1.6)
BUILD_LIBVPL=${BUILD_LIBVPL:-1}          # Intel QuickSync (apt: 2023.3, latest: 2.16)
BUILD_LIBDAV1D=${BUILD_LIBDAV1D:-1}      # AV1 decoder (apt: 1.4.1, latest: 1.5.3)
BUILD_LIBSVTAV1=${BUILD_LIBSVTAV1:-1}    # AV1 encoder (apt: 1.7.0, latest: 3.0.2)
BUILD_LIBVMAF=${BUILD_LIBVMAF:-1}        # Video quality metrics (apt: 2.3.1, latest: 3.0.0)
BUILD_LIBVA=${BUILD_LIBVA:-1}            # VA-API (apt: 2.20.0, latest: 2.23.0 - Xe support)
BUILD_LIBJXL=${BUILD_LIBJXL:-1}          # JPEG XL (apt: 0.7.0, latest: 0.11.1)
BUILD_LIBX264=${BUILD_LIBX264:-1}        # H.264 encoder (apt: 8-bit only, src: 8/10-bit)
SVTAV1_GIT_REF=${SVTAV1_GIT_REF:-}       # Optional SVT-AV1 git ref override (empty = no pin)

# FFmpeg version: "snapshot" for latest git, or specific version like "7.1"
FFMPEG_VERSION=${FFMPEG_VERSION:-snapshot}

# Skip apt dependency installation (use if deps already installed, avoids sudo)
SKIP_DEPS=${SKIP_DEPS:-0}
PHASE=${PHASE:-all}

# Noninteractive apt installs (prevents prompts)
export DEBIAN_FRONTEND="${DEBIAN_FRONTEND:-noninteractive}"

# Capture script directory before any cd commands (with error handling)
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" || {
    echo "ERROR: Failed to determine script directory" >&2
    exit 1
}

# Validate SCRIPT_DIR is not empty and exists
if [ -z "$SCRIPT_DIR" ] || [ ! -d "$SCRIPT_DIR" ]; then
    echo "ERROR: Invalid script directory: $SCRIPT_DIR" >&2
    exit 1
fi

# =============================================================================
# Helper Functions
# =============================================================================

# Log error message to stderr
log_error() {
    echo "ERROR: $*" >&2
}

# Log warning message to stderr
log_warn() {
    echo "WARNING: $*" >&2
}

# Log info message to stdout
log_info() {
    echo "INFO: $*"
}

# Verify a patch was applied by checking for expected content.
#   $1 - file to inspect
#   $2 - grep (basic regex) pattern that must be present
#   $3 - human-readable description used in the error message
# Returns 0 if the pattern is found, 1 otherwise.
verify_patch() {
    local file="$1"
    local pattern="$2"
    local description="$3"
    # -e / -- keep a pattern or filename that starts with '-' from being read as an option
    if ! grep -q -e "$pattern" -- "$file"; then
        log_error "Patch verification failed: $description"
        log_error "Expected pattern '$pattern' not found in $file"
        return 1
    fi
    return 0
}

# Clone a git repo with validation.
#   $1 - repository URL
#   $2 - destination directory
#   $3 - clone depth (default: 1)
#   $4 - optional branch/tag to check out
git_clone_validated() {
    local url="$1"
    local dir="$2"
    local depth="${3:-1}"
    local ref="${4:-}"
    if [ -n "$ref" ]; then
        git clone --depth "$depth" --branch "$ref" "$url" "$dir"
    else
        git clone --depth "$depth" "$url" "$dir"
    fi
    # Validate clone succeeded
    if [ ! -d "$dir/.git" ]; then
        log_error "Git clone failed: $url -> $dir"
        return 1
    fi
    return 0
}

# Update an existing checkout with `git pull`, or (re-)clone it.
# Same arguments as git_clone_validated. A failed or conflicting pull, or a
# non-git directory in the way, is deleted via safe_rm_rf and cloned fresh.
# If safe_rm_rf refuses the path, this returns 1 instead of falling back to an
# unchecked `rm -rf` (which previously made the safety check meaningless).
git_update_or_clone() {
    local url="$1"
    local dir="$2"
    local depth="${3:-1}"
    local ref="${4:-}"
    if [ -d "$dir/.git" ]; then
        # Check for conflicts or errors during pull
        local pull_output
        if ! pull_output=$(git -C "$dir" pull 2>&1); then
            log_warn "Git pull failed for $dir: $pull_output"
            log_warn "Re-cloning..."
            safe_rm_rf "$dir" 2 || return 1
            git_clone_validated "$url" "$dir" "$depth" "$ref"
        elif echo "$pull_output" | grep -qE "CONFLICT|fatal"; then
            log_warn "Git pull had conflicts for $dir, re-cloning..."
            safe_rm_rf "$dir" 2 || return 1
            git_clone_validated "$url" "$dir" "$depth" "$ref"
        fi
    else
        # Remove if exists but not a git repo
        if [ -e "$dir" ] || [ -L "$dir" ]; then
            safe_rm_rf "$dir" 2 || return 1
        fi
        git_clone_validated "$url" "$dir" "$depth" "$ref"
    fi
}

# Print the Ubuntu VERSION_ID from /etc/os-release with dots removed
# (e.g. "24.04" -> "2404"). Accepts quoted or unquoted values.
# Returns 1 (after logging) if the file or field is missing.
get_ubuntu_version() {
    local version_id
    if [ ! -f /etc/os-release ]; then
        log_error "/etc/os-release not found - cannot detect Ubuntu version"
        return 1
    fi
    # sed strips the key; tr drops optional surrounding quotes (os-release allows unquoted values)
    version_id=$(sed -n 's/^VERSION_ID=//p' /etc/os-release | head -n1 | tr -d "\"'")
    if [ -z "$version_id" ]; then
        log_error "Could not parse VERSION_ID from /etc/os-release"
        return 1
    fi
    # Remove dots: "24.04" -> "2404"
    echo "$version_id" | tr -d '.'
}

# Safe rm -rf: validates a path before deleting it to prevent catastrophic mistakes.
#   $1 - path to delete (relative paths are resolved against $PWD)
#   $2 - minimum number of '/' separators the normalized path must contain (default: 2)
# Returns 0 if the path was removed or did not exist, 1 if deletion was refused.
safe_rm_rf() {
    local path="$1"
    local min_depth="${2:-2}"

    # Never delete empty paths
    if [ -z "$path" ]; then
        log_error "safe_rm_rf: empty path"
        return 1
    fi

    # Resolve relative paths so the depth check measures the real location
    # (a bare "AMF" would otherwise always look like depth 0)
    case "$path" in
        /*) ;;
        *) path="$PWD/$path" ;;
    esac

    # Normalize: collapse repeated slashes ("//usr" -> "/usr") and strip trailing
    # ones ("/usr/" -> "/usr") so neither form can slip past the checks below
    path=$(printf '%s' "$path" | tr -s '/')
    while [ "${#path}" -gt 1 ] && [ "${path%/}" != "$path" ]; do
        path="${path%/}"
    done

    # Refuse "." / ".." components: they make the textual depth check meaningless
    case "$path" in
        */./*|*/../*|*/.|*/..)
            log_error "safe_rm_rf: refusing path with '.' or '..' components: $path"
            return 1
            ;;
    esac

    # Never delete root or shallow paths
    local depth
    depth=$(printf '%s' "$path" | tr -cd '/' | wc -c)
    if [ "$depth" -lt "$min_depth" ]; then
        log_error "safe_rm_rf: path '$path' too shallow (depth $depth < $min_depth)"
        return 1
    fi

    # Never delete common dangerous paths
    case "$path" in
        /|/bin|/boot|/dev|/etc|/home|/lib|/lib64|/media|/mnt|/opt|/proc|/root|/run|/sbin|/srv|/sys|/tmp|/usr|/var)
            log_error "safe_rm_rf: refusing to delete system path: $path"
            return 1
            ;;
    esac

    # Never delete the user's home directory itself (e.g. /home/user has depth 2)
    if [ -n "${HOME:-}" ] && [ "$path" = "${HOME%/}" ]; then
        log_error "safe_rm_rf: refusing to delete home directory: $path"
        return 1
    fi

    # Nothing to delete (a dangling symlink still counts as something to remove)
    if [ ! -e "$path" ] && [ ! -L "$path" ]; then
        return 0
    fi

    rm -rf -- "$path"
}

# Validate CUDA version format (e.g., "12.4", "13.0")
validate_cuda_version() {
    local version="$1"
    if ! [[ "$version" =~ ^[0-9]+\.[0-9]+$ ]]; then
        log_error "Invalid CUDA version format: '$version' (expected X.Y)"
        return 1
    fi
    return 0
}

# Validate torch variant format (e.g., "cu124", "cu130", "cpu")
validate_torch_variant() {
    local variant="$1"
    if ! [[ "$variant" =~ ^(cu[0-9]+|rocm[0-9]+\.[0-9]+|cpu)$ ]]; then
        log_error "Invalid TORCH_VARIANT: '$variant' (expected cuXXX, rocmX.Y, or cpu)"
        return 1
    fi
    return 0
}

# Verify the SHA256 checksum of a file.
#   $1 - file to hash
#   $2 - expected hex digest (compared case-insensitively; sha256sum prints lowercase)
verify_sha256() {
    local file="$1"
    local expected="${2,,}"
    if [ ! -f "$file" ]; then
        log_error "File not found for checksum verification: $file"
        return 1
    fi
    local actual
    actual=$(sha256sum "$file" | awk '{print $1}')
    if [ "$actual" != "$expected" ]; then
        log_error "SHA256 checksum mismatch for $file"
        log_error " Expected: $expected"
        log_error " Actual: $actual"
        return 1
    fi
    log_info "SHA256 verified: $file"
    return 0
}

# Download a file with optional checksum verification.
#   $1 - URL, $2 - output path, $3 - optional expected SHA256
# On any failure the (possibly partial) output file is removed so a later run
# cannot mistake it for a complete download.
download_file() {
    local url="$1"
    local output="$2"
    local sha256="${3:-}"
    log_info "Downloading: $url"
    if ! wget -q -O "$output" "$url"; then
        log_error "Failed to download: $url"
        rm -f "$output"
        return 1
    fi
    if [ -n "$sha256" ]; then
        if ! verify_sha256 "$output" "$sha256"; then
            rm -f "$output"
            return 1
        fi
    fi
    return 0
}

# Succeed if version $1 >= version $2 (per `sort -V` ordering)
version_ge() {
    [ "$(printf '%s\n' "$2" "$1" | sort -V | head -n1)" = "$2" ]
}

# Upgrade meson via pip when the installed version is older than $1
ensure_meson_min_version() {
    local min_version="$1"
    local current_version
    current_version=$(meson --version 2>/dev/null || echo 0)
    if ! version_ge "$current_version" "$min_version"; then
        echo "Meson $current_version < $min_version, upgrading via pip..."
        sudo apt-get update
        sudo apt-get install -y python3-pip
        if pip3 install --help 2>/dev/null | grep -q -- '--break-system-packages'; then
            sudo -E pip3 install --upgrade --break-system-packages meson
        else
            sudo -E pip3 install --upgrade meson
        fi
        export PATH="/usr/local/bin:$PATH"
    fi
}

# NVIDIA CUDA setup (only used if ENABLE_NVIDIA_CUDA=1)
# CUDA_VERSION options:
#   "auto" - (default) use installed CUDA if available, else install latest
#   "12.8" - explicit version (e.g., 12.4, 12.6, 13.0) - dots converted to dashes internally
CUDA_VERSION=${CUDA_VERSION:-auto}

# Validate CUDA_VERSION format before normalization
if [ "$CUDA_VERSION" != "auto" ]; then
    if ! validate_cuda_version "$CUDA_VERSION"; then
        log_error "Set CUDA_VERSION to 'auto' or a valid version like '12.8' or '13.0'"
        exit 1
    fi
fi

# Normalize: convert "12.4" to "12-4" for apt package names
CUDA_VERSION="${CUDA_VERSION//./-}"

# NVCC_GENCODE options:
#   "native"  - (default) compile for build machine's GPU via nvidia-smi
#   "minimum" - lowest arch for CUDA version (sm_52 for <13, sm_75 for 13+)
#   "75"      - explicit single arch (e.g., 75, 86, 89)
NVCC_GENCODE=${NVCC_GENCODE:-native}

# Ensure HOME is set (for cron/systemd contexts).
# This must run BEFORE the build paths below, which default to $HOME/...;
# otherwise an empty HOME would silently produce paths like /ffmpeg_sources.
if [ -z "$HOME" ]; then
    HOME=$(getent passwd "$(id -un)" 2>/dev/null | cut -d: -f6) || true
    if [ -z "$HOME" ] || [ ! -d "$HOME" ]; then
        log_error "HOME environment variable not set and could not be detected"
        exit 1
    fi
    export HOME
fi

# Build paths
SRC_DIR="${SRC_DIR:-$HOME/ffmpeg_sources}"  # Source code cache (can be deleted after build)
BUILD_DIR="${BUILD_DIR:-$HOME/ffmpeg_build}"  # Build artifacts cache (can be deleted after build)
BIN_DIR="${BIN_DIR:-$HOME/.local/bin}"      # Final binary install location
LIB_DIR="${LIB_DIR:-$HOME/.local/lib}"      # Final shared library install location (for libva)

# Get number of processors with fallback and validation
NPROC=$(nproc 2>/dev/null || echo 4)
if ! [[ "$NPROC" =~ ^[0-9]+$ ]] || [ "$NPROC" -lt 1 ] || [ "$NPROC" -gt 256 ]; then
    log_warn "Invalid NPROC value '$NPROC', using 4"
    NPROC=4
fi

# Create build directories (with validation)
mkdir -p "$SRC_DIR" "$BUILD_DIR" "$BIN_DIR" "$LIB_DIR" || {
    log_error "Failed to create build directories"
    exit 1
}

# Note: libplacebo pin was previously needed for jammy because older versions
# lacked dependencies. Now using latest which is compatible with FFmpeg snapshot API.
# Note: SVT-AV1 pin was previously needed for FFmpeg 7.0 API compatibility on jammy.
# With FFmpeg snapshot, we use latest SVT-AV1 (no pin needed).
# Base packages (installed first, includes wget needed for CUDA repo setup) APT_PACKAGES=( autoconf automake build-essential cmake doxygen git meson nasm ninja-build pkg-config texinfo unzip wget xxd yasm libass-dev libbluray-dev libfdk-aac-dev libfontconfig1-dev libfreetype6-dev libfribidi-dev libharfbuzz-dev libsoxr-dev libsrt-openssl-dev libssl-dev libzstd-dev libzimg-dev liblzma-dev liblzo2-dev libmp3lame-dev libnuma-dev ocl-icd-opencl-dev libopus-dev librubberband-dev libsdl2-dev libtool python3-jinja2 libunistring-dev libvdpau-dev libvidstab-dev libdrm-dev libx11-dev libvorbis-dev libvpx-dev libxcb-shm0-dev libxcb-xfixes0-dev libxcb1-dev zlib1g-dev # Intel oneVPL/QSV runtime (needed for Intel GPU hardware encoding) libmfx-gen1.2 ) # Add apt packages for libraries we're not building from source [ "$BUILD_LIBX265" != "1" ] && APT_PACKAGES+=(libx265-dev) [ "$BUILD_LIBAOM" != "1" ] && APT_PACKAGES+=(libaom-dev) [ "$BUILD_LIBWEBP" != "1" ] && APT_PACKAGES+=(libwebp-dev) [ "$BUILD_LIBVPL" != "1" ] && APT_PACKAGES+=(libvpl-dev) [ "$BUILD_LIBDAV1D" != "1" ] && APT_PACKAGES+=(libdav1d-dev) [ "$BUILD_LIBSVTAV1" != "1" ] && APT_PACKAGES+=(libsvtav1enc-dev) [ "$BUILD_LIBVMAF" != "1" ] && APT_PACKAGES+=(libvmaf-dev) [ "$BUILD_LIBVA" != "1" ] && APT_PACKAGES+=(libva-dev) [ "$BUILD_LIBJXL" != "1" ] && APT_PACKAGES+=(libjxl-dev) [ "$BUILD_LIBX264" != "1" ] && APT_PACKAGES+=(libx264-dev) # Note: TensorRT headers (libnvinfer-headers-dev) installed later after CUDA repo is set up if [ "$SKIP_DEPS" != "1" ]; then sudo apt-get update sudo apt-get install -y "${APT_PACKAGES[@]}" ensure_meson_min_version 0.63 fi CUDA_FLAGS=() NVCC_ARCH="" if [ "$ENABLE_NVIDIA_CUDA" = "1" ]; then # Check if CUDA is already installed if [ "$CUDA_VERSION" = "auto" ]; then if command -v nvcc &> /dev/null; then # Extract version from nvcc (e.g., "12.9" -> "12-9") NVCC_VERSION=$(nvcc --version | grep -oP 'release \K[0-9]+\.[0-9]+') CUDA_VERSION=$(echo "$NVCC_VERSION" | tr '.' 
'-') echo "Detected installed CUDA $NVCC_VERSION (using version $CUDA_VERSION)" else echo "No CUDA installed, will install latest from NVIDIA repo" fi fi # Add CUDA repo if not present or if we need to install if [ "$SKIP_DEPS" != "1" ]; then if [ "$CUDA_VERSION" = "auto" ] || ! command -v nvcc &> /dev/null; then if ! dpkg -l cuda-keyring 2>/dev/null | grep -q ^ii; then # Detect Ubuntu version for correct CUDA repo (24.04 -> ubuntu2404, 25.04 -> ubuntu2504) UBUNTU_VERSION=$(get_ubuntu_version) || exit 1 CUDA_REPO_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VERSION}/x86_64/cuda-keyring_1.1-1_all.deb" if ! wget --progress=dot:giga "$CUDA_REPO_URL" -O cuda-keyring.deb; then log_error "Failed to download CUDA keyring from $CUDA_REPO_URL" exit 1 fi sudo dpkg -i cuda-keyring.deb rm cuda-keyring.deb sudo apt-get update fi # Get latest CUDA version if still auto if [ "$CUDA_VERSION" = "auto" ]; then CUDA_VERSION=$(apt-cache search '^cuda-nvcc-[0-9]' | sed 's/cuda-nvcc-//' | cut -d' ' -f1 | sort -V | tail -1) if [ -z "$CUDA_VERSION" ]; then echo "Error: No CUDA packages found. Install CUDA repo first or set CUDA_VERSION manually." 
>&2 exit 1 fi echo "Will install latest CUDA version: $CUDA_VERSION" fi fi # Install CUDA packages sudo apt-get install -y libffmpeg-nvenc-dev cuda-nvcc-$CUDA_VERSION cuda-cudart-dev-$CUDA_VERSION # Install TensorRT headers only (requires NVIDIA repo set up above) # We only need headers for compilation - libnvinfer is loaded via dlopen at runtime if [ "$ENABLE_TENSORRT" = "1" ]; then sudo apt-get install -y libnvinfer-headers-dev fi fi echo "Using CUDA version: $CUDA_VERSION" # Detect CUDA installation path CUDA_VERSION_DOT=$(echo "$CUDA_VERSION" | tr '-' '.') if [ -d "/usr/local/cuda" ]; then CUDA_PATH="/usr/local/cuda" elif [ -d "/usr/local/cuda-${CUDA_VERSION_DOT}" ]; then CUDA_PATH="/usr/local/cuda-${CUDA_VERSION_DOT}" else echo "Warning: CUDA path not found, using /usr/local/cuda (headers may be missing)" >&2 CUDA_PATH="/usr/local/cuda" fi echo "Using CUDA path: $CUDA_PATH" # Patch CUDA headers for glibc 2.42+ compatibility (Ubuntu 25.04+) # glibc 2.42 added rsqrt/rsqrtf to mathcalls.h which conflicts with CUDA's definitions # This causes "exception specification is incompatible" errors during nvcc compilation if [ "$SKIP_DEPS" != "1" ]; then CUDA_MATH_HEADER="$CUDA_PATH/targets/x86_64-linux/include/crt/math_functions.h" if [ -f "$CUDA_MATH_HEADER" ]; then GLIBC_VERSION=$(ldd --version | head -1 | grep -oP '\d+\.\d+$') GLIBC_MAJOR=$(echo "$GLIBC_VERSION" | cut -d. -f1) GLIBC_MINOR=$(echo "$GLIBC_VERSION" | cut -d. -f2) # Only patch if glibc >= 2.42 and patch not already applied if [ "$GLIBC_MAJOR" -gt 2 ] || ([ "$GLIBC_MAJOR" -eq 2 ] && [ "$GLIBC_MINOR" -ge 42 ]); then # Check for our patch OR NVIDIA's fix (they use __NV_GLIBC_PROVIDES_IEC_60559_FUNCS for similar issues) if grep -q "rsqrt" "$CUDA_MATH_HEADER" && \ ! grep -B2 "double[[:space:]]*rsqrt(double" "$CUDA_MATH_HEADER" | grep -q "GLIBC"; then echo "Patching CUDA headers for glibc $GLIBC_VERSION compatibility..." # Backup original if no backup exists [ ! 
-f "${CUDA_MATH_HEADER}.bak" ] && sudo cp "$CUDA_MATH_HEADER" "${CUDA_MATH_HEADER}.bak" # Add guards around rsqrt declaration (prevent conflict with glibc's rsqrt) sudo sed -i '/extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ double[[:space:]]*rsqrt(double/c\ #if !(defined(__GLIBC__) \&\& __GLIBC_USE_IEC_60559_FUNCS_EXT_C23)\ extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ double rsqrt(double x);\ #endif' "$CUDA_MATH_HEADER" # Add guards around rsqrtf declaration sudo sed -i '/extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ float[[:space:]]*rsqrtf(float/c\ #if !(defined(__GLIBC__) \&\& __GLIBC_USE_IEC_60559_FUNCS_EXT_C23)\ extern __DEVICE_FUNCTIONS_DECL__ __device_builtin__ float rsqrtf(float x);\ #endif' "$CUDA_MATH_HEADER" # Verify patch was applied if grep -B2 "double[[:space:]]*rsqrt(double" "$CUDA_MATH_HEADER" | grep -q "GLIBC"; then echo "CUDA header patched successfully" else echo "Warning: CUDA header patch may have failed - rsqrt declaration not found" >&2 echo "CUDA version may have different header format. Check $CUDA_MATH_HEADER" >&2 fi else echo "CUDA headers already patched for glibc compatibility" fi fi fi fi CUDA_FLAGS=(--enable-cuda-nvcc --enable-nvenc --enable-cuvid --enable-nvdec) CUDA_MAJOR="${CUDA_VERSION%%-*}" if [ "$NVCC_GENCODE" = "native" ]; then # Detect GPU compute capability if ! command -v nvidia-smi &> /dev/null; then log_warn "nvidia-smi not found, falling back to minimum arch" NVCC_GENCODE="minimum" elif ! 
COMPUTE_CAP=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -1); then log_warn "nvidia-smi failed to query GPU, falling back to minimum arch" NVCC_GENCODE="minimum" elif [ -z "$COMPUTE_CAP" ]; then log_warn "nvidia-smi found but no GPU detected, falling back to minimum" NVCC_GENCODE="minimum" else COMPUTE_CAP_NUM=$(echo "$COMPUTE_CAP" | tr -d '.') NVCC_ARCH="-arch=sm_${COMPUTE_CAP_NUM}" log_info "CUDA $CUDA_VERSION NVCC_GENCODE=native -> $NVCC_ARCH (detected via nvidia-smi)" fi fi if [ "$NVCC_GENCODE" = "minimum" ]; then if [ "$CUDA_MAJOR" -ge 13 ]; then NVCC_ARCH="-arch=sm_75" else NVCC_ARCH="-arch=sm_52" fi echo "CUDA $CUDA_VERSION NVCC_GENCODE=minimum -> $NVCC_ARCH" elif [ "$NVCC_GENCODE" != "native" ]; then # Explicit arch number NVCC_ARCH="-arch=sm_$NVCC_GENCODE" echo "CUDA $CUDA_VERSION NVCC_GENCODE=$NVCC_GENCODE -> $NVCC_ARCH" fi # Pin nv-codec-headers for specific builds: # - netv-ffmpeg:cuda12.4 OR CUDA_VERSION=12.4 -> use NVENC API 12.2 headers (sdk/12.2) # - otherwise -> use upstream master NV_CODEC_REF="master" if [ "${FFMPEG_IMAGE:-}" = "netv-ffmpeg:cuda12.4" ] || [ "${CUDA_VERSION:-}" = "12-4" ]; then NV_CODEC_REF="sdk/12.2" fi echo "nv-codec-headers: FFMPEG_IMAGE=${FFMPEG_IMAGE:-unset}, CUDA_VERSION=${CUDA_VERSION:-unset}, ref=$NV_CODEC_REF" cd "$SRC_DIR" if [ -d nv-codec-headers/.git ]; then git -C nv-codec-headers fetch --depth 1 origin "$NV_CODEC_REF" git -C nv-codec-headers checkout -f "$NV_CODEC_REF" else git clone --depth 1 --branch "$NV_CODEC_REF" https://git.videolan.org/git/ffmpeg/nv-codec-headers.git fi cd nv-codec-headers make make PREFIX="$BUILD_DIR" install fi # AMD AMF setup (hardware encoding for AMD GPUs) # AMF is header-only at build time; runtime driver comes from host's AMD GPU driver AMF_FLAGS=() if [ "$ENABLE_AMD_AMF" = "1" ]; then cd "$SRC_DIR" git_update_or_clone "https://github.com/GPUOpen-LibrariesAndSDKs/AMF.git" "AMF" 1 if [ ! 
-d "AMF/amf/public/include" ]; then log_error "AMF clone succeeded but include directory not found" exit 1 fi mkdir -p "$BUILD_DIR/include/AMF" # Use /. to copy directory contents without glob expansion issues cp -r "AMF/amf/public/include/." "$BUILD_DIR/include/AMF/" AMF_FLAGS=(--enable-amf) log_info "AMF headers installed for AMD GPU encoding" fi if [ "$PHASE" = "deps" ]; then echo "PHASE=deps set; skipping source builds." exit 0 fi # libx264 (H.264/AVC encoder) # Build with --bit-depth=all for 8-bit and 10-bit support if [ "$BUILD_LIBX264" = "1" ]; then cd "$SRC_DIR" git_update_or_clone "https://code.videolan.org/videolan/x264.git" "x264" 1 cd x264 PATH="$BIN_DIR:$PATH" ./configure --prefix="$BUILD_DIR" --enable-static --enable-pic --disable-cli --bit-depth=all PATH="$BIN_DIR:$PATH" make -j "$NPROC" make install fi # libx265 (H.265/HEVC encoder) # Multilib build: 8-bit + 10-bit + 12-bit support (required for HDR) # Build order: 12-bit → 10-bit → 8-bit (main links the others) if [ "$BUILD_LIBX265" = "1" ]; then cd "$SRC_DIR" git_update_or_clone "https://bitbucket.org/multicoreware/x265_git.git" "x265_git" 1 cd x265_git/build/linux # Clean previous builds rm -rf 8bit 10bit 12bit mkdir -p 8bit 10bit 12bit # Build 12-bit cd 12bit PATH="$BIN_DIR:$PATH" cmake -G "Unix Makefiles" \ -DCMAKE_INSTALL_PREFIX="$BUILD_DIR" \ -DHIGH_BIT_DEPTH=ON \ -DEXPORT_C_API=OFF \ -DENABLE_SHARED=OFF \ -DENABLE_CLI=OFF \ -DMAIN12=ON \ ../../../source PATH="$BIN_DIR:$PATH" make -j "$NPROC" # Build 10-bit cd ../10bit PATH="$BIN_DIR:$PATH" cmake -G "Unix Makefiles" \ -DCMAKE_INSTALL_PREFIX="$BUILD_DIR" \ -DHIGH_BIT_DEPTH=ON \ -DEXPORT_C_API=OFF \ -DENABLE_SHARED=OFF \ -DENABLE_CLI=OFF \ ../../../source PATH="$BIN_DIR:$PATH" make -j "$NPROC" # Build 8-bit (main) and link in 10-bit and 12-bit cd ../8bit ln -sf ../10bit/libx265.a libx265_main10.a ln -sf ../12bit/libx265.a libx265_main12.a PATH="$BIN_DIR:$PATH" cmake -G "Unix Makefiles" \ -DCMAKE_INSTALL_PREFIX="$BUILD_DIR" \ 
-DLIB_INSTALL_DIR="$BUILD_DIR/lib" \ -DENABLE_SHARED=OFF \ -DENABLE_CLI=OFF \ -DEXTRA_LIB="x265_main10.a;x265_main12.a" \ -DEXTRA_LINK_FLAGS="-L." \ -DLINKED_10BIT=ON \ -DLINKED_12BIT=ON \ ../../../source PATH="$BIN_DIR:$PATH" make -j "$NPROC" # Merge 8-bit, 10-bit, and 12-bit libraries into one (cmake doesn't do this automatically) mv libx265.a libx265_main.a mkdir -p merged/8bit merged/10bit merged/12bit (cd merged/8bit && ar x ../../libx265_main.a) (cd merged/10bit && ar x ../../libx265_main10.a) (cd merged/12bit && ar x ../../libx265_main12.a) ar crs libx265.a merged/*/*.o rm -rf merged libx265_main.a make install # x265's cmake doesn't reliably install x265.pc, so we create it manually # Extract version from x265.h (format: #define X265_BUILD 215) X265_VERSION=$(grep '#define X265_BUILD' "$BUILD_DIR/include/x265.h" | awk '{print $3}') mkdir -p "$BUILD_DIR/lib/pkgconfig" cat > "$BUILD_DIR/lib/pkgconfig/x265.pc" << PCEOF prefix=$BUILD_DIR exec_prefix=\${prefix} libdir=\${exec_prefix}/lib includedir=\${prefix}/include Name: x265 Description: H.265/HEVC video encoder (8-bit + 10-bit + 12-bit) Version: $X265_VERSION Libs: -L\${libdir} -lx265 Libs.private: -lstdc++ -lm -lrt -ldl -lnuma -lpthread Cflags: -I\${includedir} PCEOF fi # libaom (AV1 reference codec) if [ "$BUILD_LIBAOM" = "1" ]; then cd "$SRC_DIR" git_update_or_clone "https://aomedia.googlesource.com/aom" "aom" 1 mkdir -p aom_build cd aom_build PATH="$BIN_DIR:$PATH" cmake -G "Unix Makefiles" -DCMAKE_INSTALL_PREFIX="$BUILD_DIR" -DENABLE_TESTS=OFF -DENABLE_NASM=on -DBUILD_SHARED_LIBS=OFF -DCONFIG_AV1_HIGHBITDEPTH=1 ../aom PATH="$BIN_DIR:$PATH" make -j "$NPROC" make install fi # libwebp (WebP image codec) if [ "$BUILD_LIBWEBP" = "1" ]; then cd "$SRC_DIR" git_update_or_clone "https://chromium.googlesource.com/webm/libwebp" "libwebp" 1 cd libwebp ./autogen.sh ./configure --prefix="$BUILD_DIR" --disable-shared --enable-static make -j "$NPROC" make install fi # libjxl (JPEG XL image codec) # Ubuntu 24.04 ships 
0.7.0 which is quite old; latest is 0.11.1 with HDR improvements if [ "$BUILD_LIBJXL" = "1" ]; then cd "$SRC_DIR" # libjxl needs --recursive for submodules if [ -d libjxl/.git ]; then git -C libjxl pull git -C libjxl submodule update --init --recursive else rm -rf libjxl git clone --depth 1 --recursive https://github.com/libjxl/libjxl.git if [ ! -d libjxl/.git ]; then log_error "libjxl clone failed" exit 1 fi fi cd libjxl mkdir -p build cd build cmake -G "Unix Makefiles" -DCMAKE_INSTALL_PREFIX="$BUILD_DIR" -DCMAKE_BUILD_TYPE=Release \ -DBUILD_SHARED_LIBS=OFF -DJPEGXL_ENABLE_BENCHMARK=OFF -DJPEGXL_ENABLE_EXAMPLES=OFF \ -DJPEGXL_ENABLE_MANPAGES=OFF -DJPEGXL_ENABLE_PLUGINS=OFF -DJPEGXL_ENABLE_VIEWERS=OFF \ -DJPEGXL_ENABLE_TOOLS=OFF -DJPEGXL_ENABLE_DOXYGEN=OFF -DJPEGXL_ENABLE_JPEGLI=OFF .. make -j "$NPROC" make install fi # libvpl (Intel Video Processing Library / QuickSync) if [ "$BUILD_LIBVPL" = "1" ]; then cd "$SRC_DIR" git_update_or_clone "https://github.com/intel/libvpl.git" "libvpl" 1 mkdir -p libvpl/build cd libvpl/build cmake -G "Unix Makefiles" -DCMAKE_INSTALL_PREFIX="$BUILD_DIR" -DBUILD_SHARED_LIBS=OFF .. 
make -j "$NPROC" make install fi # libdav1d (AV1 decoder) if [ "$BUILD_LIBDAV1D" = "1" ]; then cd "$SRC_DIR" git_update_or_clone "https://code.videolan.org/videolan/dav1d.git" "dav1d" 1 cd dav1d if [ -f build/build.ninja ]; then meson setup --reconfigure build --buildtype=release --default-library=static --prefix="$BUILD_DIR" --libdir="$BUILD_DIR/lib" || \ meson setup --wipe build --buildtype=release --default-library=static --prefix="$BUILD_DIR" --libdir="$BUILD_DIR/lib" else meson setup build --buildtype=release --default-library=static --prefix="$BUILD_DIR" --libdir="$BUILD_DIR/lib" fi ninja -C build ninja -C build install fi # libsvtav1 (AV1 encoder) if [ "$BUILD_LIBSVTAV1" = "1" ]; then cd "$SRC_DIR" git_update_or_clone "https://gitlab.com/AOMediaCodec/SVT-AV1.git" "SVT-AV1" 1 mkdir -p SVT-AV1/build cd SVT-AV1/build if [ -n "$SVTAV1_GIT_REF" ]; then git -C .. fetch --depth 1 origin "$SVTAV1_GIT_REF" git -C .. checkout -q FETCH_HEAD fi PATH="$BIN_DIR:$PATH" cmake -G "Unix Makefiles" -DCMAKE_INSTALL_PREFIX="$BUILD_DIR" -DCMAKE_BUILD_TYPE=Release -DBUILD_DEC=OFF -DBUILD_SHARED_LIBS=OFF .. 
PATH="$BIN_DIR:$PATH" make -j "$NPROC" make install fi # libvmaf (video quality metrics) if [ "$BUILD_LIBVMAF" = "1" ]; then cd "$SRC_DIR" git_update_or_clone "https://github.com/Netflix/vmaf" "vmaf" 1 mkdir -p vmaf/libvmaf/build cd vmaf/libvmaf/build if [ -f build.ninja ]; then meson setup --reconfigure -Denable_tests=false -Denable_docs=false --buildtype=release --default-library=static '../' --prefix "$BUILD_DIR" --bindir="$BIN_DIR" --libdir="$BUILD_DIR/lib" || \ meson setup --wipe -Denable_tests=false -Denable_docs=false --buildtype=release --default-library=static '../' --prefix "$BUILD_DIR" --bindir="$BIN_DIR" --libdir="$BUILD_DIR/lib" else meson setup -Denable_tests=false -Denable_docs=false --buildtype=release --default-library=static '../' --prefix "$BUILD_DIR" --bindir="$BIN_DIR" --libdir="$BUILD_DIR/lib" fi ninja ninja install fi # libva (VA-API) # Ubuntu 24.04 ships 2.20.0 which lacks Intel Xe kernel driver support (added in 2.21) # Build from source to get Xe support for newer Intel GPUs if [ "$BUILD_LIBVA" = "1" ]; then cd "$SRC_DIR" git_update_or_clone "https://github.com/intel/libva.git" "libva" 1 cd libva if [ -f build/build.ninja ]; then meson setup --reconfigure build --buildtype=release --default-library=shared --prefix="$BUILD_DIR" --libdir="$BUILD_DIR/lib" || \ meson setup --wipe build --buildtype=release --default-library=shared --prefix="$BUILD_DIR" --libdir="$BUILD_DIR/lib" else meson setup build --buildtype=release --default-library=shared --prefix="$BUILD_DIR" --libdir="$BUILD_DIR/lib" fi ninja -C build ninja -C build install # Copy shared libs to permanent location (LIB_DIR) for runtime cp -a "$BUILD_DIR/lib"/libva*.so* "$LIB_DIR/" fi meson_supports_prefer_static() { meson setup --help 2>/dev/null | grep -q -- '--prefer-static' } build_libplacebo() { local meson_args=( --buildtype=release --default-library=static -Dvulkan=enabled -Dvulkan-registry="$VULKAN_SDK/share/vulkan/registry/vk.xml" -Dopengl=disabled -Dd3d11=disabled -Ddemos=false 
--prefix "$BUILD_DIR" --libdir "$BUILD_DIR/lib" ) if meson_supports_prefer_static; then meson_args+=(--prefer-static) fi if [ -f build/build.ninja ]; then meson setup --reconfigure "${meson_args[@]}" build || \ meson setup --wipe "${meson_args[@]}" build else meson setup "${meson_args[@]}" build fi } # libplacebo (for GPU tone mapping) if [ "$BUILD_LIBPLACEBO" = "1" ]; then # Download Vulkan SDK tarball (apt packages deprecated May 2025) VULKAN_SDK_VERSION=${VULKAN_SDK_VERSION:-1.4.335.0} VULKAN_SDK_DIR="${SRC_DIR}/vulkan-sdk-${VULKAN_SDK_VERSION}" if [ ! -d "$VULKAN_SDK_DIR" ]; then echo "Downloading Vulkan SDK $VULKAN_SDK_VERSION..." cd "$SRC_DIR" rm -f vulkansdk.tar.xz # Clean up any partial download wget --progress=dot:giga -O vulkansdk.tar.xz "https://sdk.lunarg.com/sdk/download/${VULKAN_SDK_VERSION}/linux/vulkansdk-linux-x86_64-${VULKAN_SDK_VERSION}.tar.xz" tar xf vulkansdk.tar.xz mv "${VULKAN_SDK_VERSION}" "vulkan-sdk-${VULKAN_SDK_VERSION}" rm -f vulkansdk.tar.xz fi export VULKAN_SDK="$VULKAN_SDK_DIR/x86_64" export PATH="$VULKAN_SDK/bin:$PATH" export PKG_CONFIG_PATH="$VULKAN_SDK/lib/pkgconfig:$PKG_CONFIG_PATH" echo "Using Vulkan SDK: $VULKAN_SDK" # Use static shaderc (avoid runtime .so dependency) if [ ! -f "$VULKAN_SDK/lib/pkgconfig/shaderc.pc.bak" ]; then cp "$VULKAN_SDK/lib/pkgconfig/shaderc.pc" "$VULKAN_SDK/lib/pkgconfig/shaderc.pc.bak" fi cp "$VULKAN_SDK/lib/pkgconfig/shaderc_combined.pc" "$VULKAN_SDK/lib/pkgconfig/shaderc.pc" cd "$SRC_DIR" git_update_or_clone "https://code.videolan.org/videolan/libplacebo.git" "libplacebo" 1 cd libplacebo if [ -n "$LIBPLACEBO_GIT_REF" ]; then git fetch --depth 1 origin "$LIBPLACEBO_GIT_REF" git checkout -q FETCH_HEAD fi build_libplacebo ninja -C build ninja -C build install fi # LibTorch (PyTorch C++ library for DNN backend) # Enables AI-based video filters like dnn_processing for upscaling, denoising, etc. # NOTE: LibTorch 2.6.0+ renamed initXPU() to init(). We patch ffmpeg to handle both. 
# Use 2.7.0+ for RTX 50-series (Blackwell/SM 12.0) support. LIBTORCH_FLAGS=() if [ "$ENABLE_LIBTORCH" = "1" ]; then LIBTORCH_VERSION=${LIBTORCH_VERSION:-2.7.0} LIBTORCH_DIR="$SRC_DIR/libtorch" # Determine LibTorch CUDA variant # LIBTORCH_VARIANT: cu124 (default), auto, cpu, cu126, cu128, cu130, rocm6.4 # PyTorch only releases for even-numbered CUDA versions if [ "$LIBTORCH_VARIANT" != "auto" ]; then # Validate user-provided variant if ! validate_torch_variant "$LIBTORCH_VARIANT"; then log_error "Valid variants: cu124, cu126, cu128, cu130, rocm6.4, cpu" exit 1 fi TORCH_VARIANT="$LIBTORCH_VARIANT" echo "LibTorch: using $TORCH_VARIANT (explicit)" elif [ "$ENABLE_NVIDIA_CUDA" = "1" ]; then CUDA_MAJOR="${CUDA_VERSION%%-*}" CUDA_MINOR="${CUDA_VERSION#*-}" if [ "$CUDA_MAJOR" -ge 13 ]; then TORCH_VARIANT="cu130" else # Round down to nearest even, clamp to [6, 8] EVEN_MINOR=$(( (CUDA_MINOR / 2) * 2 )) [ "$EVEN_MINOR" -gt 8 ] && EVEN_MINOR=8 [ "$EVEN_MINOR" -lt 6 ] && EVEN_MINOR=6 TORCH_VARIANT="cu12${EVEN_MINOR}" fi echo "LibTorch: using $TORCH_VARIANT (from CUDA $CUDA_VERSION)" else TORCH_VARIANT="cpu" echo "LibTorch: using CPU-only variant" fi # Download LibTorch if not present or wrong variant LIBTORCH_MARKER="$LIBTORCH_DIR/.variant-${TORCH_VARIANT}" if [ ! -f "$LIBTORCH_MARKER" ]; then echo "Downloading LibTorch $LIBTORCH_VERSION ($TORCH_VARIANT)..." cd "$SRC_DIR" # Clean up any existing libtorch directory (safe: inside $SRC_DIR) safe_rm_rf "$SRC_DIR/libtorch" 3 rm -f libtorch.zip # Download from pytorch.org # cu124 and earlier use cxx11-abi prefix, cu130+ dropped it if [[ "$TORCH_VARIANT" == cu13* ]] || [[ "$TORCH_VARIANT" == cu14* ]]; then LIBTORCH_URL="https://download.pytorch.org/libtorch/${TORCH_VARIANT}/libtorch-shared-with-deps-${LIBTORCH_VERSION}%2B${TORCH_VARIANT}.zip" else LIBTORCH_URL="https://download.pytorch.org/libtorch/${TORCH_VARIANT}/libtorch-cxx11-abi-shared-with-deps-${LIBTORCH_VERSION}%2B${TORCH_VARIANT}.zip" fi if ! 
wget --progress=dot:giga -O libtorch.zip "$LIBTORCH_URL"; then log_error "Failed to download LibTorch from $LIBTORCH_URL" exit 1 fi unzip -q libtorch.zip rm -f libtorch.zip # Verify extraction succeeded before marking complete if [ ! -f "$LIBTORCH_DIR/lib/libtorch.so" ]; then log_error "LibTorch extraction failed - libtorch.so not found" exit 1 fi touch "$LIBTORCH_MARKER" fi export LIBTORCH_PATH="$LIBTORCH_DIR" LIBTORCH_FLAGS=(--enable-libtorch) echo "Using LibTorch: $LIBTORCH_PATH" # Copy libtorch shared libs to permanent location (LIB_DIR) echo "Installing libtorch libs to $LIB_DIR..." if ! cp -a "$LIBTORCH_DIR/lib"/*.so* "$LIB_DIR/" 2>/dev/null; then log_warn "Some libtorch libs may not have been copied - check $LIB_DIR" fi # Verify at least libtorch.so was copied if [ ! -f "$LIB_DIR/libtorch.so" ]; then log_error "libtorch.so not found in $LIB_DIR after copy" exit 1 fi # Create pkg-config file for libtorch (FFmpeg configure uses pkg-config for detection) mkdir -p "$BUILD_DIR/lib/pkgconfig" # Include CUDA libs if using CUDA variant if [[ "$TORCH_VARIANT" == cu* ]]; then TORCH_LIBS="-ltorch -lc10 -ltorch_cpu -ltorch_cuda -lc10_cuda" # Needed for ffmpeg extra-libs to ensure libtorch_cuda is linked (not just dlopen'd) TORCH_EXTRA_LIBS="-lc10_cuda -ltorch_cuda" else TORCH_LIBS="-ltorch -lc10 -ltorch_cpu" TORCH_EXTRA_LIBS="" fi cat > "$BUILD_DIR/lib/pkgconfig/libtorch.pc" << PCEOF prefix=$LIBTORCH_DIR exec_prefix=\${prefix} libdir=$LIB_DIR includedir=\${prefix}/include Name: libtorch Description: PyTorch C++ library Version: $LIBTORCH_VERSION Libs: -L\${libdir} $TORCH_LIBS Cflags: -I\${includedir} -I\${includedir}/torch/csrc/api/include -std=c++17 PCEOF echo "Created libtorch.pc for pkg-config detection (variant: $TORCH_VARIANT)" fi # ffmpeg FFMPEG_DIR="ffmpeg-${FFMPEG_VERSION}" cd "$SRC_DIR" if [ ! 
-d "$FFMPEG_DIR" ]; then if [ "$FFMPEG_VERSION" = "snapshot" ]; then rm -f ffmpeg-snapshot.tar.bz2 # Clean up any partial download wget --progress=dot:giga -O ffmpeg-snapshot.tar.bz2 https://ffmpeg.org/releases/ffmpeg-snapshot.tar.bz2 tar xjf ffmpeg-snapshot.tar.bz2 mv ffmpeg "$FFMPEG_DIR" rm -f ffmpeg-snapshot.tar.bz2 else rm -f "ffmpeg-${FFMPEG_VERSION}.tar.xz" # Clean up any partial download wget --progress=dot:giga -O "ffmpeg-${FFMPEG_VERSION}.tar.xz" "https://ffmpeg.org/releases/ffmpeg-${FFMPEG_VERSION}.tar.xz" tar xJf "ffmpeg-${FFMPEG_VERSION}.tar.xz" rm -f "ffmpeg-${FFMPEG_VERSION}.tar.xz" fi fi # Patch ffmpeg for SVT-AV1 v4.0.0 API compatibility # v4.0.0 renamed enable_adaptive_quantization -> aq_mode SVTAV1_CODEC="$FFMPEG_DIR/libavcodec/libsvtav1.c" SVTAV1_HEADER="$BUILD_DIR/include/svt-av1/EbSvtAv1Enc.h" if [ -f "$SVTAV1_CODEC" ] && grep -q "enable_adaptive_quantization" "$SVTAV1_CODEC" && \ [ -f "$SVTAV1_HEADER" ] && grep -q "uint8_t aq_mode;" "$SVTAV1_HEADER"; then echo "Patching ffmpeg for SVT-AV1 v4.0.0 (adding compat alias)..." sed -i '/#include /i\ /* SVT-AV1 v4.0.0 compat: renamed enable_adaptive_quantization -> aq_mode */\ #define enable_adaptive_quantization aq_mode' "$SVTAV1_CODEC" fi # Patch ffmpeg's torch backend if [ "$ENABLE_LIBTORCH" = "1" ]; then TORCH_BACKEND="$FFMPEG_DIR/libavfilter/dnn/dnn_backend_torch.cpp" # Patch 1: Fix initXPU() -> init() for libtorch 2.6+ compatibility if [ -f "$TORCH_BACKEND" ] && grep -q "initXPU()" "$TORCH_BACKEND"; then TORCH_MAJOR=$(echo "$LIBTORCH_VERSION" | cut -d. -f1) TORCH_MINOR=$(echo "$LIBTORCH_VERSION" | cut -d. -f2) if [ "$TORCH_MAJOR" -gt 2 ] || { [ "$TORCH_MAJOR" -eq 2 ] && [ "$TORCH_MINOR" -ge 6 ]; }; then echo "Patching ffmpeg for libtorch 2.6+ (initXPU -> init)..." sed -i 's/initXPU()/init()/g' "$TORCH_BACKEND" fi fi # Patch 2: Add CUDA device support (upstream only supports CPU/XPU) if [ -f "$TORCH_BACKEND" ] && ! 
grep -q "device.is_cuda()" "$TORCH_BACKEND"; then echo "Patching ffmpeg torch backend for CUDA support..." # Add CUDA device support between XPU and the catch-all error # Also adds dlopen for libtorch_cuda.so to load CUDA kernels at runtime sed -i '/at::detail::getXPUHooks().init/a\ } else if (device.is_cuda()) {\ if (!at::cuda::is_available()) {\ av_log(ctx, AV_LOG_ERROR, "No CUDA device found\\n");\ goto fail;\ }\ // Load CUDA kernels - required for libtorch CUDA ops\ static bool cuda_lib_loaded = false;\ if (!cuda_lib_loaded) {\ cuda_lib_loaded = true;\ void *cuda_handle = dlopen("libtorch_cuda.so", RTLD_NOW | RTLD_GLOBAL);\ if (cuda_handle) {\ av_log(ctx, AV_LOG_DEBUG, "libtorch_cuda.so loaded\\n");\ } else {\ av_log(ctx, AV_LOG_WARNING, "Failed to load libtorch_cuda.so: %s\\n", dlerror());\ }\ }' "$TORCH_BACKEND" # Add required CUDA header if ! grep -q "#include " "$TORCH_BACKEND"; then sed -i '/#include /a #include ' "$TORCH_BACKEND" fi echo "Torch CUDA patch applied" fi # Patch 3: Add TensorRT support (load runtime + handle tuple outputs) if [ -f "$TORCH_BACKEND" ] && ! grep -q "isTuple" "$TORCH_BACKEND"; then echo "Patching ffmpeg torch backend for TensorRT support..." # Add dlfcn.h header for dlopen if ! grep -q "#include " "$TORCH_BACKEND"; then sed -i '/#include /a #include ' "$TORCH_BACKEND" fi # Add TensorRT runtime loading in model init (before torch::jit::load) if ! 
grep -q "libtorchtrt_runtime" "$TORCH_BACKEND"; then sed -i '/torch::jit::load(ctx->model_filename)/i\ // Load TensorRT runtime if available (enables TRT-compiled models)\ static bool trt_init_attempted = false;\ if (!trt_init_attempted) {\ trt_init_attempted = true;\ void *trt_handle = dlopen("libtorchtrt_runtime.so", RTLD_NOW | RTLD_GLOBAL);\ if (trt_handle) {\ av_log(ctx, AV_LOG_INFO, "TensorRT runtime loaded\\n");\ }\ }' "$TORCH_BACKEND" fi # Change forward().toTensor() to handle TRT tuple outputs sed -i 's/\*infer_request->output = th_model->jit_model->forward(inputs)\.toTensor();/auto _fwd_out = th_model->jit_model->forward(inputs);\ if (_fwd_out.isTuple()) {\ *infer_request->output = _fwd_out.toTuple()->elements()[0].toTensor();\ } else {\ *infer_request->output = _fwd_out.toTensor();\ }/' "$TORCH_BACKEND" # Fix device detection for TRT models (they may not have parameters) sed -i 's/c10::Device device = (\*th_model->jit_model->parameters()\.begin())\.device();/c10::Device device = torch::kCUDA;\ auto params = th_model->jit_model->parameters();\ if (params.begin() != params.end()) {\ device = (*params.begin()).device();\ }/' "$TORCH_BACKEND" echo "Torch TensorRT patch applied" fi fi # Patch ffmpeg's libplacebo filter to include libavformat version header # (suppresses LIBAVFORMAT_VERSION_INT -Wundef warnings). LIBPLACEBO_FILTER="$FFMPEG_DIR/libavfilter/vf_libplacebo.c" if [ -f "$LIBPLACEBO_FILTER" ] && ! grep -q "libavformat/version.h" "$LIBPLACEBO_FILTER"; then sed -i '/#include "libavfilter\/avfilter.h"/a #include "libavformat/version.h"' "$LIBPLACEBO_FILTER" fi # TensorRT native backend (no libtorch dependency) # Loads pre-compiled .engine files directly for maximum performance TENSORRT_FLAGS=() if [ "$ENABLE_TENSORRT" = "1" ]; then echo "Patching FFmpeg for native TensorRT DNN backend..." 
PATCH_DIR="$SCRIPT_DIR/patches" # Copy TensorRT backend source file if [ -f "$PATCH_DIR/dnn_backend_tensorrt.cpp" ]; then cp "$PATCH_DIR/dnn_backend_tensorrt.cpp" "$FFMPEG_DIR/libavfilter/dnn/" echo "Copied dnn_backend_tensorrt.cpp" else echo "WARNING: dnn_backend_tensorrt.cpp not found in $PATCH_DIR" fi # Copy CUDA kernels for GPU-resident format conversion (zero-copy) # Rename to .cuda to prevent FFmpeg from auto-compiling it as .cu -> .o # Our custom PTX rules in the Makefile will compile it properly if [ -f "$PATCH_DIR/dnn_cuda_kernels.cu" ]; then cp "$PATCH_DIR/dnn_cuda_kernels.cu" "$FFMPEG_DIR/libavfilter/dnn/dnn_cuda_kernels.cuda" cp "$PATCH_DIR/dnn_cuda_kernels.h" "$FFMPEG_DIR/libavfilter/dnn/" echo "Copied CUDA format conversion kernels (renamed to .cuda to avoid auto-build)" else echo "WARNING: dnn_cuda_kernels.cu not found in $PATCH_DIR" fi # Copy patched vf_dnn_processing.c for CUDA frame support if [ -f "$PATCH_DIR/vf_dnn_processing.c" ]; then cp "$PATCH_DIR/vf_dnn_processing.c" "$FFMPEG_DIR/libavfilter/" echo "Copied vf_dnn_processing.c with CUDA frame support" fi # Patch dnn_interface.h to add DNN_TRT enum and TRTOptions DNN_INTERFACE_H="$FFMPEG_DIR/libavfilter/dnn_interface.h" if [ -f "$DNN_INTERFACE_H" ] && ! 
grep -q "DNN_TRT" "$DNN_INTERFACE_H"; then # FFmpeg 7.0 has single-line enum: {DNN_TF = 1, DNN_OV, DNN_TH} # FFmpeg 7.1+/snapshot has multi-line: DNN_TH = 1 << 2 if grep -q "DNN_TH = 1 << 2" "$DNN_INTERFACE_H"; then # Multi-line format (7.1+/snapshot) sed -i '/DNN_TH = 1 << 2/s/$/,/' "$DNN_INTERFACE_H" sed -i '/DNN_TH = 1 << 2,$/a\ DNN_TRT = 1 << 3' "$DNN_INTERFACE_H" elif grep -q "DNN_TF = 1, DNN_OV, DNN_TH}" "$DNN_INTERFACE_H"; then # Single-line format (7.0) sed -i 's/DNN_TF = 1, DNN_OV, DNN_TH}/DNN_TF = 1, DNN_OV, DNN_TH, DNN_TRT}/' "$DNN_INTERFACE_H" else echo "ERROR: Unknown DNNBackendType enum format in dnn_interface.h" >&2 grep -A5 "DNNBackendType" "$DNN_INTERFACE_H" >&2 exit 1 fi # Add TRTOptions struct after THOptions sed -i '/^} THOptions;$/a\ \ typedef struct TRTOptions {\ const AVClass *clazz;\ int device_id;\ } TRTOptions;' "$DNN_INTERFACE_H" # Add trt_option to DnnContext (after torch_option) sed -i '/#if CONFIG_LIBTORCH/,/#endif/{ /#endif/a\ #if CONFIG_LIBTENSORRT\ TRTOptions trt_option;\ #endif }' "$DNN_INTERFACE_H" # Verify patch was applied if grep -q "DNN_TRT" "$DNN_INTERFACE_H"; then echo "Patched dnn_interface.h" else echo "ERROR: Failed to patch dnn_interface.h - DNN_TRT not found after patching" >&2 exit 1 fi fi # Patch dnn_interface.c to register TensorRT backend DNN_INTERFACE_C="$FFMPEG_DIR/libavfilter/dnn/dnn_interface.c" if [ -f "$DNN_INTERFACE_C" ] && ! grep -q "ff_dnn_backend_tensorrt" "$DNN_INTERFACE_C"; then # Add extern declaration sed -i '/extern const DNNModule ff_dnn_backend_torch;/a extern const DNNModule ff_dnn_backend_tensorrt;' "$DNN_INTERFACE_C" # Add to backend list sed -i '/#if CONFIG_LIBTORCH/,/#endif/{ /#endif/a\ #if CONFIG_LIBTENSORRT\ {offsetof(DnnContext, trt_option), .module = \&ff_dnn_backend_tensorrt},\ #endif }' "$DNN_INTERFACE_C" # Verify patch was applied if ! 
verify_patch "$DNN_INTERFACE_C" "ff_dnn_backend_tensorrt" "dnn_interface.c TensorRT registration"; then exit 1 fi log_info "Patched dnn_interface.c" fi # Patch dnn/Makefile to add TensorRT objects and CUDA kernel PTX compilation DNN_MAKEFILE="$FFMPEG_DIR/libavfilter/dnn/Makefile" if [ -f "$DNN_MAKEFILE" ] && ! grep -q "CONFIG_LIBTENSORRT" "$DNN_MAKEFILE"; then # Add TensorRT backend object sed -i '/CONFIG_LIBTORCH.*dnn_backend_torch/a DNN-OBJS-$(CONFIG_LIBTENSORRT) += dnn/dnn_backend_tensorrt.o' "$DNN_MAKEFILE" # Add embedded PTX object (compiled PTX as C byte array) sed -i '/dnn_backend_tensorrt.o/a DNN-OBJS-$(CONFIG_LIBTENSORRT) += dnn/dnn_cuda_kernels_ptx.o' "$DNN_MAKEFILE" # Add PTX compilation and embedding rules # 1. Compile .cu to .ptx with nvcc # 2. Embed .ptx as C byte array using xxd (bin2c alternative) # 3. Compile embedded C to object cat >> "$DNN_MAKEFILE" << 'PTXRULES' # CUDA kernel PTX compilation and embedding (no cudart dependency) # Step 1: Compile CUDA kernels to PTX (intermediate representation) # Source is .cuda (not .cu) to prevent FFmpeg from auto-compiling to .o libavfilter/dnn/dnn_cuda_kernels.ptx: libavfilter/dnn/dnn_cuda_kernels.cuda $(NVCC) --ptx -o $@ -x cu $< -m64 # Step 2: Embed PTX as C byte array (using xxd, like bin2c) libavfilter/dnn/dnn_cuda_kernels_ptx.c: libavfilter/dnn/dnn_cuda_kernels.ptx @echo "Embedding PTX as C byte array..." @echo "/* Auto-generated - do not edit */" > $@ @echo "#include " >> $@ @echo "" >> $@ @printf "const unsigned char ff_dnn_cuda_kernels_ptx[] = {\n" >> $@ @xxd -i < $< >> $@ @echo "};" >> $@ @echo "" >> $@ @printf "const unsigned int ff_dnn_cuda_kernels_ptx_len = sizeof(ff_dnn_cuda_kernels_ptx);\n" >> $@ # Step 3: Compile embedded PTX C file to object libavfilter/dnn/dnn_cuda_kernels_ptx.o: libavfilter/dnn/dnn_cuda_kernels_ptx.c $(CC) $(CPPFLAGS) $(CFLAGS) -c -o $@ $< PTXRULES # Verify Makefile patch if ! 
verify_patch "$DNN_MAKEFILE" "dnn_cuda_kernels_ptx.o" "dnn/Makefile PTX rules"; then exit 1 fi log_info "Patched dnn/Makefile with TensorRT and CUDA PTX kernel support" fi # Patch configure to add --enable-libtensorrt option CONFIGURE="$FFMPEG_DIR/configure" if [ -f "$CONFIGURE" ] && ! grep -q "enable-libtensorrt" "$CONFIGURE"; then # Add help text sed -i '/--enable-libtorch.*enable Torch/a\ --enable-libtensorrt enable TensorRT as one DNN backend [no]' "$CONFIGURE" # Add to library list sed -i '/^ libtorch$/a\ libtensorrt' "$CONFIGURE" # Add to dnn_deps_any sed -i 's/dnn_deps_any="libtensorflow libopenvino libtorch"/dnn_deps_any="libtensorflow libopenvino libtorch libtensorrt"/' "$CONFIGURE" # Add header check (after libtorch check) # TensorRT (libnvinfer) and CUDA (libcuda) are loaded via dlopen at runtime. # CUDA kernels are compiled to PTX and loaded via Driver API - no cudart dependency. sed -i '/enabled libtorch.*require_cxx libtorch/a\ enabled libtensorrt && check_cxxflags -std=c++17 && check_headers NvInfer.h' "$CONFIGURE" echo "Patched configure (TensorRT via dlopen, CUDA linked for kernels)" fi TENSORRT_FLAGS=(--enable-libtensorrt) echo "TensorRT DNN backend patches applied" fi cd "$FFMPEG_DIR" # Build configure flags # MARCH=native for CPU-specific optimizations (opt-in, not portable) EXTRA_CFLAGS="-I$BUILD_DIR/include -O3${MARCH:+ -march=$MARCH -mtune=$MARCH}" EXTRA_CXXFLAGS="" # -rpath embeds library search path in binary so it finds our built libs at runtime EXTRA_LDFLAGS="-L$BUILD_DIR/lib -s -Wl,-rpath,$LIB_DIR" if [ "$ENABLE_NVIDIA_CUDA" = "1" ]; then EXTRA_CFLAGS="$EXTRA_CFLAGS -I$CUDA_PATH/include" EXTRA_LDFLAGS="$EXTRA_LDFLAGS -L$CUDA_PATH/lib64" fi if [ "$BUILD_LIBPLACEBO" = "1" ]; then EXTRA_CFLAGS="$EXTRA_CFLAGS -I$VULKAN_SDK/include" EXTRA_LDFLAGS="$EXTRA_LDFLAGS -L$VULKAN_SDK/lib" fi if [ "$ENABLE_LIBTORCH" = "1" ]; then # LibTorch needs C++ flags (FFmpeg uses require_cxx for libtorch detection) # Include CUDA path for CUDA torch support 
EXTRA_CXXFLAGS="-I$LIBTORCH_PATH/include -I$LIBTORCH_PATH/include/torch/csrc/api/include" if [ "$ENABLE_NVIDIA_CUDA" = "1" ]; then EXTRA_CXXFLAGS="$EXTRA_CXXFLAGS -I$CUDA_PATH/include" fi EXTRA_LDFLAGS="$EXTRA_LDFLAGS -L$LIB_DIR -Wl,-rpath,$LIB_DIR" fi if [ "$ENABLE_TENSORRT" = "1" ]; then # TensorRT needs C++ flags with CUDA headers (uses require_cxx for detection) # Note: We use CUDA Driver API (libcuda.so) loaded via dlopen at runtime, # NOT CUDA Runtime API (libcudart.so). This avoids load-time dependency. if [ "$ENABLE_NVIDIA_CUDA" = "1" ]; then EXTRA_CXXFLAGS="$EXTRA_CXXFLAGS -I$CUDA_PATH/include" fi fi CONFIGURE_CMD=( ./configure --prefix="$BUILD_DIR" --pkg-config-flags="--static" --extra-cflags="$EXTRA_CFLAGS" --extra-cxxflags="$EXTRA_CXXFLAGS" --extra-ldflags="$EXTRA_LDFLAGS" --extra-libs="-lpthread -lm -ldl${TORCH_EXTRA_LIBS:+ $TORCH_EXTRA_LIBS}" --ld="g++" --bindir="$BIN_DIR" --disable-debug --enable-gpl --enable-version3 --enable-openssl --enable-libaom --enable-libass --enable-libbluray --enable-libfdk-aac --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libharfbuzz --enable-libjxl --enable-libmp3lame --enable-libopus --enable-libsvtav1 --enable-libdav1d --enable-libvmaf --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx264 --enable-libx265 --enable-librubberband --enable-libsoxr --enable-libsrt --enable-libvidstab --enable-libvpl --enable-libzimg --enable-opencl --enable-vaapi --enable-nonfree "${CUDA_FLAGS[@]}" "${AMF_FLAGS[@]}" "${LIBTORCH_FLAGS[@]}" "${TENSORRT_FLAGS[@]}" ) if [ "$BUILD_LIBPLACEBO" = "1" ]; then CONFIGURE_CMD+=(--enable-vulkan --enable-libplacebo) fi if [ -n "$NVCC_ARCH" ]; then CONFIGURE_CMD+=(--nvccflags="$NVCC_ARCH") fi # Build PATH: include CUDA bin if NVIDIA enabled BUILD_PATH="$BIN_DIR:$PATH" [ "$ENABLE_NVIDIA_CUDA" = "1" ] && BUILD_PATH="$CUDA_PATH/bin:$BUILD_PATH" PATH="$BUILD_PATH" PKG_CONFIG_PATH="$BUILD_DIR/lib/pkgconfig:$PKG_CONFIG_PATH" "${CONFIGURE_CMD[@]}" PATH="$BUILD_PATH" 
make -j "$NPROC" make install hash -r grep -q "$BUILD_DIR/share/man" "$HOME/.manpath" 2>/dev/null || echo "MANPATH_MAP $BIN_DIR $BUILD_DIR/share/man" >> "$HOME/.manpath" log_info "FFmpeg build completed successfully" ================================================ FILE: tools/install-letsencrypt.sh ================================================ #!/bin/bash # Install and configure Let's Encrypt certificates set -e DOMAIN="${1:-}" if [ -z "$DOMAIN" ]; then echo "Usage: $0 " echo "Example: $0 yourdomain.com" exit 1 fi echo "=== Installing certbot ===" sudo apt update sudo apt install -y certbot # Detect web server and choose authenticator if systemctl is-active --quiet apache2; then echo "=== Apache detected, using apache authenticator ===" sudo apt install -y python3-certbot-apache CERTBOT_MODE="--apache" elif systemctl is-active --quiet nginx; then echo "=== Nginx detected, using nginx authenticator ===" sudo apt install -y python3-certbot-nginx CERTBOT_MODE="--nginx" else echo "=== No web server detected, using standalone mode ===" echo "Note: Port 80 must be free for domain verification" CERTBOT_MODE="--standalone" fi echo "=== Obtaining certificate for $DOMAIN ===" sudo certbot $CERTBOT_MODE -d "$DOMAIN" echo "=== Setting up ssl-cert group permissions ===" sudo chgrp -R ssl-cert /etc/letsencrypt/archive/ sudo chmod -R g+r /etc/letsencrypt/archive/ echo "=== Installing deploy hook ===" cat <<'EOF' | sudo tee /etc/letsencrypt/renewal-hooks/deploy/ssl-cert-perms #!/bin/bash # Fix cert permissions after renewal for ssl-cert group chgrp -R ssl-cert /etc/letsencrypt/archive/ chmod -R g+r /etc/letsencrypt/archive/ EOF sudo chmod +x /etc/letsencrypt/renewal-hooks/deploy/ssl-cert-perms echo "" echo "=== Done ===" echo "" echo "Certificate installed for $DOMAIN" echo "Certbot timer will auto-renew (check: systemctl list-timers | grep certbot)" echo "" echo "To give a user access to certs:" echo " sudo usermod -aG ssl-cert " 
================================================ FILE: tools/install-netv.sh ================================================ #!/bin/bash # Install netv systemd service # Prerequisites: uv (install time only), install-letsencrypt.sh # # Usage: sudo ./install-netv.sh [--port PORT] # --port PORT Port to listen on (default: 8000) set -e IPTV_DIR="$(cd "$(dirname "$0")/.." && pwd)" USER="${SUDO_USER:-$USER}" PORT=8000 # Parse arguments while [[ $# -gt 0 ]]; do case $1 in --port) PORT="$2" shift 2 ;; *) echo "Unknown option: $1" echo "Usage: sudo $0 [--port PORT]" exit 1 ;; esac done # Validate if [ "$USER" = "root" ]; then echo "Error: Run with sudo, not as root directly" echo "Usage: sudo $0 [--port PORT]" exit 1 fi # Find uv in user's environment (only needed at install time) UV_PATH=$(su - "$USER" -c "which uv" 2>/dev/null) if [ -z "$UV_PATH" ]; then echo "Error: uv not found for user $USER. Install with:" echo " curl -LsSf https://astral.sh/uv/install.sh | sh" echo "See: https://docs.astral.sh/uv/" exit 1 fi echo "=== Syncing dependencies ===" su - "$USER" -c "cd '$IPTV_DIR' && '$UV_PATH' sync" if [ ! -d /etc/letsencrypt/live ]; then echo "Warning: Let's Encrypt not configured. Run install-letsencrypt.sh first for HTTPS." echo "Continuing with HTTP-only setup..." 
HTTPS_FLAG="" else HTTPS_FLAG="--https" fi echo "=== Installing netv for user: $USER (port $PORT) ===" echo "=== Adding $USER to ssl-cert group ===" sudo usermod -aG ssl-cert "$USER" echo "=== Installing netv systemd service ===" # Build PATH - prefer custom ffmpeg in ~/.local/bin if it exists USER_LOCAL_BIN="/home/$USER/.local/bin" if [ -x "$USER_LOCAL_BIN/ffmpeg" ]; then echo " Found custom ffmpeg in $USER_LOCAL_BIN" ENV_PATH="$USER_LOCAL_BIN:/usr/local/bin:/usr/bin:/bin" else ENV_PATH="/usr/local/bin:/usr/bin:/bin" fi # Build LIBVA env vars if custom libva exists (for VAAPI on hybrid GPU systems) USER_LOCAL_LIB="/home/$USER/.local/lib" LIBVA_ENVS="" if [ -f "$USER_LOCAL_LIB/libva.so" ]; then echo " Found custom libva in $USER_LOCAL_LIB" # Auto-detect LIBVA driver based on GPU vendor LIBVA_DRIVER="" if lspci -nn 2>/dev/null | grep -qE "VGA.*\[8086:"; then LIBVA_DRIVER="i965" # Intel elif lspci -nn 2>/dev/null | grep -qE "VGA.*\[1002:"; then LIBVA_DRIVER="radeonsi" # AMD fi # Auto-detect DRI path DRI_PATH="" for p in /usr/lib/x86_64-linux-gnu/dri /usr/lib64/dri /usr/lib/dri; do if [ -d "$p" ]; then DRI_PATH="$p" break fi done if [ -n "$LIBVA_DRIVER" ] && [ -n "$DRI_PATH" ]; then echo " Detected VAAPI driver: $LIBVA_DRIVER, path: $DRI_PATH" LIBVA_ENVS="Environment=\"LIBVA_DRIVER_NAME=$LIBVA_DRIVER\" Environment=\"LIBVA_DRIVERS_PATH=$DRI_PATH\"" fi fi cat < /dev/null; then echo "Error: $cmd not found. Install it with your package manager." exit 1 fi done echo "=== Installing uv ===" if command -v uv &> /dev/null; then echo "uv already installed: $(uv --version)" else curl -LsSf https://astral.sh/uv/install.sh | sh export PATH="$HOME/.local/bin:$PATH" fi echo "=== Installing Python 3.11 via uv ===" uv python install 3.11 echo "" echo "=== Done ===" echo "" echo "Next steps:" echo " 1. Run: ./tools/install-letsencrypt.sh " echo " 2. Run: ./tools/install-ffmpeg.sh (optional, for transcoding)" echo " 3. 
Run: sudo ./tools/install-netv.sh" ================================================ FILE: tools/patches/dnn_backend_tensorrt.cpp ================================================ /* * Copyright 2026 Joshua V. Dillon * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /** * @file * DNN TensorRT backend implementation. * * This backend loads pre-compiled TensorRT engine files (.engine) for * high-performance GPU inference. Use tools/export-tensorrt.py to convert * PyTorch models to TensorRT engines. * * All libraries are loaded at runtime via dlopen - no CUDA or TensorRT * dependency at ffmpeg load time. Errors only occur when the TRT backend * is actually used. 
*
 * Usage:
 * ffmpeg -i input.mp4 -vf "dnn_processing=dnn_backend=tensorrt:model=model.engine" output.mp4
 */

// NOTE(review): the header names on the ten #include lines below are missing.
// Angle-bracketed text appears to have been stripped when this file was
// extracted. Code in this file uses dlopen/dlsym, std::mutex, std::atomic,
// std::unordered_map, std::string and nvinfer1:: types; restore the exact list
// from the original source.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

// ============================================================================
// Engine cache - avoid reloading same engine file multiple times
// ============================================================================

// One TensorRT engine plus the runtime it was deserialized with, shared by every
// model instance that loads the same engine file. dnn_free_model_trt()
// decrements refcount under g_engine_cache_mutex and, when it reaches zero,
// erases the cache entry and deletes the engine, then the runtime.
struct CachedEngine {
    nvinfer1::ICudaEngine *engine;
    nvinfer1::IRuntime *runtime;
    // NOTE(review): template argument appears stripped in extraction. The
    // fetch_sub() result is stored in an int, so presumably std::atomic<int>.
    std::atomic refcount;
    // New entries start with a single reference (presumably held by the model
    // that created them; confirm at the cache-insert site).
    CachedEngine(nvinfer1::ICudaEngine *e, nvinfer1::IRuntime *r) : engine(e), runtime(r), refcount(1) {}
};
// Serializes lookups/erases on g_engine_cache and the final-release decision.
static std::mutex g_engine_cache_mutex;
// Keyed by engine file path (dnn_free_model_trt erases std::string(engine_path)).
// NOTE(review): template arguments appear stripped; presumably
// std::unordered_map<std::string, CachedEngine*>. Confirm.
static std::unordered_map g_engine_cache;

// ============================================================================
// CUDA Driver API types (from cuda.h - we dlopen libcuda.so instead of linking)
// ============================================================================
// Opaque handles are declared as void* / int so this file needs no cuda.h.
typedef int CUresult;
typedef int CUdevice;
typedef void* CUcontext;
typedef void* CUmodule;
typedef void* CUfunction;
typedef void* CUstream;
typedef unsigned long long CUdeviceptr; // 64-bit device address
#define CUDA_SUCCESS 0

// CUDA Driver API function pointer types
// NOTE(review): cuda.h maps cuMemAlloc, cuMemFree, cuMemcpy*, cuCtxPushCurrent,
// cuCtxPopCurrent and cuStreamDestroy to *_v2 entry points. These typedefs use
// the 64-bit CUdeviceptr/size_t signatures, so the symbols must be resolved by
// their _v2 names. Confirm in load_libs().
typedef CUresult (*fn_cuInit)(unsigned int);
typedef CUresult (*fn_cuDeviceGet)(CUdevice*, int);
typedef CUresult (*fn_cuDevicePrimaryCtxRetain)(CUcontext*, CUdevice);
typedef CUresult (*fn_cuCtxGetCurrent)(CUcontext*);
typedef CUresult (*fn_cuCtxSetCurrent)(CUcontext);
typedef CUresult (*fn_cuCtxPushCurrent)(CUcontext);
typedef CUresult (*fn_cuCtxPopCurrent)(CUcontext*);
typedef CUresult (*fn_cuMemAlloc)(CUdeviceptr*, size_t);
typedef CUresult (*fn_cuMemFree)(CUdeviceptr);

// CUDA Runtime API function pointers (for compatibility with TensorRT which uses Runtime API)
// Note: cudaError_t is already defined via NvInfer.h -> cuda_runtime_api.h
typedef cudaError_t (*fn_cudaMalloc)(void**, size_t);
typedef cudaError_t (*fn_cudaFree)(void*);
typedef cudaError_t (*fn_cudaSetDevice)(int);
// The memcpy "kind" argument is passed as a plain int.
typedef cudaError_t (*fn_cudaMemcpy)(void*, const void*, size_t, int);
typedef cudaError_t (*fn_cudaMemcpyAsync)(void*, const void*, size_t, int, cudaStream_t);
typedef cudaError_t (*fn_cudaStreamSynchronize)(cudaStream_t);
typedef cudaError_t (*fn_cudaStreamCreate)(cudaStream_t*, unsigned int);
typedef cudaError_t (*fn_cudaStreamDestroy)(cudaStream_t);
typedef cudaError_t (*fn_cudaMemGetInfo)(size_t*, size_t*);
// Values for the int "kind" argument above. These are intended to match
// cudaMemcpyKind (review: confirm; if cuda_runtime_api.h is already included,
// these macros shadow the enumerators of the same name).
#define cudaMemcpyHostToDevice 1
#define cudaMemcpyDeviceToHost 2

typedef CUresult (*fn_cuMemcpyHtoD)(CUdeviceptr, const void*, size_t);
typedef CUresult (*fn_cuMemcpyDtoH)(void*, CUdeviceptr, size_t);
typedef CUresult (*fn_cuMemcpyHtoDAsync)(CUdeviceptr, const void*, size_t, CUstream);
typedef CUresult (*fn_cuMemcpyDtoHAsync)(void*, CUdeviceptr, size_t, CUstream);
typedef CUresult (*fn_cuStreamCreate)(CUstream*, unsigned int);
typedef CUresult (*fn_cuStreamDestroy)(CUstream);
typedef CUresult (*fn_cuStreamSynchronize)(CUstream);
typedef CUresult (*fn_cuModuleLoadData)(CUmodule*, const void*);
typedef CUresult (*fn_cuModuleUnload)(CUmodule);
typedef CUresult (*fn_cuModuleGetFunction)(CUfunction*, CUmodule, const char*);
// (func, gridX, gridY, gridZ, blockX, blockY, blockZ, sharedMemBytes, stream,
//  kernelParams, extra), matching how launch_kernel() calls it.
typedef CUresult (*fn_cuLaunchKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, CUstream, void**, void**);
typedef CUresult (*fn_cuGetErrorString)(CUresult, const char**);

// CUDA Graph API function pointer types
// (load_libs() resolves these but currently leaves CUDA Graphs disabled.)
typedef void* CUgraph;
typedef void* CUgraphExec;
typedef CUresult (*fn_cuStreamBeginCapture)(CUstream, int);
typedef CUresult (*fn_cuStreamEndCapture)(CUstream, CUgraph*);
typedef CUresult (*fn_cuGraphInstantiate)(CUgraphExec*, CUgraph, unsigned long long);
typedef CUresult (*fn_cuGraphLaunch)(CUgraphExec, CUstream);
typedef CUresult (*fn_cuGraphDestroy)(CUgraph);
typedef CUresult (*fn_cuGraphExecDestroy)(CUgraphExec);
#define CU_STREAM_CAPTURE_MODE_GLOBAL 0

// ============================================================================
// Dynamic library loading
for CUDA and TensorRT // NOTE: These handles are intentionally never dlclose'd. CUDA/TensorRT libraries // have complex cleanup requirements and calling dlclose can cause crashes. // The OS reclaims resources on process exit. // ============================================================================ static void *libcuda_handle = NULL; static void *libnvinfer_handle = NULL; static int cuda_loaded = 0; static int tensorrt_loaded = 0; static std::atomic libs_load_attempted(0); static std::mutex libs_load_mutex; // CUDA Driver API function pointers static fn_cuInit p_cuInit = NULL; static fn_cuDeviceGet p_cuDeviceGet = NULL; static fn_cuDevicePrimaryCtxRetain p_cuDevicePrimaryCtxRetain = NULL; static fn_cuCtxGetCurrent p_cuCtxGetCurrent = NULL; static fn_cuCtxSetCurrent p_cuCtxSetCurrent = NULL; static fn_cuCtxPushCurrent p_cuCtxPushCurrent = NULL; static fn_cuCtxPopCurrent p_cuCtxPopCurrent = NULL; static fn_cuMemAlloc p_cuMemAlloc = NULL; static fn_cuMemFree p_cuMemFree = NULL; static fn_cudaMalloc p_cudaMalloc = NULL; static fn_cudaFree p_cudaFree = NULL; static fn_cudaSetDevice p_cudaSetDevice = NULL; static fn_cudaMemcpy p_cudaMemcpy = NULL; static fn_cudaMemcpyAsync p_cudaMemcpyAsync = NULL; static fn_cudaStreamSynchronize p_cudaStreamSynchronize_rt = NULL; // Runtime API stream sync static fn_cudaStreamCreate p_cudaStreamCreate_rt = NULL; // Runtime API stream create static fn_cudaMemGetInfo p_cudaMemGetInfo = NULL; // For memory diagnostics static void *cuda_rt_handle = NULL; // libcudart.so handle static fn_cuMemcpyHtoD p_cuMemcpyHtoD = NULL; static fn_cuMemcpyDtoH p_cuMemcpyDtoH = NULL; static fn_cuMemcpyHtoDAsync p_cuMemcpyHtoDAsync = NULL; static fn_cuMemcpyDtoHAsync p_cuMemcpyDtoHAsync = NULL; static fn_cuStreamCreate p_cuStreamCreate = NULL; static fn_cuStreamDestroy p_cuStreamDestroy = NULL; static fn_cuStreamSynchronize p_cuStreamSynchronize = NULL; static fn_cuModuleLoadData p_cuModuleLoadData = NULL; static fn_cuModuleUnload p_cuModuleUnload = 
NULL; static fn_cuModuleGetFunction p_cuModuleGetFunction = NULL; static fn_cuLaunchKernel p_cuLaunchKernel = NULL; static fn_cuGetErrorString p_cuGetErrorString = NULL; // CUDA Graph API function pointers (optional - graceful fallback if unavailable) static fn_cuStreamBeginCapture p_cuStreamBeginCapture = NULL; static fn_cuStreamEndCapture p_cuStreamEndCapture = NULL; static fn_cuGraphInstantiate p_cuGraphInstantiate = NULL; static fn_cuGraphLaunch p_cuGraphLaunch = NULL; static fn_cuGraphDestroy p_cuGraphDestroy = NULL; static fn_cuGraphExecDestroy p_cuGraphExecDestroy = NULL; static int cuda_graphs_available = 0; // TensorRT factory function pointer typedef nvinfer1::IRuntime* (*fn_createInferRuntime)(nvinfer1::ILogger&); static fn_createInferRuntime p_createInferRuntime = NULL; // Forward declaration static int load_libs(void *log_ctx); extern "C" { #include "dnn_io_proc.h" #include "dnn_backend_common.h" #include "libavutil/opt.h" #include "libavutil/mem.h" #include "libavutil/avassert.h" #include "libavutil/internal.h" #include "libavutil/hwcontext.h" #include "libavutil/pixfmt.h" #include "libavutil/pixdesc.h" #include "queue.h" #include "safe_queue.h" #include "dnn_cuda_kernels.h" } // Get CUDA error string static const char* cuda_error_string(CUresult err) { const char *str = NULL; if (p_cuGetErrorString && p_cuGetErrorString(err, &str) == CUDA_SUCCESS && str) return str; return "unknown CUDA error"; } // Load CUDA Driver API and TensorRT via dlopen static int load_libs(void *log_ctx) { // Double-checked locking for thread safety if (libs_load_attempted.load(std::memory_order_acquire)) return (cuda_loaded && tensorrt_loaded) ? 0 : AVERROR(ENOSYS); std::lock_guard lock(libs_load_mutex); // Check again after acquiring lock if (libs_load_attempted.load(std::memory_order_relaxed)) return (cuda_loaded && tensorrt_loaded) ? 
0 : AVERROR(ENOSYS); // Set at end of function, not here, to ensure proper initialization // before other threads see libs_load_attempted == 1 // Load CUDA Driver API (libcuda.so - NOT libcudart.so!) const char *cuda_names[] = {"libcuda.so.1", "libcuda.so", NULL}; for (int i = 0; cuda_names[i] && !libcuda_handle; i++) { libcuda_handle = dlopen(cuda_names[i], RTLD_NOW); } if (!libcuda_handle) { av_log(log_ctx, AV_LOG_ERROR, "CUDA driver not available: %s\n" "Install NVIDIA driver or run with --gpus all to use nvidia-container-toolkit\n", dlerror()); libs_load_attempted.store(1, std::memory_order_release); return AVERROR(ENOSYS); } // Load all required CUDA functions #define LOAD_CUDA_FUNC(name) \ p_##name = (fn_##name)dlsym(libcuda_handle, #name); \ if (!p_##name) { \ av_log(log_ctx, AV_LOG_ERROR, "Failed to load CUDA function: %s\n", #name); \ dlclose(libcuda_handle); libcuda_handle = NULL; \ libs_load_attempted.store(1, std::memory_order_release); \ return AVERROR(ENOSYS); \ } LOAD_CUDA_FUNC(cuInit); LOAD_CUDA_FUNC(cuDeviceGet); LOAD_CUDA_FUNC(cuDevicePrimaryCtxRetain); LOAD_CUDA_FUNC(cuCtxGetCurrent); LOAD_CUDA_FUNC(cuCtxSetCurrent); LOAD_CUDA_FUNC(cuCtxPushCurrent); LOAD_CUDA_FUNC(cuCtxPopCurrent); LOAD_CUDA_FUNC(cuMemAlloc); LOAD_CUDA_FUNC(cuMemFree); LOAD_CUDA_FUNC(cuMemcpyHtoD); LOAD_CUDA_FUNC(cuMemcpyDtoH); LOAD_CUDA_FUNC(cuMemcpyHtoDAsync); LOAD_CUDA_FUNC(cuMemcpyDtoHAsync); LOAD_CUDA_FUNC(cuStreamCreate); LOAD_CUDA_FUNC(cuStreamDestroy); LOAD_CUDA_FUNC(cuStreamSynchronize); LOAD_CUDA_FUNC(cuModuleLoadData); LOAD_CUDA_FUNC(cuModuleUnload); LOAD_CUDA_FUNC(cuModuleGetFunction); LOAD_CUDA_FUNC(cuLaunchKernel); // cuGetErrorString is optional (for better error messages) p_cuGetErrorString = (fn_cuGetErrorString)dlsym(libcuda_handle, "cuGetErrorString"); // CUDA Graph API (optional - CUDA 10.0+, graceful fallback if unavailable) p_cuStreamBeginCapture = (fn_cuStreamBeginCapture)dlsym(libcuda_handle, "cuStreamBeginCapture"); p_cuStreamEndCapture = 
(fn_cuStreamEndCapture)dlsym(libcuda_handle, "cuStreamEndCapture"); p_cuGraphInstantiate = (fn_cuGraphInstantiate)dlsym(libcuda_handle, "cuGraphInstantiateWithFlags"); p_cuGraphLaunch = (fn_cuGraphLaunch)dlsym(libcuda_handle, "cuGraphLaunch"); p_cuGraphDestroy = (fn_cuGraphDestroy)dlsym(libcuda_handle, "cuGraphDestroy"); p_cuGraphExecDestroy = (fn_cuGraphExecDestroy)dlsym(libcuda_handle, "cuGraphExecDestroy"); // CUDA Graphs disabled - adds ~1GB memory overhead with no fps improvement // (TensorRT is compute-bound in convolutions, not kernel-launch-bound) // Keep the function pointers loaded in case we want to re-enable later (void)p_cuStreamBeginCapture; (void)p_cuStreamEndCapture; (void)p_cuGraphInstantiate; (void)p_cuGraphLaunch; (void)p_cuGraphDestroy; (void)p_cuGraphExecDestroy; cuda_graphs_available = 0; #undef LOAD_CUDA_FUNC // Initialize CUDA (needed before any other CUDA calls) CUresult err = p_cuInit(0); if (err != CUDA_SUCCESS) { av_log(log_ctx, AV_LOG_ERROR, "cuInit failed: %s\n", cuda_error_string(err)); dlclose(libcuda_handle); libcuda_handle = NULL; libs_load_attempted.store(1, std::memory_order_release); return AVERROR(ENOSYS); } cuda_loaded = 1; av_log(log_ctx, AV_LOG_INFO, "CUDA driver API loaded via dlopen\n"); // Load CUDA Runtime API (required for TensorRT compatibility) const char *cudart_names[] = { "libcudart.so.12", "libcudart.so.11", "libcudart.so", NULL }; for (int i = 0; cudart_names[i] && !cuda_rt_handle; i++) { cuda_rt_handle = dlopen(cudart_names[i], RTLD_NOW); } if (!cuda_rt_handle) { av_log(log_ctx, AV_LOG_ERROR, "Failed to load CUDA runtime library (libcudart.so)\n"); dlclose(libcuda_handle); libcuda_handle = NULL; cuda_loaded = 0; libs_load_attempted.store(1, std::memory_order_release); return AVERROR(ENOSYS); } p_cudaMalloc = (fn_cudaMalloc)dlsym(cuda_rt_handle, "cudaMalloc"); p_cudaFree = (fn_cudaFree)dlsym(cuda_rt_handle, "cudaFree"); p_cudaSetDevice = (fn_cudaSetDevice)dlsym(cuda_rt_handle, "cudaSetDevice"); p_cudaMemcpy = 
(fn_cudaMemcpy)dlsym(cuda_rt_handle, "cudaMemcpy"); p_cudaMemcpyAsync = (fn_cudaMemcpyAsync)dlsym(cuda_rt_handle, "cudaMemcpyAsync"); p_cudaStreamSynchronize_rt = (fn_cudaStreamSynchronize)dlsym(cuda_rt_handle, "cudaStreamSynchronize"); p_cudaStreamCreate_rt = (fn_cudaStreamCreate)dlsym(cuda_rt_handle, "cudaStreamCreate"); p_cudaMemGetInfo = (fn_cudaMemGetInfo)dlsym(cuda_rt_handle, "cudaMemGetInfo"); // Optional, for diagnostics if (!p_cudaMalloc || !p_cudaFree || !p_cudaSetDevice || !p_cudaMemcpy) { av_log(log_ctx, AV_LOG_ERROR, "Failed to load CUDA runtime API functions\n"); dlclose(cuda_rt_handle); cuda_rt_handle = NULL; dlclose(libcuda_handle); libcuda_handle = NULL; cuda_loaded = 0; libs_load_attempted.store(1, std::memory_order_release); return AVERROR(ENOSYS); } // Load TensorRT const char *nvinfer_names[] = { "libnvinfer.so.10", "libnvinfer.so.9", "libnvinfer.so.8", "libnvinfer.so", NULL }; for (int i = 0; nvinfer_names[i] && !libnvinfer_handle; i++) { libnvinfer_handle = dlopen(nvinfer_names[i], RTLD_NOW); } if (!libnvinfer_handle) { av_log(log_ctx, AV_LOG_ERROR, "TensorRT not available: %s\n" "Install TensorRT or run with --gpus all to use nvidia-container-toolkit\n", dlerror()); dlclose(cuda_rt_handle); cuda_rt_handle = NULL; dlclose(libcuda_handle); libcuda_handle = NULL; cuda_loaded = 0; libs_load_attempted.store(1, std::memory_order_release); return AVERROR(ENOSYS); } // Get TensorRT factory function const char *create_runtime_names[] = { "createInferRuntime_INTERNAL", // TensorRT 10+ "_ZN8nvinfer118createInferRuntimeERNS_7ILoggerE", // GCC mangling (TRT 8-9) "createInferRuntime", // Some builds export unmangled NULL }; for (int i = 0; create_runtime_names[i] && !p_createInferRuntime; i++) { p_createInferRuntime = (fn_createInferRuntime)dlsym(libnvinfer_handle, create_runtime_names[i]); } if (!p_createInferRuntime) { av_log(log_ctx, AV_LOG_ERROR, "Failed to find createInferRuntime in TensorRT library\n"); dlclose(libnvinfer_handle); libnvinfer_handle 
= NULL; dlclose(cuda_rt_handle); cuda_rt_handle = NULL; dlclose(libcuda_handle); libcuda_handle = NULL; cuda_loaded = 0; libs_load_attempted.store(1, std::memory_order_release); return AVERROR(ENOSYS); } tensorrt_loaded = 1; av_log(log_ctx, AV_LOG_INFO, "TensorRT library loaded via dlopen\n"); // Mark as attempted AFTER successful initialization (release semantics) libs_load_attempted.store(1, std::memory_order_release); return 0; } // Log GPU memory usage for diagnostics static void log_gpu_memory(void *log_ctx, const char *label) { if (!p_cudaMemGetInfo) return; size_t free_mem = 0, total_mem = 0; if (p_cudaMemGetInfo(&free_mem, &total_mem) == 0) { size_t used_mb = (total_mem - free_mem) / (1024 * 1024); av_log(log_ctx, AV_LOG_WARNING, "GPU_MEM [%s]: %zu MiB used\n", label, used_mb); } } // TensorRT logger - forward to FFmpeg's logging class TRTLogger : public nvinfer1::ILogger { public: void *log_ctx; TRTLogger(void *ctx = nullptr) : log_ctx(ctx) {} void log(Severity severity, const char *msg) noexcept override { int level; switch (severity) { case Severity::kINTERNAL_ERROR: case Severity::kERROR: level = AV_LOG_ERROR; break; case Severity::kWARNING: level = AV_LOG_WARNING; break; case Severity::kINFO: level = AV_LOG_INFO; break; default: level = AV_LOG_DEBUG; break; } av_log(log_ctx, level, "TensorRT: %s\n", msg); } }; // Supported tensor data types typedef enum TRTDataType { TRT_DT_FLOAT32 = 0, // 4 bytes TRT_DT_FLOAT16 = 1, // 2 bytes TRT_DT_BFLOAT16 = 2, // 2 bytes TRT_DT_INT8 = 3, // 1 byte TRT_DT_UINT8 = 4, // 1 byte TRT_DT_UNKNOWN = -1 } TRTDataType; static const char *trt_dtype_name(TRTDataType dt) { switch (dt) { case TRT_DT_FLOAT32: return "FP32"; case TRT_DT_FLOAT16: return "FP16"; case TRT_DT_BFLOAT16: return "BF16"; case TRT_DT_INT8: return "INT8"; case TRT_DT_UINT8: return "UINT8"; default: return "UNKNOWN"; } } static size_t trt_dtype_size(TRTDataType dt) { switch (dt) { case TRT_DT_FLOAT32: return 4; case TRT_DT_FLOAT16: return 2; case 
TRT_DT_BFLOAT16: return 2; case TRT_DT_INT8: return 1; case TRT_DT_UINT8: return 1; default: return 0; } } static TRTDataType nvinfer_to_trt_dtype(nvinfer1::DataType dt) { switch (dt) { case nvinfer1::DataType::kFLOAT: return TRT_DT_FLOAT32; case nvinfer1::DataType::kHALF: return TRT_DT_FLOAT16; case nvinfer1::DataType::kBF16: return TRT_DT_BFLOAT16; case nvinfer1::DataType::kINT8: return TRT_DT_INT8; case nvinfer1::DataType::kUINT8: return TRT_DT_UINT8; default: return TRT_DT_UNKNOWN; } } typedef struct TRTModel { DNNModel model; DnnContext *ctx; nvinfer1::IRuntime *runtime; nvinfer1::ICudaEngine *engine; nvinfer1::IExecutionContext *context; // Lazily created on first inference TRTLogger *logger; CUstream stream; // CUDA Graphs for reduced kernel launch overhead CUgraph cuda_graph; CUgraphExec cuda_graph_exec; int cuda_graph_captured; // 1 if graph has been captured and is ready int cuda_graph_failed; // 1 if capture failed, don't retry // Engine cache entry (for refcounting shared engines) CachedEngine *cached_engine; char *engine_path; // Key for cache lookup (av_strdup'd) // CUDA kernel module (loaded from PTX) CUmodule cuda_module; // Kernels for each dtype: [0]=FP32, [1]=FP16, [2]=BF16 CUfunction kernel_hwc_to_nchw[3]; CUfunction kernel_nchw_to_hwc[3]; CUfunction kernel_hwc4_to_nchw[3]; CUfunction kernel_nchw_to_hwc4[3]; // I/O tensor info (TensorRT 10.x API) char *input_name; char *output_name; nvinfer1::Dims input_dims; nvinfer1::Dims output_dims; TRTDataType input_dtype; TRTDataType output_dtype; // CUDA buffers (using CUdeviceptr for Driver API) CUdeviceptr input_buffer; CUdeviceptr output_buffer; size_t input_size; size_t output_size; // Task management (reuse FFmpeg's queue infrastructure) SafeQueue *request_queue; Queue *task_queue; Queue *lltask_queue; } TRTModel; typedef struct TRTInferRequest { float *output_data; // CPU output buffer } TRTInferRequest; typedef struct TRTRequestItem { TRTInferRequest *infer_request; LastLevelTaskItem *lltask; 
DNNAsyncExecModule exec_module; } TRTRequestItem; #define OFFSET(x) offsetof(TRTOptions, x) #define FLAGS AV_OPT_FLAG_FILTERING_PARAM static const AVOption dnn_trt_options[] = { { "device_id", "CUDA device ID", OFFSET(device_id), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 16, FLAGS }, { NULL } }; // Check CUDA error and log #define CUDA_CHECK(call, ctx, ret) do { \ CUresult cuda_err = (call); \ if (cuda_err != CUDA_SUCCESS) { \ av_log(ctx, AV_LOG_ERROR, "CUDA error: %s\n", cuda_error_string(cuda_err)); \ return ret; \ } \ } while(0) // Load CUDA kernels from embedded PTX static int load_cuda_kernels(TRTModel *trt_model, void *log_ctx) { CUresult err; // Load PTX module err = p_cuModuleLoadData(&trt_model->cuda_module, ff_dnn_cuda_kernels_ptx); if (err != CUDA_SUCCESS) { av_log(log_ctx, AV_LOG_ERROR, "Failed to load CUDA kernel module: %s\n", cuda_error_string(err)); return AVERROR(ENOSYS); } // Kernel names for each dtype: [0]=FP32, [1]=FP16, [2]=BF16 const char *hwc_to_nchw_names[] = { DNN_CUDA_KERNEL_HWC_UINT8_TO_NCHW_FLOAT32, DNN_CUDA_KERNEL_HWC_UINT8_TO_NCHW_FLOAT16, DNN_CUDA_KERNEL_HWC_UINT8_TO_NCHW_BFLOAT16 }; const char *nchw_to_hwc_names[] = { DNN_CUDA_KERNEL_NCHW_FLOAT32_TO_HWC_UINT8, DNN_CUDA_KERNEL_NCHW_FLOAT16_TO_HWC_UINT8, DNN_CUDA_KERNEL_NCHW_BFLOAT16_TO_HWC_UINT8 }; const char *hwc4_to_nchw_names[] = { DNN_CUDA_KERNEL_HWC4_UINT8_TO_NCHW_FLOAT32, DNN_CUDA_KERNEL_HWC4_UINT8_TO_NCHW_FLOAT16, DNN_CUDA_KERNEL_HWC4_UINT8_TO_NCHW_BFLOAT16 }; const char *nchw_to_hwc4_names[] = { DNN_CUDA_KERNEL_NCHW_FLOAT32_TO_HWC4_UINT8, DNN_CUDA_KERNEL_NCHW_FLOAT16_TO_HWC4_UINT8, DNN_CUDA_KERNEL_NCHW_BFLOAT16_TO_HWC4_UINT8 }; // Load all kernel variants for (int i = 0; i < 3; i++) { err = p_cuModuleGetFunction(&trt_model->kernel_hwc_to_nchw[i], trt_model->cuda_module, hwc_to_nchw_names[i]); if (err != CUDA_SUCCESS) { av_log(log_ctx, AV_LOG_ERROR, "Failed to get kernel %s: %s\n", hwc_to_nchw_names[i], cuda_error_string(err)); p_cuModuleUnload(trt_model->cuda_module); 
trt_model->cuda_module = NULL; return AVERROR(ENOSYS); } err = p_cuModuleGetFunction(&trt_model->kernel_nchw_to_hwc[i], trt_model->cuda_module, nchw_to_hwc_names[i]); if (err != CUDA_SUCCESS) { av_log(log_ctx, AV_LOG_ERROR, "Failed to get kernel %s: %s\n", nchw_to_hwc_names[i], cuda_error_string(err)); p_cuModuleUnload(trt_model->cuda_module); trt_model->cuda_module = NULL; return AVERROR(ENOSYS); } err = p_cuModuleGetFunction(&trt_model->kernel_hwc4_to_nchw[i], trt_model->cuda_module, hwc4_to_nchw_names[i]); if (err != CUDA_SUCCESS) { av_log(log_ctx, AV_LOG_ERROR, "Failed to get kernel %s: %s\n", hwc4_to_nchw_names[i], cuda_error_string(err)); p_cuModuleUnload(trt_model->cuda_module); trt_model->cuda_module = NULL; return AVERROR(ENOSYS); } err = p_cuModuleGetFunction(&trt_model->kernel_nchw_to_hwc4[i], trt_model->cuda_module, nchw_to_hwc4_names[i]); if (err != CUDA_SUCCESS) { av_log(log_ctx, AV_LOG_ERROR, "Failed to get kernel %s: %s\n", nchw_to_hwc4_names[i], cuda_error_string(err)); p_cuModuleUnload(trt_model->cuda_module); trt_model->cuda_module = NULL; return AVERROR(ENOSYS); } } av_log(log_ctx, AV_LOG_INFO, "CUDA format conversion kernels loaded (FP32/FP16/BF16)\n"); return 0; } // Launch kernel with Driver API static int launch_kernel(CUfunction func, CUstream stream, int width, int height, void **args, void *log_ctx) { // 32x8 thread blocks (better warp utilization for row-major image access) unsigned int block_x = 32, block_y = 8; unsigned int grid_x = (width + block_x - 1) / block_x; unsigned int grid_y = (height + block_y - 1) / block_y; CUresult err = p_cuLaunchKernel(func, grid_x, grid_y, 1, // grid dimensions block_x, block_y, 1, // block dimensions 0, // shared memory stream, // stream args, // kernel arguments NULL); // extra if (err != CUDA_SUCCESS) { av_log(log_ctx, AV_LOG_ERROR, "Kernel launch failed: %s\n", cuda_error_string(err)); return AVERROR(EIO); } return 0; } // Lazily create execution context on first inference // Returns 0 on success, 
// negative AVERROR on failure.
//
// Only the first successful call does work: it creates the IExecutionContext
// from trt_model->engine, cudaMalloc's input_size / output_size bytes for the
// device I/O buffers (if input_buffer is still 0), and binds them to the
// tensors named input_name / output_name. Every failure path releases what was
// allocated here and resets context/input_buffer/output_buffer, so the next
// call starts from scratch.
static int ensure_execution_context(TRTModel *trt_model, void *log_ctx)
{
    if (trt_model->context)
        return 0; // Already created

    av_log(log_ctx, AV_LOG_INFO, "Creating TensorRT execution context (lazy init)\n");
    trt_model->context = trt_model->engine->createExecutionContext();
    if (!trt_model->context) {
        av_log(log_ctx, AV_LOG_ERROR, "Failed to create execution context\n");
        return AVERROR(ENOMEM);
    }
    log_gpu_memory(log_ctx, "after createExecutionContext");

    // Allocate GPU buffers now that we have context
    if (!trt_model->input_buffer) {
        void *input_ptr = NULL, *output_ptr = NULL;
        // Buffers come from the Runtime API (cudaMalloc) and must be released
        // with cudaFree, as dnn_free_model_trt() does.
        cudaError_t cuda_err = p_cudaMalloc(&input_ptr, trt_model->input_size);
        if (cuda_err != cudaSuccess) {
            av_log(log_ctx, AV_LOG_ERROR, "cudaMalloc failed for input buffer: %d\n", cuda_err);
            delete trt_model->context;
            trt_model->context = nullptr;
            return AVERROR(ENOMEM);
        }
        trt_model->input_buffer = (CUdeviceptr)input_ptr;

        cuda_err = p_cudaMalloc(&output_ptr, trt_model->output_size);
        if (cuda_err != cudaSuccess) {
            av_log(log_ctx, AV_LOG_ERROR, "cudaMalloc failed for output buffer: %d\n", cuda_err);
            p_cudaFree((void*)trt_model->input_buffer);
            trt_model->input_buffer = 0;
            delete trt_model->context;
            trt_model->context = nullptr;
            return AVERROR(ENOMEM);
        }
        trt_model->output_buffer = (CUdeviceptr)output_ptr;

        // Set tensor addresses (bind the device buffers to the engine's I/O tensors by name)
        if (!trt_model->context->setTensorAddress(trt_model->input_name, (void*)trt_model->input_buffer)) {
            av_log(log_ctx, AV_LOG_ERROR, "Failed to set input tensor address\n");
            p_cudaFree((void*)trt_model->input_buffer);
            p_cudaFree((void*)trt_model->output_buffer);
            trt_model->input_buffer = 0;
            trt_model->output_buffer = 0;
            delete trt_model->context;
            trt_model->context = nullptr;
            return AVERROR(EINVAL);
        }
        if (!trt_model->context->setTensorAddress(trt_model->output_name, (void*)trt_model->output_buffer)) {
            av_log(log_ctx, AV_LOG_ERROR, "Failed to set output tensor address\n");
            p_cudaFree((void*)trt_model->input_buffer);
            p_cudaFree((void*)trt_model->output_buffer);
            trt_model->input_buffer = 0;
            trt_model->output_buffer = 0;
            delete trt_model->context;
            trt_model->context = nullptr;
            return AVERROR(EINVAL);
        }

        av_log(log_ctx, AV_LOG_INFO, " Allocated GPU buffers: input=%zuMB output=%zuMB\n",
               trt_model->input_size / (1024 * 1024), trt_model->output_size / (1024 * 1024));
        log_gpu_memory(log_ctx, "after buffer allocation");
    }

    return 0;
}

// Wrap a task in a single LastLevelTaskItem and append it to lltask_queue
// (one inference per task: inference_todo is set to 1).
// Returns 0, or AVERROR(ENOMEM) if allocation or the queue push fails.
static int extract_lltask_from_task(TaskItem *task, Queue *lltask_queue)
{
    TRTModel *trt_model = (TRTModel *)task->model;
    DnnContext *ctx = trt_model->ctx;
    LastLevelTaskItem *lltask = (LastLevelTaskItem *)av_malloc(sizeof(*lltask));
    if (!lltask) {
        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for LastLevelTaskItem\n");
        return AVERROR(ENOMEM);
    }
    task->inference_todo = 1;
    task->inference_done = 0;
    lltask->task = task;
    if (ff_queue_push_back(lltask_queue, lltask) < 0) {
        av_log(ctx, AV_LOG_ERROR, "Failed to push back lltask_queue.\n");
        av_freep(&lltask);
        return AVERROR(ENOMEM);
    }
    return 0;
}

// Free the CPU-side buffers owned by a request (not the request itself). NULL-safe.
static void trt_free_request(TRTInferRequest *request)
{
    if (!request)
        return;
    if (request->output_data) {
        av_freep(&request->output_data);
    }
}

// Free a request item and everything it owns (request, lltask, async module),
// then set *arg to NULL. Safe with arg == NULL or *arg == NULL.
static inline void destroy_request_item(TRTRequestItem **arg)
{
    TRTRequestItem *item;
    if (!arg || !*arg) {
        return;
    }
    item = *arg;
    trt_free_request(item->infer_request);
    av_freep(&item->infer_request);
    av_freep(&item->lltask);
    ff_dnn_async_module_cleanup(&item->exec_module);
    av_freep(arg);
}

// Release everything owned by a TRT model and set *model to NULL.
// Order: wait for the stream; destroy CUDA graph objects; cudaFree the device
// buffers; unload the kernel module; destroy the stream; free tensor names;
// delete the execution context; drop this model's reference to the shared
// engine/runtime (deleting them on the last reference); delete the logger;
// free engine_path; drain and destroy the queues. Safe with NULL / *model == NULL.
// NOTE(review): the logger is per model while the runtime may be shared via the
// engine cache. If the cached runtime was created with this model's TRTLogger,
// deleting the logger while other models still hold the engine could leave
// TensorRT with a dangling logger. Confirm how the runtime's logger is owned.
static void dnn_free_model_trt(DNNModel **model)
{
    TRTModel *trt_model;
    if (!model || !*model)
        return;

    trt_model = (TRTModel *)(*model);

    // Synchronize stream before cleanup to ensure all GPU operations complete
    if (trt_model->stream && p_cuStreamSynchronize) {
        p_cuStreamSynchronize(trt_model->stream);
    }

    // Free CUDA Graph resources
    if (trt_model->cuda_graph_exec && p_cuGraphExecDestroy) {
        p_cuGraphExecDestroy(trt_model->cuda_graph_exec);
        trt_model->cuda_graph_exec = NULL;
    }
    if (trt_model->cuda_graph && p_cuGraphDestroy) {
        p_cuGraphDestroy(trt_model->cuda_graph);
        trt_model->cuda_graph = NULL;
    }

    // Free CUDA resources (using Runtime API - must match cudaMalloc allocation)
    if (trt_model->input_buffer && p_cudaFree) {
        p_cudaFree((void*)trt_model->input_buffer);
        trt_model->input_buffer = 0;
    }
    if (trt_model->output_buffer && p_cudaFree) {
        p_cudaFree((void*)trt_model->output_buffer);
        trt_model->output_buffer = 0;
    }
    if (trt_model->cuda_module && p_cuModuleUnload) {
        p_cuModuleUnload(trt_model->cuda_module);
        trt_model->cuda_module = NULL;
    }
    if (trt_model->stream && p_cuStreamDestroy) {
        p_cuStreamDestroy(trt_model->stream);
        trt_model->stream = NULL;
    }

    // Free tensor names (engine_path freed after cache cleanup)
    av_freep(&trt_model->input_name);
    av_freep(&trt_model->output_name);

    // Free TensorRT resources
    if (trt_model->context) {
        delete trt_model->context;
        trt_model->context = nullptr;
    }

    // Handle cached engine - decrement refcount and only free when last reference
    if (trt_model->cached_engine) {
        // Holding the cache mutex keeps the "last reference -> erase + delete"
        // step atomic with respect to other cache users that take this mutex.
        std::lock_guard lock(g_engine_cache_mutex);
        // Use atomic fetch_sub to avoid race condition - returns OLD value
        int old_refcount = trt_model->cached_engine->refcount.fetch_sub(1, std::memory_order_acq_rel);
        int remaining = old_refcount - 1;
        av_log(trt_model->ctx, AV_LOG_DEBUG, "Engine refcount: %d -> %d (path=%s)\n",
               old_refcount, remaining, trt_model->engine_path ?
               trt_model->engine_path : "null");

        if (remaining == 0) {
            // Last reference - remove from cache and delete
            if (trt_model->engine_path) {
                std::string path_key(trt_model->engine_path);
                size_t erased = g_engine_cache.erase(path_key);
                av_log(trt_model->ctx, AV_LOG_DEBUG, "Erased %zu entries from cache\n", erased);
            }
            // Engine is deleted before the runtime that deserialized it.
            if (trt_model->cached_engine->engine) {
                delete trt_model->cached_engine->engine;
            }
            if (trt_model->cached_engine->runtime) {
                delete trt_model->cached_engine->runtime;
            }
            delete trt_model->cached_engine;
            av_log(trt_model->ctx, AV_LOG_DEBUG, "Released last reference to cached engine\n");
        } else if (remaining < 0) {
            av_log(trt_model->ctx, AV_LOG_ERROR, "BUG: Engine refcount went negative (%d)!\n", remaining);
        } else {
            av_log(trt_model->ctx, AV_LOG_DEBUG, "Released engine reference (remaining=%d)\n", remaining);
        }
        // engine/runtime alias the cache entry's pointers; never delete them directly here.
        trt_model->cached_engine = nullptr;
        trt_model->engine = nullptr;
        trt_model->runtime = nullptr;
    } else {
        // Not cached (shouldn't happen normally, but handle gracefully)
        if (trt_model->engine) {
            delete trt_model->engine;
            trt_model->engine = nullptr;
        }
        if (trt_model->runtime) {
            delete trt_model->runtime;
            trt_model->runtime = nullptr;
        }
    }

    if (trt_model->logger) {
        delete trt_model->logger;
        trt_model->logger = nullptr;
    }

    // Free engine path (after cache cleanup which uses it)
    av_freep(&trt_model->engine_path);

    // Free queues
    if (trt_model->request_queue) {
        while (ff_safe_queue_size(trt_model->request_queue) != 0) {
            TRTRequestItem *item = (TRTRequestItem *)ff_safe_queue_pop_front(trt_model->request_queue);
            destroy_request_item(&item);
        }
        ff_safe_queue_destroy(trt_model->request_queue);
    }

    if (trt_model->lltask_queue) {
        while (ff_queue_size(trt_model->lltask_queue) != 0) {
            LastLevelTaskItem *item = (LastLevelTaskItem *)ff_queue_pop_front(trt_model->lltask_queue);
            av_freep(&item);
        }
        ff_queue_destroy(trt_model->lltask_queue);
    }

    if (trt_model->task_queue) {
        while (ff_queue_size(trt_model->task_queue) != 0) {
            TaskItem *item = (TaskItem
                *)ff_queue_pop_front(trt_model->task_queue);
            av_frame_free(&item->in_frame);
            av_frame_free(&item->out_frame);
            av_freep(&item);
        }
        ff_queue_destroy(trt_model->task_queue);
    }

    av_freep(&trt_model);
    *model = NULL;
}

// Describe the engine's input tensor as a DNNData.
// Requires a 4-D (NCHW) input, else AVERROR(EINVAL). Always reports DNN_FLOAT /
// DCO_RGB / DL_NCHW; conversion to the engine's actual dtype is done by the
// per-dtype CUDA kernels. The input_name argument is not used.
static int get_input_trt(DNNModel *model, DNNData *input, const char *input_name)
{
    TRTModel *trt_model = (TRTModel *)model;

    // Validate tensor has expected dimensions (NCHW = 4)
    if (trt_model->input_dims.nbDims != 4) {
        av_log(trt_model->ctx, AV_LOG_ERROR, "Expected 4D input tensor (NCHW), got %d dimensions\n",
               trt_model->input_dims.nbDims);
        return AVERROR(EINVAL);
    }

    input->dt = DNN_FLOAT;
    input->order = DCO_RGB;
    input->layout = DL_NCHW;

    // Get dimensions from engine
    input->dims[0] = trt_model->input_dims.d[0]; // N (batch)
    input->dims[1] = trt_model->input_dims.d[1]; // C (channels)
    input->dims[2] = trt_model->input_dims.d[2]; // H (height)
    input->dims[3] = trt_model->input_dims.d[3]; // W (width)

    return 0;
}

// Pop the next lltask, check its frame size against the engine input, and fill
// the engine's input buffer from the frame.
static int fill_model_input_trt(TRTModel *trt_model, TRTRequestItem *request)
{
    LastLevelTaskItem *lltask = NULL;
    TaskItem *task = NULL;
    DNNData input = { 0 };
    DnnContext *ctx = trt_model->ctx;
    int ret;

    // Ensure execution context and buffers are created (lazy initialization)
    ret = ensure_execution_context(trt_model, ctx);
    if (ret < 0) {
        return ret;
    }

    lltask = (LastLevelTaskItem *)ff_queue_pop_front(trt_model->lltask_queue);
    if (!lltask) {
        return AVERROR(EINVAL);
    }
    request->lltask = lltask;
    task = lltask->task;

    ret = get_input_trt(&trt_model->model, &input, NULL);
    if (ret != 0) {
        return ret;
    }

    int height_idx = dnn_get_height_idx_by_layout(input.layout);
    int width_idx = dnn_get_width_idx_by_layout(input.layout);

    // Check input dimensions match engine
    if (task->in_frame->height != input.dims[height_idx] || task->in_frame->width != input.dims[width_idx]) {
        av_log(ctx, AV_LOG_ERROR, "Input size %dx%d doesn't match engine's expected %dx%d\n",
               task->in_frame->width, task->in_frame->height, input.dims[width_idx], input.dims[height_idx]);
        return
AVERROR(EINVAL); } int width = task->in_frame->width; int height = task->in_frame->height; // Check for CUDA hardware frames (zero-copy input path) if (task->in_frame->format == AV_PIX_FMT_CUDA && task->in_frame->hw_frames_ctx) { AVHWFramesContext *hw_frames = (AVHWFramesContext *)task->in_frame->hw_frames_ctx->data; int linesize = task->in_frame->linesize[0]; CUdeviceptr cuda_data = (CUdeviceptr)task->in_frame->data[0]; int dtype_idx = (int)trt_model->input_dtype; // Kernel array index: 0=FP32, 1=FP16, 2=BF16 // For RGB24/BGR24: convert uint8 HWC to NCHW on GPU (zero-copy) if (hw_frames->sw_format == AV_PIX_FMT_RGB24 || hw_frames->sw_format == AV_PIX_FMT_BGR24) { void *args[] = {&cuda_data, &trt_model->input_buffer, &height, &width, &linesize}; ret = launch_kernel(trt_model->kernel_hwc_to_nchw[dtype_idx], trt_model->stream, width, height, args, ctx); if (ret != 0) return ret; return 0; } // For 4-channel formats (RGB0, RGBA, BGR0, BGRA) if (hw_frames->sw_format == AV_PIX_FMT_RGB0 || hw_frames->sw_format == AV_PIX_FMT_BGR0 || hw_frames->sw_format == AV_PIX_FMT_RGBA || hw_frames->sw_format == AV_PIX_FMT_BGRA) { int r_off = 0, g_off = 1, b_off = 2; if (hw_frames->sw_format == AV_PIX_FMT_BGR0 || hw_frames->sw_format == AV_PIX_FMT_BGRA) { r_off = 2; b_off = 0; } void *args[] = {&cuda_data, &trt_model->input_buffer, &height, &width, &linesize, &r_off, &g_off, &b_off}; ret = launch_kernel(trt_model->kernel_hwc4_to_nchw[dtype_idx], trt_model->stream, width, height, args, ctx); if (ret != 0) return ret; return 0; } // For 0RGB/ARGB formats (alpha first) if (hw_frames->sw_format == AV_PIX_FMT_0RGB || hw_frames->sw_format == AV_PIX_FMT_0BGR || hw_frames->sw_format == AV_PIX_FMT_ARGB || hw_frames->sw_format == AV_PIX_FMT_ABGR) { int r_off = 1, g_off = 2, b_off = 3; if (hw_frames->sw_format == AV_PIX_FMT_0BGR || hw_frames->sw_format == AV_PIX_FMT_ABGR) { r_off = 3; b_off = 1; } void *args[] = {&cuda_data, &trt_model->input_buffer, &height, &width, &linesize, &r_off, &g_off, 
&b_off}; ret = launch_kernel(trt_model->kernel_hwc4_to_nchw[dtype_idx], trt_model->stream, width, height, args, ctx); if (ret != 0) return ret; return 0; } av_log(ctx, AV_LOG_WARNING, "CUDA sw_format %s not supported for zero-copy, using CPU path\n", av_get_pix_fmt_name(hw_frames->sw_format)); } // Standard CPU path - only supports FP32 engines // For FP16/BF16, use CUDA hw frames for zero-copy path if (trt_model->input_dtype != TRT_DT_FLOAT32) { av_log(ctx, AV_LOG_ERROR, "CPU input path only supports FP32 engines, got %s. " "Use hwupload to provide CUDA frames for FP16/BF16 zero-copy.\n", trt_dtype_name(trt_model->input_dtype)); return AVERROR(ENOSYS); } size_t input_elements = input.dims[0] * input.dims[1] * input.dims[2] * input.dims[3]; float *input_data = (float *)av_malloc(input_elements * sizeof(float)); if (!input_data) { return AVERROR(ENOMEM); } input.data = input_data; input.scale = 255; switch (trt_model->model.func_type) { case DFT_PROCESS_FRAME: if (task->do_ioproc) { if (trt_model->model.frame_pre_proc != NULL) { trt_model->model.frame_pre_proc(task->in_frame, &input, trt_model->model.filter_ctx); } else { ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx); } } break; default: av_log(ctx, AV_LOG_ERROR, "Unsupported model function type %d\n", trt_model->model.func_type); av_freep(&input_data); return AVERROR(EINVAL); } // Copy input to GPU using CUDA Runtime API (compatible with TensorRT's Runtime API context) cudaError_t cuda_err = p_cudaMemcpy((void*)trt_model->input_buffer, input_data, trt_model->input_size, cudaMemcpyHostToDevice); if (cuda_err != cudaSuccess) { av_log(ctx, AV_LOG_ERROR, "cudaMemcpy failed for input: %d\n", cuda_err); av_freep(&input_data); return AVERROR(EIO); } av_freep(&input_data); return 0; } // Capture TensorRT inference into a CUDA Graph for reduced kernel launch overhead // Returns 0 on success, negative on failure (non-fatal, falls back to regular enqueue) static int trt_capture_cuda_graph(TRTModel *trt_model, void 
*log_ctx) { CUresult err; if (!cuda_graphs_available || trt_model->cuda_graph_failed) { return -1; } av_log(log_ctx, AV_LOG_INFO, "Capturing TensorRT inference into CUDA Graph...\n"); // Synchronize stream before capture to ensure any pending work (e.g., input kernel) completes // This prevents undefined behavior from capturing a stream with pending async operations err = p_cuStreamSynchronize(trt_model->stream); if (err != CUDA_SUCCESS) { av_log(log_ctx, AV_LOG_WARNING, "Stream sync before graph capture failed: %s\n", cuda_error_string(err)); trt_model->cuda_graph_failed = 1; return -1; } // Begin stream capture err = p_cuStreamBeginCapture(trt_model->stream, CU_STREAM_CAPTURE_MODE_GLOBAL); if (err != CUDA_SUCCESS) { av_log(log_ctx, AV_LOG_WARNING, "CUDA Graph capture begin failed: %s\n", cuda_error_string(err)); trt_model->cuda_graph_failed = 1; return -1; } // Execute TensorRT inference (this gets captured into the graph) bool success = trt_model->context->enqueueV3((cudaStream_t)trt_model->stream); if (!success) { // End capture to clean up, ignore the graph CUgraph temp_graph = NULL; p_cuStreamEndCapture(trt_model->stream, &temp_graph); if (temp_graph) p_cuGraphDestroy(temp_graph); av_log(log_ctx, AV_LOG_WARNING, "TensorRT inference failed during graph capture\n"); trt_model->cuda_graph_failed = 1; return -1; } // End stream capture err = p_cuStreamEndCapture(trt_model->stream, &trt_model->cuda_graph); if (err != CUDA_SUCCESS || !trt_model->cuda_graph) { av_log(log_ctx, AV_LOG_WARNING, "CUDA Graph capture end failed: %s\n", cuda_error_string(err)); trt_model->cuda_graph_failed = 1; return -1; } // Instantiate the graph for execution err = p_cuGraphInstantiate(&trt_model->cuda_graph_exec, trt_model->cuda_graph, 0); if (err != CUDA_SUCCESS || !trt_model->cuda_graph_exec) { av_log(log_ctx, AV_LOG_WARNING, "CUDA Graph instantiate failed: %s\n", cuda_error_string(err)); p_cuGraphDestroy(trt_model->cuda_graph); trt_model->cuda_graph = NULL; 
trt_model->cuda_graph_failed = 1; return -1; } trt_model->cuda_graph_captured = 1; av_log(log_ctx, AV_LOG_INFO, "CUDA Graph captured successfully (reduced kernel launch overhead)\n"); return 0; } static int trt_start_inference(void *args) { TRTRequestItem *request = (TRTRequestItem *)args; LastLevelTaskItem *lltask; TaskItem *task; TRTModel *trt_model; DnnContext *ctx; if (!request || !request->lltask) { av_log(NULL, AV_LOG_ERROR, "TRTRequestItem or lltask is NULL\n"); return AVERROR(EINVAL); } lltask = request->lltask; task = lltask->task; trt_model = (TRTModel *)task->model; ctx = trt_model->ctx; // Validate required resources exist if (!trt_model->context || !trt_model->stream) { av_log(ctx, AV_LOG_ERROR, "TensorRT context or CUDA stream not initialized\n"); return DNN_GENERIC_ERROR; } // NOTE: Tensor addresses are set once during model load (not per-frame) // since input/output buffers are persistent // Try to use CUDA Graph for reduced kernel launch overhead // First frame: capture the graph; subsequent frames: launch the captured graph if (cuda_graphs_available && !trt_model->cuda_graph_captured && !trt_model->cuda_graph_failed) { // Capture on first inference if (trt_capture_cuda_graph(trt_model, ctx) == 0) { // Graph captured - we already ran inference during capture, so return return 0; } // Capture failed - fall through to regular execution } if (trt_model->cuda_graph_captured && trt_model->cuda_graph_exec) { // Execute the captured graph (much lower overhead than enqueueV3) CUresult err = p_cuGraphLaunch(trt_model->cuda_graph_exec, trt_model->stream); if (err != CUDA_SUCCESS) { av_log(ctx, AV_LOG_ERROR, "CUDA Graph launch failed: %s\n", cuda_error_string(err)); return DNN_GENERIC_ERROR; } } else { // Regular execution path (fallback if graphs not available or capture failed) bool success = trt_model->context->enqueueV3((cudaStream_t)trt_model->stream); if (!success) { av_log(ctx, AV_LOG_ERROR, "TensorRT inference failed\n"); return DNN_GENERIC_ERROR; } } 
// NOTE: No sync here - for zero-copy paths, we sync once after the output kernel // For CPU paths, we sync before cudaMemcpy DtoH in infer_completion_callback return 0; } static void infer_completion_callback(void *args) { TRTRequestItem *request = (TRTRequestItem *)args; LastLevelTaskItem *lltask = request->lltask; TaskItem *task = lltask->task; TRTModel *trt_model = (TRTModel *)task->model; DnnContext *ctx = trt_model->ctx; DNNData outputs = { 0 }; float *output_data = NULL; size_t output_elements; int ret; // Output dimensions are validated during model loading, safe to access outputs.order = DCO_RGB; outputs.layout = DL_NCHW; outputs.dt = DNN_FLOAT; outputs.dims[0] = trt_model->output_dims.d[0]; // N outputs.dims[1] = trt_model->output_dims.d[1]; // C outputs.dims[2] = trt_model->output_dims.d[2]; // H outputs.dims[3] = trt_model->output_dims.d[3]; // W int out_height = outputs.dims[2]; int out_width = outputs.dims[3]; // Validate stream exists (should always be true if model loaded successfully) if (!trt_model->stream) { av_log(ctx, AV_LOG_ERROR, "CUDA stream is NULL\n"); goto err; } // Check for CUDA output frames (zero-copy output path) if (task->out_frame->format == AV_PIX_FMT_CUDA && task->out_frame->hw_frames_ctx) { AVHWFramesContext *hw_frames = (AVHWFramesContext *)task->out_frame->hw_frames_ctx->data; int out_linesize = task->out_frame->linesize[0]; CUdeviceptr cuda_out = (CUdeviceptr)task->out_frame->data[0]; int dtype_idx = (int)trt_model->output_dtype; // Kernel array index: 0=FP32, 1=FP16, 2=BF16 // For RGB24/BGR24: convert NCHW to uint8 HWC on GPU (zero-copy) if (hw_frames->sw_format == AV_PIX_FMT_RGB24 || hw_frames->sw_format == AV_PIX_FMT_BGR24) { void *args[] = {&trt_model->output_buffer, &cuda_out, &out_height, &out_width, &out_linesize}; ret = launch_kernel(trt_model->kernel_nchw_to_hwc[dtype_idx], trt_model->stream, out_width, out_height, args, ctx); if (ret != 0) goto err; if (p_cuStreamSynchronize(trt_model->stream) != CUDA_SUCCESS) { 
av_log(ctx, AV_LOG_ERROR, "CUDA stream sync failed\n"); goto err; } task->out_frame->width = out_width; task->out_frame->height = out_height; task->inference_done++; goto done; } // For 4-channel formats (RGB0, RGBA, BGR0, BGRA) if (hw_frames->sw_format == AV_PIX_FMT_RGB0 || hw_frames->sw_format == AV_PIX_FMT_BGR0 || hw_frames->sw_format == AV_PIX_FMT_RGBA || hw_frames->sw_format == AV_PIX_FMT_BGRA) { int r_off = 0, g_off = 1, b_off = 2, a_off = 3; if (hw_frames->sw_format == AV_PIX_FMT_BGR0 || hw_frames->sw_format == AV_PIX_FMT_BGRA) { r_off = 2; b_off = 0; } void *args[] = {&trt_model->output_buffer, &cuda_out, &out_height, &out_width, &out_linesize, &r_off, &g_off, &b_off, &a_off}; ret = launch_kernel(trt_model->kernel_nchw_to_hwc4[dtype_idx], trt_model->stream, out_width, out_height, args, ctx); if (ret != 0) goto err; if (p_cuStreamSynchronize(trt_model->stream) != CUDA_SUCCESS) { av_log(ctx, AV_LOG_ERROR, "CUDA stream sync failed\n"); goto err; } task->out_frame->width = out_width; task->out_frame->height = out_height; task->inference_done++; goto done; } // For 0RGB/ARGB formats (alpha first) if (hw_frames->sw_format == AV_PIX_FMT_0RGB || hw_frames->sw_format == AV_PIX_FMT_0BGR || hw_frames->sw_format == AV_PIX_FMT_ARGB || hw_frames->sw_format == AV_PIX_FMT_ABGR) { int r_off = 1, g_off = 2, b_off = 3, a_off = 0; if (hw_frames->sw_format == AV_PIX_FMT_0BGR || hw_frames->sw_format == AV_PIX_FMT_ABGR) { r_off = 3; b_off = 1; } void *args[] = {&trt_model->output_buffer, &cuda_out, &out_height, &out_width, &out_linesize, &r_off, &g_off, &b_off, &a_off}; ret = launch_kernel(trt_model->kernel_nchw_to_hwc4[dtype_idx], trt_model->stream, out_width, out_height, args, ctx); if (ret != 0) goto err; if (p_cuStreamSynchronize(trt_model->stream) != CUDA_SUCCESS) { av_log(ctx, AV_LOG_ERROR, "CUDA stream sync failed\n"); goto err; } task->out_frame->width = out_width; task->out_frame->height = out_height; task->inference_done++; goto done; } av_log(ctx, AV_LOG_WARNING, "CUDA 
output sw_format %s not supported for zero-copy, using CPU path\n", av_get_pix_fmt_name(hw_frames->sw_format)); } // Standard CPU path - only supports FP32 engines // For FP16/BF16, use CUDA hw frames for zero-copy path if (trt_model->output_dtype != TRT_DT_FLOAT32) { av_log(ctx, AV_LOG_ERROR, "CPU output path only supports FP32 engines, got %s. " "Use hwupload to provide CUDA frames for FP16/BF16 zero-copy.\n", trt_dtype_name(trt_model->output_dtype)); goto err; } output_elements = outputs.dims[0] * outputs.dims[1] * outputs.dims[2] * outputs.dims[3]; output_data = (float *)av_malloc(output_elements * sizeof(float)); if (!output_data) { av_log(ctx, AV_LOG_ERROR, "Failed to allocate output buffer\n"); goto err; } // Sync stream before copying (inference runs async on stream) if (p_cudaStreamSynchronize_rt) { cudaError_t sync_err = p_cudaStreamSynchronize_rt((cudaStream_t)trt_model->stream); if (sync_err != cudaSuccess) { av_log(ctx, AV_LOG_ERROR, "Stream sync failed before output copy: %d\n", sync_err); av_freep(&output_data); goto err; } } else { // Fallback to Driver API sync CUresult err = p_cuStreamSynchronize(trt_model->stream); if (err != CUDA_SUCCESS) { av_log(ctx, AV_LOG_ERROR, "Stream sync failed: %s\n", cuda_error_string(err)); av_freep(&output_data); goto err; } } // Copy output from GPU using CUDA Runtime API { cudaError_t cuda_err = p_cudaMemcpy(output_data, (void*)trt_model->output_buffer, trt_model->output_size, cudaMemcpyDeviceToHost); if (cuda_err != cudaSuccess) { av_log(ctx, AV_LOG_ERROR, "cudaMemcpy failed for output: %d\n", cuda_err); av_freep(&output_data); goto err; } } switch (trt_model->model.func_type) { case DFT_PROCESS_FRAME: if (task->do_ioproc) { outputs.scale = 255; outputs.data = output_data; if (trt_model->model.frame_post_proc != NULL) { trt_model->model.frame_post_proc(task->out_frame, &outputs, trt_model->model.filter_ctx); } else { ff_proc_from_dnn_to_frame(task->out_frame, &outputs, ctx); } } else { task->out_frame->width = 
out_width; task->out_frame->height = out_height; } break; default: av_log(ctx, AV_LOG_ERROR, "Unsupported model function type %d\n", trt_model->model.func_type); av_freep(&output_data); goto err; } av_freep(&output_data); task->inference_done++; goto done; err: // Increment inference_done even on error so task completion tracking works // The caller can detect failure through other means (e.g., frame validation) task->inference_done++; done: av_freep(&request->lltask); if (ff_safe_queue_push_back(trt_model->request_queue, request) < 0) { destroy_request_item(&request); av_log(ctx, AV_LOG_ERROR, "Unable to push back request_queue.\n"); } } static int execute_model_trt(TRTRequestItem *request, Queue *lltask_queue) { TRTModel *trt_model = NULL; LastLevelTaskItem *lltask; TaskItem *task = NULL; int ret = 0; if (ff_queue_size(lltask_queue) == 0) { destroy_request_item(&request); return 0; } lltask = (LastLevelTaskItem *)ff_queue_peek_front(lltask_queue); if (lltask == NULL) { av_log(NULL, AV_LOG_ERROR, "Failed to get LastLevelTaskItem\n"); ret = AVERROR(EINVAL); goto err; } task = lltask->task; trt_model = (TRTModel *)task->model; ret = fill_model_input_trt(trt_model, request); if (ret != 0) { goto err; } // Synchronous execution (TensorRT is fast, async adds complexity) ret = trt_start_inference((void *)request); if (ret != 0) { goto err; } infer_completion_callback(request); return (task->inference_done == task->inference_todo) ? 
0 : DNN_GENERIC_ERROR; err: trt_free_request(request->infer_request); av_freep(&request->lltask); // Free lltask that was popped from queue // Clean up the task that was left in task_queue to prevent memory leak // The task is at the back since we just pushed it in dnn_execute_model_trt if (trt_model && task) { // Remove task from queue - it should be the last one we added // Iterate to find and remove it (safer than assuming position) Queue *tq = trt_model->task_queue; size_t queue_size = ff_queue_size(tq); for (size_t i = 0; i < queue_size; i++) { TaskItem *queued_task = (TaskItem *)ff_queue_peek_front(tq); if (queued_task == task) { ff_queue_pop_front(tq); av_frame_free(&task->in_frame); av_frame_free(&task->out_frame); av_freep(&task); break; } // Move to next by popping and re-pushing (rotate queue) ff_queue_pop_front(tq); ff_queue_push_back(tq, queued_task); } } if (!trt_model || ff_safe_queue_push_back(trt_model->request_queue, request) < 0) { destroy_request_item(&request); } return ret; } static int get_output_trt(DNNModel *model, const char *input_name, int input_width, int input_height, const char *output_name, int *output_width, int *output_height) { TRTModel *trt_model = (TRTModel *)model; // Get from engine's output dimensions *output_width = trt_model->output_dims.d[3]; *output_height = trt_model->output_dims.d[2]; return 0; } static TRTInferRequest *trt_create_inference_request(void) { TRTInferRequest *request = (TRTInferRequest *)av_mallocz(sizeof(TRTInferRequest)); return request; } static DNNModel *dnn_load_model_trt(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx) { TRTModel *trt_model = NULL; TRTRequestItem *item = NULL; CUresult err; trt_model = (TRTModel *)av_mallocz(sizeof(TRTModel)); if (!trt_model) return NULL; trt_model->ctx = ctx; // Load CUDA and TensorRT libraries via dlopen if (load_libs(ctx) < 0) { goto fail; } log_gpu_memory(ctx, "after load_libs"); // Set CUDA device using Runtime API for TensorRT 
compatibility if (p_cudaSetDevice) { int device_id = ctx->trt_option.device_id; cudaError_t cuda_err = p_cudaSetDevice(device_id); if (cuda_err != cudaSuccess) { av_log(ctx, AV_LOG_ERROR, "cudaSetDevice(%d) failed: %d\n", device_id, cuda_err); goto fail; } av_log(ctx, AV_LOG_DEBUG, "Set CUDA device %d for TensorRT\n", device_id); } // Create TensorRT logger (cleaned up by dnn_free_model_trt on any failure path) trt_model->logger = new TRTLogger(ctx); // Check engine cache first (avoid reloading same engine file) trt_model->engine_path = av_strdup(ctx->model_filename); if (!trt_model->engine_path) { av_log(ctx, AV_LOG_ERROR, "Failed to allocate engine path\n"); goto fail; } { std::lock_guard lock(g_engine_cache_mutex); std::string path_key(trt_model->engine_path); av_log(ctx, AV_LOG_DEBUG, "Checking engine cache for: %s (cache size=%zu)\n", trt_model->engine_path, g_engine_cache.size()); auto it = g_engine_cache.find(path_key); if (it != g_engine_cache.end()) { // Found in cache - reuse existing engine trt_model->cached_engine = it->second; if (!trt_model->cached_engine->engine || !trt_model->cached_engine->runtime) { av_log(ctx, AV_LOG_ERROR, "BUG: Cached engine has NULL pointers! 
Removing stale entry.\n"); g_engine_cache.erase(it); trt_model->cached_engine = nullptr; } else { // Use atomic fetch_add to avoid race condition int new_refcount = trt_model->cached_engine->refcount.fetch_add(1, std::memory_order_acq_rel) + 1; trt_model->engine = trt_model->cached_engine->engine; trt_model->runtime = trt_model->cached_engine->runtime; av_log(ctx, AV_LOG_INFO, "Reusing cached TensorRT engine (refcount=%d, engine=%p)\n", new_refcount, (void*)trt_model->engine); } } if (!trt_model->cached_engine) { av_log(ctx, AV_LOG_DEBUG, "Engine not in cache, will load from file\n"); } } // If not in cache, load engine from file if (!trt_model->engine) { // Create runtime using dynamically loaded function trt_model->runtime = p_createInferRuntime(*trt_model->logger); if (!trt_model->runtime) { av_log(ctx, AV_LOG_ERROR, "Failed to create TensorRT runtime\n"); goto fail; } // Load engine from file { std::ifstream file(ctx->model_filename, std::ios::binary | std::ios::ate); if (!file.is_open()) { av_log(ctx, AV_LOG_ERROR, "Failed to open engine file: %s\n", ctx->model_filename); goto fail; } std::streampos pos = file.tellg(); if (pos == std::streampos(-1) || pos <= 0) { av_log(ctx, AV_LOG_ERROR, "Engine file is empty or unreadable: %s\n", ctx->model_filename); goto fail; } size_t size = static_cast(pos); file.seekg(0, std::ios::beg); std::vector buffer(size); if (!file.read(buffer.data(), size)) { av_log(ctx, AV_LOG_ERROR, "Failed to read engine file\n"); goto fail; } trt_model->engine = trt_model->runtime->deserializeCudaEngine(buffer.data(), size); if (!trt_model->engine) { av_log(ctx, AV_LOG_ERROR, "Failed to deserialize CUDA engine\n"); goto fail; } log_gpu_memory(ctx, "after engine deserialize"); } // Add to cache { std::lock_guard lock(g_engine_cache_mutex); std::string path_key(trt_model->engine_path); // Double-check another thread didn't add it while we were loading auto it = g_engine_cache.find(path_key); if (it == g_engine_cache.end()) { 
trt_model->cached_engine = new CachedEngine(trt_model->engine, trt_model->runtime); g_engine_cache[path_key] = trt_model->cached_engine; av_log(ctx, AV_LOG_INFO, "Added TensorRT engine to cache\n"); } else { // Another thread added it - use theirs, discard ours delete trt_model->engine; delete trt_model->runtime; trt_model->cached_engine = it->second; // Use atomic fetch_add to avoid race condition int new_refcount = trt_model->cached_engine->refcount.fetch_add(1, std::memory_order_acq_rel) + 1; trt_model->engine = trt_model->cached_engine->engine; trt_model->runtime = trt_model->cached_engine->runtime; av_log(ctx, AV_LOG_INFO, "Using engine added by another thread (refcount=%d)\n", new_refcount); } } } // NOTE: Execution context is created lazily on first inference (saves ~720MB for probe instances) // FFmpeg creates two filter instances: one for probing (never runs inference) and one for execution trt_model->context = nullptr; // CUDA Graph state (captured lazily on first inference) trt_model->cuda_graph = NULL; trt_model->cuda_graph_exec = NULL; trt_model->cuda_graph_captured = 0; trt_model->cuda_graph_failed = 0; // Create CUDA stream for TensorRT operations err = p_cuStreamCreate(&trt_model->stream, 0); if (err != CUDA_SUCCESS) { av_log(ctx, AV_LOG_ERROR, "Failed to create CUDA stream: %s\n", cuda_error_string(err)); goto fail; } // Load CUDA kernels from embedded PTX if (load_cuda_kernels(trt_model, ctx) < 0) { goto fail; } log_gpu_memory(ctx, "after load_cuda_kernels"); // Get I/O tensor info (TensorRT 10.x API) { int nb_io_tensors = trt_model->engine->getNbIOTensors(); if (nb_io_tensors < 2) { av_log(ctx, AV_LOG_ERROR, "Engine must have at least 2 tensors (input and output), got %d\n", nb_io_tensors); goto fail; } // Find input and output tensors for (int i = 0; i < nb_io_tensors; i++) { const char *name = trt_model->engine->getIOTensorName(i); nvinfer1::TensorIOMode mode = trt_model->engine->getTensorIOMode(name); if (mode == nvinfer1::TensorIOMode::kINPUT 
&& !trt_model->input_name) { trt_model->input_name = av_strdup(name); trt_model->input_dims = trt_model->engine->getTensorShape(name); trt_model->input_dtype = nvinfer_to_trt_dtype(trt_model->engine->getTensorDataType(name)); } else if (mode == nvinfer1::TensorIOMode::kOUTPUT && !trt_model->output_name) { trt_model->output_name = av_strdup(name); trt_model->output_dims = trt_model->engine->getTensorShape(name); trt_model->output_dtype = nvinfer_to_trt_dtype(trt_model->engine->getTensorDataType(name)); } } if (!trt_model->input_name || !trt_model->output_name) { av_log(ctx, AV_LOG_ERROR, "Could not find input/output tensors\n"); goto fail; } // Validate dtypes are supported if (trt_model->input_dtype == TRT_DT_UNKNOWN) { av_log(ctx, AV_LOG_ERROR, "Unsupported input tensor data type\n"); goto fail; } if (trt_model->output_dtype == TRT_DT_UNKNOWN) { av_log(ctx, AV_LOG_ERROR, "Unsupported output tensor data type\n"); goto fail; } // For now, we only support FP32/FP16/BF16 for zero-copy kernels if (trt_model->input_dtype > TRT_DT_BFLOAT16 || trt_model->output_dtype > TRT_DT_BFLOAT16) { av_log(ctx, AV_LOG_ERROR, "Only FP32/FP16/BF16 tensors supported for zero-copy, got input=%s output=%s\n", trt_dtype_name(trt_model->input_dtype), trt_dtype_name(trt_model->output_dtype)); goto fail; } // Validate tensor dimensions (must be 4D for NCHW format) if (trt_model->input_dims.nbDims != 4) { av_log(ctx, AV_LOG_ERROR, "Input tensor must be 4D (NCHW), got %d dimensions\n", trt_model->input_dims.nbDims); goto fail; } if (trt_model->output_dims.nbDims != 4) { av_log(ctx, AV_LOG_ERROR, "Output tensor must be 4D (NCHW), got %d dimensions\n", trt_model->output_dims.nbDims); goto fail; } // Validate all dimensions are positive for (int i = 0; i < 4; i++) { if (trt_model->input_dims.d[i] <= 0) { av_log(ctx, AV_LOG_ERROR, "Invalid input dimension[%d] = %ld\n", i, (long)trt_model->input_dims.d[i]); goto fail; } if (trt_model->output_dims.d[i] <= 0) { av_log(ctx, AV_LOG_ERROR, "Invalid 
output dimension[%d] = %ld\n", i, (long)trt_model->output_dims.d[i]); goto fail; } } // Log dimensions and dtypes av_log(ctx, AV_LOG_INFO, "TensorRT engine loaded:\n"); av_log(ctx, AV_LOG_INFO, " Input '%s': %ldx%ldx%ldx%ld (%s)\n", trt_model->input_name, (long)trt_model->input_dims.d[0], (long)trt_model->input_dims.d[1], (long)trt_model->input_dims.d[2], (long)trt_model->input_dims.d[3], trt_dtype_name(trt_model->input_dtype)); av_log(ctx, AV_LOG_INFO, " Output '%s': %ldx%ldx%ldx%ld (%s)\n", trt_model->output_name, (long)trt_model->output_dims.d[0], (long)trt_model->output_dims.d[1], (long)trt_model->output_dims.d[2], (long)trt_model->output_dims.d[3], trt_dtype_name(trt_model->output_dtype)); // Calculate buffer sizes (allocation deferred to first inference via ensure_execution_context) { // Cast each factor to int64_t to prevent overflow during multiplication int64_t in_elems = (int64_t)trt_model->input_dims.d[0] * (int64_t)trt_model->input_dims.d[1] * (int64_t)trt_model->input_dims.d[2] * (int64_t)trt_model->input_dims.d[3]; int64_t out_elems = (int64_t)trt_model->output_dims.d[0] * (int64_t)trt_model->output_dims.d[1] * (int64_t)trt_model->output_dims.d[2] * (int64_t)trt_model->output_dims.d[3]; size_t in_elem_size = trt_dtype_size(trt_model->input_dtype); size_t out_elem_size = trt_dtype_size(trt_model->output_dtype); // Check for overflow (max reasonable GPU buffer ~16GB) const int64_t max_bytes = (int64_t)16 * 1024 * 1024 * 1024; if (in_elems * (int64_t)in_elem_size > max_bytes || out_elems * (int64_t)out_elem_size > max_bytes) { av_log(ctx, AV_LOG_ERROR, "Tensor size exceeds maximum supported (16GB)\n"); goto fail; } trt_model->input_size = (size_t)in_elems * in_elem_size; trt_model->output_size = (size_t)out_elems * out_elem_size; av_log(ctx, AV_LOG_INFO, " Buffer sizes (deferred): input=%zuMB output=%zuMB\n", trt_model->input_size / (1024 * 1024), trt_model->output_size / (1024 * 1024)); } // NOTE: GPU buffers and execution context are allocated lazily 
on first inference // This saves ~720MB+ for FFmpeg's probe filter instance that never runs inference } // Initialize queues trt_model->request_queue = ff_safe_queue_create(); if (!trt_model->request_queue) goto fail; item = (TRTRequestItem *)av_mallocz(sizeof(TRTRequestItem)); if (!item) goto fail; item->infer_request = trt_create_inference_request(); if (!item->infer_request) goto fail; item->exec_module.start_inference = &trt_start_inference; item->exec_module.callback = &infer_completion_callback; item->exec_module.args = item; if (ff_safe_queue_push_back(trt_model->request_queue, item) < 0) goto fail; item = NULL; trt_model->task_queue = ff_queue_create(); if (!trt_model->task_queue) goto fail; trt_model->lltask_queue = ff_queue_create(); if (!trt_model->lltask_queue) goto fail; // Set up model interface trt_model->model.get_input = &get_input_trt; trt_model->model.get_output = &get_output_trt; trt_model->model.filter_ctx = filter_ctx; trt_model->model.func_type = func_type; return &trt_model->model; fail: if (item) { destroy_request_item(&item); } dnn_free_model_trt((DNNModel **)&trt_model); return NULL; } static int dnn_execute_model_trt(const DNNModel *model, DNNExecBaseParams *exec_params) { TRTModel *trt_model = (TRTModel *)model; DnnContext *ctx = trt_model->ctx; TaskItem *task; TRTRequestItem *request; int ret = 0; ret = ff_check_exec_params(ctx, DNN_TRT, model->func_type, exec_params); if (ret != 0) { av_log(ctx, AV_LOG_ERROR, "exec parameter checking fail.\n"); return ret; } task = (TaskItem *)av_malloc(sizeof(TaskItem)); if (!task) { av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task item.\n"); return AVERROR(ENOMEM); } ret = ff_dnn_fill_task(task, exec_params, trt_model, 0, 1); if (ret != 0) { av_log(ctx, AV_LOG_ERROR, "unable to fill task.\n"); av_frame_free(&task->in_frame); av_frame_free(&task->out_frame); av_freep(&task); return ret; } ret = ff_queue_push_back(trt_model->task_queue, task); if (ret < 0) { av_log(ctx, AV_LOG_ERROR, "unable 
to push back task_queue.\n"); av_frame_free(&task->in_frame); av_frame_free(&task->out_frame); av_freep(&task); return ret; } ret = extract_lltask_from_task(task, trt_model->lltask_queue); if (ret != 0) { av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n"); // Remove task from queue since extraction failed ff_queue_pop_back(trt_model->task_queue); av_frame_free(&task->in_frame); av_frame_free(&task->out_frame); av_freep(&task); return ret; } request = (TRTRequestItem *)ff_safe_queue_pop_front(trt_model->request_queue); if (!request) { av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n"); // Clean up: remove lltask and task we just added LastLevelTaskItem *lltask = (LastLevelTaskItem *)ff_queue_pop_back(trt_model->lltask_queue); av_freep(&lltask); ff_queue_pop_back(trt_model->task_queue); av_frame_free(&task->in_frame); av_frame_free(&task->out_frame); av_freep(&task); return AVERROR(EINVAL); } return execute_model_trt(request, trt_model->lltask_queue); } static DNNAsyncStatusType dnn_get_result_trt(const DNNModel *model, AVFrame **in, AVFrame **out) { TRTModel *trt_model = (TRTModel *)model; return ff_dnn_get_result_common(trt_model->task_queue, in, out); } static int dnn_flush_trt(const DNNModel *model) { TRTModel *trt_model = (TRTModel *)model; TRTRequestItem *request; if (ff_queue_size(trt_model->lltask_queue) == 0) return 0; request = (TRTRequestItem *)ff_safe_queue_pop_front(trt_model->request_queue); if (!request) { av_log(trt_model->ctx, AV_LOG_ERROR, "unable to get infer request.\n"); return AVERROR(EINVAL); } return execute_model_trt(request, trt_model->lltask_queue); } extern const DNNModule ff_dnn_backend_tensorrt = { .clazz = DNN_DEFINE_CLASS(dnn_trt), .type = DNN_TRT, .load_model = dnn_load_model_trt, .execute_model = dnn_execute_model_trt, .get_result = dnn_get_result_trt, .flush = dnn_flush_trt, .free_model = dnn_free_model_trt, }; ================================================ FILE: 
tools/patches/dnn_backend_torch.cpp ================================================ /* * Copyright 2026 Joshua V. Dillon * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * Based on FFmpeg's dnn_backend_torch.cpp with extensive modifications * for CUDA zero-copy support and hardware frame integration. */ /** * @file * DNN Torch backend implementation. */ #include #include #include #include #include #include #include #include extern "C" { #include "dnn_io_proc.h" #include "dnn_backend_common.h" #include "libavutil/opt.h" #include "libavutil/mem.h" #include "libavutil/hwcontext.h" #include "libavutil/hwcontext_cuda.h" #include "libavutil/pixfmt.h" #include "libavutil/pixdesc.h" #include "queue.h" #include "safe_queue.h" } #include typedef struct THModel { DNNModel model; DnnContext *ctx; torch::jit::Module *jit_model; SafeQueue *request_queue; Queue *task_queue; Queue *lltask_queue; SafeQueue *pending_queue; ///< requests waiting for inference std::thread *worker_thread; ///< background worker thread std::mutex *mutex; ///< mutex for the condition variable std::condition_variable *cond; ///< condition variable for worker wakeup std::atomic worker_stop; ///< signal for thread exit } THModel; typedef struct THInferRequest { torch::Tensor *output; torch::Tensor *input_tensor; } THInferRequest; typedef struct THRequestItem { THInferRequest *infer_request; LastLevelTaskItem *lltask; DNNAsyncExecModule exec_module; } THRequestItem; #define OFFSET(x) offsetof(THOptions, x) #define FLAGS 
AV_OPT_FLAG_FILTERING_PARAM static const AVOption dnn_th_options[] = { { "optimize", "turn on graph executor optimization", OFFSET(optimize), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 1, FLAGS}, { NULL } }; static int extract_lltask_from_task(TaskItem *task, Queue *lltask_queue) { THModel *th_model = (THModel *)task->model; DnnContext *ctx = th_model->ctx; LastLevelTaskItem *lltask = (LastLevelTaskItem *)av_malloc(sizeof(*lltask)); if (!lltask) { av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for LastLevelTaskItem\n"); return AVERROR(ENOMEM); } task->inference_todo = 1; task->inference_done = 0; lltask->task = task; if (ff_queue_push_back(lltask_queue, lltask) < 0) { av_log(ctx, AV_LOG_ERROR, "Failed to push back lltask_queue.\n"); av_freep(&lltask); return AVERROR(ENOMEM); } return 0; } static void th_free_request(THInferRequest *request) { if (!request) return; if (request->output) { delete(request->output); request->output = NULL; } if (request->input_tensor) { delete(request->input_tensor); request->input_tensor = NULL; } return; } static inline void destroy_request_item(THRequestItem **arg) { THRequestItem *item; if (!arg || !*arg) { return; } item = *arg; th_free_request(item->infer_request); av_freep(&item->infer_request); av_freep(&item->lltask); ff_dnn_async_module_cleanup(&item->exec_module); av_freep(arg); } static void dnn_free_model_th(DNNModel **model) { THModel *th_model; if (!model || !*model) return; th_model = (THModel *)(*model); /* 1. Stop and join the worker thread if it exists */ if (th_model->worker_thread) { { std::lock_guard lock(*th_model->mutex); th_model->worker_stop = true; } th_model->cond->notify_all(); th_model->worker_thread->join(); delete th_model->worker_thread; th_model->worker_thread = NULL; } /* 2. Safely delete C++ synchronization objects */ if (th_model->mutex) { delete th_model->mutex; th_model->mutex = NULL; } if (th_model->cond) { delete th_model->cond; th_model->cond = NULL; } /* 3. 
Clean up the pending queue */ if (th_model->pending_queue) { while (ff_safe_queue_size(th_model->pending_queue) > 0) { THRequestItem *item = (THRequestItem *)ff_safe_queue_pop_front(th_model->pending_queue); destroy_request_item(&item); } ff_safe_queue_destroy(th_model->pending_queue); } /* 4. Clean up standard backend queues */ if (th_model->request_queue) { while (ff_safe_queue_size(th_model->request_queue) != 0) { THRequestItem *item = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue); destroy_request_item(&item); } ff_safe_queue_destroy(th_model->request_queue); } if (th_model->lltask_queue) { while (ff_queue_size(th_model->lltask_queue) != 0) { LastLevelTaskItem *item = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue); av_freep(&item); } ff_queue_destroy(th_model->lltask_queue); } if (th_model->task_queue) { while (ff_queue_size(th_model->task_queue) != 0) { TaskItem *item = (TaskItem *)ff_queue_pop_front(th_model->task_queue); av_frame_free(&item->in_frame); av_frame_free(&item->out_frame); av_freep(&item); } ff_queue_destroy(th_model->task_queue); } /* 5. 
Final model cleanup */ if (th_model->jit_model) delete th_model->jit_model; av_freep(&th_model); *model = NULL; } static int get_input_th(DNNModel *model, DNNData *input, const char *input_name) { input->dt = DNN_FLOAT; input->order = DCO_RGB; input->layout = DL_NCHW; input->dims[0] = 1; input->dims[1] = 3; input->dims[2] = -1; input->dims[3] = -1; return 0; } static void deleter(void *arg) { av_freep(&arg); } static int fill_model_input_th(THModel *th_model, THRequestItem *request) { LastLevelTaskItem *lltask = NULL; TaskItem *task = NULL; THInferRequest *infer_request = NULL; DNNData input = { 0 }; DnnContext *ctx = th_model->ctx; int ret, width_idx, height_idx, channel_idx; lltask = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue); if (!lltask) { ret = AVERROR(EINVAL); goto err; } request->lltask = lltask; task = lltask->task; infer_request = request->infer_request; ret = get_input_th(&th_model->model, &input, NULL); if (ret != 0) { goto err; } width_idx = dnn_get_width_idx_by_layout(input.layout); height_idx = dnn_get_height_idx_by_layout(input.layout); channel_idx = dnn_get_channel_idx_by_layout(input.layout); input.dims[height_idx] = task->in_frame->height; input.dims[width_idx] = task->in_frame->width; // Allocate tensors. Note: th_create_inference_request() NULL-initializes both pointers, // so th_free_request() in the err path safely handles partial allocation if second new throws. 
try { infer_request->input_tensor = new torch::Tensor(); infer_request->output = new torch::Tensor(); } catch (const std::exception& e) { av_log(ctx, AV_LOG_ERROR, "Failed to allocate torch tensors: %s\n", e.what()); ret = AVERROR(ENOMEM); goto err; } // Check for CUDA hardware frames (zero-copy input path) if (task->in_frame->format == AV_PIX_FMT_CUDA && task->in_frame->hw_frames_ctx) { AVHWFramesContext *hw_frames = (AVHWFramesContext *)task->in_frame->hw_frames_ctx->data; int width = task->in_frame->width; int height = task->in_frame->height; int linesize = task->in_frame->linesize[0]; uint8_t *cuda_data = task->in_frame->data[0]; av_log(ctx, AV_LOG_DEBUG, "CUDA frame input: %dx%d, sw_format=%s, linesize=%d\n", width, height, av_get_pix_fmt_name(hw_frames->sw_format), linesize); try { // Handle RGB24/BGR24 sw_format - zero-copy path (3 bytes per pixel) if (hw_frames->sw_format == AV_PIX_FMT_RGB24 || hw_frames->sw_format == AV_PIX_FMT_BGR24) { auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); // Create tensor from CUDA memory (HWC format, uint8) torch::Tensor input_hwc = torch::from_blob( cuda_data, {height, width, 3}, {linesize, 3, 1}, // strides for row-major with padding options ); // Convert: HWC uint8 [0,255] -> NCHW float32 [0,1] *infer_request->input_tensor = input_hwc.permute({2, 0, 1}) // HWC -> CHW .unsqueeze(0) // CHW -> NCHW .to(torch::kFloat32) .div(255.0f) .contiguous(); av_log(ctx, AV_LOG_DEBUG, "Zero-copy CUDA input created (RGB24/BGR24)\n"); return 0; } // Handle RGB0/BGR0/0RGB/0BGR sw_format - zero-copy path (4 bytes per pixel, ignore alpha) if (hw_frames->sw_format == AV_PIX_FMT_RGB0 || hw_frames->sw_format == AV_PIX_FMT_BGR0 || hw_frames->sw_format == AV_PIX_FMT_0RGB || hw_frames->sw_format == AV_PIX_FMT_0BGR || hw_frames->sw_format == AV_PIX_FMT_RGBA || hw_frames->sw_format == AV_PIX_FMT_BGRA || hw_frames->sw_format == AV_PIX_FMT_ARGB || hw_frames->sw_format == AV_PIX_FMT_ABGR) { auto options = 
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); // Create tensor from CUDA memory (4 channels, uint8) torch::Tensor input_hwc4 = torch::from_blob( cuda_data, {height, width, 4}, {linesize, 4, 1}, // strides for row-major with padding options ); // Extract RGB channels based on format torch::Tensor input_hwc; if (hw_frames->sw_format == AV_PIX_FMT_RGB0 || hw_frames->sw_format == AV_PIX_FMT_BGR0 || hw_frames->sw_format == AV_PIX_FMT_RGBA || hw_frames->sw_format == AV_PIX_FMT_BGRA) { // RGB(A) format: first 3 channels are R, G, B input_hwc = input_hwc4.slice(2, 0, 3); // slice along channel dim } else { // (A)RGB format: last 3 channels are R, G, B input_hwc = input_hwc4.slice(2, 1, 4); // slice along channel dim } // Convert: HWC uint8 [0,255] -> NCHW float32 [0,1] *infer_request->input_tensor = input_hwc.permute({2, 0, 1}) // HWC -> CHW .unsqueeze(0) // CHW -> NCHW .to(torch::kFloat32) .div(255.0f) .contiguous(); av_log(ctx, AV_LOG_DEBUG, "Zero-copy CUDA input created (4-channel format)\n"); return 0; } } catch (const std::exception& e) { av_log(ctx, AV_LOG_ERROR, "Torch exception in zero-copy input: %s\n", e.what()); ret = AVERROR(ENOSYS); goto err; } av_log(ctx, AV_LOG_WARNING, "CUDA sw_format %s not supported for zero-copy, falling back to CPU\n", av_get_pix_fmt_name(hw_frames->sw_format)); } // Standard CPU path - allocate memory for input data input.data = av_malloc(input.dims[height_idx] * input.dims[width_idx] * input.dims[channel_idx] * sizeof(float)); if (!input.data) { ret = AVERROR(ENOMEM); goto err; } switch (th_model->model.func_type) { case DFT_PROCESS_FRAME: input.scale = 255; if (task->do_ioproc) { if (th_model->model.frame_pre_proc != NULL) { th_model->model.frame_pre_proc(task->in_frame, &input, th_model->model.filter_ctx); } else { ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx); } } break; default: avpriv_report_missing_feature(NULL, "model function type %d", th_model->model.func_type); av_freep(&input.data); ret = 
AVERROR(ENOSYS); goto err; } *infer_request->input_tensor = torch::from_blob(input.data, {1, input.dims[channel_idx], input.dims[height_idx], input.dims[width_idx]}, deleter, torch::kFloat32); return 0; err: th_free_request(infer_request); return ret; } static int th_start_inference(void *args) { THRequestItem *request = (THRequestItem *)args; THInferRequest *infer_request = NULL; LastLevelTaskItem *lltask = NULL; TaskItem *task = NULL; THModel *th_model = NULL; DnnContext *ctx = NULL; std::vector inputs; torch::NoGradGuard no_grad; if (!request) { av_log(NULL, AV_LOG_ERROR, "THRequestItem is NULL\n"); return AVERROR(EINVAL); } infer_request = request->infer_request; lltask = request->lltask; if (!lltask) { av_log(NULL, AV_LOG_ERROR, "THRequestItem lltask is NULL\n"); return AVERROR(EINVAL); } task = lltask->task; th_model = (THModel *)task->model; ctx = th_model->ctx; if (!infer_request->input_tensor || !infer_request->output) { av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n"); return DNN_GENERIC_ERROR; } try { // Transfer tensor to the same device as model c10::Device device = torch::kCUDA; auto params = th_model->jit_model->parameters(); if (params.begin() != params.end()) { device = (*params.begin()).device(); } if (infer_request->input_tensor->device() != device) *infer_request->input_tensor = infer_request->input_tensor->to(device); inputs.push_back(*infer_request->input_tensor); auto _fwd_out = th_model->jit_model->forward(inputs); if (_fwd_out.isTuple()) { *infer_request->output = _fwd_out.toTuple()->elements()[0].toTensor(); } else { *infer_request->output = _fwd_out.toTensor(); } } catch (const std::exception& e) { av_log(ctx, AV_LOG_ERROR, "Torch inference failed: %s\n", e.what()); return DNN_GENERIC_ERROR; } return 0; } static void infer_completion_callback(void *args) { THRequestItem *request = (THRequestItem*)args; LastLevelTaskItem *lltask = request->lltask; TaskItem *task = lltask->task; DNNData outputs = { 0 }; THInferRequest 
*infer_request = request->infer_request; THModel *th_model = (THModel *)task->model; torch::Tensor *output = infer_request->output; c10::IntArrayRef sizes; try { sizes = output->sizes(); outputs.order = DCO_RGB; outputs.layout = DL_NCHW; outputs.dt = DNN_FLOAT; if (sizes.size() == 4) { // 4 dimensions: [batch_size, channel, height, width] // this format of data is normally used for video frame SR outputs.dims[0] = sizes.at(0); // N outputs.dims[1] = sizes.at(1); // C outputs.dims[2] = sizes.at(2); // H outputs.dims[3] = sizes.at(3); // W } else { avpriv_report_missing_feature(th_model->ctx, "Support of this kind of model"); goto err; } switch (th_model->model.func_type) { case DFT_PROCESS_FRAME: // Check for CUDA output frames (zero-copy output path) if (task->out_frame->format == AV_PIX_FMT_CUDA && task->out_frame->hw_frames_ctx) { AVHWFramesContext *hw_frames = (AVHWFramesContext *)task->out_frame->hw_frames_ctx->data; int out_height = outputs.dims[dnn_get_height_idx_by_layout(outputs.layout)]; int out_width = outputs.dims[dnn_get_width_idx_by_layout(outputs.layout)]; int out_linesize = task->out_frame->linesize[0]; uint8_t *cuda_out = task->out_frame->data[0]; av_log(th_model->ctx, AV_LOG_DEBUG, "CUDA frame output: %dx%d, sw_format=%s\n", out_width, out_height, av_get_pix_fmt_name(hw_frames->sw_format)); if (hw_frames->sw_format == AV_PIX_FMT_RGB24 || hw_frames->sw_format == AV_PIX_FMT_BGR24) { // Ensure output is on CUDA if (!output->is_cuda()) { *output = output->to(torch::kCUDA); } // Convert: NCHW float32 [0,1] -> HWC uint8 [0,255] torch::Tensor output_hwc = output->squeeze(0) // NCHW -> CHW .permute({1, 2, 0}) // CHW -> HWC .mul(255.0f) .clamp(0.0f, 255.0f) .to(torch::kUInt8) .contiguous(); // Copy to output CUDA frame cudaError_t cuda_err; if (out_linesize == out_width * 3) { // Contiguous - single copy cuda_err = cudaMemcpy(cuda_out, output_hwc.data_ptr(), out_height * out_width * 3, cudaMemcpyDeviceToDevice); if (cuda_err != cudaSuccess) { 
av_log(th_model->ctx, AV_LOG_ERROR, "cudaMemcpy failed: %s\n", cudaGetErrorString(cuda_err)); goto err; } } else { // Padded rows - copy row by row for (int y = 0; y < out_height; y++) { cuda_err = cudaMemcpy(cuda_out + y * out_linesize, (uint8_t*)output_hwc.data_ptr() + y * out_width * 3, out_width * 3, cudaMemcpyDeviceToDevice); if (cuda_err != cudaSuccess) { av_log(th_model->ctx, AV_LOG_ERROR, "cudaMemcpy row %d failed: %s\n", y, cudaGetErrorString(cuda_err)); goto err; } } } task->out_frame->width = out_width; task->out_frame->height = out_height; av_log(th_model->ctx, AV_LOG_DEBUG, "Zero-copy CUDA output done (RGB24/BGR24)\n"); break; } // Handle 4-channel output formats (RGB0, BGR0, RGBA, etc.) if (hw_frames->sw_format == AV_PIX_FMT_RGB0 || hw_frames->sw_format == AV_PIX_FMT_BGR0 || hw_frames->sw_format == AV_PIX_FMT_0RGB || hw_frames->sw_format == AV_PIX_FMT_0BGR || hw_frames->sw_format == AV_PIX_FMT_RGBA || hw_frames->sw_format == AV_PIX_FMT_BGRA || hw_frames->sw_format == AV_PIX_FMT_ARGB || hw_frames->sw_format == AV_PIX_FMT_ABGR) { // Ensure output is on CUDA if (!output->is_cuda()) { *output = output->to(torch::kCUDA); } // Convert: NCHW float32 [0,1] -> HWC uint8 [0,255] torch::Tensor output_hwc = output->squeeze(0) // NCHW -> CHW .permute({1, 2, 0}) // CHW -> HWC .mul(255.0f) .clamp(0.0f, 255.0f) .to(torch::kUInt8) .contiguous(); // Create 4-channel output with alpha=255 auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); torch::Tensor alpha = torch::full({out_height, out_width, 1}, 255, options); torch::Tensor output_hwc4; if (hw_frames->sw_format == AV_PIX_FMT_RGB0 || hw_frames->sw_format == AV_PIX_FMT_BGR0 || hw_frames->sw_format == AV_PIX_FMT_RGBA || hw_frames->sw_format == AV_PIX_FMT_BGRA) { // RGB(A) format: R, G, B, A output_hwc4 = torch::cat({output_hwc, alpha}, 2).contiguous(); } else { // (A)RGB format: A, R, G, B output_hwc4 = torch::cat({alpha, output_hwc}, 2).contiguous(); } // Copy to output CUDA frame 
cudaError_t cuda_err4; if (out_linesize == out_width * 4) { // Contiguous - single copy cuda_err4 = cudaMemcpy(cuda_out, output_hwc4.data_ptr(), out_height * out_width * 4, cudaMemcpyDeviceToDevice); if (cuda_err4 != cudaSuccess) { av_log(th_model->ctx, AV_LOG_ERROR, "cudaMemcpy failed: %s\n", cudaGetErrorString(cuda_err4)); goto err; } } else { // Padded rows - copy row by row for (int y = 0; y < out_height; y++) { cuda_err4 = cudaMemcpy(cuda_out + y * out_linesize, (uint8_t*)output_hwc4.data_ptr() + y * out_width * 4, out_width * 4, cudaMemcpyDeviceToDevice); if (cuda_err4 != cudaSuccess) { av_log(th_model->ctx, AV_LOG_ERROR, "cudaMemcpy row %d failed: %s\n", y, cudaGetErrorString(cuda_err4)); goto err; } } } task->out_frame->width = out_width; task->out_frame->height = out_height; av_log(th_model->ctx, AV_LOG_DEBUG, "Zero-copy CUDA output done (4-channel format)\n"); break; } av_log(th_model->ctx, AV_LOG_WARNING, "CUDA output sw_format %s not supported, falling back to CPU\n", av_get_pix_fmt_name(hw_frames->sw_format)); } // Standard CPU output path if (task->do_ioproc) { // Post process can only deal with CPU memory. if (output->device() != torch::kCPU) *output = output->to(torch::kCPU); // Expensive GPU->CPU copy! 
outputs.scale = 255; outputs.data = output->data_ptr(); if (th_model->model.frame_post_proc != NULL) { th_model->model.frame_post_proc(task->out_frame, &outputs, th_model->model.filter_ctx); } else { ff_proc_from_dnn_to_frame(task->out_frame, &outputs, th_model->ctx); } } else { task->out_frame->width = outputs.dims[dnn_get_width_idx_by_layout(outputs.layout)]; task->out_frame->height = outputs.dims[dnn_get_height_idx_by_layout(outputs.layout)]; } break; default: avpriv_report_missing_feature(th_model->ctx, "model function type %d", th_model->model.func_type); goto err; } task->inference_done++; goto done; } catch (const std::exception& e) { av_log(th_model->ctx, AV_LOG_ERROR, "Torch exception in completion callback: %s\n", e.what()); } err: // Increment inference_done even on error so task completion tracking works task->inference_done++; done: // Free lltask - it was popped from the queue in fill_model_input_th av_freep(&request->lltask); // Free the inference request data (tensors) th_free_request(infer_request); // Don't free infer_request struct here - it's reused when pushed back to request_queue // The struct will be freed when the model is destroyed if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) { // Only destroy if we can't push back - this will free the struct av_freep(&request->infer_request); av_freep(&request); av_log(th_model->ctx, AV_LOG_ERROR, "Unable to push back request_queue when failed to start inference.\n"); } } static void th_worker_thread(THModel *th_model) { while (true) { THRequestItem *request = NULL; { std::unique_lock lock(*th_model->mutex); th_model->cond->wait(lock, [&]{ return th_model->worker_stop || ff_safe_queue_size(th_model->pending_queue) > 0; }); if (th_model->worker_stop && ff_safe_queue_size(th_model->pending_queue) == 0) break; request = (THRequestItem *)ff_safe_queue_pop_front(th_model->pending_queue); } if (request) { int ret = th_start_inference(request); if (ret < 0) { av_log(NULL, AV_LOG_ERROR, 
"Async inference failed: %d\n", ret); } infer_completion_callback(request); } } } static int execute_model_th(THRequestItem *request, Queue *lltask_queue) { THModel *th_model = NULL; LastLevelTaskItem *lltask; TaskItem *task = NULL; int ret = 0; if (ff_queue_size(lltask_queue) == 0) { destroy_request_item(&request); return 0; } lltask = (LastLevelTaskItem *)ff_queue_peek_front(lltask_queue); if (lltask == NULL) { av_log(NULL, AV_LOG_ERROR, "Failed to get LastLevelTaskItem\n"); ret = AVERROR(EINVAL); goto err; } task = lltask->task; th_model = (THModel *)task->model; ret = fill_model_input_th(th_model, request); if ( ret != 0) { goto err; } if (task->async) { std::lock_guard lock(*th_model->mutex); if (ff_safe_queue_push_back(th_model->pending_queue, request) < 0) { th_free_request(request->infer_request); av_freep(&request->lltask); if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) { destroy_request_item(&request); } return AVERROR(ENOMEM); } th_model->cond->notify_one(); return 0; } else { // Synchronous execution path ret = th_start_inference((void *)request); if (ret != 0) { goto err; } infer_completion_callback(request); return (task->inference_done == task->inference_todo) ? 
0 : DNN_GENERIC_ERROR; } err: th_free_request(request->infer_request); av_freep(&request->lltask); // Free lltask to avoid leak and dangling pointer if (!th_model || ff_safe_queue_push_back(th_model->request_queue, request) < 0) { destroy_request_item(&request); } return ret; } static int get_output_th(DNNModel *model, const char *input_name, int input_width, int input_height, const char *output_name, int *output_width, int *output_height) { int ret = 0; THModel *th_model = (THModel*) model; DnnContext *ctx = th_model->ctx; TaskItem task = { 0 }; THRequestItem *request = NULL; DNNExecBaseParams exec_params = { .input_name = input_name, .output_names = &output_name, .nb_output = 1, .in_frame = NULL, .out_frame = NULL, }; ret = ff_dnn_fill_gettingoutput_task(&task, &exec_params, th_model, input_height, input_width, ctx); if ( ret != 0) { goto err; } ret = extract_lltask_from_task(&task, th_model->lltask_queue); if ( ret != 0) { av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n"); goto err; } request = (THRequestItem*) ff_safe_queue_pop_front(th_model->request_queue); if (!request) { av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n"); ret = AVERROR(EINVAL); // Clean up lltask that was pushed by extract_lltask_from_task LastLevelTaskItem *lltask = (LastLevelTaskItem *)ff_queue_pop_back(th_model->lltask_queue); av_freep(&lltask); goto err; } ret = execute_model_th(request, th_model->lltask_queue); *output_width = task.out_frame->width; *output_height = task.out_frame->height; err: av_frame_free(&task.out_frame); av_frame_free(&task.in_frame); return ret; } static THInferRequest *th_create_inference_request(void) { THInferRequest *request = (THInferRequest *)av_malloc(sizeof(THInferRequest)); if (!request) { return NULL; } request->input_tensor = NULL; request->output = NULL; return request; } static DNNModel *dnn_load_model_th(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx) { DNNModel *model = NULL; THModel 
*th_model = NULL; THRequestItem *item = NULL; const char *device_name = ctx->device ? ctx->device : "cpu"; th_model = (THModel *)av_mallocz(sizeof(THModel)); if (!th_model) return NULL; model = &th_model->model; th_model->ctx = ctx; c10::Device device = c10::Device(device_name); if (device.is_xpu()) { if (!at::hasXPU()) { av_log(ctx, AV_LOG_ERROR, "No XPU device found\n"); goto fail; } at::detail::getXPUHooks().init(); } else if (device.is_cuda()) { if (!at::cuda::is_available()) { av_log(ctx, AV_LOG_ERROR, "No CUDA device found\n"); goto fail; } // Load CUDA kernels - required for libtorch CUDA ops // Thread-safe initialization using call_once // NOTE: These handles are intentionally never dlclose'd. CUDA/TensorRT libraries // have complex cleanup requirements and calling dlclose can cause crashes. // The OS reclaims resources on process exit. static std::once_flag cuda_lib_once; static void *cuda_lib_handle = NULL; std::call_once(cuda_lib_once, [ctx]() { cuda_lib_handle = dlopen("libtorch_cuda.so", RTLD_NOW | RTLD_GLOBAL); if (cuda_lib_handle) { av_log(ctx, AV_LOG_DEBUG, "libtorch_cuda.so loaded\n"); } else { av_log(ctx, AV_LOG_WARNING, "Failed to load libtorch_cuda.so: %s\n", dlerror()); } }); } else if (!device.is_cpu()) { av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", device_name); goto fail; } try { th_model->jit_model = new torch::jit::Module; // Load TensorRT runtime if available (enables TRT-compiled models) // Thread-safe initialization using call_once static std::once_flag trt_once; static void *trt_lib_handle = NULL; std::call_once(trt_once, [ctx]() { trt_lib_handle = dlopen("libtorchtrt_runtime.so", RTLD_NOW | RTLD_GLOBAL); if (trt_lib_handle) { av_log(ctx, AV_LOG_INFO, "TensorRT runtime loaded\n"); } }); (*th_model->jit_model) = torch::jit::load(ctx->model_filename); th_model->jit_model->to(device); // Set JIT optimization once at model load time (thread-safe) torch::jit::setGraphExecutorOptimize(ctx->torch_option.optimize ? 
true : false); av_log(ctx, AV_LOG_INFO, "Model loaded to device: %s (JIT optimize=%d)\n", device_name, ctx->torch_option.optimize); if (device.is_cuda()) { av_log(ctx, AV_LOG_INFO, "CUDA available: %s, device count: %d\n", at::cuda::is_available() ? "yes" : "no", at::cuda::device_count()); } } catch (const c10::Error& e) { av_log(ctx, AV_LOG_ERROR, "Failed to load torch model: %s\n", e.what()); goto fail; } catch (const std::exception& e) { av_log(ctx, AV_LOG_ERROR, "Failed to load torch model: %s\n", e.what()); goto fail; } th_model->request_queue = ff_safe_queue_create(); if (!th_model->request_queue) { goto fail; } item = (THRequestItem *)av_mallocz(sizeof(THRequestItem)); if (!item) { goto fail; } item->lltask = NULL; item->infer_request = th_create_inference_request(); if (!item->infer_request) { av_log(NULL, AV_LOG_ERROR, "Failed to allocate memory for Torch inference request\n"); goto fail; } item->exec_module.start_inference = &th_start_inference; item->exec_module.callback = &infer_completion_callback; item->exec_module.args = item; if (ff_safe_queue_push_back(th_model->request_queue, item) < 0) { goto fail; } item = NULL; th_model->task_queue = ff_queue_create(); if (!th_model->task_queue) { goto fail; } th_model->lltask_queue = ff_queue_create(); if (!th_model->lltask_queue) { goto fail; } th_model->pending_queue = ff_safe_queue_create(); if (!th_model->pending_queue) { goto fail; } try { th_model->mutex = new std::mutex(); th_model->cond = new std::condition_variable(); th_model->worker_stop = false; th_model->worker_thread = new std::thread(th_worker_thread, th_model); } catch (const std::exception& e) { av_log(ctx, AV_LOG_ERROR, "Failed to create worker thread: %s\n", e.what()); goto fail; } model->get_input = &get_input_th; model->get_output = &get_output_th; model->filter_ctx = filter_ctx; model->func_type = func_type; return model; fail: if (item) { destroy_request_item(&item); // Note: destroy_request_item already calls av_freep(arg), so item is 
now NULL } dnn_free_model_th(&model); return NULL; } static int dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_params) { THModel *th_model = (THModel *)model; DnnContext *ctx = th_model->ctx; TaskItem *task; THRequestItem *request; int ret = 0; ret = ff_check_exec_params(ctx, DNN_TH, model->func_type, exec_params); if (ret != 0) { av_log(ctx, AV_LOG_ERROR, "exec parameter checking fail.\n"); return ret; } task = (TaskItem *)av_malloc(sizeof(TaskItem)); if (!task) { av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task item.\n"); return AVERROR(ENOMEM); } ret = ff_dnn_fill_task(task, exec_params, th_model, 0, 1); if (ret != 0) { av_log(ctx, AV_LOG_ERROR, "unable to fill task.\n"); av_frame_free(&task->in_frame); av_frame_free(&task->out_frame); av_freep(&task); return ret; } ret = ff_queue_push_back(th_model->task_queue, task); if (ret < 0) { av_log(ctx, AV_LOG_ERROR, "unable to push back task_queue.\n"); av_frame_free(&task->in_frame); av_frame_free(&task->out_frame); av_freep(&task); return ret; } ret = extract_lltask_from_task(task, th_model->lltask_queue); if (ret != 0) { av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n"); ff_queue_pop_back(th_model->task_queue); av_frame_free(&task->in_frame); av_frame_free(&task->out_frame); av_freep(&task); return ret; } request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue); if (!request) { av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n"); LastLevelTaskItem *lltask = (LastLevelTaskItem *)ff_queue_pop_back(th_model->lltask_queue); av_freep(&lltask); ff_queue_pop_back(th_model->task_queue); av_frame_free(&task->in_frame); av_frame_free(&task->out_frame); av_freep(&task); return AVERROR(EINVAL); } return execute_model_th(request, th_model->lltask_queue); } static DNNAsyncStatusType dnn_get_result_th(const DNNModel *model, AVFrame **in, AVFrame **out) { THModel *th_model = (THModel *)model; return ff_dnn_get_result_common(th_model->task_queue, in, 
out); } static int dnn_flush_th(const DNNModel *model) { THModel *th_model = (THModel *)model; THRequestItem *request; if (ff_queue_size(th_model->lltask_queue) == 0) // no pending task need to flush return 0; request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue); if (!request) { av_log(th_model->ctx, AV_LOG_ERROR, "unable to get infer request.\n"); return AVERROR(EINVAL); } return execute_model_th(request, th_model->lltask_queue); } extern const DNNModule ff_dnn_backend_torch = { .clazz = DNN_DEFINE_CLASS(dnn_th), .type = DNN_TH, .load_model = dnn_load_model_th, .execute_model = dnn_execute_model_th, .get_result = dnn_get_result_th, .flush = dnn_flush_th, .free_model = dnn_free_model_th, }; ================================================ FILE: tools/patches/dnn_cuda_kernels.cu ================================================ /* * Copyright 2026 Joshua V. Dillon * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * CUDA kernels for DNN backend format conversion. * Compiled to PTX at build time, loaded dynamically at runtime. * No cudart dependency - uses CUDA Driver API via FFmpeg's dynlink. * * WHY CUSTOM CUDA KERNELS FOR SUCH A PRIMITIVE OPERATION? * ======================================================== * These kernels convert FFmpeg frames (HWC uint8) to neural network input (NCHW float). * This seems like it should be a one-liner, but no standard library handles our needs: * * What we need (all in one pass, zero-copy from FFmpeg CUDA frames): * 1. 
Handle FFmpeg's linesize padding (row stride != width * channels)
 * 2. Convert uint8 [0,255] to float [0,1] (type conversion + scaling)
 * 3. Transpose HWC to NCHW (layout conversion)
 * 4. Support multiple pixel formats (RGB24, BGR24, RGBA, BGRA, ARGB, etc.)
 * 5. Support multiple tensor types (FP32, FP16, BF16)
 * 6. The reverse path for model output (NCHW float -> HWC uint8), with
 *    rounding, saturation to [0,255], and NaN/Inf handling
 *
 * Why existing libraries don't work:
 * - cuDNN cudnnTransformTensor: float-to-float layout changes only, no uint8
 * - NPP (nppiConvert_8u32f): type conversion but no transpose, separate calls
 * - TensorRT IReformatLayer: layout changes for float, not uint8 ingestion
 * - Chaining these: multiple kernel launches + intermediate allocations
 *
 * Alternatives considered:
 * - Preprocessing in ONNX model: Can't handle variable linesize padding
 * - TensorRT custom plugin: Adds export complexity, less flexible
 * - NPP + cuBLAS transpose: 2 passes, intermediate buffer, slower
 *
 * The fused kernel approach: one read, one write, no intermediate buffers.
 *
 * CODE LAYOUT:
 * ============
 * The conversion logic lives in two templated device helpers,
 * packed_uint8_to_planar() and planar_to_packed_uint8(), parameterized by the
 * tensor element type and the number of bytes per pixel. The extern "C"
 * __global__ kernels at the bottom are thin wrappers around them: C linkage
 * keeps their PTX entry names unmangled so they match the name strings in
 * dnn_cuda_kernels.h. Adding a dtype means adding one store_unit/load_unit
 * overload plus wrappers, not another copy of the kernel body.
 *
 * MEMORY ACCESS PATTERN NOTES:
 * ============================
 * The HWC->NCHW conversion writes R,G,B to memory locations separated by H*W elements.
 * This looks like poor cache behavior, but:
 * - Writes within each channel plane ARE coalesced (adjacent threads write adjacent addresses)
 * - Whether several planes stay resident in L2 is GPU-dependent: L2 ranges from a
 *   few MB on older parts to tens of MB on newer ones, while a single 1080p FP32
 *   plane is already ~8 MB
 * - The strided HWC reads (3 bytes apart) are actually the slower part
 * - Shared memory staging for better write coalescing was tested but didn't help
 *
 * An elementwise approach (3 separate kernels, one per channel) would:
 * - Triple kernel launch overhead
 * - Read the input 3x instead of 1x
 * - Not improve coalescing (still stride-3 reads)
 */

#include <cuda_fp16.h>  // __half, __float2half, __half2float
#include <cuda_bf16.h>  // __nv_bfloat16, __float2bfloat16, __bfloat162float

// Precomputed reciprocal for [0,255] -> [0,1] conversion.
// Multiplying by it is cheaper than dividing by 255.0f per channel.
__device__ __constant__ float kScale255Inv = 1.0f / 255.0f;

namespace {

// ----------------------------------------------------------------------------
// Element-type adapters: every tensor dtype is converted through float.
// ----------------------------------------------------------------------------

__device__ __forceinline__ void store_unit(float* dst, float v) { *dst = v; }
__device__ __forceinline__ void store_unit(__half* dst, float v) { *dst = __float2half(v); }
__device__ __forceinline__ void store_unit(__nv_bfloat16* dst, float v) { *dst = __float2bfloat16(v); }

__device__ __forceinline__ float load_unit(float v) { return v; }
__device__ __forceinline__ float load_unit(__half v) { return __half2float(v); }
__device__ __forceinline__ float load_unit(__nv_bfloat16 v) { return __bfloat162float(v); }

// [0,1] float -> [0,255] uint8 with round-half-up and saturation.
// Non-finite values (NaN, +Inf, -Inf) map to 0 so corrupted model output
// shows up as black instead of being clamped to an arbitrary extreme.
__device__ __forceinline__ unsigned char unit_to_uint8(float v)
{
    if (!isfinite(v))
        return 0;
    v = v * 255.0f + 0.5f;              // +0.5 then truncation rounds to nearest
    v = fminf(fmaxf(v, 0.0f), 255.0f);  // saturate: an out-of-range float->uint8 cast is UB
    return (unsigned char)v;
}

// Bytes per pixel of the packed frame layouts handled below.
constexpr int kRgbBytesPerPixel = 3;   // 3-channel kernels: bytes 0/1/2 -> R/G/B planes
constexpr int kRgbaBytesPerPixel = 4;  // 4-channel kernels: caller-supplied byte offsets

// Packed uint8 pixels -> planar tensor (batch 1, planes in R, G, B order).
//   input:  rows are `input_linesize` bytes apart (may include padding),
//           `bytes_per_pixel` bytes per pixel, R/G/B at the given byte offsets
//   output: three contiguous height*width planes, values scaled to [0,1]
// One thread per pixel: x indexes columns, y indexes rows; threads outside
// the frame return immediately.
template <typename T>
__device__ __forceinline__ void packed_uint8_to_planar(
    const unsigned char* __restrict__ input, T* __restrict__ output,
    int height, int width, int input_linesize,
    int bytes_per_pixel, int r_offset, int g_offset, int b_offset)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    if (x >= width || y >= height)
        return;

    const unsigned char* pixel = input + y * input_linesize + x * bytes_per_pixel;
    const unsigned char r = pixel[r_offset];
    const unsigned char g = pixel[g_offset];
    const unsigned char b = pixel[b_offset];

    const int plane = height * width;
    const int offset = y * width + x;
    store_unit(&output[0 * plane + offset], r * kScale255Inv);
    store_unit(&output[1 * plane + offset], g * kScale255Inv);
    store_unit(&output[2 * plane + offset], b * kScale255Inv);
}

// Planar tensor (batch 1, R/G/B planes, values nominally in [0,1]) -> packed
// uint8 pixels with `bytes_per_pixel` bytes per pixel and rows
// `output_linesize` bytes apart. When kWriteAlpha is set, the byte at
// a_offset is written as 255 (opaque); otherwise a_offset is ignored.
template <bool kWriteAlpha, typename T>
__device__ __forceinline__ void planar_to_packed_uint8(
    const T* __restrict__ input, unsigned char* __restrict__ output,
    int height, int width, int output_linesize,
    int bytes_per_pixel, int r_offset, int g_offset, int b_offset, int a_offset)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    if (x >= width || y >= height)
        return;

    const int plane = height * width;
    const int offset = y * width + x;
    const float r = load_unit(input[0 * plane + offset]);
    const float g = load_unit(input[1 * plane + offset]);
    const float b = load_unit(input[2 * plane + offset]);

    unsigned char* pixel = output + y * output_linesize + x * bytes_per_pixel;
    pixel[r_offset] = unit_to_uint8(r);
    pixel[g_offset] = unit_to_uint8(g);
    pixel[b_offset] = unit_to_uint8(b);
    if (kWriteAlpha)
        pixel[a_offset] = 255;  // Alpha = opaque
}

}  // namespace

extern "C" {

// ============================================================================
// Kernel entry points (names must match dnn_cuda_kernels.h)
// ============================================================================
// All kernels expect a 2-D launch covering at least width x height threads.
// NOTE: the hwc4 kernels' r/g/b(/a)_offset arguments must be validated by the
// host before launch (range [0,3]). Device-side bounds checking would add
// branching overhead to every pixel - not worth it.

// ---- FP32 ------------------------------------------------------------------

// HWC uint8 [0,255] (height, width, 3) -> NCHW float32 [0,1] (1, 3, height, width)
__global__ void hwc_uint8_to_nchw_float32_kernel(
    const unsigned char* __restrict__ input, float* __restrict__ output,
    int height, int width, int input_linesize)
{
    packed_uint8_to_planar(input, output, height, width, input_linesize,
                           kRgbBytesPerPixel, 0, 1, 2);
}

// NCHW float32 [0,1] (1, 3, height, width) -> HWC uint8 [0,255] (height, width, 3)
__global__ void nchw_float32_to_hwc_uint8_kernel(
    const float* __restrict__ input, unsigned char* __restrict__ output,
    int height, int width, int output_linesize)
{
    planar_to_packed_uint8<false>(input, output, height, width, output_linesize,
                                  kRgbBytesPerPixel, 0, 1, 2, 0);
}

// 4-channel HWC uint8 -> NCHW float32 (extract RGB, ignore alpha)
__global__ void hwc4_uint8_to_nchw_float32_kernel(
    const unsigned char* __restrict__ input, float* __restrict__ output,
    int height, int width, int input_linesize,
    int r_offset, int g_offset, int b_offset)
{
    packed_uint8_to_planar(input, output, height, width, input_linesize,
                           kRgbaBytesPerPixel, r_offset, g_offset, b_offset);
}

// NCHW float32 -> 4-channel HWC uint8 (alpha = 255)
__global__ void nchw_float32_to_hwc4_uint8_kernel(
    const float* __restrict__ input, unsigned char* __restrict__ output,
    int height, int width, int output_linesize,
    int r_offset, int g_offset, int b_offset, int a_offset)
{
    planar_to_packed_uint8<true>(input, output, height, width, output_linesize,
                                 kRgbaBytesPerPixel, r_offset, g_offset, b_offset, a_offset);
}

// ---- FP16 (half precision) -------------------------------------------------

// HWC uint8 [0,255] -> NCHW float16 [0,1]
__global__ void hwc_uint8_to_nchw_float16_kernel(
    const unsigned char* __restrict__ input, __half* __restrict__ output,
    int height, int width, int input_linesize)
{
    packed_uint8_to_planar(input, output, height, width, input_linesize,
                           kRgbBytesPerPixel, 0, 1, 2);
}

// NCHW float16 [0,1] -> HWC uint8 [0,255]
__global__ void nchw_float16_to_hwc_uint8_kernel(
    const __half* __restrict__ input, unsigned char* __restrict__ output,
    int height, int width, int output_linesize)
{
    planar_to_packed_uint8<false>(input, output, height, width, output_linesize,
                                  kRgbBytesPerPixel, 0, 1, 2, 0);
}

// 4-channel HWC uint8 -> NCHW float16 (extract RGB, ignore alpha)
__global__ void hwc4_uint8_to_nchw_float16_kernel(
    const unsigned char* __restrict__ input, __half* __restrict__ output,
    int height, int width, int input_linesize,
    int r_offset, int g_offset, int b_offset)
{
    packed_uint8_to_planar(input, output, height, width, input_linesize,
                           kRgbaBytesPerPixel, r_offset, g_offset, b_offset);
}

// NCHW float16 -> 4-channel HWC uint8 (alpha = 255)
__global__ void nchw_float16_to_hwc4_uint8_kernel(
    const __half* __restrict__ input, unsigned char* __restrict__ output,
    int height, int width, int output_linesize,
    int r_offset, int g_offset, int b_offset, int a_offset)
{
    planar_to_packed_uint8<true>(input, output, height, width, output_linesize,
                                 kRgbaBytesPerPixel, r_offset, g_offset, b_offset, a_offset);
}

// ---- BF16 (bfloat16) -------------------------------------------------------

// HWC uint8 [0,255] -> NCHW bfloat16 [0,1]
__global__ void hwc_uint8_to_nchw_bfloat16_kernel(
    const unsigned char* __restrict__ input, __nv_bfloat16* __restrict__ output,
    int height, int width, int input_linesize)
{
    packed_uint8_to_planar(input, output, height, width, input_linesize,
                           kRgbBytesPerPixel, 0, 1, 2);
}

// NCHW bfloat16 [0,1] -> HWC uint8 [0,255]
__global__ void nchw_bfloat16_to_hwc_uint8_kernel(
    const __nv_bfloat16* __restrict__ input, unsigned char* __restrict__ output,
    int height, int width, int output_linesize)
{
    planar_to_packed_uint8<false>(input, output, height, width, output_linesize,
                                  kRgbBytesPerPixel, 0, 1, 2, 0);
}

// 4-channel HWC uint8 -> NCHW bfloat16 (extract RGB, ignore alpha)
__global__ void hwc4_uint8_to_nchw_bfloat16_kernel(
    const unsigned char* __restrict__ input, __nv_bfloat16* __restrict__ output,
    int height, int width, int input_linesize,
    int r_offset, int g_offset, int b_offset)
{
    packed_uint8_to_planar(input, output, height, width, input_linesize,
                           kRgbaBytesPerPixel, r_offset, g_offset, b_offset);
}

// NCHW bfloat16 -> 4-channel HWC uint8 (alpha = 255)
__global__ void nchw_bfloat16_to_hwc4_uint8_kernel(
    const __nv_bfloat16* __restrict__ input, unsigned char* __restrict__ output,
    int height, int width, int output_linesize,
    int r_offset, int g_offset, int b_offset, int a_offset)
{
    planar_to_packed_uint8<true>(input, output, height, width, output_linesize,
                                 kRgbaBytesPerPixel, r_offset, g_offset, b_offset, a_offset);
}

} // extern "C"

================================================
FILE: tools/patches/dnn_cuda_kernels.h
================================================
/*
 * Copyright 2026 Joshua V. Dillon
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * CUDA kernel PTX declarations for DNN backend format conversion.
 * Kernels are compiled to PTX at build time and loaded via Driver API at runtime.
 * This avoids any CUDA runtime (cudart) dependency.
*/
#ifndef AVFILTER_DNN_CUDA_KERNELS_H
#define AVFILTER_DNN_CUDA_KERNELS_H

/* NOTE(review): the header name of this #include is missing in this copy, and
 * an #include with no operand does not preprocess. The declarations below only
 * use built-in types (unsigned char / unsigned int), so no system header
 * appears to be required; confirm against the upstream file what was meant. */
#include

/* PTX bytecode embedded at compile time via bin2c */
/* Presumably ff_dnn_cuda_kernels_ptx_len is the size of the array in bytes;
 * TODO confirm whether it counts a trailing NUL (relevant if the PTX is passed
 * to the driver as a NUL-terminated string). */
extern const unsigned char ff_dnn_cuda_kernels_ptx[];
extern const unsigned int ff_dnn_cuda_kernels_ptx_len;

/* Kernel names within the PTX module */
/* Each string must match, character for character, an extern "C" __global__
 * function in dnn_cuda_kernels.cu; C linkage keeps those names unmangled in
 * the generated PTX.
 *
 * Naming: <src layout>_<src type>_to_<dst layout>_<dst type>
 *   hwc   packed uint8 frame data, 3 bytes per pixel
 *   hwc4  packed uint8 frame data, 4 bytes per pixel; these kernels take
 *         per-channel byte offsets that the host must validate to [0,3]
 *   nchw  planar tensor: three height*width channel planes (batch 1) */

/* FP32 variants */
#define DNN_CUDA_KERNEL_HWC_UINT8_TO_NCHW_FLOAT32 "hwc_uint8_to_nchw_float32_kernel"
#define DNN_CUDA_KERNEL_NCHW_FLOAT32_TO_HWC_UINT8 "nchw_float32_to_hwc_uint8_kernel"
#define DNN_CUDA_KERNEL_HWC4_UINT8_TO_NCHW_FLOAT32 "hwc4_uint8_to_nchw_float32_kernel"
#define DNN_CUDA_KERNEL_NCHW_FLOAT32_TO_HWC4_UINT8 "nchw_float32_to_hwc4_uint8_kernel"

/* FP16 variants */
#define DNN_CUDA_KERNEL_HWC_UINT8_TO_NCHW_FLOAT16 "hwc_uint8_to_nchw_float16_kernel"
#define DNN_CUDA_KERNEL_NCHW_FLOAT16_TO_HWC_UINT8 "nchw_float16_to_hwc_uint8_kernel"
#define DNN_CUDA_KERNEL_HWC4_UINT8_TO_NCHW_FLOAT16 "hwc4_uint8_to_nchw_float16_kernel"
#define DNN_CUDA_KERNEL_NCHW_FLOAT16_TO_HWC4_UINT8 "nchw_float16_to_hwc4_uint8_kernel"

/* BF16 variants */
#define DNN_CUDA_KERNEL_HWC_UINT8_TO_NCHW_BFLOAT16 "hwc_uint8_to_nchw_bfloat16_kernel"
#define DNN_CUDA_KERNEL_NCHW_BFLOAT16_TO_HWC_UINT8 "nchw_bfloat16_to_hwc_uint8_kernel"
#define DNN_CUDA_KERNEL_HWC4_UINT8_TO_NCHW_BFLOAT16 "hwc4_uint8_to_nchw_bfloat16_kernel"
#define DNN_CUDA_KERNEL_NCHW_BFLOAT16_TO_HWC4_UINT8 "nchw_bfloat16_to_hwc4_uint8_kernel"

#endif /* AVFILTER_DNN_CUDA_KERNELS_H */

================================================
FILE: tools/patches/vf_dnn_processing.c
================================================
/*
 * Copyright (c) 2019 Guo Yejun
 * Copyright (c) 2026 Joshua V. Dillon (TensorRT/Torch backend integration, CUDA hw frame support)
 *
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
* * FFmpeg is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with FFmpeg; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA */ /** * @file * implementing a generic image processing filter using deep learning networks. */ #include "libavutil/opt.h" #include "libavutil/pixdesc.h" #include "libavutil/avassert.h" #include "libavutil/imgutils.h" #include "libavutil/hwcontext.h" #include "libavutil/hwcontext_cuda.h" #include "filters.h" #include "dnn_filter_common.h" #include "video.h" #include "libswscale/swscale.h" #include "libavutil/time.h" typedef struct DnnProcessingContext { const AVClass *class; DnnContext dnnctx; struct SwsContext *sws_uv_scale; int sws_uv_height; AVBufferRef *hw_frames_ctx; // For CUDA output frames } DnnProcessingContext; #define OFFSET(x) offsetof(DnnProcessingContext, dnnctx.x) #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM static const AVOption dnn_processing_options[] = { { "dnn_backend", "DNN backend", OFFSET(backend_type), AV_OPT_TYPE_INT, { .i64 = DNN_TF }, INT_MIN, INT_MAX, FLAGS, .unit = "backend" }, #if (CONFIG_LIBTENSORFLOW == 1) { "tensorflow", "tensorflow backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = DNN_TF }, 0, 0, FLAGS, .unit = "backend" }, #endif #if (CONFIG_LIBOPENVINO == 1) { "openvino", "openvino backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = DNN_OV }, 0, 0, FLAGS, .unit = "backend" }, #endif #if (CONFIG_LIBTORCH == 1) { "torch", "torch backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = DNN_TH }, 0, 0, FLAGS, .unit = "backend" }, #endif #if (CONFIG_LIBTENSORRT == 1) { "tensorrt", "tensorrt backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = DNN_TRT }, 0, 0, FLAGS, .unit = "backend" }, 
#endif { NULL } }; AVFILTER_DNN_DEFINE_CLASS(dnn_processing, DNN_TF | DNN_OV | DNN_TH | DNN_TRT); static av_cold int init(AVFilterContext *context) { DnnProcessingContext *ctx = context->priv; return ff_dnn_init(&ctx->dnnctx, DFT_PROCESS_FRAME, context); } static const enum AVPixelFormat pix_fmts[] = { AV_PIX_FMT_RGB24, AV_PIX_FMT_BGR24, AV_PIX_FMT_GRAY8, AV_PIX_FMT_GRAYF32, AV_PIX_FMT_YUV420P, AV_PIX_FMT_YUV422P, AV_PIX_FMT_YUV444P, AV_PIX_FMT_YUV410P, AV_PIX_FMT_YUV411P, AV_PIX_FMT_NV12, AV_PIX_FMT_CUDA, // CUDA hardware frames for zero-copy GPU inference AV_PIX_FMT_NONE }; #define LOG_FORMAT_CHANNEL_MISMATCH() \ av_log(ctx, AV_LOG_ERROR, \ "the frame's format %s does not match " \ "the model input channel %d\n", \ av_get_pix_fmt_name(fmt), \ model_input->dims[dnn_get_channel_idx_by_layout(model_input->layout)]); static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLink *inlink) { AVFilterContext *ctx = inlink->dst; enum AVPixelFormat fmt = inlink->format; int width_idx, height_idx; width_idx = dnn_get_width_idx_by_layout(model_input->layout); height_idx = dnn_get_height_idx_by_layout(model_input->layout); // the design is to add explicit scale filter before this filter if (model_input->dims[height_idx] != -1 && model_input->dims[height_idx] != inlink->h) { av_log(ctx, AV_LOG_ERROR, "the model requires frame height %d but got %d\n", model_input->dims[height_idx], inlink->h); return AVERROR(EIO); } if (model_input->dims[width_idx] != -1 && model_input->dims[width_idx] != inlink->w) { av_log(ctx, AV_LOG_ERROR, "the model requires frame width %d but got %d\n", model_input->dims[width_idx], inlink->w); return AVERROR(EIO); } if (model_input->dt != DNN_FLOAT) { avpriv_report_missing_feature(ctx, "data type rather than DNN_FLOAT"); return AVERROR(EIO); } switch (fmt) { case AV_PIX_FMT_RGB24: case AV_PIX_FMT_BGR24: if (model_input->dims[dnn_get_channel_idx_by_layout(model_input->layout)] != 3) { LOG_FORMAT_CHANNEL_MISMATCH(); return AVERROR(EIO); 
} return 0; case AV_PIX_FMT_GRAY8: case AV_PIX_FMT_GRAYF32: case AV_PIX_FMT_YUV420P: case AV_PIX_FMT_YUV422P: case AV_PIX_FMT_YUV444P: case AV_PIX_FMT_YUV410P: case AV_PIX_FMT_YUV411P: case AV_PIX_FMT_NV12: if (model_input->dims[dnn_get_channel_idx_by_layout(model_input->layout)] != 1) { LOG_FORMAT_CHANNEL_MISMATCH(); return AVERROR(EIO); } return 0; case AV_PIX_FMT_CUDA: // CUDA frames: torch backend handles conversion internally // Model expects 3 channels (RGB) if (model_input->dims[dnn_get_channel_idx_by_layout(model_input->layout)] != 3) { LOG_FORMAT_CHANNEL_MISMATCH(); return AVERROR(EIO); } return 0; default: avpriv_report_missing_feature(ctx, "%s", av_get_pix_fmt_name(fmt)); return AVERROR(EIO); } } static int config_input(AVFilterLink *inlink) { AVFilterContext *context = inlink->dst; DnnProcessingContext *ctx = context->priv; int result; DNNData model_input; int check; result = ff_dnn_get_input(&ctx->dnnctx, &model_input); if (result != 0) { av_log(ctx, AV_LOG_ERROR, "could not get input from the model\n"); return result; } check = check_modelinput_inlink(&model_input, inlink); if (check != 0) { return check; } return 0; } static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt) { const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt); if (!desc) return 0; return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3; } static int prepare_uv_scale(AVFilterLink *outlink) { AVFilterContext *context = outlink->src; DnnProcessingContext *ctx = context->priv; AVFilterLink *inlink = context->inputs[0]; enum AVPixelFormat fmt = inlink->format; if (isPlanarYUV(fmt)) { if (inlink->w != outlink->w || inlink->h != outlink->h) { if (fmt == AV_PIX_FMT_NV12) { ctx->sws_uv_scale = sws_getContext(inlink->w >> 1, inlink->h >> 1, AV_PIX_FMT_YA8, outlink->w >> 1, outlink->h >> 1, AV_PIX_FMT_YA8, SWS_BICUBIC, NULL, NULL, NULL); if (!ctx->sws_uv_scale) { av_log(context, AV_LOG_ERROR, "Failed to create UV scale context for NV12\n"); return 
AVERROR(ENOMEM); } ctx->sws_uv_height = inlink->h >> 1; } else { const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(fmt); int sws_src_h, sws_src_w, sws_dst_h, sws_dst_w; if (!desc) { av_log(context, AV_LOG_ERROR, "Unknown pixel format %d\n", fmt); return AVERROR(EINVAL); } sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h); sws_src_w = AV_CEIL_RSHIFT(inlink->w, desc->log2_chroma_w); sws_dst_h = AV_CEIL_RSHIFT(outlink->h, desc->log2_chroma_h); sws_dst_w = AV_CEIL_RSHIFT(outlink->w, desc->log2_chroma_w); ctx->sws_uv_scale = sws_getContext(sws_src_w, sws_src_h, AV_PIX_FMT_GRAY8, sws_dst_w, sws_dst_h, AV_PIX_FMT_GRAY8, SWS_BICUBIC, NULL, NULL, NULL); if (!ctx->sws_uv_scale) { av_log(context, AV_LOG_ERROR, "Failed to create UV scale context\n"); return AVERROR(ENOMEM); } ctx->sws_uv_height = sws_src_h; } } } return 0; } static int config_output(AVFilterLink *outlink) { AVFilterContext *context = outlink->src; DnnProcessingContext *ctx = context->priv; FilterLink *ol = ff_filter_link(outlink); FilterLink *il = ff_filter_link(context->inputs[0]); int result; AVFilterLink *inlink = context->inputs[0]; // have a try run in case that the dnn model resize the frame result = ff_dnn_get_output(&ctx->dnnctx, inlink->w, inlink->h, &outlink->w, &outlink->h); if (result != 0) { av_log(ctx, AV_LOG_ERROR, "could not get output from the model\n"); return result; } // Handle CUDA frames - set up output hw_frames_ctx if (inlink->format == AV_PIX_FMT_CUDA && il->hw_frames_ctx) { AVHWFramesContext *in_frames_ctx = (AVHWFramesContext *)il->hw_frames_ctx->data; AVHWFramesContext *out_frames_ctx; ctx->hw_frames_ctx = av_hwframe_ctx_alloc(in_frames_ctx->device_ref); if (!ctx->hw_frames_ctx) return AVERROR(ENOMEM); out_frames_ctx = (AVHWFramesContext *)ctx->hw_frames_ctx->data; out_frames_ctx->format = AV_PIX_FMT_CUDA; out_frames_ctx->sw_format = in_frames_ctx->sw_format; out_frames_ctx->width = outlink->w; out_frames_ctx->height = outlink->h; result = 
av_hwframe_ctx_init(ctx->hw_frames_ctx); if (result < 0) { av_buffer_unref(&ctx->hw_frames_ctx); return result; } ol->hw_frames_ctx = av_buffer_ref(ctx->hw_frames_ctx); if (!ol->hw_frames_ctx) { av_buffer_unref(&ctx->hw_frames_ctx); return AVERROR(ENOMEM); } av_log(context, AV_LOG_INFO, "CUDA output frames: %dx%d\n", outlink->w, outlink->h); return 0; } result = prepare_uv_scale(outlink); if (result < 0) return result; return 0; } static int copy_uv_planes(DnnProcessingContext *ctx, AVFrame *out, const AVFrame *in) { const AVPixFmtDescriptor *desc; int uv_height; if (!ctx->sws_uv_scale) { av_assert0(in->height == out->height && in->width == out->width); desc = av_pix_fmt_desc_get(in->format); if (!desc) return AVERROR(EINVAL); uv_height = AV_CEIL_RSHIFT(in->height, desc->log2_chroma_h); for (int i = 1; i < 3; ++i) { int bytewidth = av_image_get_linesize(in->format, in->width, i); if (bytewidth < 0) { return AVERROR(EINVAL); } av_image_copy_plane(out->data[i], out->linesize[i], in->data[i], in->linesize[i], bytewidth, uv_height); } } else if (in->format == AV_PIX_FMT_NV12) { int ret = sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 1), in->linesize + 1, 0, ctx->sws_uv_height, out->data + 1, out->linesize + 1); if (ret < 0) return ret; } else { int ret = sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 1), in->linesize + 1, 0, ctx->sws_uv_height, out->data + 1, out->linesize + 1); if (ret < 0) return ret; ret = sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 2), in->linesize + 2, 0, ctx->sws_uv_height, out->data + 2, out->linesize + 2); if (ret < 0) return ret; } return 0; } static int flush_frame(AVFilterLink *outlink, int64_t pts, int64_t *out_pts) { DnnProcessingContext *ctx = outlink->src->priv; int ret; DNNAsyncStatusType async_state; ret = ff_dnn_flush(&ctx->dnnctx); if (ret != 0) { return ret; } do { AVFrame *in_frame = NULL; AVFrame *out_frame = NULL; async_state = ff_dnn_get_result(&ctx->dnnctx, &in_frame, &out_frame); if 
(out_frame) { int64_t frame_pts = out_frame->pts; // Save before ff_filter_frame may free if (in_frame && isPlanarYUV(in_frame->format)) { ret = copy_uv_planes(ctx, out_frame, in_frame); if (ret < 0) { av_frame_free(&in_frame); av_frame_free(&out_frame); return ret; } } av_frame_free(&in_frame); ret = ff_filter_frame(outlink, out_frame); if (ret < 0) return ret; if (out_pts) *out_pts = frame_pts + pts; } av_usleep(5000); } while (async_state >= DAST_NOT_READY); return 0; } static int activate(AVFilterContext *filter_ctx) { AVFilterLink *inlink = filter_ctx->inputs[0]; AVFilterLink *outlink = filter_ctx->outputs[0]; DnnProcessingContext *ctx = filter_ctx->priv; AVFrame *in = NULL, *out = NULL; int64_t pts; int ret, status; int got_frame = 0; int async_state; FF_FILTER_FORWARD_STATUS_BACK(outlink, inlink); do { // drain all input frames ret = ff_inlink_consume_frame(inlink, &in); if (ret < 0) return ret; if (ret > 0) { // Allocate CUDA frames for CUDA input, CPU frames otherwise if (in->format == AV_PIX_FMT_CUDA && ctx->hw_frames_ctx) { out = av_frame_alloc(); if (!out) { av_frame_free(&in); return AVERROR(ENOMEM); } ret = av_hwframe_get_buffer(ctx->hw_frames_ctx, out, 0); if (ret < 0) { av_frame_free(&out); av_frame_free(&in); return ret; } } else { out = ff_get_video_buffer(outlink, outlink->w, outlink->h); if (!out) { av_frame_free(&in); return AVERROR(ENOMEM); } } ret = av_frame_copy_props(out, in); if (ret < 0) { av_frame_free(&in); av_frame_free(&out); return ret; } if (ff_dnn_execute_model(&ctx->dnnctx, in, out) != 0) { av_log(ctx, AV_LOG_ERROR, "DNN model execution failed\n"); av_frame_free(&in); av_frame_free(&out); return AVERROR(EIO); } } } while (ret > 0); // drain all processed frames do { AVFrame *in_frame = NULL; AVFrame *out_frame = NULL; async_state = ff_dnn_get_result(&ctx->dnnctx, &in_frame, &out_frame); if (out_frame) { if (in_frame && isPlanarYUV(in_frame->format)) { ret = copy_uv_planes(ctx, out_frame, in_frame); if (ret < 0) { 
av_frame_free(&in_frame); av_frame_free(&out_frame); return ret; } } av_frame_free(&in_frame); ret = ff_filter_frame(outlink, out_frame); if (ret < 0) return ret; got_frame = 1; } } while (async_state == DAST_SUCCESS); // if frame got, schedule to next filter if (got_frame) return 0; if (ff_inlink_acknowledge_status(inlink, &status, &pts)) { if (status == AVERROR_EOF) { int64_t out_pts = pts; ret = flush_frame(outlink, pts, &out_pts); ff_outlink_set_status(outlink, status, out_pts); return ret; } } FF_FILTER_FORWARD_WANTED(outlink, inlink); return 0; } static av_cold void uninit(AVFilterContext *ctx) { DnnProcessingContext *context = ctx->priv; sws_freeContext(context->sws_uv_scale); av_buffer_unref(&context->hw_frames_ctx); ff_dnn_uninit(&context->dnnctx); } static const AVFilterPad dnn_processing_inputs[] = { { .name = "default", .type = AVMEDIA_TYPE_VIDEO, .config_props = config_input, }, }; static const AVFilterPad dnn_processing_outputs[] = { { .name = "default", .type = AVMEDIA_TYPE_VIDEO, .config_props = config_output, }, }; const FFFilter ff_vf_dnn_processing = { .p.name = "dnn_processing", .p.description = NULL_IF_CONFIG_SMALL("Apply DNN processing filter to the input."), .p.priv_class = &dnn_processing_class, .priv_size = sizeof(DnnProcessingContext), .preinit = ff_dnn_filter_init_child_class, .init = init, .uninit = uninit, FILTER_INPUTS(dnn_processing_inputs), FILTER_OUTPUTS(dnn_processing_outputs), FILTER_PIXFMTS_ARRAY(pix_fmts), .activate = activate, .flags_internal = FF_FILTER_FLAG_HWFRAME_AWARE, }; ================================================ FILE: tools/uninstall-netv.sh ================================================ #!/bin/bash # Uninstall netv systemd service # # Usage: sudo ./uninstall-netv.sh set -e if [ "$EUID" -ne 0 ]; then echo "Error: Run with sudo" echo "Usage: sudo $0" exit 1 fi echo "=== Uninstalling netv ===" if systemctl is-active --quiet netv 2>/dev/null; then echo "Stopping netv service..." 
systemctl stop netv fi if systemctl is-enabled --quiet netv 2>/dev/null; then echo "Disabling netv service..." systemctl disable netv fi if [ -f /etc/systemd/system/netv.service ]; then echo "Removing service file..." rm /etc/systemd/system/netv.service systemctl daemon-reload fi if [ -f /etc/letsencrypt/renewal-hooks/deploy/netv ]; then echo "Removing certbot hook..." rm /etc/letsencrypt/renewal-hooks/deploy/netv fi echo "" echo "=== Done ===" echo "" echo "The netv service has been removed." echo "Project files and cache remain in place - delete manually if desired." ================================================ FILE: tools/xtream2m3u.py ================================================ #!/usr/bin/env python3 # pyright: reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportImplicitStringConcatenation=false, reportUnknownParameterType=false # M3U Details: # https://github.com/HamzaBhf00/m3u-tags-iptv # Xtream Codes: # https://github.com/engenex/xtream-codes-api-v2/blob/main/%5BHow-To%5D%20Player%20API%20v2%20-%20Tutorials%20-%20Xtream%20Codes.pdf from __future__ import annotations from typing import Any, Protocol import collections import concurrent.futures import functools import gzip # lzma(80%), bz2(78%), gzip(75%) but gzip was fastest. 
import http.client
import json
import math
import pathlib
import pickle
import shutil
import threading
import time
import urllib
import urllib.error
import urllib.parse
import urllib.request


class RetryableError(Exception):
    """A fetch failed for a transient reason and may succeed if retried.

    `_task` leaves a series entry as None on this error, so the next retry
    round in `fetch_series_info` picks it up again.
    """


TOOLS_DIR = pathlib.Path(__file__).parent.resolve()
CONFIG_FILE = TOOLS_DIR / "xtream.json"
TEMPDIR = TOOLS_DIR  # Raw API responses and the series-info cache.
DESTDIR = TOOLS_DIR  # Generated .m3u playlists.


def _load_config() -> dict:
    """Load config from xtream.json, creating a template if it is missing.

    The file is read and written as UTF-8 explicitly (JSON's encoding), so
    behavior does not depend on the platform's locale.

    Raises:
        SystemExit: The config did not exist; a template was written for the
            user to fill in with their provider URL and credentials.
    """
    if not CONFIG_FILE.exists():
        template = {
            "url": "https://your-provider.com",
            "username": "your_username",
            "password": "your_password",
            "live_filter": {},
            "locals_group": "",
            "locals_filter": [],
        }
        CONFIG_FILE.write_text(json.dumps(template, indent=2), encoding="utf-8")
        raise SystemExit(f"Created {CONFIG_FILE} - edit with your credentials and re-run")
    return json.loads(CONFIG_FILE.read_text(encoding="utf-8"))


def _build_urls(base: str, user: str, passwd: str) -> tuple[str, str, str]:
    """Return (api_url, get_url, epg_url) for an Xtream server root `base`.

    Credentials are percent-encoded, so characters such as ``&``, ``#``, ``+``
    or spaces in a username/password cannot corrupt the query string.
    Callers append further parameters with ``&action=...``.
    """
    base = base.rstrip("/")
    query = urllib.parse.urlencode({"username": user, "password": passwd})
    return (
        f"{base}/player_api.php?{query}",
        f"{base}/get.php?{query}",
        f"{base}/xmltv.php?{query}",
    )


def _get_urls() -> tuple[str, str, str]:
    """Return (api_url, get_url, epg_url) built from the config file."""
    cfg = _load_config()
    return _build_urls(cfg["url"], cfg["username"], cfg["password"])


def _get_filters() -> tuple[dict[int, str], str, set[str]]:
    """Return (live_filter, locals_group, locals_filter) from config.

    `live_filter` maps provider category ids to group titles. JSON object keys
    are always strings, so they are converted to int here.
    """
    cfg = _load_config()
    live_filter = {int(k): v for k, v in cfg.get("live_filter", {}).items()}
    locals_group = cfg.get("locals_group", "")
    locals_filter = set(cfg.get("locals_filter", []))
    return live_filter, locals_group, locals_filter


def main(cached_only: bool = False) -> None:
    """Rebuild live.m3u, vod.m3u and series.m3u from the provider's API.

    Args:
        cached_only: If True, make no network requests; reuse the JSON
            responses and the series-info cache already saved in TEMPDIR.
    """
    api_url, _, epg_url = _get_urls()
    live_filter, locals_group, locals_filter = _get_filters()
    if not cached_only:
        fetch_all_data(api_url)

    auth = load_dict("authentication.json")
    iptv_url = process_iptv_url(auth)

    live, live_categories = process(
        load_list("get_live_stream.json"),
        load_list("get_live_categories.json"),
        iptv_url,
    )
    del live_categories
    live = filter_live(live, live_filter, locals_group, locals_filter)
    write_m3u_live(live, auth, epg_url)

    # VOD and series stream URLs carry their content type as an extra path
    # segment right after "host:port" (index 2 of the URL parts).
    vod_url = list(iptv_url)
    vod_url.insert(2, "movie")
    vod, vod_categories = process(
        load_list("get_vod_streams.json"),
        load_list("get_vod_categories.json"),
        vod_url,
    )
    del vod_categories
    write_m3u_vod(vod, auth)

    series_url = list(iptv_url)
    series_url.insert(2, "series")
    series, series_categories = process(
        load_list("get_series.json"),
        load_list("get_series_categories.json"),
        series_url,
    )
    del series_categories
    series_info = fetch_series_info(series, api_url, cached_only=cached_only)
    write_m3u_series(series, series_info, auth, series_url)


###############################################################################
#  ____            __                       _                                 #
# |  _ \    ___   / _|  _ __    ___   ___  | |__                              #
# | |_) |  / _ \ | |_  | '__|  / _ \ / __| | '_ \                             #
# |  _ <  |  __/ |  _| | |    |  __/ \__ \ | | | |                            #
# |_| \_\  \___| |_|   |_|     \___| |___/ |_| |_|                            #
#                                                                             #
###############################################################################


# Each entry: (progress label, query suffix appended to the API URL,
# file written under TEMPDIR, request timeout in seconds). The bulk listings
# get a long timeout because providers can take a while to serialize them.
_FETCH_PLAN = (
    ("authentication", "", "authentication.json", 5),
    ("live streams", "&action=get_live_streams", "get_live_stream.json", 120),
    ("live categories", "&action=get_live_categories", "get_live_categories.json", 5),
    ("series", "&action=get_series", "get_series.json", 120),
    ("series categories", "&action=get_series_categories", "get_series_categories.json", 5),
    ("VOD streams", "&action=get_vod_streams", "get_vod_streams.json", 120),
    ("VOD categories", "&action=get_vod_categories", "get_vod_categories.json", 5),
)


def fetch_all_data(api_url: str) -> None:
    """Download every listing in `_FETCH_PLAN` and save it raw to TEMPDIR.

    Prints one progress line per request with its wall-clock duration.
    Errors from `fetch_text` propagate and abort the refresh.
    """
    for label, query, filename, timeout in _FETCH_PLAN:
        print(f"Fetching {label}...", end=" ", flush=True)
        t0 = time.perf_counter()
        text = fetch_text(api_url + query, timeout=timeout)
        with open(TEMPDIR / filename, "w") as f:
            f.write(text)
        print(f"({time.perf_counter() - t0:.1f}s)")


# HTTP statuses that mean "try again later" rather than "this will never work".
_RETRYABLE_HTTP_CODES = frozenset({429, 500, 502, 503, 504})


def fetch_text(url: str, timeout: int = 5) -> str:
    """Fetch `url` and return its body decoded as UTF-8.

    Args:
        url: An http:// or https:// URL.
        timeout: Socket timeout in seconds for connect and read.

    Raises:
        ValueError: `url` is not http(s); the server returned a non-retryable
            HTTP error; or the body is not valid UTF-8.
        RetryableError: Rate limiting or a transient server error (see
            `_RETRYABLE_HTTP_CODES`), or a network-level failure such as a
            timeout, refused/reset connection, or a truncated response.
    """
    parsed = urllib.parse.urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"Unsupported URL scheme: {parsed.scheme}")
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            body = response.read()
    except urllib.error.HTTPError as e:
        # HTTPError subclasses URLError, so it must be handled first.
        if e.code in _RETRYABLE_HTTP_CODES:
            raise RetryableError(f"Unable to get {url}; http error {e.code}.") from e
        raise ValueError(f"Unable to get {url}; http error {e.code}.") from e
    except TimeoutError as e:
        # Raised directly when the read (not the connect) times out.
        raise RetryableError(f"Unable to get {url}; timeout {e}.") from e
    except urllib.error.URLError as e:
        # Wraps connect-phase socket errors (DNS, refused, connect timeout).
        raise RetryableError(f"Unable to get {url}; {e.reason}.") from e
    except (OSError, http.client.HTTPException) as e:
        # Mid-body failures (connection reset, IncompleteRead) would otherwise
        # escape `_task` and abort the whole concurrent series fetch.
        raise RetryableError(f"Unable to get {url}; {e!r}.") from e
    return body.decode("utf-8")


def fetch_series_info(
    series: dict[int, dict[str, Any]],
    api_url: str,
    cached_only: bool = False,
) -> dict[int, Any]:
    """Return per-series `get_series_info` responses, keyed by series id.

    Results are cached in TEMPDIR/series_info.pickle.gz. An entry is
    (re)fetched when it is missing, or when the provider's `last_modified` for
    that series is newer than the cached `info.last_modified`. Entries for
    series no longer listed are dropped. Fetches run on a thread pool behind a
    4-requests/second limiter, with up to 3 rounds for entries that hit a
    `RetryableError`. The cache is rewritten only if something changed.

    Args:
        series: Parsed series listing; each value needs an int "last_modified".
        api_url: Authenticated player_api.php URL.
        cached_only: Return whatever is cached (all-None values on a cache
            miss) without any network access or cache update.
    """
    cache_path = TEMPDIR / "series_info.pickle.gz"
    series_info: dict[int, None | Any] = {}
    try:
        # NOTE(review): pickle can execute arbitrary code on load; this is
        # only acceptable because the cache is written by this script itself.
        with gzip.open(cache_path, "rb") as f:
            series_info = pickle.load(f)
    except Exception as e:
        print(f"Cache miss: {e}")
        series_info = dict.fromkeys(series.keys())
    if cached_only:
        return series_info

    changed = False

    # Mark stale or missing entries (None) for fetching. A malformed or absent
    # cached timestamp counts as -1, i.e. always stale.
    refetch_count = 0
    for k in series:
        series_info.setdefault(k, None)
        try:
            t = int(series_info[k]["info"]["last_modified"])  # pyright: ignore[reportOptionalSubscript]
        except (KeyError, TypeError, ValueError):
            t = -1
        if series[k]["last_modified"] > t:
            refetch_count += 1
            series_info[k] = None
    print(f"Marked {refetch_count}/{len(series)} series for re/fetch.")

    # Drop cached entries for series the provider no longer lists.
    for k in tuple(series_info.keys()):
        if k in series:
            continue
        changed = True
        del series_info[k]

    progress_lock = threading.Lock()
    progress_count = sum(v is not None for v in series_info.values())
    limiter = SlidingRateLimiter(max_calls=4, per_seconds=1)
    task_ = functools.partial(
        _task,
        limiter=limiter,
        series_info=series_info,
        api_url=api_url,
        progress_lock=progress_lock,
        # One-element list so worker threads share a mutable counter.
        progress_count_ref=[progress_count],
    )

    # `_task` leaves an entry None after a RetryableError; each round retries
    # exactly those entries.
    max_retries = 3
    max_workers = math.ceil(1.5 * limiter.max_calls)
    for _ in range(max_retries):
        ids = [k for k, v in series_info.items() if v is None]
        if not ids:
            break
        changed = True
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
        try:
            list(executor.map(task_, ids))
        except KeyboardInterrupt:
            print("\nCancelling...")
            executor.shutdown(wait=False, cancel_futures=True)
            raise
        except Exception:
            # Don't leave queued lookups running behind an unexpected error.
            executor.shutdown(wait=False, cancel_futures=True)
            raise
        executor.shutdown(wait=True)

    if changed:
        # Write to a sibling temp file first so an interrupted dump never
        # replaces a good cache with a truncated one.
        tmp_path = cache_path.with_suffix(".tmp")
        with gzip.open(tmp_path, "wb") as f:
            pickle.dump(series_info, f)
        shutil.move(tmp_path, cache_path)
    return series_info


class RateLimiter(Protocol):
    """Structural type for the request throttles accepted by `_task`.

    `_task` calls `acquire()` before each request; implementations are
    expected to block there until another call is allowed.
    """

    def __init__(self, max_calls: int, per_seconds: float = 1): ...

    def acquire(self) -> None: ...
def _task( id_: int, limiter: RateLimiter, series_info: dict[int, Any], api_url: str, progress_lock: threading.Lock, progress_count_ref: list[int], ) -> None: try: limiter.acquire() result = json.loads( fetch_text( url=f"{api_url}&action=get_series_info&series_id={id_}", timeout=60, ) ) series_info[id_] = result with progress_lock: progress_count_ref[0] += 1 print_progress_bar( iteration=progress_count_ref[0], total=len(series_info), ) except RetryableError as e: print(e) except (json.JSONDecodeError, ValueError, KeyError) as e: print(e) series_info[id_] = {} def filter_live( live: dict[int, dict[str, Any]], live_filter: dict[int, str], locals_group: str, locals_filter: set[str], ) -> dict[int, dict[str, Any]]: if live_filter: live_ = collections.defaultdict(dict) for k, v in live.items(): if not any(c in live_filter for c in v["category_ids"]): continue if len(v["group-title"]) != 1: raise ValueError(f"Expected single group-title, got {v['group-title']}") live_[v["group-title"][0]][k] = v live = {} for v in live_filter.values(): live.update(live_[v]) if locals_group and locals_filter: live = { k: v for k, v in live.items() if ( locals_group not in v["group-title"] or any(c in v["tvg-name"] for c in locals_filter) ) } return live ############################################################################### # ____ # # | _ \ __ _ _ __ ___ ___ # # | |_) | / _` | | '__| / __| / _ \ # # | __/ | (_| | | | \__ \ | __/ # # |_| \__,_| |_| |___/ \___| # # # ############################################################################### def process( elements: list[dict[str, Any]], categories: list[dict[str, Any]], iptv_url: list[str], ) -> tuple[dict[int, dict[str, Any]], dict[int, Any]]: categories_dict: dict[int, str] = { int(c["category_id"]): c["category_name"] for c in categories } elements_dict: dict[int, dict[str, None | int | str | list[str]]] = {} for s in elements: stream_type = s.get("stream_type") if stream_type in ("live", "radio_streams"): id_ = 
int(s["stream_id"]) attr = { "tvg-name": s["name"] or s["title"], "tvg-logo": s["stream_icon"], "group-title": [categories_dict[c] for c in s["category_ids"]], "tvg-id": "" if s["epg_channel_id"] is None else s["epg_channel_id"], "url": "/".join([*iptv_url, str(id_)]), "category_ids": s["category_ids"], "year": None, "rating": None, "num": s["num"], "last_modified": None, # 'timeshift': None, ??? } if s["tv_archive"] not in (0, 1): raise ValueError(f"Invalid tv_archive value: {s}") # assert not s["direct_source"], s elif stream_type == "series" or s.get("series_id") is not None: id_ = int(s["series_id"]) attr = { "tvg-name": s["name"] or s["title"], "tvg-logo": s["cover"], "group-title": [categories_dict[c] for c in s["category_ids"]], "tvg-id": None, "url": None, "category_ids": s["category_ids"], "year": toint(s.get("year")), "rating": tofloat(s.get("rating")), "num": s["num"], "last_modified": int(s["last_modified"]), } elif stream_type == "movie": id_ = int(s["stream_id"]) attr = { "tvg-name": s["name"] or s["title"], "tvg-logo": s["stream_icon"], "group-title": [categories_dict[c] for c in s["category_ids"]], "tvg-id": None, "url": "/".join([*iptv_url, f"{id_}.{s['container_extension']}"]), "category_ids": s["category_ids"], "year": toint(s.get("year")), "rating": tofloat(s.get("rating")), "num": s["num"], "last_modified": None, } else: print(f"Unrecognized {stream_type=}: {s}") continue if id_ in elements_dict: raise ValueError(f"Duplicate id {id_}: {attr}") elements_dict[id_] = attr return elements_dict, categories_dict def process_iptv_url(auth: dict[str, dict[str, Any]]) -> list[str]: if (status := auth["user_info"]["status"]) != "Active": raise ValueError(f"Unsupported {status=}.") if (max_connections := int(auth["user_info"]["max_connections"])) < 1: raise ValueError(f"Insufficient {max_connections=}.") if (server_protocol := auth["server_info"]["server_protocol"]) not in ( "http", "https", ): raise ValueError(f"Unrecognized {server_protocol=}.") # We 
used to respect server protocol but now we just force HTTPS. server_protocol = "https" port_key = server_protocol + "_port" port = auth["server_info"].get(port_key, auth["server_info"][port_key]) return [ f"{server_protocol}:/", # We'll join everything with slashes later. f"{auth['server_info']['url']}:{port}", auth["user_info"]["username"], auth["user_info"]["password"], ] def toint(x: str | None) -> int | None: return int(x) if x else None def tofloat(x: str | None) -> float | None: return float(x) if x else None def load(filename: str) -> Any: with open(TEMPDIR / filename) as f: return json.load(f) def load_dict(filename: str) -> dict[str, Any]: result = load(filename) if not isinstance(result, dict): raise TypeError(f"Expected dict from {filename}, got {type(result)}") return result def load_list(filename: str) -> list[dict[str, Any]]: result = load(filename) if not isinstance(result, list): raise TypeError(f"Expected list from {filename}, got {type(result)}") return result class SlidingRateLimiter: def __init__(self, max_calls: int, per_seconds: float = 1): self.max_calls = max_calls self.per_seconds = per_seconds self.lock = threading.Lock() self.requests = collections.deque() def acquire(self) -> None: while True: with self.lock: cutoff = time.perf_counter() - self.per_seconds while self.requests and self.requests[0] <= cutoff: self.requests.popleft() if len(self.requests) < self.max_calls: self.requests.append(time.perf_counter()) return sleep_time = max(0, self.requests[0] - cutoff) time.sleep(sleep_time) class ChunkingRateLimiter: def __init__(self, max_calls: int, per_seconds: float = 1): self.max_calls = max_calls self.per_seconds = per_seconds self.condition = threading.Condition() self.calls = 0 self.last_reset = time.perf_counter() def acquire(self) -> None: with self.condition: if self.calls >= self.max_calls: now = time.perf_counter() elapsed = now - self.last_reset if elapsed < self.per_seconds: sleep_time = self.per_seconds - elapsed 
self.condition.wait(timeout=sleep_time) # Basically just, # self.lock.release() # time.sleep(sleep_time) # self.lock.acquire() self.calls = 0 self.last_reset = time.perf_counter() self.calls += 1 def print_progress_bar( iteration: int, total: int, prefix: str = "", suffix: str = "", decimals: int = 1, length: int = 50, fill: str = "█", printEnd: str = "\r", ) -> None: r"""Call in a loop to create terminal progress bar @params: iteration - Required : current iteration (Int) total - Required : total iterations (Int) prefix - Optional : prefix string (Str) suffix - Optional : suffix string (Str) decimals - Optional : positive number of decimals in percent complete (Int) length - Optional : character length of bar (Int) fill - Optional : bar fill character (Str) printEnd - Optional : end character (e.g. "\r", "\r\n") (Str) """ if total == 0: return percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total))) filledLength = int(length * iteration // total) bar = fill * filledLength + "-" * (length - filledLength) print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=printEnd) # Print New Line on Complete if iteration == total: print() def write_m3u_live( live: dict[int, dict[str, Any]], auth: dict[str, dict[str, Any]], epg_url: str, ) -> None: with open(DESTDIR / "live.m3u", "w") as f: print(f'#EXTM3U url-tvg="{epg_url}" x-tvg-url="{epg_url}"', file=f) if auth["server_info"].get("xui") is not None: version = auth["server_info"]["version"] print(f'#EXT-X-SESSION-DATA:DATA-ID="com.xui.{version}"', file=f) for v in live.values(): tvg_name = v["tvg-name"] tvg_logo = v["tvg-logo"] group_title = "TV | " + v["group-title"][0] url = v["url"] tvg_id = v["tvg-id"] print( f'#EXTINF:-1 tvg-id="{tvg_id}" tvg-name="{tvg_name}" ' f'tvg-logo="{tvg_logo}" group-title="{group_title}",{tvg_name}', file=f, ) print(url, file=f) def write_m3u_vod( vod: dict[int, dict[str, Any]], auth: dict[str, dict[str, Any]], ) -> None: with open(DESTDIR / "vod.m3u", "w") as f: 
print("#EXTM3U", file=f) if auth["server_info"].get("xui") is not None: version = auth["server_info"]["version"] print(f'#EXT-X-SESSION-DATA:DATA-ID="com.xui.{version}"', file=f) for v in vod.values(): tvg_name = v["tvg-name"] tvg_logo = v["tvg-logo"] group_title = "VOD | " + v["group-title"][0] url = v["url"] print( f'#EXTINF:-1 tvg-name="{tvg_name}" tvg-logo="{tvg_logo}" ' f'group-title="{group_title}",{tvg_name}', file=f, ) print(url, file=f) def write_m3u_series( series: dict[int, dict[str, Any]], series_info: dict[int, None | Any], auth: dict[str, dict[str, Any]], series_url: list[str], ) -> None: series_episodes = {} for k in series: info = series_info.get(k) if not info or "episodes" not in info: continue try: series_episodes[k] = list(_descend(info["episodes"])) except Exception as e: print(f"Series {k}: {e}") with open(DESTDIR / "series.m3u", "w") as f: print("#EXTM3U", file=f) if auth["server_info"].get("xui") is not None: version = auth["server_info"]["version"] print(f'#EXT-X-SESSION-DATA:DATA-ID="com.xui.{version}"', file=f) for k, vv in series_episodes.items(): v = series[k] tvg_logo = v["tvg-logo"] group_title = "Series | " + v["group-title"][0] for e in vv: tvg_name = e["title"] url = "/".join([*series_url, f"{e['id']}.{e['container_extension']}"]) print( f'#EXTINF:-1 tvg-name="{tvg_name}" tvg-logo="{tvg_logo}" ' f'group-title="{group_title}",{tvg_name}', file=f, ) print(url, file=f) def _descend(x: Any): if isinstance(x, dict): if "id" in x: yield x else: for x_ in x.values(): yield from _descend(x_) elif isinstance(x, list): for x_ in x: yield from _descend(x_) if __name__ == "__main__": main(cached_only=False) ================================================ FILE: tools/zap2xml.py ================================================ #!/usr/bin/env python3 # pyright: reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownLambdaType=false """zap2xml.py -- Fetch TV guide data from zap2it/gracenote 
in XMLTV format. Scrapes the internal JSON feed from zap2it/gracenote to generate XMLTV guide data. The site occasionally returns 400 errors for certain time windows; this tool ignores those and continues fetching available data. Written with only standard library dependencies. Usage: ./zap2xml.py --zip 90210 --days 7 Cron example: 0 0 * * * cd /path/to/tools && ./zap2xml.py --zip 90210 """ from __future__ import annotations from collections.abc import Mapping from typing import Any, ClassVar import argparse import datetime import gzip # lzma(80%), bz2(78%), gzip(75%) but gzip was fastest. import json import math import pathlib import re import sys import time import urllib.error import urllib.parse import urllib.request import xml.etree.ElementTree as xml SECONDS_PER_HOUR = 3_600 SECONDS_PER_DAY = 86_400 # https://en.wikipedia.org/wiki/Call_signs_in_the_United_States#Suffixes # Note: Doesn't correctly handly boosters. _CALLSIGN_REGEX = re.compile(r"^([A-Z]+?)(LD|DT|CD|CA|LP|TV|FM|D)(\d*)$") class Namespace(dict): # pyright: ignore[reportMissingTypeArgument] """Allows a dictionary to be accessed as `x.item` vs. `x['item']`.""" __slots__: ClassVar[tuple[str, ...]] = () __getattr__ = dict.__getitem__ __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ def main() -> None: args = parse_args() working_dir = pathlib.Path(args.path) cache_dir = working_dir / ".zap2xml" if not cache_dir.is_dir(): cache_dir.mkdir() url_flags = {k[len("zap_") :]: v for k, v in vars(args).items() if k.startswith("zap_")} url_flags["lineupId"] = f"{args.zap_country}-{args.zap_headendId}-DEFAULT" # Start time parameter is now rounded down to nearest `zap_timespan`, in s. 
zap_time = int(datetime.datetime.now().timestamp()) print(f"Local time: {zap_time} {strf_time_int(zap_time)}") zap_time_window = args.zap_timespan * SECONDS_PER_HOUR zap_time = (zap_time // zap_time_window) * zap_time_window print(f"First zap time: {zap_time} {strf_time_int(zap_time)}") remove_stale_cache(cache_dir, zap_time) # https://wiki.xmltv.org/index.php/XMLTVFormat # https://github.com/XMLTV/xmltv/blob/master/xmltv.dtd#L529 out = add_xml_child( parent=None, tag="tv", attrib={ "source-info-url": f"https://{args.base_url}/grid-affiliates.html?aid=gapzap", "source-info-name": "zap2it", "generator-info-name": "zap2xml.py", "generator-info-url": "https://github.com/jvdillon/netv", }, ) channel_map = {} # Only used for debugging. done_channels = False # Fetch data in `zap_timespan` chunks. if args.days > 15: raise ValueError(f"Can only collect at most 15 days; {args.days} too large.") num_fetch = math.ceil(args.days * 24 / args.zap_timespan) for i in range(num_fetch): i_time = zap_time + (i * zap_time_window) print(f"Getting data: {i_time} {strf_time_int(i_time)}") url = f"https://{args.base_url}/api/grid?" 
url += urllib.parse.urlencode({**url_flags, "time": i_time}) result = get_cached(cache_dir, i_time, args.delay, url) json_result = json.loads(result) if not done_channels: done_channels = True for c_in in json_result["channels"]: # {'affiliateCallSign': 'null', # 'affiliateName': 'AMERICAN BROADCASTING COMPANY', # 'callSign': 'KXTVDT', # 'channelId': '20775', # 'channelNo': '10.1', # 'id': '2077555', # 'stationFilters': ['filter-sports'], # 'stationGenres': [False], # 'thumbnail': '//zap2it.tmsimg.com/h3/NowShowing/20775/s28708_ll_h15_ac.png?w=55'} channel_key = get_channel_key(c_in) channel_display_name = " - ".join( [ c_in["affiliateName"].title(), # Eg, "CATCHY COMEDY" parse_callsign(c_in["callSign"]), # Eg, "KOVR-DT-5" c_in["channelNo"], # Eg., "13.5" ] ) channel_map[channel_key] = channel_display_name c_out = add_xml_child( parent=out, tag="channel", id=channel_key, ) _ = add_xml_child( parent=c_out, tag="display-name", text=channel_display_name, ) _ = add_xml_child( parent=c_out, tag="icon", src=f"https:{c_in['thumbnail'].split('?')[0]}", ) channel_map = dict(sorted(channel_map.items(), key=lambda kv: kv[0])) f = add_programme_tvimate if args.tvimate else add_programme for c_in in json_result["channels"]: channel_key = get_channel_key(c_in) for event in c_in["events"]: f(out, event, channel_key) # https://docs.python.org/3/library/xml.etree.elementtree.html#xml.etree.ElementTree.indent # Note: xml.indent must be done last. 
xml.indent(out, space="\t", level=0) with pathlib.Path.open((working_dir / "xmltv.xml").resolve(), "wb") as f: f.write(b'\n') f.write(xml.tostring(out, encoding="UTF-8")) sys.exit(0) def get_cached( cache_dir: pathlib.Path, timestamp: int, delay: int, url: str, ) -> bytes: cache_path = (cache_dir / str(timestamp)).with_suffix(".json.gz") if cache_path.is_file(): print(f"Cached: {url}") with gzip.open(cache_path, "rb") as f: return f.read() print(f"Fetching: '{url}'.") if not url.startswith(("http:", "https:")): raise ValueError(f"URL '{url}' must start with 'http:' or 'https:'") from None request = urllib.request.Request( url, headers={ "User-Agent": ( "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 " "(KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36" ), "Accept": "*/*", "Sec-Ch-Ua": ('"Chromium";v="140", "Not=A?Brand";v="24", "Google Chrome";v="140"'), "Sec-Ch-Ua-Mobile": "?0", "Sec-Ch-Ua-Platform": '"Linux"', # "Accept-Encoding": "br, gzip, deflate, zstd, identity", "Accept-Language": "en-US,en;q=0.9", "Content-Type": "text/plain;charset=UTF-8", "Priority": "u=1, i", }, ) try: response = urllib.request.urlopen(request) result = response.read() except urllib.error.HTTPError as e: if e.code != 400: e.add_note(f'Url is "{url}".') raise print("Got a 400 error! 
Ignoring it.") result = b'{"note": "Got a 400 error at this time, skipping.","channels": []}' with gzip.open(cache_path, "wb") as f: f.write(result) time.sleep(delay) return result def remove_stale_cache(cache_dir: pathlib.Path, zap_time: int) -> None: for p in sorted(cache_dir.glob("*")): x = Namespace() x.name = p.name x.zap_time = zap_time x.data_time = int(str(p.name).removesuffix("".join(p.suffixes))) x.file_time = int(p.stat().st_mtime) x.is_irrelevant = x.data_time < zap_time x.is_1day_expired = _expired(3, 1, x.data_time, x.file_time, zap_time) x.is_7day_expired = _expired(7, 7, x.data_time, x.file_time, zap_time) if any(v for k, v in x.items() if k.startswith("is_")): x.file_time_str = strf_time_int(x.file_time) x.data_time_str = strf_time_int(x.data_time) s = " ".join(f"{k}={v}" for k, v in x.items()) print(f"Removing stale cache file: {s}") p.unlink() def _expired( data_days: float, file_days: float, data_time: int, file_time: int, zap_time: int, ) -> bool: data_time_within_limit = data_time < zap_time + data_days * SECONDS_PER_DAY file_time_within_limit = zap_time < file_time + file_days * SECONDS_PER_DAY return data_time_within_limit and not file_time_within_limit def add_programme( out: xml.Element, event: Mapping[str, Any], channel_key: str, ) -> None: # {'callSign': 'KCRADT2', 'duration': '30', 'startTime': '2025-04-20T18:00:00Z', 'endTime': '2025-04-20T18:30:00Z', # 'thumbnail': 'p1119901_e_v9_ab', 'channelNo': '3.2', 'filter': [], 'seriesId': 'SH00001996', 'rating': 'TV-G', 'flag': [], 'tags': ['CC'], # 'program': { # 'title': 'Happy Days', # 'id': 'EP000019960180', # 'tmsId': 'EP000019960180', # 'shortDesc': 'Richie is selected to become a contestant on a popular game show with a chance to win $3,200.', # 'season': '2', 'releaseYear': None, 'episode': '9', 'episodeTitle': 'Big Money', 'seriesId': 'SH00001996', 'isGeneric': '0'} # } # https://tvlistings.gracenote.com/overview-affiliates.html?programSeriesId=SH00001996&tmsId=EP000019960180&aid=lat 
prog_out = add_xml_child( parent=out, tag="programme", start=strf_time_str(event["startTime"]), stop=strf_time_str(event["endTime"]), channel=channel_key, ) prog_in = event["program"] if prog_in["title"]: _ = add_xml_child( parent=prog_out, tag="title", # lang="en", text=prog_in["title"], ) year = toint(prog_in["releaseYear"]) if prog_in["episodeTitle"]: _ = add_xml_child( parent=prog_out, tag="sub-title", # lang="en", text=prog_in["episodeTitle"], ) elif "filter-movie" in event["filter"]: if prog_in["title"] == "Movie": text = "TBD" elif year: text = f"Movie ({year})" else: text = "Movie" _ = add_xml_child( parent=prog_out, tag="sub-title", # lang="en", text=text, ) if prog_in["shortDesc"]: _ = add_xml_child( parent=prog_out, tag="desc", # lang="en", text=prog_in["shortDesc"], ) if prog_in["season"] and prog_in["episode"]: # Format: # season_num/season_total.episode_num/episode_total.part_num/part_total # where "num" is zero indexed and "/total" is optional # and "num/total" is also optional. _ = add_xml_child( parent=prog_out, tag="episode-num", system="xmltv_ns", text=f"{int(prog_in['season']) - 1}.{int(prog_in['episode']) - 1}.", ) if event["rating"]: r = add_xml_child( parent=prog_out, tag="rating", system="VCHIP", ) _ = add_xml_child( parent=r, tag="value", text=event["rating"], ) _ = add_xml_child( parent=prog_out, tag="length", units="minutes", text=event["duration"], ) if year: _ = add_xml_child( parent=prog_out, tag="date", text=str(year), ) if event["thumbnail"]: # Not part of xmltv spec but we're including it anyway. 
_ = add_xml_child( parent=prog_out, tag="icon", src=f"https://zap2it.tmsimg.com/assets/{event['thumbnail']}.jpg", ) for f in event["filter"]: if f not in { "filter-family", "filter-movie", "filter-news", "filter-sports", "filter-talk", }: print(f"Novel filter '{f}'.") if not f.startswith("filter-"): continue _ = add_xml_child( parent=prog_out, tag="category", # Was: "genre" # lang="en", text=f[len("filter-") :].title(), ) if "Dolby Digital" in event["tags"]: audio = "dolby digital" elif "Dolby" in event["tags"]: audio = "dolby" elif "Surround" in event["tags"]: audio = "surround" elif "Stereo" in event["tags"]: audio = "stereo" elif "Mono" in event["tags"]: audio = "mono" else: audio = "stereo" r = add_xml_child( parent=prog_out, tag="audio", ) _ = add_xml_child( parent=r, tag="present", text="yes", ) _ = add_xml_child( parent=r, tag="stereo", text=audio, ) if "DVS" in event["tags"]: _ = add_xml_child( parent=r, tag="stereo", text="bilingual", ) # if False: # a = strf_time_str( # event["startTime"], # format_str="%Y-%b-%d %_I:%M%P", # ) # t = prog_in["title"] # e = prog_in["episodeTitle"] if prog_in["episodeTitle"] else "" # c = channel_map[channel_key] # print(f"### {a:30s} {t:40s} {e:50s} {c:20s}") if "CC" in event["tags"]: r = add_xml_child( parent=prog_out, tag="subtitles", type="teletext", ) _ = add_xml_child( parent=r, tag="language", text="English", ) if "New" in event["flag"]: # and "Live" not in event["flag"]: _ = add_xml_child( parent=prog_out, tag="new", ) def add_programme_tvimate( out: xml.Element, event: Mapping[str, Any], channel_key: str, ) -> None: prog_out = add_xml_child( parent=out, tag="programme", start=strf_time_str(event["startTime"]), stop=strf_time_str(event["endTime"]), channel=channel_key, ) prog_in = event["program"] title = prog_in["title"] subtitle = prog_in["episodeTitle"] year = toint(prog_in["releaseYear"]) season = toint(prog_in["season"]) episode = toint(prog_in["episode"]) description = prog_in["shortDesc"] if title and subtitle 
and "filter-sports" in event["filter"]: title = f"{title}: {subtitle}" subtitle = None elif not subtitle and "filter-movie" in event["filter"]: if title == "Movie": subtitle = None elif year: subtitle = f"Movie ({year})" else: subtitle = "Movie" if title: if "Live" in event["flag"]: if "filter-news" not in event["filter"]: title += " ᴸⁱᵛᵉ" elif "New" in event["flag"]: title += " ᴺᵉʷ" _ = add_xml_child( parent=prog_out, tag="title", # lang="en", text=title, ) if season and episode: season_episode = f"S{season:02d}E{episode:02d}" else: season_episode = None short = " ".join([a_ for a_ in [season_episode, subtitle] if a_]) description = "\n".join([a_ for a_ in [short, description] if a_]) if description: _ = add_xml_child( parent=prog_out, tag="desc", # lang="en", text=description, ) # if event["rating"]: # r = add_xml_child( # parent=prog_out, # tag="rating", # system="VCHIP", # ) # _ = add_xml_child( # parent=r, # tag="value", # text=event["rating"], # ) for f in event["filter"]: if f not in { "filter-family", "filter-movie", "filter-news", "filter-sports", "filter-talk", }: print(f"Novel filter '{f}'.") if not f.startswith("filter-"): continue _ = add_xml_child( parent=prog_out, tag="category", # Was: "genre" # lang="en", text=f[len("filter-") :].title(), ) def get_channel_key(c: Mapping[str, Any]) -> str: # old way: # return f"I{c['channelNo']}.{c['channelId']}.zap2it.com" return c["callSign"] def parse_callsign(coded_callsign: str) -> str: result = _CALLSIGN_REGEX.search(coded_callsign.upper()) assert result call, suffix, num = result.groups() assert suffix assert num != "1" if call == "KQS" and suffix == "LD": # Appears to be a bug in their coded callsign. 
call = "KQSL" suffix = "LD" if not num: num = "1" return f"{call}-{suffix}-{num}" def strf_time_str(tm: str, format_str: str = "%Y%m%d%H%M%S %z") -> str: tm = tm.replace("Z", "+00:00") return parse_time_iso(tm).strftime(format_str) def strf_time_int(timestamp: int, format_str: str = "%Y-%b-%d %_I:%M%P %z") -> str: return parse_time_int(timestamp).strftime(format_str) def parse_time_iso(tm: str) -> datetime.datetime: tm = tm.replace("Z", "+00:00") return datetime.datetime.fromisoformat(tm).astimezone() def parse_time_int(timestamp: int) -> datetime.datetime: return datetime.datetime.fromtimestamp(timestamp).astimezone() def add_xml_child( parent: xml.Element | None, tag: str, text: str | None = None, attrib: Mapping[str, str] | None = None, **extra: Any, ) -> xml.Element: attrib = {} if attrib is None else dict(attrib) if parent is None: # https://docs.python.org/3/library/xml.etree.elementtree.html#xml.etree.ElementTree.Element el = xml.Element(tag, attrib, **extra) else: # https://docs.python.org/3/library/xml.etree.elementtree.html#xml.etree.ElementTree.SubElement el = xml.SubElement(parent, tag, attrib, **extra) if text is not None: el.text = text return el def toint(x: str | None, fail: int = 0) -> int: if x is None: return fail return int(x) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Fetch TV data from zap2it.", epilog="This tool is noisy to stdout; with cron use chronic from moreutils.", ) _ = parser.add_argument( "--delay", dest="delay", type=int, default=5, help="Delay, in seconds, between server fetches.", ) _ = parser.add_argument( "--url", dest="base_url", type=str, default="tvlistings.gracenote.com", # default="tvlistings.zap2it.com", help="Source url without http prefix.", ) _ = parser.add_argument( "--days", dest="days", type=float, default=15, help="Num days to fetch.", ) _ = parser.add_argument( "--path", dest="path", type=str, default=str(pathlib.Path(__file__).parent.resolve()), help="Path to store 
files.", ) _ = parser.add_argument( "--aid", dest="zap_aid", type=str, # Previously we used "gapzap" but redditors seem to have found this one. # https://www.reddit.com/r/cordcutters/comments/1m1iba0/zap2it_and_gracenote_listings_are_gone_again/ default="orbebb", help="Raw zap2it input parameter. (Affiliate ID?)", ) _ = parser.add_argument( "--country", dest="zap_country", type=str, default="USA", help="Country identifying the listings to fetch.", ) _ = parser.add_argument( "--device", dest="zap_device", type=str, default="-", help="Raw zap2it input parameter. (?)", ) _ = parser.add_argument( "--headend-id", dest="zap_headendId", type=str, default="lineupId", help="Raw zap2it input parameter. (?)", ) _ = parser.add_argument( "--is-override", dest="zap_isOverride", type=bool, default=True, help="Raw zap2it input parameter. (?)", ) _ = parser.add_argument( "--language", dest="zap_languagecode", type=str, default="en", help="Raw zap2it input parameter. (Language.)", ) _ = parser.add_argument( "--pref", dest="zap_pref", type=str, default="", help="Raw zap2it input parameter. (Preferences?)", ) _ = parser.add_argument( "--timespan", dest="zap_timespan", type=int, default=3, help="Raw zap2it input parameter. (Hours of data per fetch?)", ) _ = parser.add_argument( "--timezone", dest="zap_timezone", type=str, default="", help="Raw zap2it input parameter. (Time zone?)", ) _ = parser.add_argument( "--user-id", dest="zap_userId", type=str, default="-", help="Raw zap2it input parameter. 
(?)", ) _ = parser.add_argument( "--zip", dest="zap_postalCode", type=str, required=True, help="The zip/postal code identifying the listings to fetch.", ) _ = parser.add_argument( "--tvimate", dest="tvimate", type=bool, default=True, action=argparse.BooleanOptionalAction, help="Guide formatted specifically for TViMate.", ) return parser.parse_args() if __name__ == "__main__": main() ================================================ FILE: util.py ================================================ """Shared utilities.""" from __future__ import annotations from typing import Any import urllib.error import urllib.parse import urllib.request class _SafeRedirectHandler(urllib.request.HTTPRedirectHandler): """Redirect handler that only allows http/https schemes.""" def redirect_request( self, req: urllib.request.Request, fp: Any, code: int, msg: str, headers: Any, newurl: str, ) -> urllib.request.Request | None: parsed = urllib.parse.urlparse(newurl) if parsed.scheme not in ("http", "https"): raise urllib.error.URLError(f"Unsafe redirect scheme: {parsed.scheme}") return super().redirect_request(req, fp, code, msg, headers, newurl) _DEFAULT_USER_AGENT = "VLC/3.0.20 LibVLC/3.0.20" def safe_urlopen(url: str, timeout: int = 30, user_agent: str | None = None) -> Any: """Open URL with safe redirect handling. Args: url: URL to open timeout: Request timeout in seconds user_agent: User-Agent header to send. If None, uses a default VLC User-Agent to avoid being blocked by providers that reject Python's default. 
""" parsed = urllib.parse.urlparse(url) if parsed.scheme not in ("http", "https"): raise urllib.error.URLError(f"Unsafe URL scheme: {parsed.scheme}") ua = user_agent if user_agent else _DEFAULT_USER_AGENT req = urllib.request.Request(url, headers={"User-Agent": ua}) opener = urllib.request.build_opener(_SafeRedirectHandler()) return opener.open(req, timeout=timeout) ================================================ FILE: util_test.py ================================================ """Tests for util.py.""" from __future__ import annotations from typing import Any import urllib.error import pytest import util def _fake_request(url: str) -> Any: """Create a minimal request object for testing.""" class _Req: full_url = url headers: dict[str, str] = {} data = None origin_req_host = "original.com" def get_method(self) -> str: return "GET" return _Req() class TestSafeRedirectHandler: def test_handler_allows_http(self): handler = util._SafeRedirectHandler() req = _fake_request("http://original.com") result = handler.redirect_request( req, fp=None, code=302, msg="Found", headers={}, newurl="http://redirect.com/path", ) assert result is not None def test_handler_allows_https(self): handler = util._SafeRedirectHandler() req = _fake_request("https://original.com") result = handler.redirect_request( req, fp=None, code=302, msg="Found", headers={}, newurl="https://secure.com/path", ) assert result is not None def test_handler_rejects_file_scheme(self): handler = util._SafeRedirectHandler() req = _fake_request("http://original.com") with pytest.raises(urllib.error.URLError, match="Unsafe redirect scheme"): handler.redirect_request( req, fp=None, code=302, msg="Found", headers={}, newurl="file:///etc/passwd", ) def test_handler_rejects_data_scheme(self): handler = util._SafeRedirectHandler() req = _fake_request("http://original.com") with pytest.raises(urllib.error.URLError, match="Unsafe redirect scheme"): handler.redirect_request( req, fp=None, code=302, msg="Found", headers={}, 
newurl="data:text/html,", ) def test_handler_rejects_javascript_scheme(self): handler = util._SafeRedirectHandler() req = _fake_request("http://original.com") with pytest.raises(urllib.error.URLError, match="Unsafe redirect scheme"): handler.redirect_request( req, fp=None, code=302, msg="Found", headers={}, newurl="javascript:alert(1)", ) if __name__ == "__main__": from testing import run_tests run_tests(__file__) ================================================ FILE: xtream.py ================================================ """Xtream Codes API client.""" from __future__ import annotations from dataclasses import dataclass from typing import Any import json import urllib.parse from util import safe_urlopen @dataclass(slots=True) class XtreamClient: """Client for Xtream Codes API. Handles authentication and API calls to Xtream-compatible IPTV providers. """ base_url: str username: str password: str def __post_init__(self) -> None: # Normalize URL: strip trailing slashes self.base_url = self.base_url.rstrip("/") @property def _base_params(self) -> dict[str, str]: return {"username": self.username, "password": self.password} @property def api_url(self) -> str: params = urllib.parse.urlencode(self._base_params) return f"{self.base_url}/player_api.php?{params}" def _fetch(self, url: str, timeout: int = 30) -> str: with safe_urlopen(url, timeout=timeout) as resp: return resp.read().decode("utf-8") def _api(self, action: str | None = None, timeout: int = 30, **params: Any) -> Any: query = dict(self._base_params) if action: query["action"] = action query.update(params) url = f"{self.base_url}/player_api.php?{urllib.parse.urlencode(query)}" return json.loads(self._fetch(url, timeout=timeout)) def get_server_info(self, timeout: int = 15) -> dict[str, Any]: """Returns user_info and server_info; check user_info['auth'] == 1.""" return self._api(timeout=timeout) def get_live_categories(self) -> list[dict[str, Any]]: return self._api("get_live_categories") def 
get_live_streams(self, category_id: int | None = None) -> list[dict[str, Any]]: if category_id: return self._api("get_live_streams", category_id=category_id) return self._api("get_live_streams") def get_vod_categories(self) -> list[dict[str, Any]]: return self._api("get_vod_categories") def get_vod_streams(self, category_id: int | None = None) -> list[dict[str, Any]]: if category_id: return self._api("get_vod_streams", category_id=category_id) return self._api("get_vod_streams") def get_series_categories(self) -> list[dict[str, Any]]: return self._api("get_series_categories") def get_series(self, category_id: int | None = None) -> list[dict[str, Any]]: if category_id: return self._api("get_series", category_id=category_id) return self._api("get_series") def get_series_info(self, series_id: int) -> dict[str, Any]: return self._api("get_series_info", series_id=series_id) def get_vod_info(self, vod_id: int) -> dict[str, Any]: return self._api("get_vod_info", vod_id=vod_id) def get_short_epg(self, stream_id: int, limit: int = 10) -> dict[str, Any]: """Returns epg_listings for stream; some providers ignore limit.""" return self._api("get_short_epg", stream_id=stream_id, limit=limit) def build_stream_url(self, stream_type: str, stream_id: int, ext: str = "") -> str: # URL-encode username/password to handle special chars like # in passwords user = urllib.parse.quote(self.username, safe="") pwd = urllib.parse.quote(self.password, safe="") base = f"{self.base_url}/{stream_type}/{user}/{pwd}/{stream_id}" return f"{base}.{ext}" if ext else base def build_timeshift_url( self, stream_id: int, duration: int, start: str, ext: str = "ts", ) -> str: """For streams with tv_archive=1. 
start format: YYYY-MM-DD:HH-MM.""" # URL-encode username/password to handle special chars like # in passwords user = urllib.parse.quote(self.username, safe="") pwd = urllib.parse.quote(self.password, safe="") return f"{self.base_url}/timeshift/{user}/{pwd}/{duration}/{start}/{stream_id}.{ext}" @property def epg_url(self) -> str: params = urllib.parse.urlencode(self._base_params) return f"{self.base_url}/xmltv.php?{params}" ================================================ FILE: xtream_test.py ================================================ """Tests for xtream.py - Xtream Codes API client.""" from __future__ import annotations from unittest.mock import MagicMock, patch import json import pytest from xtream import XtreamClient class TestXtreamClient: """Tests for XtreamClient.""" def test_api_url_property(self): client = XtreamClient("http://example.com", "user", "pass") assert client.api_url == "http://example.com/player_api.php?username=user&password=pass" def test_epg_url_property(self): client = XtreamClient("http://example.com", "user", "pass") assert client.epg_url == "http://example.com/xmltv.php?username=user&password=pass" def test_url_normalization_strips_trailing_slash(self): client = XtreamClient("http://example.com/", "user", "pass") assert client.base_url == "http://example.com" # No double slashes after the domain assert "example.com/player_api" in client.api_url def test_url_normalization_strips_multiple_trailing_slashes(self): client = XtreamClient("http://example.com///", "user", "pass") assert client.base_url == "http://example.com" def test_special_chars_in_credentials_are_encoded(self): client = XtreamClient("http://example.com", "user@test", "p&ss=word") # Check that special chars are URL-encoded in api_url assert "user%40test" in client.api_url assert "p%26ss%3Dword" in client.api_url # Same for epg_url assert "user%40test" in client.epg_url assert "p%26ss%3Dword" in client.epg_url def test_build_stream_url_live_no_ext(self): client = 
XtreamClient("http://example.com", "user", "pass") url = client.build_stream_url("live", 123) assert url == "http://example.com/live/user/pass/123" def test_build_stream_url_live_with_ext(self): client = XtreamClient("http://example.com", "user", "pass") url = client.build_stream_url("live", 123, "m3u8") assert url == "http://example.com/live/user/pass/123.m3u8" def test_build_stream_url_movie(self): client = XtreamClient("http://example.com", "user", "pass") url = client.build_stream_url("movie", 456, "mkv") assert url == "http://example.com/movie/user/pass/456.mkv" def test_build_stream_url_series(self): client = XtreamClient("http://example.com", "user", "pass") url = client.build_stream_url("series", 789, "mp4") assert url == "http://example.com/series/user/pass/789.mp4" def test_build_timeshift_url(self): client = XtreamClient("http://example.com", "user", "pass") url = client.build_timeshift_url(123, 60, "2024-01-15:14-30") assert url == "http://example.com/timeshift/user/pass/60/2024-01-15:14-30/123.ts" def test_build_timeshift_url_custom_ext(self): client = XtreamClient("http://example.com", "user", "pass") url = client.build_timeshift_url(123, 30, "2024-01-15:10-00", ext="m3u8") assert url == "http://example.com/timeshift/user/pass/30/2024-01-15:10-00/123.m3u8" class TestXtreamClientApi: """Tests for XtreamClient API methods with mocked network.""" @pytest.fixture def client(self): return XtreamClient("http://example.com", "user", "pass") @pytest.fixture def mock_urlopen(self): with patch("xtream.safe_urlopen") as mock: yield mock def _setup_response(self, mock_urlopen, data): """Helper to setup mock response.""" mock_resp = MagicMock() mock_resp.read.return_value = json.dumps(data).encode("utf-8") mock_resp.__enter__ = MagicMock(return_value=mock_resp) mock_resp.__exit__ = MagicMock(return_value=False) mock_urlopen.return_value = mock_resp def test_get_live_categories(self, client, mock_urlopen): categories = [{"category_id": "1", "category_name": 
"News"}] self._setup_response(mock_urlopen, categories) result = client.get_live_categories() assert result == categories mock_urlopen.assert_called_once() url = mock_urlopen.call_args[0][0] assert "action=get_live_categories" in url def test_get_live_streams(self, client, mock_urlopen): streams = [{"stream_id": 1, "name": "CNN"}] self._setup_response(mock_urlopen, streams) result = client.get_live_streams() assert result == streams url = mock_urlopen.call_args[0][0] assert "action=get_live_streams" in url def test_get_live_streams_with_category(self, client, mock_urlopen): streams = [{"stream_id": 1, "name": "CNN"}] self._setup_response(mock_urlopen, streams) result = client.get_live_streams(category_id=5) assert result == streams url = mock_urlopen.call_args[0][0] assert "action=get_live_streams" in url assert "category_id=5" in url def test_get_vod_categories(self, client, mock_urlopen): categories = [{"category_id": "10", "category_name": "Movies"}] self._setup_response(mock_urlopen, categories) result = client.get_vod_categories() assert result == categories url = mock_urlopen.call_args[0][0] assert "action=get_vod_categories" in url def test_get_vod_streams(self, client, mock_urlopen): streams = [{"stream_id": 100, "name": "Movie 1"}] self._setup_response(mock_urlopen, streams) result = client.get_vod_streams() assert result == streams url = mock_urlopen.call_args[0][0] assert "action=get_vod_streams" in url def test_get_vod_streams_with_category(self, client, mock_urlopen): streams = [{"stream_id": 100, "name": "Movie 1"}] self._setup_response(mock_urlopen, streams) result = client.get_vod_streams(category_id=10) assert result == streams url = mock_urlopen.call_args[0][0] assert "category_id=10" in url def test_get_series_categories(self, client, mock_urlopen): categories = [{"category_id": "20", "category_name": "Drama"}] self._setup_response(mock_urlopen, categories) result = client.get_series_categories() assert result == categories url = 
mock_urlopen.call_args[0][0] assert "action=get_series_categories" in url def test_get_series(self, client, mock_urlopen): series = [{"series_id": 200, "name": "Show 1"}] self._setup_response(mock_urlopen, series) result = client.get_series() assert result == series url = mock_urlopen.call_args[0][0] assert "action=get_series" in url def test_get_series_with_category(self, client, mock_urlopen): series = [{"series_id": 200, "name": "Show 1"}] self._setup_response(mock_urlopen, series) result = client.get_series(category_id=20) assert result == series url = mock_urlopen.call_args[0][0] assert "category_id=20" in url def test_get_series_info(self, client, mock_urlopen): info = {"info": {"name": "Show 1"}, "episodes": {"1": []}} self._setup_response(mock_urlopen, info) result = client.get_series_info(series_id=200) assert result == info url = mock_urlopen.call_args[0][0] assert "action=get_series_info" in url assert "series_id=200" in url def test_get_vod_info(self, client, mock_urlopen): info = {"info": {"name": "Movie 1", "plot": "A story"}} self._setup_response(mock_urlopen, info) result = client.get_vod_info(vod_id=100) assert result == info url = mock_urlopen.call_args[0][0] assert "action=get_vod_info" in url assert "vod_id=100" in url def test_get_server_info(self, client, mock_urlopen): server_info = { "user_info": { "auth": 1, "username": "user", "status": "Active", "exp_date": "1735689600", "max_connections": "2", }, "server_info": { "url": "example.com", "port": "80", "https_port": "443", "server_protocol": "http", }, } self._setup_response(mock_urlopen, server_info) result = client.get_server_info() assert result == server_info assert result["user_info"]["auth"] == 1 url = mock_urlopen.call_args[0][0] # get_server_info calls API with no action assert "action" not in url def test_get_server_info_auth_failed(self, client, mock_urlopen): server_info = {"user_info": {"auth": 0}} self._setup_response(mock_urlopen, server_info) result = client.get_server_info() 
assert result["user_info"]["auth"] == 0 def test_get_server_info_uses_shorter_timeout(self, client, mock_urlopen): self._setup_response(mock_urlopen, {"user_info": {"auth": 1}}) client.get_server_info() # get_server_info uses 15s timeout by default (vs 30s for other calls) _, kwargs = mock_urlopen.call_args assert kwargs["timeout"] == 15 def test_custom_timeout_passed_through(self, client, mock_urlopen): self._setup_response(mock_urlopen, []) client.get_live_categories() # Default timeout is 30s _, kwargs = mock_urlopen.call_args assert kwargs["timeout"] == 30 def test_api_encodes_special_chars_in_params(self, client, mock_urlopen): self._setup_response(mock_urlopen, {}) client._api("test_action", foo="bar&baz", key="val=ue") url = mock_urlopen.call_args[0][0] assert "foo=bar%26baz" in url assert "key=val%3Due" in url def test_get_short_epg(self, client, mock_urlopen): epg_data = { "epg_listings": [ {"title": "Show 1", "start": "2024-01-15 14:00:00"}, {"title": "Show 2", "start": "2024-01-15 15:00:00"}, ] } self._setup_response(mock_urlopen, epg_data) result = client.get_short_epg(stream_id=123) assert result == epg_data url = mock_urlopen.call_args[0][0] assert "action=get_short_epg" in url assert "stream_id=123" in url assert "limit=10" in url # default limit def test_get_short_epg_custom_limit(self, client, mock_urlopen): self._setup_response(mock_urlopen, {"epg_listings": []}) client.get_short_epg(stream_id=456, limit=5) url = mock_urlopen.call_args[0][0] assert "stream_id=456" in url assert "limit=5" in url if __name__ == "__main__": from testing import run_tests run_tests(__file__)