# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gradio app to show the results"""
import functools
import json
import logging
import os
import re
import tempfile
from io import BytesIO
from typing import Any
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from datasets import load_dataset
from huggingface_hub import HfFileSystem
from PIL import Image
from processing import (
filter_data,
get_model_ids,
get_metric_preferences,
get_task_columns,
_get_metric_explanation,
_TASK_PARETO_DEFAULTS,
compute_pareto_frontier,
format_df,
load_task_results,
)
from sanitizer import parse_and_filter
# Module-level logger. The handlers below use it to report best-effort failures (unreadable images, configs or
# filter queries) without interrupting the app.
logger = logging.getLogger(__name__)
def generate_pareto_plot(df, metric_x, metric_y, metric_preferences):
    """Build a Plotly figure highlighting the Pareto frontier of ``df`` for two metrics.

    Frontier points (as computed by ``compute_pareto_frontier``) are drawn in color with one legend entry per
    experiment, and are connected by a line ordered along ``metric_x``. All other points are drawn in semi-transparent
    gray without a legend entry. Hovering over a point shows its experiment name, PEFT type and both metric values.

    Args:
        df: Results to plot; must contain the ``experiment_name`` and ``peft_type`` columns as well as the two metric
            columns.
        metric_x: Column plotted on the x axis.
        metric_y: Column plotted on the y axis.
        metric_preferences: Passed through to ``compute_pareto_frontier``.

    Returns:
        A ``plotly.graph_objects.Figure``; an empty figure when ``df`` has no rows.
    """
    if df.empty:
        # Return an (empty) figure rather than an empty dict: gr.Plot knows how to render a figure, not a bare dict.
        return go.Figure()

    # Compute Pareto frontier and non-frontier points.
    pareto_df = compute_pareto_frontier(df, metric_x, metric_y, metric_preferences)
    non_pareto_df = df.drop(pareto_df.index)

    # Create an empty figure.
    fig = go.Figure()

    # Draw the line connecting Pareto frontier points.
    if not pareto_df.empty:
        # Sort the Pareto frontier points by metric_x for a meaningful connection.
        pareto_sorted = pareto_df.sort_values(by=metric_x)
        line_trace = go.Scatter(
            x=pareto_sorted[metric_x],
            y=pareto_sorted[metric_y],
            mode="lines",
            line={"color": "rgba(0,0,255,0.3)", "width": 4},
            name="Pareto Frontier",
        )
        fig.add_trace(line_trace)

    # Add non-frontier points in gray with semi-transparency.
    if not non_pareto_df.empty:
        non_frontier_trace = go.Scatter(
            x=non_pareto_df[metric_x],
            y=non_pareto_df[metric_y],
            mode="markers",
            marker={"color": "rgba(128,128,128,0.5)", "size": 12},
            hoverinfo="text",
            # Plotly hover text is HTML-like: line breaks have to be written as <br>.
            text=non_pareto_df.apply(
                lambda row: (
                    f"experiment_name: {row['experiment_name']}<br>"
                    f"peft_type: {row['peft_type']}<br>"
                    f"{metric_x}: {row[metric_x]}<br>"
                    f"{metric_y}: {row[metric_y]}"
                ),
                axis=1,
            ),
            showlegend=False,
        )
        fig.add_trace(non_frontier_trace)

    # Add Pareto frontier points with legend (px.scatter creates one trace per experiment).
    if not pareto_df.empty:
        pareto_scatter = px.scatter(
            pareto_df,
            x=metric_x,
            y=metric_y,
            color="experiment_name",
            hover_data={"experiment_name": True, "peft_type": True, metric_x: True, metric_y: True},
        )
        for trace in pareto_scatter.data:
            # NOTE(review): assigning a new marker dict presumably replaces the marker px.scatter configured
            # (including its color); confirm whether only the size was meant to change.
            trace.marker = {"size": 12}
            fig.add_trace(trace)

    # Update layout with axes labels.
    fig.update_layout(
        title=f"Pareto Frontier for {metric_x} vs {metric_y}",
        template="seaborn",
        height=700,
        autosize=True,
        xaxis_title=metric_x,
        yaxis_title=metric_y,
    )
    return fig
