Optimizing image generation prompting using CLIP scores

In this notebook we compare different image generation prompts by measuring a text-to-image similarity metric, the CLIP score.

import torch
import numpy as np
from skimage.io import imread
import stackview
from diffusers import DiffusionPipeline
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

This is how the metric is initialized:

from torchmetrics.multimodal.clip_score import CLIPScore
metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
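For intuition: torchmetrics documents the CLIP score as 100 times the cosine similarity between the image embedding and the text embedding, clipped at zero. The following minimal sketch reproduces only that last step on plain Python lists. The function name `clip_score_from_embeddings` and the toy vectors are assumptions made for illustration; real embeddings come from the CLIP model loaded above.

```python
import math

def clip_score_from_embeddings(image_embedding, text_embedding):
    """Return max(100 * cosine similarity, 0) for two non-zero vectors."""
    dot = sum(a * b for a, b in zip(image_embedding, text_embedding))
    norm_image = math.sqrt(sum(a * a for a in image_embedding))
    norm_text = math.sqrt(sum(b * b for b in text_embedding))
    cosine = dot / (norm_image * norm_text)
    return max(100.0 * cosine, 0.0)

# Toy vectors for illustration only, not actual CLIP embeddings.
same_direction = clip_score_from_embeddings([1.0, 2.0], [2.0, 4.0])
orthogonal = clip_score_from_embeddings([1.0, 0.0], [0.0, 1.0])
opposite = clip_score_from_embeddings([1.0, 0.0], [-1.0, 0.0])
print(same_direction, orthogonal, opposite)
```

Vectors pointing in the same direction give the maximum score, while orthogonal or opposing vectors give zero. Scores are therefore best read relative to each other for a fixed model rather than as absolute quality values.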

We start with this example image.

image = imread("data/real_cat.png")
stackview.insight(image)
shape: (512, 512, 3)
dtype: uint8
size: 768.0 kB
min: 0
max: 255
score = metric(torch.as_tensor(image), "cat")
score.detach()
tensor(25.3473)
score = metric(torch.as_tensor(image), "microscope")
float(score.detach())
30.786287307739258
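The two calls above score the same image against different texts. To compare more texts at once, the calls can be wrapped in a small loop. In the sketch below, `score_texts` is a hypothetical helper name and `score_fn` stands for any callable that returns a number for an (image, text) pair, for example `lambda img, txt: float(metric(torch.as_tensor(img), txt).detach())`. A stand-in scorer that replays the two values measured above is used so that the snippet runs without loading CLIP.

```python
def score_texts(score_fn, image, texts):
    """Score one image against several texts; return (text, score) pairs, best first."""
    scored = [(text, float(score_fn(image, text))) for text in texts]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)

# Stand-in scorer for illustration: it replays the two scores measured above
# instead of calling the CLIP model.
measured = {"cat": 25.3473, "microscope": 30.786287307739258}

def replay_score(image, text):
    return measured[text]

ranking = score_texts(replay_score, image=None, texts=["cat", "microscope"])
for text, score in ranking:
    print(f"{text}: {score:.2f}")
```

Sorting by score makes it easy to see which text the model considers the closest match for the image.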

Recap: Generating images

We will now use a prompt to generate an image and measure how well the result matches the text "cat".

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
prompt = "Draw a realistic photo of a cat."

cat = pipe(prompt).images[0]
cat
score = metric(torch.as_tensor(np.array(cat)), "cat")
np.asarray(score.detach())

Benchmarking prompts

To compare different prompts quantitatively, we generate several images per prompt in a loop and compute the CLIP score of each image against the text "cat". As a control, we also generate images that show dogs instead of cats, which are expected to score lower.

Using this strategy, we can do prompt engineering / prompt optimization in an informed way.

num_attempts = 2
prompts = ["Draw a realistic photo of a cat.", 
           "Draw a cat",
           "cat", 
           "Draw a realistic photo of a dog."]

data = {"prompt":[],
        "score":[]}
for prompt in prompts:
    for i in range(num_attempts):
        image = pipe(prompt).images[0]

        # CLIP score of the generated image against the text "cat"
        score = metric(torch.as_tensor(np.array(image)), "cat")
        data["score"].append(float(score.detach()))
        data["prompt"].append(prompt)

        print(f"{prompt}: {float(score.detach()):.2f}")
data = pd.DataFrame(data)
data
# Plotting the boxplot
plt.figure(figsize=(10, 6))
sns.boxplot(x='prompt', y='score', data=data)
plt.title('Boxplot of Scores by Prompt')
plt.xlabel('Prompt')
plt.ylabel('Score')
plt.show()
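Alongside the boxplot, a per-prompt summary of mean and spread can make the comparison easier to read, especially with only a few attempts per prompt. The sketch below uses only the standard library. `summarize_scores` is a hypothetical helper name, and the numbers are placeholders for illustration, not measured results.

```python
from collections import defaultdict
from statistics import mean, stdev

def summarize_scores(prompts, scores):
    """Group scores by prompt and return count, mean and sample standard deviation."""
    grouped = defaultdict(list)
    for prompt, score in zip(prompts, scores):
        grouped[prompt].append(score)
    summary = {}
    for prompt, values in grouped.items():
        # The sample standard deviation needs at least two values.
        spread = stdev(values) if len(values) > 1 else 0.0
        summary[prompt] = {"n": len(values), "mean": mean(values), "std": spread}
    return summary

# Placeholder values for illustration only, not measured results.
example_prompts = ["cat", "cat", "dog", "dog"]
example_scores = [28.0, 30.0, 20.0, 22.0]

summary = summarize_scores(example_prompts, example_scores)
for prompt, stats in sorted(summary.items(), key=lambda kv: kv[1]["mean"], reverse=True):
    print(f"{prompt}: n={stats['n']}, mean={stats['mean']:.2f}, std={stats['std']:.2f}")
```

With the table built above, the same helper could be called as `summarize_scores(data["prompt"], data["score"])`. Increasing `num_attempts` should make the means more reliable, at the cost of longer generation time.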

Exercise

Append more prompts to the list above and try to find one that increases the CLIP score for the text "cat".

Measure which of the prompts produces the most photorealistic images.