Improve image quality with deterministic generation

A common way to improve the quality of generated images is deterministic batch generation: generate a batch of images, then select one to improve with a more detailed prompt in a second round of inference. The key is to pass a list of torch.Generator objects to the pipeline for batched image generation, and to tie each Generator to a seed so you can reuse it for a specific image.

Let's use runwayml/stable-diffusion-v1-5 as an example, and generate several versions of the following prompt:

prompt = "Labrador in the style of Vermeer"

Instantiate a pipeline with DiffusionPipeline.from_pretrained() and place it on a GPU (if available):

>>> import torch
>>> from diffusers import DiffusionPipeline

>>> pipe = DiffusionPipeline.from_pretrained(
...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
... )
>>> pipe = pipe.to("cuda")

Now, define four different Generators and assign each Generator a seed (0 to 3) so you can reuse a Generator later for a specific image:

>>> import torch

>>> generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
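
To see why tying each Generator to a seed makes an image reproducible, the following minimal sketch draws random noise from seeded generators and then replays seed 0. It is an illustration only: it uses CPU generators so it runs without a GPU, and the tensor shape and the names `seeds`, `first_round`, and `replay` are assumptions chosen for the example, not what the pipeline does internally.

```python
import torch

# CPU generators are used here so the sketch runs without a GPU.
seeds = [0, 1, 2, 3]
generators = [torch.Generator(device="cpu").manual_seed(s) for s in seeds]

# Each generator produces its own noise tensor (the shape is arbitrary for this demo).
first_round = [torch.randn(4, 8, 8, generator=g) for g in generators]

# A fresh generator with seed 0 replays the same random stream as generators[0].
replay = torch.randn(4, 8, 8, generator=torch.Generator(device="cpu").manual_seed(0))

same_as_first = torch.equal(replay, first_round[0])   # same seed, same noise
same_as_second = torch.equal(replay, first_round[1])  # different seed, different noise
print(same_as_first, same_as_second)
```

Because the same seed yields the same starting noise, reusing a seed in a later round lets the pipeline start from the same point as the image you want to refine.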

Generate the images and have a look:

>>> images = pipe(prompt, generator=generator, num_images_per_prompt=4).images
>>> images

In this example, you'll improve upon the first image, but in practice you can use any image you want (even the image with the double set of eyes!). The first image used the Generator with seed 0, so you'll reuse that Generator for the second round of inference. To improve the quality of the image, add some additional text to the prompt:

prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
generator = [torch.Generator(device="cuda").manual_seed(0) for _ in range(4)]

This creates four generators, all with seed 0. Now generate another batch of images, all of which should look like the first image from the previous round:

>>> images = pipe(prompt, generator=generator).images
>>> images
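
If you want to refine an image other than the first, the second-round inputs can be built from whichever seed produced it. The sketch below is one possible way to do that; `second_round_inputs` is a hypothetical helper written for this example, not part of Diffusers.

```python
def second_round_inputs(base_prompt, suffixes, chosen_seed):
    """Return (prompts, seeds) for refining the image generated with chosen_seed."""
    # One prompt variation per suffix.
    prompts = [base_prompt + suffix for suffix in suffixes]
    # Every variation reuses the seed of the image being refined.
    seeds = [chosen_seed] * len(prompts)
    return prompts, seeds


prompts, seeds = second_round_inputs(
    "Labrador in the style of Vermeer",
    [", highly realistic", ", artsy", ", trending", ", colorful"],
    chosen_seed=0,
)

# The generators can then be created the same way as earlier in this guide, e.g.
# [torch.Generator(device="cuda").manual_seed(s) for s in seeds]
print(prompts)
print(seeds)
```

Keeping the seeds as a plain list makes it easy to see which seed each prompt variation starts from before any generator is created.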