Optimizing image generation prompting using CLIP scores
In this notebook, we compare different prompts by measuring a text-to-image similarity metric, the CLIP score.
import torch
import numpy as np
from skimage.io import imread
import stackview
from diffusers import DiffusionPipeline, AutoencoderTiny
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
This is how the metric is initialized:
from torchmetrics.multimodal.clip_score import CLIPScore
metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
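To make the numbers below easier to read, here is a minimal sketch of what the score expresses. As far as the torchmetrics documentation describes it, the CLIP score is the cosine similarity between the image embedding and the text embedding, scaled by 100 and clipped at 0, so higher values mean a better match. The function name `clip_style_score` and the toy vectors are assumptions for illustration only; the real metric obtains its embeddings from the CLIP model.

```python
import math


def clip_style_score(image_embedding, text_embedding):
    """Return max(100 * cosine similarity, 0) for two embedding vectors."""
    dot = sum(a * b for a, b in zip(image_embedding, text_embedding))
    norm_image = math.sqrt(sum(a * a for a in image_embedding))
    norm_text = math.sqrt(sum(b * b for b in text_embedding))
    cosine = dot / (norm_image * norm_text)
    return max(100.0 * cosine, 0.0)


# Toy 3-dimensional vectors, made up for illustration;
# real CLIP embeddings have hundreds of dimensions.
aligned = clip_style_score([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
orthogonal = clip_style_score([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])
opposite = clip_style_score([1.0, 0.0, 0.0], [-1.0, 0.0, 0.0])
print(aligned, orthogonal, opposite)
```

Vectors pointing in the same direction reach the maximum of about 100, while unrelated or opposing vectors give 0.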
We start with this example image.
image = imread("data/real_cat.png")
stackview.insight(image)
score = metric(torch.as_tensor(image), "cat")
score.detach()
tensor(25.3473)
score = metric(torch.as_tensor(image), "microscope")
float(score.detach())
30.786285400390625
Recap: Generating images
We will now generate an image from a prompt and measure how well it matches the text "cat".
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
prompt = "Draw a realistic photo of a cat."
cat = pipe(prompt).images[0]
cat
score = metric(torch.as_tensor(np.array(cat)), "cat")
np.asarray(score.detach())
array(26.65498, dtype=float32)
Benchmarking prompts
To compare different prompts quantitatively, we generate several images per prompt in a loop and score each image against the text "cat". As a control, we also generate images that show dogs and no cats; these should score lower.
Using this strategy, we can do prompt engineering / prompt optimization in an informed way.
num_attempts = 10
prompts = ["Draw a realistic photo of a cat.",
           "Draw a cat",
           "cat",
           "Draw a realistic photo of a dog."]
data = {"prompt": [],
        "score": []}
# Disable the pipeline's progress bar while generating many images
pipe.set_progress_bar_config(disable=True)
for prompt in prompts:
    for i in range(num_attempts):
        image = pipe(prompt).images[0]
        # Score every image against the text "cat", including the dog control
        score = metric(torch.as_tensor(np.array(image)), "cat")
        data["score"].append(float(score.detach()))
        data["prompt"].append(prompt)
data = pd.DataFrame(data)
data
| | prompt | score |
|---|---|---|
0 | Draw a realistic photo of a cat. | 27.656635 |
1 | Draw a realistic photo of a cat. | 26.781824 |
2 | Draw a realistic photo of a cat. | 27.483803 |
3 | Draw a realistic photo of a cat. | 26.999044 |
4 | Draw a realistic photo of a cat. | 27.809370 |
5 | Draw a realistic photo of a cat. | 27.048084 |
6 | Draw a realistic photo of a cat. | 28.171909 |
7 | Draw a realistic photo of a cat. | 25.953665 |
8 | Draw a realistic photo of a cat. | 26.899948 |
9 | Draw a realistic photo of a cat. | 30.001472 |
10 | Draw a cat | 28.417461 |
11 | Draw a cat | 27.374229 |
12 | Draw a cat | 27.655571 |
13 | Draw a cat | 27.965340 |
14 | Draw a cat | 27.172005 |
15 | Draw a cat | 27.536922 |
16 | Draw a cat | 29.083773 |
17 | Draw a cat | 29.016024 |
18 | Draw a cat | 25.260868 |
19 | Draw a cat | 25.417572 |
20 | cat | 27.310360 |
21 | cat | 25.902298 |
22 | cat | 26.449892 |
23 | cat | 27.699142 |
24 | cat | 19.328476 |
25 | cat | 27.631119 |
26 | cat | 21.009279 |
27 | cat | 22.222973 |
28 | cat | 17.562325 |
29 | cat | 27.389830 |
30 | Draw a realistic photo of a dog. | 20.009043 |
31 | Draw a realistic photo of a dog. | 21.995111 |
32 | Draw a realistic photo of a dog. | 19.825274 |
33 | Draw a realistic photo of a dog. | 21.263514 |
34 | Draw a realistic photo of a dog. | 21.207270 |
35 | Draw a realistic photo of a dog. | 19.247950 |
36 | Draw a realistic photo of a dog. | 21.005089 |
37 | Draw a realistic photo of a dog. | 22.206501 |
38 | Draw a realistic photo of a dog. | 20.775682 |
39 | Draw a realistic photo of a dog. | 18.967247 |
# Plotting the boxplot
plt.figure(figsize=(10, 6))
sns.boxplot(x='prompt', y='score', data=data)
plt.title('Boxplot of Scores by Prompt')
plt.xlabel('Prompt')
plt.ylabel('Score')
plt.show()
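The boxplot can be complemented with summary statistics per prompt. The following is a minimal sketch that uses only the first three scores of three prompts from the table above, so that it runs on its own. The variable names `sample` and `summary` are assumptions; in the notebook, the same `groupby` call could be applied to the full `data` DataFrame.

```python
import pandas as pd

# A few scores copied from the table above; in the notebook,
# use the full `data` DataFrame instead of this sample.
sample = pd.DataFrame({
    "prompt": ["Draw a realistic photo of a cat."] * 3
              + ["cat"] * 3
              + ["Draw a realistic photo of a dog."] * 3,
    "score": [27.656635, 26.781824, 27.483803,
              27.310360, 25.902298, 26.449892,
              20.009043, 21.995111, 19.825274],
})

# Mean, spread and range per prompt, best mean score first
summary = (
    sample.groupby("prompt")["score"]
    .agg(["mean", "std", "min", "max"])
    .sort_values("mean", ascending=False)
)
print(summary)
```

Sorting by the mean puts the prompt with the highest average CLIP score first. The standard deviation indicates how consistently a prompt yields cat-like images.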
Exercise
Append more prompts to the list above. Attempt to improve the score for the image showing a cat.
Measure which of the prompts produces the most photorealistic images.
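As a possible starting point for the exercise, the benchmarking loop above can be wrapped in a function whose target text is a parameter. The same procedure can then score images against "cat" or against another text of your choice. This is a sketch under assumptions: `benchmark_prompts`, `fake_generate` and `fake_score` are hypothetical names, and the two stand-in functions only exist so that the example runs without a GPU.

```python
def benchmark_prompts(prompts, generate, score, target_text, num_attempts=10):
    """Generate images for each prompt and score them against target_text.

    generate: callable mapping a prompt string to an image
    score: callable mapping (image, text) to a number
    """
    results = {"prompt": [], "score": []}
    for prompt in prompts:
        for _ in range(num_attempts):
            image = generate(prompt)
            results["prompt"].append(prompt)
            results["score"].append(float(score(image, target_text)))
    return results


# Stand-ins so the sketch runs without a model. In the notebook they could be
# replaced by wrappers around the existing objects, for example:
#   generate = lambda p: pipe(p).images[0]
#   score = lambda img, text: metric(torch.as_tensor(np.array(img)), text).detach()
def fake_generate(prompt):
    return prompt.lower()


def fake_score(image, text):
    return 1.0 if text in image else 0.0


results = benchmark_prompts(["Draw a cat", "Draw a dog"],
                            fake_generate, fake_score, "cat", num_attempts=2)
print(results)
```

The returned dictionary has the same layout as `data` before its conversion, so it can be passed to `pd.DataFrame` and plotted in the same way as above.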