[
  {
    "path": ".gitignore",
    "content": ".cog/\n__pycache__/\ndiffusers-cache/"
  },
  {
    "path": "README.md",
    "content": "# Stable Diffusion Cog model\n\nThis is an implementation of the [Diffusers Stable Diffusion 1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4) as a Cog model. [Cog packages machine learning models as standard containers.](https://github.com/replicate/cog)\n\nFirst, download the pre-trained weights [with your Hugging Face auth token](https://huggingface.co/settings/tokens):\n\n    cog run script/download-weights <your-hugging-face-auth-token>\n\nThen, you can run predictions:\n\n    cog predict -i prompt=\"monkey scuba diving\"\n\nOr, build a Docker image:\n\n    cog build\n\nOr, [push it to Replicate](https://replicate.com/docs/guides/push-a-model):\n\n    cog push r8.im/...\n"
  },
  {
    "path": "cog.yaml",
    "content": "build:\n  gpu: true\n  cuda: \"11.6.2\"\n  python_version: \"3.10\"\n  python_packages:\n    - \"diffusers==0.2.4\"\n    - \"torch==1.12.1 --extra-index-url=https://download.pytorch.org/whl/cu116\"\n    - \"ftfy==6.1.1\"\n    - \"scipy==1.9.0\"\n    - \"transformers==4.21.1\"\npredict: \"predict.py:Predictor\"\n"
  },
  {
    "path": "image_to_image.py",
    "content": "import inspect\nfrom typing import List, Optional, Union, Tuple\n\nimport numpy as np\nimport torch\n\nfrom PIL import Image\nfrom diffusers import (\n    AutoencoderKL,\n    DDIMScheduler,\n    DiffusionPipeline,\n    PNDMScheduler,\n    LMSDiscreteScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n\n\ndef preprocess_init_image(image: Image, width: int, height: int):\n    image = image.resize((width, height), resample=Image.LANCZOS)\n    image = np.array(image).astype(np.float32) / 255.0\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image)\n    return 2.0 * image - 1.0\n\n\ndef preprocess_mask(mask: Image, width: int, height: int):\n    mask = mask.convert(\"L\")\n    mask = mask.resize((width // 8, height // 8), resample=Image.LANCZOS)\n    mask = np.array(mask).astype(np.float32) / 255.0\n    mask = np.tile(mask, (4, 1, 1))\n    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?\n    mask = torch.from_numpy(mask)\n    return mask\n\n\nclass StableDiffusionImg2ImgPipeline(DiffusionPipeline):\n    \"\"\"\n    From https://github.com/huggingface/diffusers/pull/241\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPFeatureExtractor,\n    ):\n        super().__init__()\n        scheduler = scheduler.set_format(\"pt\")\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        init_image: Optional[torch.FloatTensor],\n        mask: Optional[torch.FloatTensor],\n        width: int,\n        height: int,\n        prompt_strength: float = 0.8,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        eta: float = 0.0,\n        generator: Optional[torch.Generator] = None,\n    ) -> Image:\n        if isinstance(prompt, str):\n            batch_size = 1\n        elif isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(\n                f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\"\n            )\n\n        if prompt_strength < 0 or prompt_strength > 1:\n            raise ValueError(\n                f\"The value of prompt_strength should in [0.0, 1.0] but is {prompt_strength}\"\n            )\n\n        if mask is not None and init_image is None:\n            raise ValueError(\n                \"If mask is defined, then init_image also needs to be defined\"\n            )\n\n        if width % 8 != 0 or height % 8 != 0:\n            raise ValueError(\"Width and height must both be divisible by 8\")\n\n        # set timesteps\n        accepts_offset = \"offset\" in set(\n            inspect.signature(self.scheduler.set_timesteps).parameters.keys()\n        )\n        extra_set_kwargs = {}\n        offset = 0\n        if 
accepts_offset:\n            offset = 1\n            extra_set_kwargs[\"offset\"] = 1\n\n        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)\n\n        if init_image is not None:\n            init_latents_orig, latents, init_timestep = self.latents_from_init_image(\n                init_image,\n                prompt_strength,\n                offset,\n                num_inference_steps,\n                batch_size,\n                generator,\n            )\n        else:\n            latents = torch.randn(\n                (batch_size, self.unet.in_channels, height // 8, width // 8),\n                generator=generator,\n                device=self.device,\n            )\n            init_timestep = num_inference_steps\n\n        do_classifier_free_guidance = guidance_scale > 1.0\n        text_embeddings = self.embed_text(\n            prompt, do_classifier_free_guidance, batch_size\n        )\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(\n            inspect.signature(self.scheduler.step).parameters.keys()\n        )\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)\n\n        # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas\n        if isinstance(self.scheduler, LMSDiscreteScheduler):\n            latents = latents * self.scheduler.sigmas[0]\n\n        t_start = max(num_inference_steps - init_timestep + offset, 0)\n        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = (\n                torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n            )\n\n            if isinstance(self.scheduler, LMSDiscreteScheduler):\n                sigma = self.scheduler.sigmas[i]\n                latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)\n\n            # predict the noise residual\n            noise_pred = self.unet(\n                latent_model_input, t, encoder_hidden_states=text_embeddings\n            )[\"sample\"]\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (\n                    noise_pred_text - noise_pred_uncond\n                )\n\n            # compute the previous noisy sample x_t -> x_t-1\n            if isinstance(self.scheduler, LMSDiscreteScheduler):\n                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)[\n                    \"prev_sample\"\n                ]\n            else:\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)[\n                    \"prev_sample\"\n                ]\n\n            # replace the unmasked part with original latents, with added noise\n            if mask is not None:\n                timesteps = self.scheduler.timesteps[t_start + i]\n                timesteps = torch.tensor(\n                    [timesteps] * 
batch_size, dtype=torch.long, device=self.device\n                )\n                noisy_init_latents = self.scheduler.add_noise(init_latents_orig, mask_noise, timesteps)\n                latents = noisy_init_latents * mask + latents * (1 - mask)\n\n        # scale and decode the image latents with vae\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents)\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n\n        # run safety checker\n        safety_cheker_input = self.feature_extractor(\n            self.numpy_to_pil(image), return_tensors=\"pt\"\n        ).to(self.device)\n        image, has_nsfw_concept = self.safety_checker(\n            images=image, clip_input=safety_cheker_input.pixel_values\n        )\n\n        image = self.numpy_to_pil(image)\n\n        return {\"sample\": image, \"nsfw_content_detected\": has_nsfw_concept}\n\n    def latents_from_init_image(\n        self,\n        init_image: torch.FloatTensor,\n        prompt_strength: float,\n        offset: int,\n        num_inference_steps: int,\n        batch_size: int,\n        generator: Optional[torch.Generator],\n    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, int]:\n        # encode the init image into latents and scale the latents\n        init_latents = self.vae.encode(init_image.to(self.device)).sample()\n        init_latents = 0.18215 * init_latents\n        init_latents_orig = init_latents\n\n        # prepare init_latents noise to latents\n        init_latents = torch.cat([init_latents] * batch_size)\n\n        # get the original timestep using init_timestep\n        init_timestep = int(num_inference_steps * prompt_strength) + offset\n        init_timestep = min(init_timestep, num_inference_steps)\n        timesteps = self.scheduler.timesteps[-init_timestep]\n        timesteps = torch.tensor(\n            [timesteps] * batch_size, dtype=torch.long, device=self.device\n        )\n\n        # add noise to latents using the timesteps\n        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)\n        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)\n\n        return init_latents_orig, init_latents, init_timestep\n\n    def embed_text(\n        self,\n        prompt: Union[str, List[str]],\n        do_classifier_free_guidance: bool,\n        batch_size: int,\n    ) -> torch.FloatTensor:\n        # get prompt text embeddings\n        text_input = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            max_length = text_input.input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                [\"\"] * batch_size,\n                padding=\"max_length\",\n                max_length=max_length,\n                return_tensors=\"pt\",\n            )\n            uncond_embeddings = self.text_encoder(\n                uncond_input.input_ids.to(self.device)\n            )[0]\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        return text_embeddings\n"
  },
  {
    "path": "predict.py",
    "content": "import os\nfrom typing import Optional, List\n\nimport torch\nfrom torch import autocast\nfrom diffusers import PNDMScheduler, LMSDiscreteScheduler\nfrom PIL import Image\nfrom cog import BasePredictor, Input, Path\n\nfrom image_to_image import (\n    StableDiffusionImg2ImgPipeline,\n    preprocess_init_image,\n    preprocess_mask,\n)\n\n\nMODEL_CACHE = \"diffusers-cache\"\n\n\nclass Predictor(BasePredictor):\n    def setup(self):\n        \"\"\"Load the model into memory to make running multiple predictions efficient\"\"\"\n        print(\"Loading pipeline...\")\n        scheduler = PNDMScheduler(\n            beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\"\n        )\n        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(\n            \"CompVis/stable-diffusion-v1-4\",\n            scheduler=scheduler,\n            revision=\"fp16\",\n            torch_dtype=torch.float16,\n            cache_dir=MODEL_CACHE,\n            local_files_only=True,\n        ).to(\"cuda\")\n\n    @torch.inference_mode()\n    @torch.cuda.amp.autocast()\n    def predict(\n        self,\n        prompt: str = Input(description=\"Input prompt\", default=\"\"),\n        width: int = Input(\n            description=\"Width of output image\",\n            choices=[128, 256, 512, 768, 1024],\n            default=512,\n        ),\n        height: int = Input(\n            description=\"Height of output image\",\n            choices=[128, 256, 512, 768],\n            default=512,\n        ),\n        init_image: Path = Input(\n            description=\"Inital image to generate variations of. Will be resized to the specified width and height\", default=None\n        ),\n        mask: Path = Input(\n            description=\"Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved. Experimental feature, tends to work better with prompt strength of 0.5-0.7\",\n            default=None,\n        ),\n        prompt_strength: float = Input(\n            description=\"Prompt strength when using init image. 1.0 corresponds to full destruction of information in init image\",\n            default=0.8,\n        ),\n        num_outputs: int = Input(\n            description=\"Number of images to output\", choices=[1, 4], default=1\n        ),\n        num_inference_steps: int = Input(\n            description=\"Number of denoising steps\", ge=1, le=500, default=50\n        ),\n        guidance_scale: float = Input(\n            description=\"Scale for classifier-free guidance\", ge=1, le=20, default=7.5\n        ),\n        seed: int = Input(\n            description=\"Random seed. 
Leave blank to randomize the seed\", default=None\n        ),\n    ) -> List[Path]:\n        \"\"\"Run a single prediction on the model\"\"\"\n        if seed is None:\n            seed = int.from_bytes(os.urandom(2), \"big\")\n        print(f\"Using seed: {seed}\")\n\n        if init_image:\n            init_image = Image.open(init_image).convert(\"RGB\")\n            init_image = preprocess_init_image(init_image, width, height).to(\"cuda\")\n\n            # use PNDM with init images\n            scheduler = PNDMScheduler(\n                beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\"\n            )\n        else:\n            # use LMS without init images\n            scheduler = LMSDiscreteScheduler(\n                beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\"\n            )\n\n        self.pipe.scheduler = scheduler\n\n        if mask:\n            mask = Image.open(mask).convert(\"RGB\")\n            mask = preprocess_mask(mask, width, height).to(\"cuda\")\n\n        generator = torch.Generator(\"cuda\").manual_seed(seed)\n        output = self.pipe(\n            prompt=[prompt] * num_outputs if prompt is not None else None,\n            init_image=init_image,\n            mask=mask,\n            width=width,\n            height=height,\n            prompt_strength=prompt_strength,\n            guidance_scale=guidance_scale,\n            generator=generator,\n            num_inference_steps=num_inference_steps,\n        )\n        if any(output[\"nsfw_content_detected\"]):\n            raise Exception(\"NSFW content detected, please try a different prompt\")\n\n        output_paths = []\n        for i, sample in enumerate(output[\"sample\"]):\n            output_path = f\"/tmp/out-{i}.png\"\n            sample.save(output_path)\n            output_paths.append(Path(output_path))\n\n        return output_paths\n"
  },
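  {
    "path": "script/example-predictions",
    "content": "#!/bin/bash\n# Example invocations (a sketch, not part of the original model code). Shows\n# the three modes predict.py supports: text-to-image, image-to-image via\n# init_image, and inpainting via mask. input.png and mask.png are placeholder\n# files you supply yourself; all -i names match the Input() definitions in\n# predict.py.\nset -euo pipefail\n\n# Plain text-to-image\ncog predict -i prompt=\"monkey scuba diving\"\n\n# Image-to-image: generate variations of an init image\ncog predict -i prompt=\"monkey scuba diving\" -i init_image=@input.png -i prompt_strength=0.8\n\n# Inpainting: black pixels in the mask are repainted, white pixels are preserved\ncog predict -i prompt=\"monkey scuba diving\" -i init_image=@input.png -i mask=@mask.png -i prompt_strength=0.6\n"
  },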
  {
    "path": "script/download-weights",
    "content": "#!/usr/bin/env python\n\nimport os\nimport sys\n\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\nos.makedirs(\"diffusers-cache\", exist_ok=True)\n\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    cache_dir=\"diffusers-cache\",\n    revision=\"fp16\",\n    torch_dtype=torch.float16,\n    use_auth_token=sys.argv[1],\n)\n"
  }
]