import gradio as gr
import torch
import numpy as np
import threading
from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel
from modeling_dlmberta import InteractionModelATTNForRegression, InteractionModelATTNSimplePoolingForRegression, InteractionModelATTNSimplePoolingConfig, StdScaler
from configuration_dlmberta import InteractionModelATTNConfig
from chemberta import ChembertaTokenizer
import json
import os
from pathlib import Path
import logging
# Import visualization functions
from analysis import plot_crossattention_weights, plot_presum
from PIL import Image, ImageDraw, ImageFont
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory containing this file; model and tokenizer folders are resolved
# relative to it.
BASE_DIR = Path(__file__).resolve().parent

# Registry key of the model that is pre-selected for new sessions.
_DEFAULT_MODEL_KEY = "interpr"

# One entry per selectable model:
#   name            - human-readable name shown in the UI
#   path            - directory passed to ``from_pretrained``
#   model_cls       - class used to instantiate the model
#   config_cls      - class used to load the model configuration
#   supports_presum - presumably marks whether contribution ("presum") plots
#                     are available for this model; not read in this part of
#                     the file - TODO confirm against other consumers
MODEL_REGISTRY = {
    "meanpool": {
        "name": "Mean pooling",
        "path": BASE_DIR / "meanpool",
        "model_cls": InteractionModelATTNSimplePoolingForRegression,
        "config_cls": InteractionModelATTNSimplePoolingConfig,
        "supports_presum": False,
    },
    "interpr": {
        "name": "Interpretable pooling",
        "path": BASE_DIR / "interpr",
        "model_cls": InteractionModelATTNForRegression,
        "config_cls": InteractionModelATTNConfig,
        "supports_presum": True,
    },
}

# Short qualifier displayed next to each model name in the dropdown.
_MODEL_CHOICE_NOTES = {
    "meanpool": "higher AUROC",
    "interpr": "main manuscript model",
}

# (label, value) pairs for the model dropdown. The "default" marker is derived
# from _DEFAULT_MODEL_KEY so the label can never disagree with the model that
# is actually pre-selected.
MODEL_CHOICES = [
    (
        f"{MODEL_REGISTRY[key]['name']} "
        f"({note}{', default' if key == _DEFAULT_MODEL_KEY else ''})",
        key,
    )
    for key, note in _MODEL_CHOICE_NOTES.items()
]
def create_placeholder_image(width=600, height=400, text="No visualization available", bg_color=(0, 0, 0, 0)):
    """
    Create a placeholder image with grey text centred on it.

    Args:
        width (int): Image width
        height (int): Image height
        text (str): Text to display
        bg_color (tuple): Background color (R, G, B, A) - (0,0,0,0) for transparent

    Returns:
        PIL.Image: RGBA placeholder image
    """
    img = Image.new('RGBA', (width, height), bg_color)
    draw = ImageDraw.Draw(img)

    # Font selection is best-effort: prefer Arial, fall back to Pillow's
    # built-in font, and finally to no explicit font at all. `except Exception`
    # (rather than a bare `except`) keeps KeyboardInterrupt / SystemExit
    # propagating.
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except Exception:
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

    if font is not None:
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        text_width = right - left
        text_height = bottom - top
    else:
        # Rough estimate when the text cannot be measured: 8 px per character,
        # 16 px line height.
        text_width = len(text) * 8
        text_height = 16

    x = (width - text_width) // 2
    y = (height - text_height) // 2
    draw.text((x, y), text, fill=(128, 128, 128, 255), font=font)
    return img
