Combining LLMs with random number generators for data generation

In this notebook, we use functions that generate random ages, incomes and genders, then pass these values to an LLM, which produces matching customer profiles (name, household size, grocery list).

Note: This notebook was AI-generated (and human curated) using an LLM as shown here.

import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import openai
import numpy as np
from tqdm import tqdm
# Reusing LLM utility functions from earlier exercise
import sys
sys.path.append("../20_chatbots/")
from llm_utilities import prompt_scadsai_llm, prompt_ollama, prompt_blablador, prompt_kisski
prompt = prompt_scadsai_llm

Define random data generation functions

We'll create functions that draw ages, incomes and genders from simple probability distributions.

def generate_random_age():
    """Generate a random age: normal distribution (mean 40, std 12), clipped to [18, 80] and truncated to int"""
    age = int(np.clip(np.random.normal(40, 12), 18, 80))
    return age

def generate_random_income():
    """Generate a random yearly income: log-normal distribution (median exp(11) ≈ 59,874), clipped to [20000, 150000]"""
    income = np.random.lognormal(mean=11, sigma=0.5)
    return np.clip(income, 20000, 150000)

def generate_random_gender():
    """Generate a random gender from a list of possibilities."""
    # Use the module-level numpy import instead of re-importing inside the function
    return np.random.choice(["male", "female", "non-binary"], p=[0.45, 0.45, 0.1])

# Test the functions
print(f"Sample age: {generate_random_age()}")
print(f"Sample income: ${generate_random_income():.2f}")
print(f"Sample gender: {generate_random_gender()}")
Sample age: 34
Sample income: $59100.99
Sample gender: male
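The helpers above draw from NumPy's global random state, so every run produces different customers. Below is a minimal sketch that uses a seeded `numpy.random.Generator` instead. This is an assumption, not what the notebook does. It draws the same distributions in bulk, so runs become reproducible. The seed 42 and the `sample_*` names are illustrative. The sketch also measures how much probability the clipping piles onto the income bounds. Under a log-normal with mean 11 and sigma 0.5, about 3% of draws exceed 150,000 and about 1.4% fall below 20,000, so those exact boundary values will recur in the generated data.

```python
import numpy as np

# Seeded generator (seed 42 is arbitrary) so repeated runs give identical samples.
rng = np.random.default_rng(42)

def sample_ages(n: int) -> np.ndarray:
    """Normal(40, 12) clipped to [18, 80] and truncated to int, as in generate_random_age."""
    return np.clip(rng.normal(40, 12, size=n), 18, 80).astype(int)

def sample_incomes(n: int) -> np.ndarray:
    """Log-normal(mean=11, sigma=0.5) clipped to [20000, 150000], as in generate_random_income."""
    return np.clip(rng.lognormal(mean=11, sigma=0.5, size=n), 20_000, 150_000)

def sample_genders(n: int) -> np.ndarray:
    """Same categories and probabilities as generate_random_gender."""
    return rng.choice(["male", "female", "non-binary"], size=n, p=[0.45, 0.45, 0.1])

n = 10_000
ages, incomes, genders = sample_ages(n), sample_incomes(n), sample_genders(n)

print(f"age range: {ages.min()}-{ages.max()}")
print(f"median income: {np.median(incomes):,.0f}  (exp(11) = {np.exp(11):,.0f})")
# Clipped values are exactly equal to the bounds, so equality counts them.
print(f"share clipped to 150000: {np.mean(incomes == 150_000):.3f}")
print(f"share clipped to 20000:  {np.mean(incomes == 20_000):.3f}")
```

Clipping leaves the median unchanged, because both bounds lie on either side of it. The point masses at the bounds, however, distort the mean and any plot of the income axis.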

Define the synthetic data generation function

We'll create a function that asks the LLM for a customer profile in JSON format, passing in the random age, income and gender values.

def generate_customer(age: int, income: float, gender: str) -> str:
    # Use double quotes in the template: single-quoted keys and strings are not valid JSON
    prompt_text = f"""Generate one realistic customer profile in valid JSON format with the following structure:
    {{
        "name": str,
        "gender": "{gender}",
        "age": {age},
        "income": {income:.2f},
        "household_size": int,
        "grocery_list": [
            {{"item": str, "price": float}},
            ...
        ]
    }}
    Include 5-10 grocery items with realistic prices. Do not modify the gender, age and income specified above.
    
    Respond with the JSON data only and no markdown fences."""
    return prompt(prompt_text)

# Test the function
sample_data = generate_customer(generate_random_age(), generate_random_income(), generate_random_gender())
print(sample_data)
{
    "name": "Emily Wilson",
    "gender": "female",
    "age": 33,
    "income": 46189.05,
    "household_size": 3,
    "grocery_list": [
        {"item": "Apple Juice", "price": 3.99},
        {"item": "Bread", "price": 2.49},
        {"item": "Milk", "price": 2.99},
        {"item": "Eggs", "price": 1.79},
        {"item": "Chicken Breasts", "price": 8.99},
        {"item": "Carrots", "price": 1.29},
        {"item": "Pasta", "price": 1.49},
        {"item": "Marinara Sauce", "price": 2.99},
        {"item": "Shredded Cheese", "price": 3.49},
        {"item": "Granola", "price": 4.99}
    ]
}
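The prompt asks for JSON without markdown fences, but models do not always comply, and `json.loads` rejects fenced output. Below is a minimal sketch of a more forgiving parser. It assumes that a surrounding markdown code fence (three backticks, optionally tagged `json`) is the most common deviation. It strips such a fence, parses the text, and checks that the expected keys are present. `parse_customer` and `REQUIRED_KEYS` are hypothetical names, not part of `llm_utilities`.

```python
import json
import re

# Keys the prompt asks the model to return (hypothetical constant for this sketch).
REQUIRED_KEYS = {"name", "gender", "age", "income", "household_size", "grocery_list"}

# Matches text wrapped in a markdown fence, optionally tagged "json"; group 1 is the body.
FENCE_RE = re.compile(r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", re.DOTALL)

def parse_customer(text: str) -> dict:
    """Parse one LLM response into a dict, tolerating a surrounding markdown fence.

    Raises json.JSONDecodeError for malformed JSON and ValueError for missing keys.
    """
    match = FENCE_RE.match(text)
    if match:
        text = match.group(1)
    record = json.loads(text)
    if not isinstance(record, dict):
        raise ValueError("expected a JSON object")
    missing = REQUIRED_KEYS - set(record)
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    return record

fenced = ('```json\n{"name": "Emily Wilson", "gender": "female", "age": 33, '
          '"income": 46189.05, "household_size": 3, "grocery_list": []}\n```')
print(parse_customer(fenced)["name"])  # Emily Wilson
```

`json.JSONDecodeError` is a subclass of `ValueError`, so a collection loop that calls `parse_customer` could catch `ValueError` alone to cover both malformed JSON and missing keys.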

Collect and validate customer records

We'll generate 100 customer records from random age, income and gender values and check that each LLM response parses as JSON.

customer_records = []

for i in tqdm(range(100)):
    try:
        age = generate_random_age()
        income = generate_random_income()
        gender = generate_random_gender()
        data = generate_customer(age, income, gender)
        # Validate JSON
        customer_data = json.loads(data)
        customer_records.append(customer_data)
    except json.JSONDecodeError:
        print(f"Invalid JSON format in record {i+1}")

print(f"\nCollected {len(customer_records)} valid customer records")
100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [26:11<00:00, 15.71s/it]
Collected 100 valid customer records
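The loop above keeps only the parsed LLM output. Once the loop finishes, there is no way to tell whether the model changed the age, gender or income it was given. The sketch below assumes that each requested set of values is stored next to the returned record, for example by appending `(requested, customer_data)` pairs inside the loop. It then counts how often each fixed field was altered. `count_modified` and the income tolerance of 0.01 are illustrative choices. The prompt formats income with two decimals, so rounding alone can shift the value by up to 0.005. The toy pairs are made up purely to exercise the function.

```python
def count_modified(pairs, income_tol: float = 0.01) -> dict:
    """Count records whose returned age, gender or income differs from the requested value.

    pairs: iterable of (requested, returned) dicts, both with 'age', 'gender', 'income'.
    """
    counts = {"age": 0, "gender": 0, "income": 0, "total": 0}
    for requested, returned in pairs:
        counts["total"] += 1
        if returned.get("age") != requested["age"]:
            counts["age"] += 1
        # Compare case-insensitively so "Male" vs "male" is not counted as a change.
        if str(returned.get("gender", "")).lower() != str(requested["gender"]).lower():
            counts["gender"] += 1
        try:
            income_changed = abs(float(returned["income"]) - float(requested["income"])) > income_tol
        except (KeyError, TypeError, ValueError):
            income_changed = True  # missing or non-numeric income counts as modified
        if income_changed:
            counts["income"] += 1
    return counts

toy_pairs = [
    ({"age": 33, "gender": "female", "income": 46189.0512}, {"age": 33, "gender": "female", "income": 46189.05}),
    ({"age": 45, "gender": "female", "income": 72000.0}, {"age": 45, "gender": "female", "income": 70000.0}),
    ({"age": 28, "gender": "male", "income": 66737.59}, {"age": 29, "gender": "Male", "income": 66737.59}),
]
print(count_modified(toy_pairs))  # {'age': 1, 'gender': 0, 'income': 1, 'total': 3}
```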

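Transform data into a DataFrame

We'll build a DataFrame with one row per customer and compute weekly grocery spending as the sum of the item prices in each grocery list. The prompt does not specify how long the grocery list should last, so treating it as one week of shopping is an assumption.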

# Extract customer info
customer_info = []
for record in customer_records:
    try:
        # Calculate total spending
        total_spending = sum(item['price'] for item in record['grocery_list'])
        
        customer_info.append({
            'name': record['name'],
            'gender': record['gender'],
            'age': record['age'],
            'income': record['income'],
            'household_size': record['household_size'],
            'weekly_spending': total_spending
        })
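    except (KeyError, TypeError) as e:
        # Skip records with missing fields or non-numeric prices instead of silencing every error
        print(f"Error processing record: {e!r}")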

df = pd.DataFrame(customer_info)
display(df)
              name      gender  age     income  household_size  weekly_spending
0   Karen Thompson      female   45  150000.00               4            58.20
1      Jamie Reyes  non-binary   28   31254.86               2            42.51
2       John Smith        male   39   45706.93               4            30.20
3   Ethan Thompson        male   27  150000.00               2            35.20
4   Ethan Thompson        male   28   66737.59               2            48.80
..             ...         ...  ...        ...             ...              ...
95  David Thompson        male   30   81494.39               3            29.01
96    Emily Wilson      female   47   34766.22               3            30.20
97        John Doe        male   46   32911.36               4            27.41
98  Martha Johnson      female   66   51588.06               2            39.81
99      Jamie Reed  non-binary   38   75559.86               3            38.31

100 rows × 6 columns
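The loop above builds one row per customer by hand. An alternative sketch, assuming the records follow the structure requested in the prompt, lets pandas flatten the nested grocery lists with `pd.json_normalize` and then sums prices per customer. Names repeat in the generated data (the table contains two customers called Ethan Thompson), so the sketch groups by a positional `record_id` rather than by name. `record_id` is a column introduced here for illustration. The sketch also adds a per-person spending column that normalizes for household size. The two toy records are made up.

```python
import pandas as pd

# Two toy records shaped like the LLM output (values are illustrative only).
records = [
    {"name": "Ethan Thompson", "gender": "male", "age": 27, "income": 150000.0,
     "household_size": 2,
     "grocery_list": [{"item": "Bread", "price": 2.49}, {"item": "Milk", "price": 2.99}]},
    {"name": "Ethan Thompson", "gender": "male", "age": 28, "income": 66737.59,
     "household_size": 4,
     "grocery_list": [{"item": "Eggs", "price": 1.79}]},
]

# Names are not unique, so tag each record with its position as an id.
with_ids = [{**r, "record_id": i} for i, r in enumerate(records)]

# One row per grocery item; meta copies record_id onto every item row.
items = pd.json_normalize(with_ids, record_path="grocery_list", meta=["record_id"])
items["record_id"] = items["record_id"].astype(int)  # align dtype with the index below
spending = items.groupby("record_id")["price"].sum().rename("weekly_spending")

customers = (
    pd.DataFrame(with_ids)
    .drop(columns="grocery_list")
    .set_index("record_id")
    .join(spending)
)
# A record with an empty grocery list has no item rows, so the join leaves NaN.
customers["weekly_spending"] = customers["weekly_spending"].fillna(0.0)
customers["spending_per_person"] = customers["weekly_spending"] / customers["household_size"]
print(customers[["name", "household_size", "weekly_spending", "spending_per_person"]])
```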

Visualize relationships in the data

Let's create three scatter plots to examine how age, income, household size and weekly spending relate to each other:

# Create a figure with three subplots
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# Age vs Income
sns.scatterplot(data=df, x='age', y='income', ax=ax1)
ax1.set_title('Age vs Income')

# Income vs Weekly Spending
sns.scatterplot(data=df, x='income', y='weekly_spending', ax=ax2)
ax2.set_title('Income vs Weekly Spending')

# Household Size vs Weekly Spending
sns.scatterplot(data=df, x='household_size', y='weekly_spending', ax=ax3)
ax3.set_title('Household Size vs Weekly Spending')

plt.tight_layout()
plt.show()
[Figure: three scatter plots: Age vs Income, Income vs Weekly Spending, Household Size vs Weekly Spending]
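Scatter plots make trends visible but do not quantify them. The small sketch below computes Spearman rank correlations between the numeric columns, assuming `df` has the columns built above. Spearman correlation is rank-based, so it is unaffected by the right skew of the log-normal incomes and only requires a monotonic relationship. The toy frame stands in for `df` so the snippet runs on its own; its values are made up.

```python
import pandas as pd

# Toy stand-in for df (made-up values); replace with df[["age", "income", ...]].
toy = pd.DataFrame({
    "age": [25, 34, 41, 52, 63],
    "income": [31000.0, 45000.0, 150000.0, 60000.0, 52000.0],
    "household_size": [1, 2, 4, 3, 2],
    "weekly_spending": [22.5, 30.1, 55.0, 41.2, 28.7],
})

# Spearman = Pearson correlation of the ranks (ties receive their average rank).
corr = toy.corr(method="spearman")
print(corr.round(2))
```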

Additional analysis: Shopping patterns by gender

# Box plot of weekly spending by gender
plt.figure(figsize=(8, 6))
sns.boxplot(data=df, x='gender', y='weekly_spending')
plt.title('Weekly Spending Distribution by Gender')
plt.show()
[Figure: box plot of weekly spending distribution by gender]
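Box plots hide how many records fall into each group. With `p=[0.45, 0.45, 0.1]`, only about one record in ten is expected to be non-binary, so that box rests on roughly ten points out of 100. The sketch below assumes the same `df` columns and prints each group's size next to its median and quartiles. The toy frame reuses a few rows from the table above so the snippet runs on its own.

```python
import pandas as pd

# A few rows from the table above, standing in for df.
toy = pd.DataFrame({
    "gender": ["female", "male", "male", "non-binary", "female", "male"],
    "weekly_spending": [58.20, 30.20, 35.20, 42.51, 30.20, 48.80],
})

# Named aggregation: one output column per keyword.
summary = toy.groupby("gender")["weekly_spending"].agg(
    n="count",
    median="median",
    q1=lambda s: s.quantile(0.25),
    q3=lambda s: s.quantile(0.75),
)
print(summary)
```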

Exercise

Modify the prompt so that the average weekly spending increases with the number of people living in the household. Also, assuming that people tend to have children in certain age ranges, household size plotted against age could show a characteristic pattern.
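To check whether a modified prompt has the intended effect, one option is to compare mean weekly spending across household sizes and test whether it rises monotonically. This is a sketch, not the only possible check. `spending_increases_with_household` is a hypothetical helper, and the toy frame is made up.

```python
import pandas as pd

def spending_increases_with_household(df: pd.DataFrame) -> bool:
    """True if mean weekly_spending is non-decreasing across sorted household sizes."""
    means = df.groupby("household_size")["weekly_spending"].mean().sort_index()
    return bool(means.is_monotonic_increasing)

toy = pd.DataFrame({
    "household_size": [1, 1, 2, 3, 3, 4],
    "weekly_spending": [20.0, 24.0, 35.0, 41.0, 47.0, 60.0],
})
print(toy.groupby("household_size")["weekly_spending"].mean())
print(spending_increases_with_household(toy))  # True
```

The same `groupby` applied to binned ages, for example `df.groupby(pd.cut(df["age"], bins=[18, 30, 45, 60, 80]))["household_size"].mean()`, would show the household-size pattern by age that the exercise mentions.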

Exercise

Determine how often the LLM modifies the income even though we asked it not to change that number.