def compute_pareto_summary(filtered, pareto_df, metric_x, metric_y):
    """Describe the plotted data as plain text for the summary box.

    The text starts with the min/max/mean of both metrics over ``filtered``, followed by how many points there are
    in total, how many lie on the Pareto frontier (``pareto_df``) and how many do not.

    Returns:
        The summary text, or "No data available." when ``filtered`` has no rows.
    """
    if filtered.empty:
        return "No data available."
    n_total = len(filtered)
    n_frontier = len(pareto_df)
    table = filtered[[metric_x, metric_y]].agg(["min", "max", "mean"]).to_string()
    # The empty entry produces the blank line between the statistics table and the counts.
    lines = [
        table,
        "",
        f"Total points: {n_total}",
        f"Pareto frontier points: {n_frontier}",
        f"Excluded points: {n_total - n_frontier}",
    ]
    return "\n".join(lines)
def export_csv(df):
    """Write ``df`` to a new temporary CSV file and return the file's path.

    The index is not written. The file is created with ``delete=False`` so that it still exists after this function
    returns and can be handed to a ``gr.File`` output; this function never removes it.

    Returns:
        The path of the CSV file, or ``None`` when ``df`` has no rows (nothing to export).
    """
    if df.empty:
        return None
    # pandas writes its own line terminators; newline="" stops the text layer from translating them again (which
    # would turn each row ending into "\r\r\n" on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w", encoding="utf-8", newline="") as tmp:
        df.to_csv(tmp, index=False)
    return tmp.name
# Name of the task whose experiments produce images; selecting it shows the image gallery in the app.
IMAGE_GEN_TASK = "image-gen"

# Public storage bucket holding the sample images generated by the image-gen experiments.
SAMPLE_IMAGE_BUCKET = "peft-internal-testing/image-gen-benchmark"
# Directory (as an hf:// path) where the sample images are listed and read from.
SAMPLE_IMAGE_BUCKET_DIR = "/".join(("hf://buckets", SAMPLE_IMAGE_BUCKET, "sample-images", "results"))
# Browser URL of the bucket, linked from the app's description text.
SAMPLE_IMAGE_BUCKET_URL = "/".join(("https://huggingface.co/buckets", SAMPLE_IMAGE_BUCKET))

