Optimizing image generation prompting using CLIP scores#

In this notebook, we will compare different prompts and measure a text-to-image similarity metric: the CLIP score.

import torch
import numpy as np
from skimage.io import imread
import stackview
from diffusers import DiffusionPipeline, AutoencoderTiny
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

This is how the metric is initialized:

from torchmetrics.multimodal.clip_score import CLIPScore
metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
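
The metric wraps a pretrained CLIP model. Other checkpoints supported by torchmetrics can be plugged in via model_name_or_path; the following variant is not used in the rest of this notebook and only illustrates this option:

# a larger CLIP backbone can give more robust scores, at the cost of speed and memory
metric_large = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14")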

We start with this example image.

image = imread("data/real_cat.png")
stackview.insight(image)
shape: (512, 512, 3)
dtype: uint8
size: 768.0 kB
min: 0
max: 255

We can now measure how well this image matches a given text, for example the words "cat" and "microscope":

score = metric(torch.as_tensor(image), "cat")
score.detach()
tensor(25.3473)

score = metric(torch.as_tensor(image), "microscope")
float(score.detach())
30.786285400390625
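
Under the hood, the CLIP score is the cosine similarity between the CLIP embedding of the image and the CLIP embedding of the text, scaled by 100 and clipped at 0; higher values indicate a better text-image match. The following sketch reproduces this computation with the transformers library (an illustration only, it is not needed for the rest of the notebook):

from transformers import CLIPModel, CLIPProcessor

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

inputs = clip_processor(text=["cat"], images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    img_emb = clip_model.get_image_features(pixel_values=inputs["pixel_values"])
    txt_emb = clip_model.get_text_features(input_ids=inputs["input_ids"],
                                           attention_mask=inputs["attention_mask"])

# normalize both embeddings and compute their cosine similarity, scaled by 100
img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
float(100 * (img_emb * txt_emb).sum(dim=-1).clamp(min=0))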

Recap: Generating images#

We will now use a prompt to generate an image and measure whether the result shows a cat.

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
prompt = "Draw a realistic photo of a cat."

cat = pipe(prompt).images[0]
cat
(output: the generated cat image)
score = metric(torch.as_tensor(np.array(cat)), "cat")
np.asarray(score.detach())
array(26.65498, dtype=float32)
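
Note that the diffusion pipeline is stochastic: every call produces a different image and therefore a slightly different score. If a single image needs to be reproducible, a seeded torch.Generator can be passed to the pipeline, as sketched below (the seed value is arbitrary). In the benchmark that follows, we deliberately leave the seed unset and average over several attempts instead.

# fix the random number generator for a reproducible image (seed chosen arbitrarily)
generator = torch.Generator("cuda").manual_seed(42)
cat = pipe(prompt, generator=generator).images[0]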

Benchmarking prompts#

To compare different prompts quantitatively, we run the image generation in a loop and score the resulting images. As a control, we also generate images that show dogs rather than cats.

Using this strategy, we can do prompt engineering / prompt optimization in an informed way.

num_attempts = 10
prompts = ["Draw a realistic photo of a cat.", 
           "Draw a cat",
           "cat", 
           "Draw a realistic photo of a dog."]

data = {"prompt":[],
        "score":[]}

# disable the per-image progress bar to keep the loop output compact
pipe.set_progress_bar_config(disable=True)

for prompt in prompts:
    for i in range(num_attempts):
        image = pipe(prompt).images[0]

        score = metric(torch.as_tensor(np.array(image)), "cat")
        data["score"].append(float(score.detach()))
        data["prompt"].append(prompt)
data = pd.DataFrame(data)
data
prompt score
0 Draw a realistic photo of a cat. 27.656635
1 Draw a realistic photo of a cat. 26.781824
2 Draw a realistic photo of a cat. 27.483803
3 Draw a realistic photo of a cat. 26.999044
4 Draw a realistic photo of a cat. 27.809370
5 Draw a realistic photo of a cat. 27.048084
6 Draw a realistic photo of a cat. 28.171909
7 Draw a realistic photo of a cat. 25.953665
8 Draw a realistic photo of a cat. 26.899948
9 Draw a realistic photo of a cat. 30.001472
10 Draw a cat 28.417461
11 Draw a cat 27.374229
12 Draw a cat 27.655571
13 Draw a cat 27.965340
14 Draw a cat 27.172005
15 Draw a cat 27.536922
16 Draw a cat 29.083773
17 Draw a cat 29.016024
18 Draw a cat 25.260868
19 Draw a cat 25.417572
20 cat 27.310360
21 cat 25.902298
22 cat 26.449892
23 cat 27.699142
24 cat 19.328476
25 cat 27.631119
26 cat 21.009279
27 cat 22.222973
28 cat 17.562325
29 cat 27.389830
30 Draw a realistic photo of a dog. 20.009043
31 Draw a realistic photo of a dog. 21.995111
32 Draw a realistic photo of a dog. 19.825274
33 Draw a realistic photo of a dog. 21.263514
34 Draw a realistic photo of a dog. 21.207270
35 Draw a realistic photo of a dog. 19.247950
36 Draw a realistic photo of a dog. 21.005089
37 Draw a realistic photo of a dog. 22.206501
38 Draw a realistic photo of a dog. 20.775682
39 Draw a realistic photo of a dog. 18.967247
# Plotting the boxplot
plt.figure(figsize=(10, 6))
sns.boxplot(x='prompt', y='score', data=data)
plt.title('Boxplot of Scores by Prompt')
plt.xlabel('Prompt')
plt.ylabel('Score')
plt.show()
(output: boxplot of CLIP scores per prompt)
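
Besides the visual comparison, the scores can also be summarized numerically per prompt, for example using pandas (a short sketch based on the data frame created above):

# mean and standard deviation of the CLIP score for each prompt
data.groupby("prompt")["score"].agg(["mean", "std"]).sort_values("mean", ascending=False)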

Exercise#

Append more prompts to the list above. Attempt to improve the score for images showing a cat.

Measure which of the prompts produces the most photorealistic images.
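
As a possible starting point for the second task, the same metric can be queried with a text that describes the style rather than the content; the wording of the style prompt below is only a suggestion:

# score an already generated image against a style description (example wording, adjust as needed)
realism_score = metric(torch.as_tensor(np.array(cat)), "a photorealistic photograph")
float(realism_score.detach())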