class DrugTargetInteractionApp:
    """Hold one loaded interaction model together with its two tokenizers.

    An instance may be shared between concurrent requests: every forward pass
    and every toggle of the model's interpretability mode happens while
    ``self._lock`` is held.
    """

    def __init__(self):
        self.model = None
        self.target_tokenizer = None
        self.drug_tokenizer = None
        self.scaler = None
        self.current_model_name = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Per-instance lock so concurrent users sharing this model don't race on
        # INTERPR_ENABLE_MODE / INTERPR_DISABLE_MODE or inference state.
        self._lock = threading.Lock()

    def load_model(self, model_name=_DEFAULT_MODEL_KEY):
        """Load the pre-trained model and tokenizers.

        Everything is loaded into local variables first and attached to the
        instance only once all steps have succeeded, so a failure never leaves
        the instance half-initialised.

        Args:
            model_name (str): Key into ``MODEL_REGISTRY``.

        Returns:
            bool: ``True`` on success. ``False`` if the key is unknown or any
            loading step raised; the error is logged, not re-raised.
        """
        try:
            if model_name not in MODEL_REGISTRY:
                logger.error(f"Error loading model: unknown model key '{model_name}'")
                return False

            entry = MODEL_REGISTRY[model_name]
            model_path = entry["path"]
            logger.info(f"Loading model '{model_name}' from {model_path}")

            config = entry["config_cls"].from_pretrained(model_path)

            # Drug encoder (ChemBERTa): only the configuration is fetched here and
            # the module is built from it. Its weights presumably come from the
            # interaction-model checkpoint loaded below - TODO confirm.
            drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
            drug_encoder_config.pooler = None
            drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

            # Target encoder
            target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")

            # The scaler is optional: use it only when its config file exists.
            scaler = None
            if (Path(model_path) / "scaler.config").exists():
                scaler = StdScaler()
                scaler.load(str(model_path))

            model = entry["model_cls"].from_pretrained(
                model_path,
                config=config,
                target_encoder=target_encoder,
                drug_encoder=drug_encoder,
                scaler=scaler,
            )
            model.to(self.device)
            model.eval()

            target_tokenizer = AutoTokenizer.from_pretrained(BASE_DIR / "target_tokenizer")
            vocab_file = str(BASE_DIR / "drug_tokenizer" / "vocab.json")
            drug_tokenizer = ChembertaTokenizer(vocab_file)
        except Exception as e:
            # Startup boundary: log with traceback and report failure to the caller.
            logger.exception(f"Error loading model: {str(e)}")
            return False

        self.model = model
        self.scaler = scaler
        self.target_tokenizer = target_tokenizer
        self.drug_tokenizer = drug_tokenizer
        self.current_model_name = model_name
        logger.info(f"Model '{model_name}' and tokenizers loaded successfully!")
        return True

    @staticmethod
    def _as_smiles_list(drug_smiles):
        """Return *drug_smiles* as a list, wrapping a single string in one."""
        if isinstance(drug_smiles, str):
            return [drug_smiles]
        return list(drug_smiles)

    def get_target_and_smiles(self, target_sequence, drug_smiles):
        """Tokenize the target once and every SMILES string separately.

        Args:
            target_sequence (str): Target sequence.
            drug_smiles (str | list[str]): One SMILES string or several.

        Returns:
            tuple: ``(target_inputs, drug_inputs_list)``; both the target
            encoding and each drug encoding are padded/truncated to 512 tokens
            and moved to ``self.device``.
        """
        target_inputs = self.target_tokenizer(
            target_sequence,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        all_smiles = [
            self.drug_tokenizer(
                smiles.strip(),
                padding="max_length",
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(self.device)
            for smiles in self._as_smiles_list(drug_smiles)
        ]
        return target_inputs, all_smiles

    def _predict_pkd(self, target_inputs, drug_inputs):
        """Run one forward pass without gradients and return element [0][0].

        The output is passed through ``self.model.unscale`` first when the
        model carries a scaler. Callers must hold ``self._lock``.
        """
        with torch.no_grad():
            prediction = self.model(target_inputs, drug_inputs)
            if self.model.scaler is not None:
                prediction = self.model.unscale(prediction)
        return prediction.cpu().numpy()[0][0]

    def predict_interaction(self, target_sequence, drug_smiles):
        """Predict drug-target interaction.

        Args:
            target_sequence (str): Target sequence.
            drug_smiles (str | list[str]): One SMILES string or several.

        Returns:
            str: One ``"<smiles> predicted pKd: <value>"`` line per SMILES, or
            an ``"Error ..."`` message if the model is missing or inference
            failed.
        """
        if self.model is None:
            return "Error: Model not loaded."
        try:
            drug_smiles = self._as_smiles_list(drug_smiles)
            target_inputs, all_drug_inputs = self.get_target_and_smiles(target_sequence, drug_smiles)
            to_return = []
            with self._lock:
                self.model.INTERPR_DISABLE_MODE()
                for smile_name, drug_inputs in zip(drug_smiles, all_drug_inputs):
                    prediction_value = self._predict_pkd(target_inputs, drug_inputs)
                    to_return.append(f"{smile_name} predicted pKd: {prediction_value:.4f}")
            return "\n".join(to_return)
        except Exception as e:
            logger.exception(f"Prediction error: {str(e)}")
            return f"Error during prediction: {str(e)}"

    def _plot_cross_attention(self, cross_attention_weights, target_inputs, drug_inputs):
        """Render the cross-attention heatmap; return ``None`` on any failure."""
        logger.info(f"Cross-attention weights type: {type(cross_attention_weights)}")
        if cross_attention_weights is None:
            logger.warning("Cross-attention weights are None")
            return None
        logger.info(f"Cross-attention weights shape: {cross_attention_weights.shape if hasattr(cross_attention_weights, 'shape') else 'No shape attr'}")

        try:
            cross_attn_matrix = cross_attention_weights[0, 0]
        except (IndexError, TypeError, AttributeError) as e:
            logger.warning(f"Error extracting cross-attention matrix: {str(e)}")
            return None
        if cross_attn_matrix is None:
            logger.warning("Could not extract valid cross-attention matrix")
            return None

        try:
            logger.info(f"Extracted cross-attention matrix shape: {cross_attn_matrix.shape}")
            return plot_crossattention_weights(
                target_inputs["attention_mask"][0],
                drug_inputs["attention_mask"][0],
                target_inputs,
                drug_inputs,
                cross_attn_matrix,
                self.target_tokenizer,
                self.drug_tokenizer,
            )
        except Exception as e:
            logger.error(f"Cross-attention visualization error: {str(e)}")
            return None

    def _plot_contribution(self, target_inputs, presum_values, scaler, w, b, raw_affinities):
        """Render one contribution plot via ``plot_presum``; ``None`` on failure."""
        label = "Raw" if raw_affinities else "Normalized"
        try:
            return plot_presum(
                target_inputs,
                presum_values.detach(),
                scaler,
                w.detach(),
                b.detach(),
                self.target_tokenizer,
                raw_affinities=raw_affinities,
            )
        except Exception as e:
            logger.error(f"{label} contribution visualization error: {str(e)}")
            return None

    def visualize_interaction(self, target_sequence, drug_smiles):
        """
        Generate visualization images for drug-target interaction.

        A prediction is made for every SMILES entry, but the images are built
        from the state left by the forward pass of the LAST entry only.

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (str | list[str]): One SMILES string or several

        Returns:
            tuple: (cross_attention_image, raw_contribution_image,
            normalized_contribution_image, status_message). The three images
            are ``None`` when the model is missing or visualization failed.
        """
        if self.model is None:
            return None, None, None, "Error: Model not loaded."
        try:
            drug_smiles = self._as_smiles_list(drug_smiles)
            if not drug_smiles:
                return None, None, None, "Error during visualization: no SMILES provided."

            target_inputs, all_drug_inputs = self.get_target_and_smiles(target_sequence, drug_smiles)
            # Contribution plots exist only for the interpretable-pooling model.
            supports_presum = isinstance(self.model, MODEL_REGISTRY["interpr"]["model_cls"])
            to_return = []
            normalized_img = None
            raw_img = None

            with self._lock:
                self.model.INTERPR_ENABLE_MODE()
                try:
                    for smile_name, drug_inputs in zip(drug_smiles, all_drug_inputs):
                        prediction_value = self._predict_pkd(target_inputs, drug_inputs)
                        to_return.append(f"{smile_name} predicted pKd: {prediction_value:.4f}")

                    # Read what the final forward pass left on the wrapped model.
                    cross_attention_weights = self.model.model.crossattention_weights

                    # --- Cross-attention heatmap ---
                    cross_attention_img = self._plot_cross_attention(
                        cross_attention_weights, target_inputs, all_drug_inputs[-1]
                    )

                    # --- Presum visualizations (interpr model only) ---
                    if supports_presum:
                        presum_values = self.model.model.presum_layer
                        w = self.model.model.w.squeeze(1)
                        b = self.model.model.b
                        scaler = self.model.model.scaler
                        if presum_values is None:
                            logger.warning("Cannot generate contribution visualizations: presum values are None")
                        else:
                            normalized_img = self._plot_contribution(
                                target_inputs, presum_values, scaler, w, b, raw_affinities=False
                            )
                            if prediction_value > 0:
                                raw_img = self._plot_contribution(
                                    target_inputs, presum_values, scaler, w, b, raw_affinities=True
                                )
                            else:
                                logger.info("Skipping raw affinities visualization as pKd <= 0")
                finally:
                    # Always leave interpretability mode while still holding the
                    # lock, so a failure here can neither leak the mode nor flip
                    # it underneath another request.
                    self.model.INTERPR_DISABLE_MODE()

            status_msg = "\n".join(to_return)

            # --- Placeholder fallbacks ---
            if cross_attention_img is None:
                cross_attention_img = create_placeholder_image(text="Cross-Attention Heatmap\nFailed to generate")
            if supports_presum:
                if normalized_img is None:
                    normalized_img = create_placeholder_image(text="Normalized Contribution\nFailed to generate")
                if raw_img is None and prediction_value > 0:
                    raw_img = create_placeholder_image(text="Raw Contribution\nFailed to generate")
                elif raw_img is None:
                    raw_img = create_placeholder_image(text="Raw Contribution\nSkipped (pKd ≤ 0)")
                if prediction_value <= 0:
                    status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
            else:
                normalized_img = create_placeholder_image(text="Normalized Contribution\nNot supported for this model")
                raw_img = create_placeholder_image(text="Raw Contribution\nNot supported for this model")
                status_msg += " (Contribution visualizations not supported for this model)"
            if cross_attention_weights is None:
                status_msg += " (Cross-attention visualization failed: weights not available)"
            return cross_attention_img, raw_img, normalized_img, status_msg
        except Exception as e:
            logger.exception(f"Visualization error: {str(e)}")
            return None, None, None, f"Error during visualization: {str(e)}"
# ---------------------------------------------------------------------------
# Load BOTH models globally at startup — shared across all users.
# Each model gets its own lock (inside DrugTargetInteractionApp) so concurrent
# requests don't race on inference/mode-toggle state.
# ---------------------------------------------------------------------------
logger.info("Loading all models at startup...")

# Registry key -> app instance, containing only the models that loaded.
LOADED_MODELS: dict[str, DrugTargetInteractionApp] = {}
for _key in MODEL_REGISTRY:
    _instance = DrugTargetInteractionApp()
    if not _instance.load_model(_key):
        logger.error(f" ❌ Failed to load '{_key}'.")
        continue
    LOADED_MODELS[_key] = _instance
    logger.info(f" ✅ '{_key}' ready.")

# The app is useless without at least one model, so refuse to start.
if len(LOADED_MODELS) == 0:
    raise RuntimeError("No models could be loaded. Check your model files.")

# Prefer the configured default; otherwise use the first model that loaded.
if _DEFAULT_MODEL_KEY in LOADED_MODELS:
    _EFFECTIVE_DEFAULT = _DEFAULT_MODEL_KEY
else:
    _EFFECTIVE_DEFAULT = next(iter(LOADED_MODELS))
# ---------------------------------------------------------------------------
# Helper: build the green "currently loaded" HTML label
# ---------------------------------------------------------------------------
def _make_label_html(model_key: str) -> str:
    """Build the green "currently loaded model" banner for *model_key*.

    Args:
        model_key: Key into ``MODEL_REGISTRY``; its ``"name"`` entry is used
            as the display name.

    Returns:
        An HTML snippet suitable for a ``gr.HTML`` component.

    Raises:
        KeyError: If *model_key* is not in ``MODEL_REGISTRY``.
    """
    display_name = MODEL_REGISTRY[model_key]["name"]
    return (
        "<div style='padding: 8px 12px; border-radius: 6px; "
        "background-color: #e6f4ea; color: #137333; font-weight: 600;'>"
        f"✅ Currently loaded model: {display_name} "
        f"<code>({model_key})</code>"
        "</div>"
    )
def _make_error_label_html() -> str:
return (
""
"❌ No model loaded
"
)
# ---------------------------------------------------------------------------
# Preprocessing helper
# ---------------------------------------------------------------------------
def smiles_preprocessing(drug_smiles, remove_dupl, max_entries=2000):
    """Split a multi-line SMILES text box value into a clean list.

    Each line is stripped of surrounding whitespace (including a trailing
    ``\\r`` from Windows line endings) and blank lines are dropped, so that
    no empty entry is sent to the model and ``"CCO"`` / ``"CCO "`` count as
    the same molecule when de-duplicating.

    Args:
        drug_smiles (str): Raw text, one SMILES per line.
        remove_dupl (bool): Drop repeated entries, keeping the first
            occurrence and the original order.
        max_entries (int): Maximum number of entries returned.

    Returns:
        list[str]: At most ``max_entries`` SMILES strings.
    """
    stripped = (line.strip() for line in drug_smiles.splitlines())
    drugs = [smiles for smiles in stripped if smiles]

    if remove_dupl:
        # dict preserves insertion order, so this keeps first occurrences.
        unique_drugs = list(dict.fromkeys(drugs))
        logger.info(f"{len(drugs) - len(unique_drugs)} duplicate smiles removed!")
        drugs = unique_drugs

    return drugs[:max_entries]
# ---------------------------------------------------------------------------
# Wrapper functions — look up the globally loaded model from the user's state
# ---------------------------------------------------------------------------
def predict_wrapper(target_seq, drug_smiles, remove_dups=False, model_key=None):
    """Validate the form inputs and run a prediction on the selected model.

    ``remove_dups`` and ``model_key`` are optional so the function can also be
    called with fewer inputs than the Predict button supplies.

    Args:
        target_seq (str | None): Target sequence from the text box.
        drug_smiles (str | None): Raw SMILES text, one entry per line.
        remove_dups (bool): Whether duplicate SMILES are removed first.
        model_key (str | None): Key into ``LOADED_MODELS``; ``None`` selects
            ``_EFFECTIVE_DEFAULT``.

    Returns:
        str: Prediction text from ``predict_interaction`` or a user-facing
        error message.
    """
    # A cleared component may deliver None instead of an empty string.
    target_seq = (target_seq or "").strip()
    drug_smiles = drug_smiles or ""
    if not target_seq or not drug_smiles.strip():
        return "Please provide both target sequence and drug SMILES."

    if model_key is None:
        model_key = _EFFECTIVE_DEFAULT
    if model_key not in LOADED_MODELS:
        return f"Error: model '{model_key}' is not available."

    smiles_list = smiles_preprocessing(drug_smiles, remove_dups)
    return LOADED_MODELS[model_key].predict_interaction(target_seq, smiles_list)
def visualize_wrapper(target_seq, drug_smiles, remove_dups, model_key):
    """Validate the form inputs and run the visualization on the selected model.

    Returns:
        tuple: ``(image, image, image, status)`` as produced by
        ``visualize_interaction``; the three images are ``None`` when an input
        is blank or the requested model is not loaded.
    """
    sequence = target_seq.strip()
    if sequence == "" or drug_smiles.strip() == "":
        return None, None, None, "Please provide both target sequence and drug SMILES."

    app = LOADED_MODELS.get(model_key)
    if app is None:
        return None, None, None, f"Error: model '{model_key}' is not available."

    smiles_entries = smiles_preprocessing(drug_smiles, remove_dups)
    return app.visualize_interaction(sequence, smiles_entries)
def load_model_wrapper(model_name, current_model_key):
    """
    Switch the calling user's selected model.

    Nothing is loaded here: models are already in memory, so only the
    per-user state and the banner are updated.

    Returns:
        tuple: (status text, banner HTML, model key to store in the state).
        When the requested model is unavailable the previous key is kept.
    """
    if model_name in LOADED_MODELS:
        return "Model selected successfully!", _make_label_html(model_name), model_name

    failure_text = f"Model '{model_name}' failed to load at startup."
    return failure_text, _make_error_label_html(), current_model_key
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Static page content, kept out of the layout code below for readability.
_HEADER_HTML = """
<h1>🧬 Drug-Target Interaction Predictor</h1>
<p>Predict binding affinity between drugs and target RNA sequences using deep learning</p>
"""

_VISUALIZATION_HEADER_HTML = """
<h2>🔬 Interaction Analysis &amp; Visualizations</h2>
<p>Generated visualizations will appear here after clicking "Generate Visualizations" in the Prediction tab</p>
"""

# NOTE(review): in the test-set section, 2,991 positive + 2,538 negative pairs
# sum to 5,529, while the stated total (and the per-class breakdown) is 5,534.
# Confirm the correct figures with the dataset authors before editing the text.
_DATASET_MARKDOWN = """
## Training and Test Datasets
### Fine-tuning Dataset (Training)
The model was trained on a dataset comprising **1,439 RNA–drug interaction pairs**, including:
- **759 unique compounds** (SMILES representations)
- **294 unique RNA sequences**
- Dissociation constants (pKd values) for binding affinity prediction
**RNA Sequence Distribution by Type:**
| RNA Sequence Type | Number of Interactions |
|-------------------|------------------------|
| Aptamers | 520 |
| Ribosomal | 295 |
| Viral RNAs | 281 |
| miRNAs | 146 |
| Riboswitches | 100 |
| Repeats | 97 |
| **Total** | **1,439** |
### External Evaluation Dataset (Test)
Model validation was performed using external ROBIN classification datasets containing **5,534 RNA–drug pairs**:
- **2,991 positive interactions**
- **2,538 negative interactions**
**Test Dataset Composition:**
- **1,617 aptamer pairs** (5 unique RNA sequences)
- **1,828 viral RNA pairs** (6 unique RNA sequences)
- **1,459 riboswitch pairs** (5 unique RNA sequences)
- **630 miRNA pairs** (3 unique RNA sequences)
### Dataset Downloads
- [Training Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/training_data.csv?download=true)
- [Test Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/test_data.csv?download=true)
### Citation
Original datasets published by:
**Krishnan et al.** - Available on the RSAPred website in PDF format.
*Reference:*
```bibtex
@article{krishnan2024reliable,
title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
journal={Briefings in Bioinformatics},
volume={25},
number={2},
pages={bbae002},
year={2024},
publisher={Oxford University Press}
}
```
"""

_ABOUT_MARKDOWN = """
## About this application
This application implements DLRNA-BERTa, a Dual Language RoBERTa Transformer model for predicting drug-to-RNA target interactions. The architecture combines:
- **Target encoder**: RNA-BERTa for processing RNA sequences
- **Drug encoder**: ChemBERTa for SMILES representation
- **Cross-attention mechanism**: Captures interactions between drug and target
- **Regression head**: Predicts binding affinity (pKd) with either mean pooling or interpretable pooling
### Available models
- **Mean pooling (`meanpool`)**: has the higher AUROC on most external ROBIN classes. Use it for regular prediction and cross-attention visualization.
- **Interpretable pooling (`interpr`)**: use it when you need both cross-attention and token-level prediction contribution plots.
Both models are loaded into memory at startup. Switching between them is instant. The currently selected model is shown at the top of the Prediction tab.
### Input requirements
- **Target sequence**: RNA sequence (A, U, G, C)
- **Drug SMILES**: One or more SMILES strings
- For batch mode, enter each SMILES on a new line (up to 2000 entries)
- A checkbox option allows automatic removal of duplicate SMILES before prediction
### Model features
- Cross-attention for drug-target interaction modeling
- Regularization via dropout
- Layer normalization for stable training
- Dedicated interpretability mode for visualization
- Batch prediction with optional de-duplication
### Usage tips
1. Select a model in the Model Settings tab if you want to change from the default
2. Enter an RNA sequence and one or more SMILES strings
3. Use the **"Remove duplicate SMILES"** checkbox if you want duplicates filtered automatically
4. Click *Predict Interaction* for affinity scores
5. Click *Generate Visualizations* for interpretability plots
6. Visualizations are produced only for the final SMILES entry in batch mode
For best results, ensure your input sequences are properly formatted and within reasonable length limits (max 512 tokens).
### Visualization features:
- **Cross-attention heatmap**: Shows cross-attention weights between drug and target tokens
- **Unnormalized pKd contribution**: Shows unnormalized signed contributions from each target token with the `interpr` model only
- **Normalized pKd contribution**: Shows normalized non-negative contributions from each target token with the `interpr` model only
### Performance metrics:
- Training on diverse drug-target interaction datasets
- Evaluated using RMSE, Pearson correlation, and Concordance Index
- Optimized for both predictive accuracy and interpretability
### GitHub repository:
- The full model GitHub repository can be found here: https://github.com/IlPakoZ/dlrnaberta-dti-prediction
### Contribution:
- Special thanks to Umut Onur Özcan for help in developing this space:)
### Contact:
- Ziaurrehman Tanoli (ziaurrehman.tanoli@helsinki.fi)
Principal investigator at Institute for Molecular Medicine Finland
HiLIFE, University of Helsinki, Finland.
"""

# Initial text of the status box: only claim that everything is ready when
# every registered model actually loaded.
_failed_model_keys = [key for key in MODEL_REGISTRY if key not in LOADED_MODELS]
if _failed_model_keys:
    _INITIAL_MODEL_STATUS = (
        f"Pre-loaded models: {', '.join(LOADED_MODELS)}. "
        f"Failed to load: {', '.join(_failed_model_keys)}."
    )
else:
    _INITIAL_MODEL_STATUS = "Both models are pre-loaded and ready."


def visualize_and_update(target_seq, drug_smiles, remove_dups, model_key):
    """Run the visualization and extend the status text for the results box.

    Returns:
        tuple: The three images from ``visualize_wrapper`` followed by the
        status text. On failure (all three images ``None``) the status is
        passed through unchanged; on success a hint to open the Visualizations
        tab is appended, plus a note when more than one SMILES was processed.
    """
    img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles, remove_dups, model_key)

    # visualize_wrapper reports validation and runtime failures with three
    # None images; do not announce success in that case.
    if img1 is None and img2 is None and img3 is None:
        return img1, img2, img3, status

    combined_status = status + "\n\nVisualization analysis complete. Please navigate to the Visualizations tab to view the generated images."
    # Count processed entries, not characters of the raw text box value.
    if len(smiles_preprocessing(drug_smiles, remove_dups)) > 1:
        combined_status += "\nVisualizations are shown only for the last SMILES entry."
    return img1, img2, img3, combined_status


with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
    gr.HTML(_HEADER_HTML)

    # Per-user state: just a string (model key), not a full model object.
    # Seeded with the effective default and changed only through the
    # "Model Settings" tab.
    user_model_key = gr.State(_EFFECTIVE_DEFAULT)

    # Shared image states between Prediction and Visualizations tabs
    viz_state1 = gr.State()
    viz_state2 = gr.State()
    viz_state3 = gr.State()

    with gr.Tab("🔮 Prediction & Analysis"):
        model_label = gr.HTML(value=_make_label_html(_EFFECTIVE_DEFAULT))
        with gr.Row():
            with gr.Column(scale=1):
                target_input = gr.Textbox(
                    label="Target RNA Sequence",
                    placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
                    lines=5,
                    max_lines=5
                )
                drug_input = gr.Textbox(
                    label="Drug SMILES",
                    placeholder=(
                        "Enter SMILES notation for one or more drugs.\n"
                        "For multiple SMILES, enter each on a new line (max 2000):\n"
                        "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O\n"
                        "C1CCCCC1O"
                    ),
                    lines=5,
                    max_lines=5,
                )
                remove_dups_checkbox = gr.Checkbox(
                    label="Remove duplicate SMILES",
                    value=False
                )
                with gr.Row():
                    predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
                    visualize_btn = gr.Button("📊 Generate Visualizations", variant="secondary", size="lg")
            with gr.Column(scale=1):
                prediction_output = gr.Textbox(
                    label="Prediction Result",
                    interactive=False,
                    lines=4
                )

        gr.HTML("<h3>📚 Example Inputs:</h3>")
        # Clicking an example only fills the two text boxes. Each example row
        # has exactly two values, so exactly two input components are listed;
        # no function is attached because examples are neither cached nor run
        # on click.
        gr.Examples(
            examples=[
                [
                    "AUGCUAGCUAGUACGUAUAUCUGCACUGC",
                    "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
                ],
                [
                    "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
                    "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
                ]
            ],
            inputs=[target_input, drug_input],
            cache_examples=False
        )

        predict_btn.click(
            fn=predict_wrapper,
            inputs=[target_input, drug_input, remove_dups_checkbox, user_model_key],
            outputs=prediction_output
        )
        visualize_btn.click(
            fn=visualize_and_update,
            inputs=[target_input, drug_input, remove_dups_checkbox, user_model_key],
            outputs=[viz_state1, viz_state2, viz_state3, prediction_output],
            api_name="visualize_and_update"
        )

    with gr.Tab("📊 Visualizations"):
        gr.HTML(_VISUALIZATION_HEADER_HTML)
        viz_image1 = gr.Image(
            label="Cross-Attention Heatmap",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Cross-Attention Heatmap\n(Generate visualizations in the Prediction tab)")
        )
        viz_image2 = gr.Image(
            label="Raw pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Raw pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )
        viz_image3 = gr.Image(
            label="Normalized pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Normalized pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )
        # Mirror each image state into its image component whenever it changes.
        viz_state1.change(fn=lambda x: x, inputs=viz_state1, outputs=viz_image1)
        viz_state2.change(fn=lambda x: x, inputs=viz_state2, outputs=viz_image2)
        viz_state3.change(fn=lambda x: x, inputs=viz_state3, outputs=viz_image3)

    with gr.Tab("⚙️ Model Settings"):
        gr.HTML("<h3>Model Configuration</h3>")
        model_path_input = gr.Dropdown(
            label="Model Path",
            choices=MODEL_CHOICES,
            value=_EFFECTIVE_DEFAULT,
        )
        load_model_btn = gr.Button("📥 Select Model", variant="secondary")
        model_status = gr.Textbox(
            label="Status",
            interactive=False,
            value=_INITIAL_MODEL_STATUS
        )
        load_model_btn.click(
            fn=load_model_wrapper,
            inputs=[model_path_input, user_model_key],
            outputs=[model_status, model_label, user_model_key]
        )

    with gr.Tab("📊 Dataset"):
        gr.Markdown(_DATASET_MARKDOWN)

    with gr.Tab("ℹ️ About"):
        gr.Markdown(_ABOUT_MARKDOWN)

    # Sync each user's label to their own state on every page load / refresh.
    # user_model_key is already seeded to _EFFECTIVE_DEFAULT via gr.State(...),
    # so this just makes the visible label consistent with it.
    demo.load(
        fn=lambda key: _make_label_html(key),
        inputs=[user_model_key],
        outputs=[model_label]
    )
# ---------------------------------------------------------------------------
# Launch
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Listen on all interfaces on port 7860, without a public share link, and
    # surface errors in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)