# Choices of the "Image source" radio button.
GENERATED_VIEW, DATASET_VIEW = "Generated samples", "Training dataset"
def _load_default_train_config_image_gen() -> dict[str, Any]:
    """Load the default training parameters of the image-gen benchmark.

    The default training params define the prompts and dataset used by the benchmark. They are loaded once to
    caption generated images and to show the dataset images before an experiment is selected.

    Returns:
        The parsed JSON object, or an empty dict (after logging a warning) if the file cannot be read, is not valid
        JSON, or does not contain a JSON object.
    """
    path = os.path.join(os.path.dirname(__file__), "image-gen", "default_training_params.json")
    try:
        # JSON is UTF-8; don't depend on the platform's default locale encoding.
        with open(path, encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as exc:
        logger.warning("Could not load default training params from %r: %s", path, exc)
        return {}
    if not isinstance(config, dict):
        # The result is used with .get() right below (at import time), so anything but an object would crash the app.
        logger.warning("Default training params in %r are not a JSON object, ignoring them", path)
        return {}
    return config


DEFAULT_TRAIN_CONFIG_IMAGE_GEN = _load_default_train_config_image_gen()
# Prompts used to generate the sample images, in order; used to caption the generated images.
SAMPLE_IMAGE_PROMPTS = DEFAULT_TRAIN_CONFIG_IMAGE_GEN.get("sample_image_prompts", [])
@functools.cache
def _get_bucket_fs() -> HfFileSystem:
    """Return the HfFileSystem used to read the public sample image bucket.

    The instance is created on the first call and reused afterwards. No token is passed explicitly since the bucket
    is public. The listing cache is disabled so that newly uploaded sample images show up on a page refresh without
    having to redeploy the app.
    """
    bucket_fs = HfFileSystem(use_listings_cache=False)
    return bucket_fs
def get_sample_images(experiment_name: str) -> list[tuple[Image.Image, str]]:
    """Fetch the sample images of an experiment from the storage bucket.

    Looks for ``<stem>_*.png`` files in ``SAMPLE_IMAGE_BUCKET_DIR``, where ``<stem>`` is the experiment name with "/"
    replaced by "--". Files whose name ends in ``_<n>.png`` are ordered numerically by ``n`` (so ``_10`` comes after
    ``_9``) and captioned with ``SAMPLE_IMAGE_PROMPTS[n - 1]``. Any other file is placed last, in path order, and
    uses its position in the gallery as the prompt index. The file name is the caption when there is no such prompt.

    Returns a list of (PIL image, caption) tuples suitable for a gr.Gallery. The list is empty if no images are found
    or the bucket cannot be listed. Images that fail to load are skipped with a warning.
    """
    stem = experiment_name.replace("/", "--")
    fs = _get_bucket_fs()
    try:
        paths = fs.glob(f"{SAMPLE_IMAGE_BUCKET_DIR}/{stem}_*.png")
    except Exception as exc:
        logger.warning("Could not list sample images for %r: %s", experiment_name, exc)
        return []
    # NOTE(review): the glob also matches experiments whose name extends this one after an underscore (e.g. "foo"
    # picks up "foo_bar_1.png"); tightening it requires knowing the exact file naming scheme.

    def image_number(path):
        # The 1-based image number encoded in the file name, or None if the name doesn't end in "_<n>.png".
        match = re.search(r"_(\d+)\.png$", path)
        return int(match.group(1)) if match else None

    numbered = [(image_number(path), path) for path in paths]
    # Sorting the path strings would put "_10.png" before "_2.png", so sort by the parsed number instead.
    numbered.sort(key=lambda item: (item[0] is None, item[0] or 0, item[1]))

    gallery = []
    for number, path in numbered:
        try:
            with fs.open(path, "rb") as f:
                image = Image.open(BytesIO(f.read()))
                image.load()
        except Exception as exc:
            logger.warning("Could not load sample image %r: %s", path, exc)
            continue
        prompt_idx = number - 1 if number is not None else len(gallery)
        caption = (
            SAMPLE_IMAGE_PROMPTS[prompt_idx]
            if 0 <= prompt_idx < len(SAMPLE_IMAGE_PROMPTS)
            else os.path.basename(path)
        )
        gallery.append((image, caption))
    return gallery
@functools.lru_cache(maxsize=1)
def _load_dataset_images(dataset_id: str, split: str, image_column: str) -> list[Image.Image]:
    """Load one split of a Hub dataset and return the images of ``image_column`` as RGB PIL images.

    Only the most recently requested (dataset_id, split, image_column) combination is cached. Repeated requests for
    the same images therefore reuse the already converted list.
    """

    def to_rgb(item):
        # Anything that isn't already a PIL image is assumed to be array-like and converted with Image.fromarray.
        pil_image = item if isinstance(item, Image.Image) else Image.fromarray(item)
        return pil_image.convert("RGB")

    column = load_dataset(dataset_id, split=split)[image_column]
    return [to_rgb(item) for item in column]
def get_dataset_images(config: dict[str, Any]) -> list[tuple[Image.Image, str]]:
    """Fetch the training dataset images for a training configuration from the Hugging Face Hub.

    The following keys are read from ``config``:

    - ``dataset_id`` (required).
    - ``dataset_split`` (default ``"train"``).
    - ``image_column`` (default ``"image"``).
    - ``instance_prompts``: either a single prompt shared by all images or a list with one prompt per image.

    Returns:
        A list of (PIL image, caption) tuples suitable for a gr.Gallery. Images without a prompt get an empty
        caption. The list is empty when ``config`` is not a dict, has no ``dataset_id``, or the dataset cannot be
        loaded.
    """
    if not isinstance(config, dict):
        return []
    dataset_id = config.get("dataset_id")
    if not dataset_id:
        return []
    split = config.get("dataset_split", "train")
    image_column = config.get("image_column", "image")
    try:
        images = _load_dataset_images(dataset_id, split, image_column)
    except Exception as exc:
        logger.warning("Could not load dataset images for %r: %s", dataset_id, exc)
        return []
    prompts = config.get("instance_prompts") or []
    if isinstance(prompts, str):
        return [(image, prompts) for image in images]
    # Guard the index: a missing or too short prompt list would otherwise raise IndexError and break the gallery.
    return [(image, prompts[idx] if idx < len(prompts) else "") for idx, image in enumerate(images)]
def render_image_gallery(image_view, selected):
    """Return a gallery update for the given image source view and selected experiment.

    Generated samples view:
        Shows the samples of the selected experiment, with the experiment named in the gallery label. The gallery is
        cleared when nothing is selected.

    Dataset view:
        Shows the training images of the selected experiment's ``train_config``. If the config cannot be parsed, the
        gallery is emptied. If no experiment is selected, it falls back to the default training config, so the
        dataset images are visible before the user clicks a row.
    """
    if image_view != DATASET_VIEW:
        if not selected:
            return gr.update(value=None, label="Images")
        name = selected["experiment_name"]
        return gr.update(value=get_sample_images(name), label=f"Generated samples for {name}")

    config = DEFAULT_TRAIN_CONFIG_IMAGE_GEN
    if selected:
        try:
            config = json.loads(selected["train_config"])
        except (TypeError, ValueError):
            config = {}
    return gr.update(value=get_dataset_images(config), label="Images")
def load_gallery_deferred(task_name, image_view, selected):
    """Populate the image gallery in a chained (follow-up) event.

    Fetching images can be slow. The handlers that update several components therefore only clear the gallery, and
    the images are loaded here afterwards, so the other components (e.g. the results table) don't have to wait for
    them. For tasks other than image-gen the gallery is left untouched.
    """
    if task_name == IMAGE_GEN_TASK:
        return render_image_gallery(image_view, selected)
    return gr.update()
def build_app(df):
    """Build the Gradio Blocks app that compares PEFT methods on the results in ``df``.

    The app has the following parts:

    - A task and base model selector.
    - The matching results table, optionally narrowed down with a filter query.
    - A Pareto plot with summary statistics for two selectable metrics.
    - A CSV export of the (filtered) results.
    - For the image-gen task, a gallery with the training dataset images and the samples generated by an experiment.

    Args:
        df: Results of all tasks; must contain a ``task_name`` column and is further processed with the helpers
            imported from ``processing``.

    Returns:
        The ``gr.Blocks`` app (not launched yet).
    """
    task_names = sorted(df["task_name"].unique())
    initial_task = "MetaMathQA" if "MetaMathQA" in task_names else task_names[0]
    initial_prefs = get_metric_preferences(initial_task)
    initial_x, initial_y = _TASK_PARETO_DEFAULTS.get(initial_task, (list(initial_prefs)[0], list(initial_prefs)[1]))

    def filter_with_query(task_name, model_id, query):
        # Rows of the given task/model that match the filter query. An invalid query is ignored (logged at debug
        # level) and the unfiltered rows are returned instead.
        filtered = filter_data(task_name, model_id, df)
        if query.strip():
            try:
                mask = parse_and_filter(filtered, query)
                filtered = filtered[mask]
            except Exception as exc:
                logger.debug("Ignoring invalid filter query: %s", exc)
        return filtered

    with gr.Blocks() as demo:
        gr.Markdown("# PEFT method comparison")
        gr.Markdown(
            "Find more information [on the PEFT GitHub repo](https://github.com/huggingface/peft/tree/main/method_comparison)"
        )

        # Hidden state to store the current filter query.
        filter_state = gr.State("")
        # Hidden state to store the experiment selected for the image gallery.
        selected_state = gr.State(None)

        gr.Markdown("## Choose the task and base model")
        with gr.Row():
            task_dropdown = gr.Dropdown(
                label="Select Task",
                choices=task_names,
                value=initial_task,
            )
            model_dropdown = gr.Dropdown(label="Select Model ID", choices=get_model_ids(initial_task, df))
        task_info = gr.Markdown(_get_task_info(initial_task))

        # Make dataframe columns all equal in width so that they are good enough for numbers but don't get hugely
        # extended by columns like `train_config`. Tasks can have different column counts, so size the widths to the
        # widest task; experiment_name is always the first column (see _TASK_IMPORTANT_COLUMNS) and holds long names,
        # so it gets extra width.
        initial_filtered = filter_data(initial_task, get_model_ids(initial_task, df)[0], df)
        num_columns = max(len(get_task_columns(task)) for task in task_names)
        column_widths = ["150px"] * num_columns
        column_widths[0] = "300px"
        data_table = gr.DataFrame(
            label="Results",
            value=format_df(initial_filtered),
            interactive=False,
            max_chars=100,
            wrap=False,
            column_widths=column_widths,
        )

        with gr.Row():
            filter_textbox = gr.Textbox(
                label="Filter DataFrame",
                placeholder="Enter filter (e.g.: peft_type=='LORA')",
                interactive=True,
            )
            apply_filter_button = gr.Button("Apply Filter")
            reset_filter_button = gr.Button("Reset Filter")

        metric_explanation = gr.Markdown(
            _get_metric_explanation(initial_task),
        )

        with gr.Group(visible=initial_task == IMAGE_GEN_TASK) as sample_images_group:
            gr.Markdown("## Images")
            gr.Markdown(
                "The training dataset images are shown by default. Click a row in the results table above to see the "
                "sample images generated by that experiment, and use the selector to switch between the generated "
                "samples and the training dataset. Each image is captioned with its prompt. The generated images are "
                f"stored in [this bucket]({SAMPLE_IMAGE_BUCKET_URL})."
            )
            image_view_radio = gr.Radio(
                choices=[GENERATED_VIEW, DATASET_VIEW],
                value=DATASET_VIEW,
                label="Image source",
            )
            # The gallery starts empty and is populated by load_gallery_deferred on page load so that fetching the
            # images doesn't block the app startup.
            sample_gallery = gr.Gallery(
                label="Images",
                value=None,
                columns=3,
                object_fit="contain",
            )

        gr.Markdown("## Pareto plot")
        gr.Markdown(
            "Select 2 criteria to plot the Pareto frontier. This will show the best PEFT methods along this axis and "
            "the trade-offs with the other axis. The PEFT methods that Pareto-dominate are shown in colors. All other "
            "methods are inferior with regard to these two metrics. Hover over a point to show details."
        )
        with gr.Row():
            metric_x_dropdown = gr.Dropdown(
                label="1st metric for Pareto plot",
                choices=list(initial_prefs.keys()),
                value=initial_x,
            )
            metric_y_dropdown = gr.Dropdown(
                label="2nd metric for Pareto plot",
                choices=list(initial_prefs.keys()),
                value=initial_y,
            )

        pareto_plot = gr.Plot(label="Pareto Frontier Plot")
        summary_box = gr.Textbox(label="Summary Statistics", lines=6)
        csv_output = gr.File(label="Export Filtered Data as CSV")

        def update_on_task(task_name, current_filter):
            new_models = get_model_ids(task_name, df)
            filtered = filter_data(task_name, new_models[0] if new_models else "", df)
            if current_filter.strip():
                try:
                    mask = parse_and_filter(filtered, current_filter)
                    df_queried = filtered[mask]
                    # Only keep the query's result if it matches something for the newly selected task.
                    if not df_queried.empty:
                        filtered = df_queried
                except Exception as exc:
                    # invalid filter query
                    logger.debug("Ignoring invalid filter query: %s", exc)
            prefs = get_metric_preferences(task_name)
            x_default, y_default = _TASK_PARETO_DEFAULTS.get(task_name, (list(prefs)[0], list(prefs)[1]))
            metric_choices = list(prefs.keys())
            explanation = _get_metric_explanation(task_name)
            is_image_gen = task_name == IMAGE_GEN_TASK
            return (
                gr.update(choices=new_models, value=new_models[0] if new_models else None),
                _get_task_info(task_name),
                format_df(filtered),
                gr.update(choices=metric_choices, value=x_default),
                gr.update(choices=metric_choices, value=y_default),
                explanation,
                gr.update(visible=is_image_gen),
                gr.update(value=DATASET_VIEW),
                gr.update(value=None, label="Images"),
                None,
            )

        task_dropdown.change(
            fn=update_on_task,
            inputs=[task_dropdown, filter_state],
            outputs=[
                model_dropdown,
                task_info,
                data_table,
                metric_x_dropdown,
                metric_y_dropdown,
                metric_explanation,
                sample_images_group,
                image_view_radio,
                sample_gallery,
                selected_state,
            ],
        ).then(
            fn=load_gallery_deferred,
            inputs=[task_dropdown, image_view_radio, selected_state],
            outputs=sample_gallery,
        )

        def update_on_model(task_name, model_id, current_filter):
            filtered = filter_with_query(task_name, model_id, current_filter)
            return format_df(filtered), gr.update(value=DATASET_VIEW), gr.update(value=None, label="Images"), None

        model_dropdown.change(
            fn=update_on_model,
            inputs=[task_dropdown, model_dropdown, filter_state],
            outputs=[data_table, image_view_radio, sample_gallery, selected_state],
        ).then(
            fn=load_gallery_deferred,
            inputs=[task_dropdown, image_view_radio, selected_state],
            outputs=sample_gallery,
        )

        def show_sample_images(task_name, model_id, evt: gr.SelectData):
            if task_name != IMAGE_GEN_TASK or evt.index is None:
                return None, gr.update(), gr.update()
            # Look up the clicked row by its experiment name (always the first column) instead of by the row index:
            # sorting the table happens client-side only, so the row index refers to the displayed order, not the
            # order of the dataframe on the server.
            experiment_name = evt.row_value[0]
            rows = filter_data(task_name, model_id, df)
            rows = rows[rows["experiment_name"] == experiment_name]
            if rows.empty:
                return None, gr.update(), gr.update()
            row = rows.iloc[0]
            selected = {"experiment_name": row["experiment_name"], "train_config": row["train_config"]}
            # Clicking a row switches the view to the experiment's generated samples.
            return selected, gr.update(value=GENERATED_VIEW), render_image_gallery(GENERATED_VIEW, selected)

        data_table.select(
            fn=show_sample_images,
            inputs=[task_dropdown, model_dropdown],
            outputs=[selected_state, image_view_radio, sample_gallery],
        )

        def update_image_view(image_view, selected):
            return render_image_gallery(image_view, selected)

        # Use the input event (user-only) so the programmatic radio updates above don't re-trigger this.
        image_view_radio.input(
            fn=update_image_view,
            inputs=[image_view_radio, selected_state],
            outputs=sample_gallery,
        )

        def update_pareto_plot_and_summary(task_name, model_id, metric_x, metric_y, current_filter):
            prefs = get_metric_preferences(task_name)
            filtered = filter_data(task_name, model_id, df)
            if current_filter.strip():
                try:
                    mask = parse_and_filter(filtered, current_filter)
                    filtered = filtered[mask]
                except Exception as e:
                    return generate_pareto_plot(filtered, metric_x, metric_y, prefs), f"Filter error: {e}"

            pareto_df = compute_pareto_frontier(filtered, metric_x, metric_y, prefs)
            fig = generate_pareto_plot(filtered, metric_x, metric_y, prefs)
            summary = compute_pareto_summary(filtered, pareto_df, metric_x, metric_y)
            return fig, summary

        for comp in [model_dropdown, metric_x_dropdown, metric_y_dropdown]:
            comp.change(
                fn=update_pareto_plot_and_summary,
                inputs=[task_dropdown, model_dropdown, metric_x_dropdown, metric_y_dropdown, filter_state],
                outputs=[pareto_plot, summary_box],
            )

        def apply_filter(filter_query, task_name, model_id, metric_x, metric_y):
            prefs = get_metric_preferences(task_name)
            filtered = filter_data(task_name, model_id, df)
            if filter_query.strip():
                try:
                    mask = parse_and_filter(filtered, filter_query)
                    filtered = filtered[mask]
                except Exception as e:
                    # Update the table, plot, and summary even if there is a filter error. The table gets the same
                    # formatting as on success so that it keeps its columns.
                    return (
                        filter_query,
                        format_df(filtered),
                        generate_pareto_plot(filtered, metric_x, metric_y, prefs),
                        f"Filter error: {e}",
                    )

            pareto_df = compute_pareto_frontier(filtered, metric_x, metric_y, prefs)
            fig = generate_pareto_plot(filtered, metric_x, metric_y, prefs)
            summary = compute_pareto_summary(filtered, pareto_df, metric_x, metric_y)
            return filter_query, format_df(filtered), fig, summary

        apply_filter_button.click(
            fn=apply_filter,
            inputs=[filter_textbox, task_dropdown, model_dropdown, metric_x_dropdown, metric_y_dropdown],
            outputs=[filter_state, data_table, pareto_plot, summary_box],
        )

        def reset_filter(task_name, model_id, metric_x, metric_y):
            prefs = get_metric_preferences(task_name)
            filtered = filter_data(task_name, model_id, df)
            pareto_df = compute_pareto_frontier(filtered, metric_x, metric_y, prefs)
            fig = generate_pareto_plot(filtered, metric_x, metric_y, prefs)
            summary = compute_pareto_summary(filtered, pareto_df, metric_x, metric_y)
            # Return empty strings to clear the filter state and textbox.
            return "", "", format_df(filtered), fig, summary

        reset_filter_button.click(
            fn=reset_filter,
            inputs=[task_dropdown, model_dropdown, metric_x_dropdown, metric_y_dropdown],
            outputs=[filter_state, filter_textbox, data_table, pareto_plot, summary_box],
        )

        gr.Markdown("## Export data")
        # Export button for CSV download.
        export_button = gr.Button("Export Filtered Data")

        def export_filtered(task_name, model_id, current_filter):
            # Export the selected task/model rows with the applied filter query, not just the task/model selection.
            return export_csv(filter_with_query(task_name, model_id, current_filter))

        export_button.click(
            fn=export_filtered,
            inputs=[task_dropdown, model_dropdown, filter_state],
            outputs=csv_output,
        )

        demo.load(
            fn=update_pareto_plot_and_summary,
            inputs=[task_dropdown, model_dropdown, metric_x_dropdown, metric_y_dropdown, filter_state],
            outputs=[pareto_plot, summary_box],
        )
        demo.load(
            fn=load_gallery_deferred,
            inputs=[task_dropdown, image_view_radio, selected_state],
            outputs=sample_gallery,
        )

    return demo
_TASK_DESCRIPTIONS = {
"MetaMathQA": (
"Trains on the MetaMathQA dataset and validates/tests on GSM8K, comparing how well PEFT methods teach "
"mathematical chain-of-thought reasoning."
),
"image-gen": (
"DreamBooth-style fine-tuning on a "
"[cat plushy dataset](https://huggingface.co/datasets/peft-internal-testing/cat-image-dataset) image dataset."
),
}
_TASK_CHECKPOINT_URLS = {
"MetaMathQA": "https://huggingface.co/buckets/peft-internal-testing/metamathqa-checkpoints",
"image-gen": "https://huggingface.co/buckets/peft-internal-testing/image-gen-benchmark/tree/checkpoints",
}
def _get_task_info(task_name):
    """Return the Markdown info text for ``task_name``.

    The text is the task's description (empty for unknown tasks). When a checkpoint bucket is known for the task,
    a sentence linking to it is appended.
    """
    description = _TASK_DESCRIPTIONS.get(task_name, "")
    checkpoint_url = _TASK_CHECKPOINT_URLS.get(task_name)
    if not checkpoint_url:
        return description
    return f"{description} The trained PEFT checkpoints are available in [this bucket]({checkpoint_url})."
# Each task's results live in "<task>/results" next to this file.
base_dir = os.path.dirname(__file__)
_TASK_CONFIGS = {task: os.path.join(base_dir, task, "results") for task in ("MetaMathQA", "image-gen")}

df = load_task_results(_TASK_CONFIGS)
demo = build_app(df)
demo.launch(theme=gr.themes.Soft())