PeptideAI / StreamlitApp /StreamlitApp.py
m0ksh's picture
Sync from GitHub (preserve manual model files)
cccf8bd verified
Raw
History Blame
33.5 kB
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import torch
import plotly.express as px
import html as _html
from sklearn.manifold import TSNE
# modular imports
from utils.predict import load_model, predict_amp, encode_sequence
from utils.analyze import aa_composition, compute_properties
from utils.optimize import optimize_sequence
from utils.ui_helpers import (
choose_top_candidate,
format_conf_percent,
mutation_heatmap_html,
mutation_diff_table,
optimization_summary,
sequence_length_warning,
sequence_health_label,
build_analysis_summary_text,
)
from utils.peptide_extras import (
KNOWN_AMPS,
MAX_3D_SEQUENCE_LENGTH,
COMPACT_3D_LEGEND,
COMPACT_MAP_LEGEND,
COMPACT_WHEEL_LEGEND,
find_most_similar,
build_importance_map_html,
plot_helical_wheel,
render_3d_structure,
)
try:
import pyperclip # Optional; may not exist in all environments.
except Exception:
pyperclip = None
def _tooltip_label(label: str, tooltip_text: str) -> None:
"""
Render a label with a hover tooltip using HTML 'title' attribute.
"""
safe = _html.escape(tooltip_text, quote=True)
st.markdown(f"{label} <span title='{safe}' style='cursor:help;color:#666'>(i)</span>", unsafe_allow_html=True)
def _try_copy_to_clipboard(text: str) -> None:
"""
Best-effort clipboard copy (server-side only).
Avoids streamlit.components.html — iframe/JS can fail on Hugging Face Spaces
(TypeError: Failed to fetch dynamically imported module for static/js chunks).
"""
if pyperclip is not None:
try:
pyperclip.copy(text)
except Exception:
pass
# APP CONFIG
st.set_page_config(page_title="AMP Predictor", layout="wide")
# App title
st.title("PeptideAI")
st.write("Antimicrobial Peptide Predictor and Optimizer")
st.divider()
# SESSION STATE KEYS (one-time init)
if "predictions" not in st.session_state:
st.session_state.predictions = [] # list of dicts
if "predict_ran" not in st.session_state:
st.session_state.predict_ran = False
if "predict_input_widget" not in st.session_state:
st.session_state.predict_input_widget = ""
if "analyze_input" not in st.session_state:
st.session_state.analyze_input = "" # last analyze input
if "analyze_output" not in st.session_state:
st.session_state.analyze_output = None # (label, conf_display, comp, props, analysis)
if "optimize_input" not in st.session_state:
st.session_state.optimize_input = "" # last optimize input
if "optimize_output" not in st.session_state:
st.session_state.optimize_output = None # (orig_seq, orig_conf, improved_seq, improved_conf, history)
if "visualize_sequences" not in st.session_state:
st.session_state.visualize_sequences = None
if "visualize_df" not in st.session_state:
st.session_state.visualize_df = None
if "visualize_peptide_input" not in st.session_state:
st.session_state.visualize_peptide_input = ""
# SIDEBAR: navigation + global clear
st.sidebar.header("Navigation")
page = st.sidebar.radio(
"Go to",
[
"Predict",
"Analyze",
"Optimize",
"Visualize Peptide",
"Visualize t-SNE",
"About",
],
)
if st.sidebar.button("Clear All Fields"):
# clear only our known keys
keys = [
"predictions",
"predict_ran",
"predict_input_widget",
"analyze_input",
"analyze_output",
"optimize_input",
"optimize_output",
"visualize_sequences",
"visualize_df",
"visualize_peptide_input",
]
for k in keys:
if k in st.session_state:
del st.session_state[k]
st.sidebar.success("Cleared app state.")
# Streamlit renamed `experimental_rerun()` -> `rerun()` in newer versions.
# Use a version-safe call so Spaces don't fail with AttributeError.
rerun_fn = getattr(st, "rerun", None) or getattr(st, "experimental_rerun", None)
if rerun_fn is not None:
rerun_fn()
else:
st.stop()
# Load model once
model = load_model()
# PREDICT PAGE
if page == "Predict":
st.header("AMP Predictor")
preset_cols = st.columns(2)
with preset_cols[0]:
if st.button("Use strong AMP example"):
st.session_state.predict_input_widget = "RGGRLCYCRGWICFCVGR"
st.rerun()
with preset_cols[1]:
if st.button("Use weak sequence example"):
st.session_state.predict_input_widget = "KAEEEVEKNKEEAEEKAEKKIAE"
st.rerun()
seq_input = st.text_area(
"Enter peptide sequences (one per line):",
height=150,
key="predict_input_widget",
)
uploaded_file = st.file_uploader("Or upload a FASTA/text file", type=["txt", "fasta"])
# Sequence length warnings (preview only; does not run model).
preview_sequences = [s.strip() for s in (seq_input or "").splitlines() if s.strip()]
if preview_sequences:
short_cnt = sum(1 for s in preview_sequences if len(s) < 8)
long_cnt = sum(1 for s in preview_sequences if len(s) > 50)
if short_cnt:
st.caption(f"Warning: {short_cnt} sequence(s) too short for typical AMP (< 8 aa).")
if long_cnt:
st.caption(f"Warning: {long_cnt} sequence(s) unusually long (> 50 aa).")
run = st.button("Run Prediction")
if run:
# Gather sequences
sequences = []
if seq_input:
sequences += [s.strip() for s in seq_input.splitlines() if s.strip()]
if uploaded_file:
text = uploaded_file.read().decode("utf-8")
sequences += [l.strip() for l in text.splitlines() if not l.startswith(">") and l.strip()]
if not sequences:
st.warning("Please input or upload sequences first.")
else:
progress = st.progress(0.0)
with st.spinner("Running prediction..."):
results = []
for i, seq in enumerate(sequences):
label, conf = predict_amp(seq, model)
conf_display = round(conf * 100, 1) if label == "AMP" else round((1 - conf) * 100, 1)
results.append({
"Sequence": seq,
"Prediction": label,
"Confidence": conf,
"Description": f"{label} with {conf_display}% confidence"
})
progress.progress((i + 1) / max(1, len(sequences)), text=f"Predicted {i + 1}/{len(sequences)}")
progress.progress(1.0)
# Persist new predictions and mark that we ran
st.session_state.predictions = results
st.session_state.predict_ran = True
st.success("Prediction complete.")
# If user hasn't just run predictions, show the last saved results (if any)
if st.session_state.predictions and not (run and st.session_state.predict_ran is False):
st.divider()
top_candidate = choose_top_candidate(st.session_state.predictions)
if top_candidate:
with st.container():
st.write("**Top AMP Predicted Candidate**")
seq = top_candidate.get("Sequence", "")
cc = st.columns([9, 1])
with cc[0]:
st.code(seq, language="text")
with cc[1]:
if st.button("Copy", key="copy_top_candidate"):
_try_copy_to_clipboard(seq)
toast_fn = getattr(st, "toast", None)
if toast_fn is not None:
toast_fn("Copied — or select the sequence above (Ctrl+C)")
else:
st.success("Copied — or select the sequence above (Ctrl+C)")
label = top_candidate.get("Prediction", "")
conf_str = format_conf_percent(top_candidate["predicted_confidence"], digits=1)
st.write(f"**{label} with {conf_str} confidence**")
st.write(f"Reason: {top_candidate['Reason']}")
st.divider()
# Keep the original dataframe for full overview/download compatibility.
st.dataframe(pd.DataFrame(st.session_state.predictions), use_container_width=True)
csv = pd.DataFrame(st.session_state.predictions).to_csv(index=False)
st.download_button("Download predictions as CSV", csv, "predictions.csv", "text/csv")
# ANALYZE PAGE
elif page == "Analyze":
st.header("Peptide Analyzer")
# show the last saved analyze output if user navigated back
last_seq = st.session_state.analyze_input
seq = st.text_input(
"Enter a peptide sequence to analyze:",
value=last_seq,
)
warn = sequence_length_warning(seq)
if warn:
st.caption(f"Warning: {warn}")
# only run analysis when input changed from last saved input
if seq and seq != st.session_state.get("analyze_input", ""):
with st.spinner("Running analysis..."):
label, conf = predict_amp(seq, model)
conf_pct = round(conf * 100, 1)
conf_display = conf_pct if label == "AMP" else 100 - conf_pct
comp = aa_composition(seq)
props = compute_properties(seq)
# normalize property key names if necessary
net_charge = props.get("Net Charge (approx.)",
props.get("Net charge", props.get("NetCharge", 0)))
# build analysis summary (same rules as before)
length = props.get("Length", len(seq))
hydro = props.get("Hydrophobic Fraction", props.get("Hydrophobic", 0))
charge = net_charge
mw = props.get("Molecular Weight (Da)", props.get("MolecularWeight", 0))
analysis = []
if (conf_pct if label == "AMP" else (100 - conf_pct)) >= 80:
analysis.append(f"Highly likely to be {label}.")
elif (conf_pct if label == "AMP" else (100 - conf_pct)) >= 60:
analysis.append(f"Moderately likely to be {label}.")
else:
analysis.append(f"Low likelihood to be {label}.")
if hydro < 0.4:
analysis.append("Low hydrophobicity may reduce membrane interaction.")
elif hydro > 0.6:
analysis.append("High hydrophobicity may reduce solubility.")
if charge <= 0:
analysis.append("Low or negative charge may limit antimicrobial activity.")
if length < 10:
analysis.append("Short sequence may reduce efficacy.")
elif length > 50:
analysis.append("Long sequence may affect stability.")
if comp.get("K", 0) + comp.get("R", 0) + comp.get("H", 0) >= 3:
analysis.append("High basic residue content enhances membrane binding.")
if comp.get("C", 0) + comp.get("W", 0) >= 2:
analysis.append("Multiple cysteine/tryptophan residues may improve activity.")
# Save to session state
st.session_state.analyze_input = seq
st.session_state.analyze_output = (label, conf, conf_display, comp, props, analysis)
# If we have stored output, display it
if st.session_state.analyze_output:
label, conf, conf_display, comp, props, analysis = st.session_state.analyze_output
st.subheader("AMP Prediction")
display_conf = round(conf * 100, 1) if label == "AMP" else round((1 - conf) * 100, 1)
st.write(f"Prediction: **{label}** with **{display_conf}%** confidence")
# Sequence health check badge
hydro = props.get("Hydrophobic Fraction", 0)
charge = props.get("Net Charge (approx.)", props.get("Net charge", 0))
health_label, color = sequence_health_label(float(conf), float(charge), float(hydro))
st.markdown(
f"<span style='color:{color}; font-weight:800;'>{health_label}</span>",
unsafe_allow_html=True,
)
st.subheader("Amino Acid Composition")
comp_df = pd.DataFrame(list(comp.items()), columns=["Amino Acid", "Frequency"]).set_index("Amino Acid")
st.bar_chart(comp_df)
st.subheader("Physicochemical Properties and Favorability")
# pull properties safely
length = props.get("Length", len(st.session_state.analyze_input))
hydro = props.get("Hydrophobic Fraction", 0)
charge = props.get("Net Charge (approx.)", props.get("Net charge", 0))
mw = props.get("Molecular Weight (Da)", 0)
favorability = {
"Length": "Good" if 10 <= length <= 50 else "Too short" if length < 10 else "Too long",
"Hydrophobic Fraction": "Good" if 0.4 <= hydro <= 0.6 else "Low" if hydro < 0.4 else "High",
"Net Charge": "Favorable" if charge > 0 else "Neutral" if charge == 0 else "Unfavorable",
"Molecular Weight": "Acceptable" if 500 <= mw <= 5000 else "Extreme"
}
def _info_icon(tooltip_text: str) -> str:
safe = _html.escape(tooltip_text, quote=False)
return (
"<span "
"class='amp-i' "
f"data-tooltip='{safe}' "
"style=\"display:inline-flex; align-items:center; justify-content:center; "
"margin-left:6px; width:16px; height:16px; border-radius:50%; "
"background:#f2f2f2; border:1px solid #d9d9d9; color:#333; "
"font-size:12px; font-weight:700; cursor:help;\">(i)</span>"
)
# Render the favorability table with working inline tooltips.
hydro_label = f"Hydrophobic Fraction{_info_icon('Fraction of residues that prefer non-aqueous environments')}"
charge_label = f"Net Charge{_info_icon('Positive charge helps peptides bind bacterial membranes')}"
table_html = (
"<style>"
".amp-i{position:relative; display:inline-flex;}"
".amp-i::after{"
"content:attr(data-tooltip);"
"position:absolute;"
"left:50%;"
"top:125%;"
"transform:translateX(-50%);"
"max-width:860px;"
"white-space:normal;"
"padding:8px 10px;"
"background:rgba(30,30,30,0.95);"
"color:#fff;"
"border-radius:8px;"
"font-size:12px;"
"line-height:1.25;"
"box-shadow:0 8px 30px rgba(0,0,0,0.25);"
"opacity:0;"
"pointer-events:none;"
"z-index:9999;"
"}"
".amp-i:hover::after{opacity:1;}"
"</style>"
"<table style='width:100%; border-collapse:collapse;'>"
"<thead>"
"<tr>"
"<th style='text-align:left; padding:8px; border-bottom:1px solid #e6e6e6;'>Property</th>"
"<th style='text-align:right; padding:8px; border-bottom:1px solid #e6e6e6;'>Value</th>"
"<th style='text-align:left; padding:8px; border-bottom:1px solid #e6e6e6;'>Favorability</th>"
"</tr>"
"</thead>"
"<tbody>"
f"<tr><td style='padding:8px;'>{_html.escape('Length')}{_info_icon('Peptides with ~10–50 aa often balance membrane insertion and solubility.')}</td><td style='padding:8px; text-align:right;'>{_html.escape(str(length))}</td><td style='padding:8px;'>{_html.escape(favorability['Length'])}</td></tr>"
f"<tr><td style='padding:8px;'>{hydro_label}</td><td style='padding:8px; text-align:right;'>{_html.escape(str(hydro))}</td><td style='padding:8px;'>{_html.escape(favorability['Hydrophobic Fraction'])}</td></tr>"
f"<tr><td style='padding:8px;'>{charge_label}</td><td style='padding:8px; text-align:right;'>{_html.escape(str(charge))}</td><td style='padding:8px;'>{_html.escape(favorability['Net Charge'])}</td></tr>"
f"<tr><td style='padding:8px;'>{_html.escape('Molecular Weight')}{_info_icon('Moderate molecular weight can help stability and binding; extremes may hurt performance.')}</td><td style='padding:8px; text-align:right;'>{_html.escape(str(mw))}</td><td style='padding:8px;'>{_html.escape(favorability['Molecular Weight'])}</td></tr>"
"</tbody>"
"</table>"
)
st.markdown(table_html, unsafe_allow_html=True)
st.divider()
st.subheader("Property Radar Chart")
categories = ["Length", "Hydrophobic Fraction", "Net Charge", "Molecular Weight"]
values = [min(length / 50, 1), min(hydro, 1), 1 if charge > 0 else 0, min(mw / 5000, 1)]
values += values[:1]
ideal_min = [10/50, 0.4, 1/6, 500/5000] + [10/50]
ideal_max = [50/50, 0.6, 6/6, 5000/5000] + [50/50]
angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
angles += angles[:1]
# Adjusted figsize for better vertical space
fig, ax = plt.subplots(figsize=(2.8, 3.2), subplot_kw=dict(polar=True))
fig.patch.set_facecolor("white")
ax.fill_between(angles, ideal_min, ideal_max, color='#457a00', alpha=0.15, label="Ideal AMP range")
ax.plot(angles, values, 'o-', color='#457a00', linewidth=2, label="Sequence")
ax.fill(angles, values, color='#457a00', alpha=0.25)
ax.set_thetagrids(np.degrees(angles[:-1]), categories, fontsize=8)
ax.set_ylim(0, 1)
ax.tick_params(axis='y', labelsize=7)
ax.legend(loc='lower center', bbox_to_anchor=(0.85, 1.15), ncol=2, fontsize=7)
st.pyplot(fig, use_container_width=False)
st.divider()
st.subheader("Most similar known AMP")
st.caption(
f"Compared to **{len(KNOWN_AMPS)}** unique AMP sequences (label = 1 in `Data/ampData.csv`)."
)
seq_sim = str(st.session_state.analyze_input or "").strip()
seq_clean_sim = "".join(c for c in seq_sim.upper() if not c.isspace())
if seq_clean_sim:
match_seq, sim_score = find_most_similar(seq_clean_sim)
if match_seq is not None:
st.write(f"**Best match:** `{match_seq}`")
st.write(f"**Similarity score:** **{sim_score:.3f}** (position match / max length)")
if sim_score > 0.6:
st.success("High similarity to a known AMP in the reference set.")
elif sim_score > 0.3:
st.warning("Moderate similarity — interpret with care.")
else:
st.error("Low similarity — sequence is distant from reference AMPs.")
else:
st.warning("Could not compute similarity.")
else:
st.caption("Run analysis with a sequence to compare against known AMPs.")
st.divider()
# Analysis Summary
st.subheader("Analysis Summary")
for line in analysis:
st.write(f"- {line}")
# Export analysis report
st.divider()
st.subheader("Export Analysis Report")
export_format = st.radio("Format", ["CSV", "TXT"], horizontal=True)
confidence_display_str = f"{round(conf_display, 1)}%"
summary_text = build_analysis_summary_text(
sequence=st.session_state.analyze_input,
prediction=label,
confidence_display=confidence_display_str,
props=props,
analysis_lines=analysis,
)
csv_df = pd.DataFrame(
[
{
"Sequence": st.session_state.analyze_input,
"Prediction": label,
"Confidence": confidence_display_str,
"Length": props.get("Length", len(st.session_state.analyze_input)),
"Charge": charge,
"Hydrophobic fraction": hydro,
"Summary": "\n".join(analysis or []),
}
]
)
if export_format == "CSV":
csv_bytes = csv_df.to_csv(index=False).encode("utf-8")
st.download_button(
"Download CSV report",
csv_bytes,
file_name="analysis_report.csv",
mime="text/csv",
)
else:
st.download_button(
"Download TXT report",
summary_text.encode("utf-8"),
file_name="analysis_report.txt",
mime="text/plain",
)
# OPTIMIZE PAGE
elif page == "Optimize":
st.header("Peptide Optimizer")
# Form: Enter in the text field submits the form (same as clicking Run Optimization).
with st.form("optimize_form", clear_on_submit=False):
seq = st.text_input(
"Enter a peptide sequence to optimize:",
value=st.session_state.get("optimize_input", ""),
)
submitted = st.form_submit_button("Run Optimization")
warn_opt = sequence_length_warning(seq) if seq else None
if warn_opt:
st.caption(f"Warning: {warn_opt}")
if submitted and seq and str(seq).strip():
seq = str(seq).strip()
st.session_state.optimize_input = seq
progress = st.progress(0.0, text="Optimizing...")
with st.spinner("Optimizing sequence..."):
improved_seq, improved_conf, history = optimize_sequence(seq, model)
_ol, orig_conf = predict_amp(seq, model)
st.session_state.optimize_output = (seq, orig_conf, improved_seq, improved_conf, history)
progress.progress(1.0, text="Optimization complete")
st.success("Optimization finished.")
# If there is saved output show it
if st.session_state.optimize_output:
orig_seq, orig_conf, improved_seq, improved_conf, history = st.session_state.optimize_output
summary = optimization_summary(orig_seq, orig_conf, improved_seq, improved_conf)
delta_str = f"{summary['delta_conf_pct']:+.2f}%"
col_results, col_opt_summary = st.columns(2)
with col_results:
st.subheader("Results")
st.write(f"**Original Sequence:** {orig_seq} — Confidence: {round(orig_conf*100,1)}%")
st.write(
f"**Optimized Sequence:** {improved_seq} — Confidence: {round(improved_conf*100,1)}%"
)
with col_opt_summary:
st.subheader("Optimization Summary")
st.write(f"Confidence: **{delta_str}** (final - original)")
st.write(
f"Charge: **{summary['charge_change']}** (orig {summary['charge_orig']}, final {summary['charge_final']})"
)
st.write(
f"Hydrophobicity: **{summary['hydro_change']}** (orig {summary['hydro_orig']}, final {summary['hydro_final']})"
)
st.divider()
# Mutation Heatmap
st.subheader("Mutation Heatmap (Changed Residues Highlighted)")
st.markdown(mutation_heatmap_html(orig_seq, improved_seq), unsafe_allow_html=True)
with st.expander("Mutation Details (table)"):
diff_rows = mutation_diff_table(orig_seq, improved_seq)
st.dataframe(pd.DataFrame(diff_rows), use_container_width=True)
if len(history) > 1:
df_steps = pd.DataFrame([{
"Step": i,
"Change": change,
"Old Type": old_type,
"New Type": new_type,
"Reason for Improvement": reason,
"New Confidence (%)": round(conf * 100, 2)
} for i, (seq_after, conf, change, old_type, new_type, reason) in enumerate(history[1:], start=1)])
st.subheader("Mutation Steps")
st.dataframe(df_steps, use_container_width=True)
# Confidence improvement plot
step_nums = df_steps["Step"].tolist()
conf_values = df_steps["New Confidence (%)"].tolist()
df_graph = pd.DataFrame({"Step": step_nums, "Confidence (%)": conf_values})
fig = px.line(df_graph, x="Step", y="Confidence (%)", markers=True, color_discrete_sequence=["#457a00"])
fig.update_layout(yaxis=dict(range=[0, 100]), title="Confidence Improvement Over Steps")
st.plotly_chart(fig, use_container_width=True)
# VISUALIZE PEPTIDE PAGE
elif page == "Visualize Peptide":
st.header("Peptide Visualizer")
# Tighter legend expanders (summary row + scrollable body)
st.markdown(
"""
<style>
div[data-testid="stExpander"] details > summary {
padding-top: 0.3rem !important;
padding-bottom: 0.3rem !important;
min-height: 2rem !important;
}
div[data-testid="stExpander"] details div[data-testid="stMarkdownContainer"] {
max-height: 6.5rem;
overflow-y: auto;
}
</style>
""",
unsafe_allow_html=True,
)
st.text_input(
"Enter a peptide sequence to visualize:",
key="visualize_peptide_input",
placeholder="Paste or type a one-letter amino-acid sequence",
)
seq_viz = (st.session_state.get("visualize_peptide_input") or "").strip()
clean_viz = "".join(c for c in seq_viz.upper() if not c.isspace())
if clean_viz:
with st.spinner("Building 3D view and helical wheel..."):
warn_len = sequence_length_warning(clean_viz)
if warn_len:
st.warning(warn_len)
if len(clean_viz) > MAX_3D_SEQUENCE_LENGTH:
st.warning(
f"Sequence longer than **{MAX_3D_SEQUENCE_LENGTH}** aa: **3D model is disabled**; "
"helical wheel and functional map still render."
)
col_l, col_r = st.columns(2)
with col_l:
st.subheader("3D structural approximation (py3Dmol)")
st.caption("Smoothed helical CA trace; colored spheres follow the same scheme as the wheel.")
if len(clean_viz) <= MAX_3D_SEQUENCE_LENGTH:
if render_3d_structure(clean_viz, enhanced=True, spin=False):
with st.expander("3D · legend", expanded=False):
st.markdown(COMPACT_3D_LEGEND)
else:
st.info("3D view unavailable (install **py3dmol** in your environment).")
else:
st.info("3D visualization is limited to **60 residues** for performance.")
with col_r:
st.subheader("Helical wheel")
st.caption(
"Radial spokes per residue, black connectors along the sequence, colored disks (same scheme as 3D)."
)
fig_wheel = plot_helical_wheel(clean_viz)
st.pyplot(fig_wheel, use_container_width=True)
plt.close(fig_wheel)
with st.expander("Wheel · legend", expanded=False):
st.markdown(COMPACT_WHEEL_LEGEND)
st.divider()
st.subheader("Functional region map")
st.caption("Residue-level chemistry; colors align with the 3D view and wheel.")
st.markdown(build_importance_map_html(clean_viz), unsafe_allow_html=True)
with st.expander("Map · legend", expanded=False):
st.markdown(COMPACT_MAP_LEGEND)
# VISUALIZE t-SNE PAGE
elif page == "Visualize t-SNE":
st.header("t-SNE Visualizer")
st.write("Upload peptide sequences (FASTA or plain list) to embed sequences and explore clusters with t-SNE.")
uploaded_file = st.file_uploader("Upload FASTA or text file", type=["txt", "fasta"])
# If file uploaded, set session sequences (replacing previous)
if uploaded_file:
text = uploaded_file.read().decode("utf-8")
sequences = [l.strip() for l in text.splitlines() if not l.startswith(">") and l.strip()]
st.session_state.visualize_sequences = sequences
# Clear any previous df so we recompute
st.session_state.visualize_df = None
# If we have sequences stored, compute embeddings and t-SNE if no df present
if st.session_state.visualize_sequences and st.session_state.visualize_df is None:
sequences = st.session_state.visualize_sequences
if len(sequences) < 2:
st.warning("Need at least 2 sequences for t-SNE visualization.")
else:
progress = st.progress(0.0, text="Generating embedding...")
with st.spinner("Generating embedding..."):
embeddings_list, labels, confs, lengths, hydros, charges = [], [], [], [], [], []
# Use model internals for embeddings; keep same approach as your module
embedding_extractor = torch.nn.Sequential(*list(model.layers)[:-1])
for i, s in enumerate(sequences):
x = torch.tensor(encode_sequence(s), dtype=torch.float32).unsqueeze(0)
with torch.no_grad():
emb = embedding_extractor(x).squeeze().numpy()
embeddings_list.append(emb)
label, conf = predict_amp(s, model)
labels.append(label)
confs.append(conf)
props = compute_properties(s)
lengths.append(props.get("Length", len(s)))
hydros.append(props.get("Hydrophobic Fraction", 0))
charges.append(props.get("Net Charge (approx.)", props.get("Net charge", 0)))
progress.progress((i + 1) / max(1, len(sequences)), text=f"Encoding {i + 1}/{len(sequences)}")
embeddings_array = np.stack(embeddings_list)
perplexity = min(30, max(2, len(sequences) - 1))
tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
reduced = tsne.fit_transform(embeddings_array)
df = pd.DataFrame(reduced, columns=["x", "y"])
df["Sequence"] = sequences
df["Label"] = labels
df["Confidence"] = confs
df["Length"] = lengths
df["Hydrophobic Fraction"] = hydros
df["Net Charge"] = charges
st.session_state.visualize_df = df
progress.progress(1.0, text="Embedding ready")
# If we have a t-SNE dataframe, show plot and sidebar filters
if st.session_state.visualize_df is not None:
df = st.session_state.visualize_df
st.subheader("t-SNE plot")
st.sidebar.subheader("Filter Sequences")
min_len, max_len = int(df["Length"].min()), int(df["Length"].max())
if min_len == max_len:
st.sidebar.write(f"All sequences have length {min_len}")
length_range = (min_len, max_len)
else:
length_range = st.sidebar.slider("Sequence length", min_len, max_len, (min_len, max_len))
label_options = st.sidebar.multiselect("Label", ["AMP", "Non-AMP"], default=["AMP", "Non-AMP"])
filtered_df = df[(df["Length"].between(length_range[0], length_range[1])) & (df["Label"].isin(label_options))]
color_by = st.sidebar.selectbox("Color points by", ["Label", "Confidence", "Hydrophobic Fraction", "Net Charge", "Length"])
color_map = {"AMP": "#2ca02c", "Non-AMP": "#d62728"}
fig = px.scatter(
filtered_df,
x="x", y="y",
color=color_by if color_by != "Label" else "Label",
color_discrete_map=color_map if color_by == "Label" else None,
hover_data={"Sequence": True, "Label": True, "Confidence": True, "Length": True, "Hydrophobic Fraction": True, "Net Charge": True},
title="t-SNE Visualization of Model Embeddings"
)
st.plotly_chart(fig, use_container_width=True)
st.subheader("t-SNE Analysis")
st.markdown("""
• Each point represents a peptide sequence.
• Sequences close together have similar internal representations in the model.
• AMP and Non-AMP clusters indicate strong model separation.
• Coloring by properties reveals biochemical trends.
""")
# ABOUT PAGE
elif page == "About":
st.header("About the Project")
st.markdown("""
PeptideAI is a lightweight Streamlit app for exploring antimicrobial peptide (AMP) sequences.
It uses a trained neural network to estimate whether a peptide is likely to be antimicrobial, then helps you interpret and improve candidates:
- **AMP Predictor**: batch predictions from multi-line or FASTA input, length warnings, persisted results, top-candidate highlight, and CSV export.
- **Peptide Analyzer**: single-sequence numerical and textual analysis — AMP prediction, composition, physicochemical table + radar, similarity to known AMPs, and report export.
- **Peptide Optimizer**: guided sequence optimization with Enter-to-run input, mutation heatmap, step table, and confidence-vs-step trend.
- **Peptide Visualizer**: single-sequence 3D approximation + detailed helical wheel + functional region map with consistent residue coloring and concise legend dropdowns.
- **t-SNE Visualizer**: upload many sequences, embed with the model, run t-SNE, and explore clusters with filters and hover metadata.
- **About**: this overview and disclaimer.
**Disclaimer:** Predictions are model-based heuristics and are **not** a substitute for wet-lab validation or regulatory use.
""")