import gradio as gr
import torch
import numpy as np
import threading
from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel
from modeling_dlmberta import (
    InteractionModelATTNForRegression,
    InteractionModelATTNSimplePoolingForRegression,
    InteractionModelATTNSimplePoolingConfig,
    StdScaler,
)
from configuration_dlmberta import InteractionModelATTNConfig
from chemberta import ChembertaTokenizer
import json
import os
from pathlib import Path
import logging

# Import visualization functions
from analysis import plot_crossattention_weights, plot_presum
from PIL import Image, ImageDraw, ImageFont

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

BASE_DIR = Path(__file__).resolve().parent

_DEFAULT_MODEL_KEY = "interpr"

# Maximum number of SMILES lines processed per request (the UI text states 2000).
MAX_SMILES_ENTRIES = 2000

MODEL_REGISTRY = {
    "meanpool": {
        "name": "Mean pooling",
        "path": BASE_DIR / "meanpool",
        "model_cls": InteractionModelATTNSimplePoolingForRegression,
        "config_cls": InteractionModelATTNSimplePoolingConfig,
        "supports_presum": False,
    },
    "interpr": {
        "name": "Interpretable pooling",
        "path": BASE_DIR / "interpr",
        "model_cls": InteractionModelATTNForRegression,
        "config_cls": InteractionModelATTNConfig,
        "supports_presum": True,
    },
}

# Dropdown entries as (label, registry key). The ", default" marker is derived
# from _DEFAULT_MODEL_KEY so the label cannot drift from the real default.
_MODEL_CHOICE_NOTES = {
    "meanpool": "higher AUROC",
    "interpr": "main manuscript model",
}
MODEL_CHOICES = [
    (
        f"{MODEL_REGISTRY[key]['name']} ({note}{', default' if key == _DEFAULT_MODEL_KEY else ''})",
        key,
    )
    for key, note in _MODEL_CHOICE_NOTES.items()
]


def _load_placeholder_font():
    """Return a font for placeholder text, or None if no font can be loaded.

    Tries Arial first, then Pillow's built-in default font.
    """
    try:
        return ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        # Arial missing, or Pillow built without FreeType support.
        pass
    try:
        return ImageFont.load_default()
    except Exception:
        # Best effort: draw.text() falls back to its own default when font is None.
        logger.warning("Could not load any font for placeholder images")
        return None


def create_placeholder_image(width=600, height=400, text="No visualization available", bg_color=(0, 0, 0, 0)):
    """
    Create a placeholder image with centered (possibly multi-line) text.

    Args:
        width (int): Image width
        height (int): Image height
        text (str): Text to display; "\\n" separates lines
        bg_color (tuple): Background color (R, G, B, A) - (0,0,0,0) for transparent

    Returns:
        PIL.Image: RGBA placeholder image
    """
    img = Image.new("RGBA", (width, height), bg_color)
    draw = ImageDraw.Draw(img)
    font = _load_placeholder_font()

    if font is not None:
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font, align="center")
        text_width, text_height = right - left, bottom - top
    else:
        # Rough estimate (8 px per char, 16 px per line) when no font metrics exist.
        left = top = 0
        lines = text.split("\n")
        text_width = max(len(line) for line in lines) * 8
        text_height = 16 * len(lines)

    # Subtract the bbox origin so the glyphs themselves, not the anchor, are centered.
    x = (width - text_width) // 2 - left
    y = (height - text_height) // 2 - top
    draw.text((x, y), text, fill=(128, 128, 128, 255), font=font, align="center")
    return img


class DrugTargetInteractionApp:
    """One loaded DTI model plus its tokenizers, shared by all UI sessions."""

    # Token length used for padding/truncating both target and drug inputs.
    MAX_SEQ_LENGTH = 512

    def __init__(self):
        self.model = None
        self.target_tokenizer = None
        self.drug_tokenizer = None
        self.scaler = None
        self.current_model_name = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Per-instance lock so concurrent users sharing this model don't race on
        # INTERPR_ENABLE_MODE / INTERPR_DISABLE_MODE or inference state.
        self._lock = threading.Lock()

    def load_model(self, model_name=_DEFAULT_MODEL_KEY):
        """Load the pre-trained model and tokenizers for a MODEL_REGISTRY key.

        Instance attributes are only updated after every component has loaded,
        so a failure never leaves a half-initialized instance behind.

        Returns:
            bool: True on success, False on any error (which is logged).
        """
        try:
            entry = MODEL_REGISTRY[model_name]
            model_path = entry["path"]
            logger.info(f"Loading model '{model_name}' from {model_path}")

            config = entry["config_cls"].from_pretrained(model_path)

            # Drug encoder (ChemBERTa architecture, built from config only).
            drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
            drug_encoder_config.pooler = None
            drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

            # Target encoder
            target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")

            # Load the output scaler only if the model directory ships one.
            scaler = None
            if (Path(model_path) / "scaler.config").exists():
                scaler = StdScaler()
                scaler.load(str(model_path))

            model = entry["model_cls"].from_pretrained(
                model_path,
                config=config,
                target_encoder=target_encoder,
                drug_encoder=drug_encoder,
                scaler=scaler,
            )
            model.to(self.device)
            model.eval()

            target_tokenizer = AutoTokenizer.from_pretrained(BASE_DIR / "target_tokenizer")
            drug_tokenizer = ChembertaTokenizer(str(BASE_DIR / "drug_tokenizer" / "vocab.json"))
        except Exception:
            logger.exception(f"Error loading model '{model_name}'")
            return False

        self.model = model
        self.target_tokenizer = target_tokenizer
        self.drug_tokenizer = drug_tokenizer
        self.current_model_name = model_name
        logger.info(f"Model '{model_name}' and tokenizers loaded successfully!")
        return True

    def get_target_and_smiles(self, target_sequence, drug_smiles):
        """Tokenize the target once and each SMILES separately, on self.device.

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (list[str]): SMILES strings (each is stripped before tokenizing)

        Returns:
            tuple: (target encoding, list of drug encodings in input order)
        """
        target_inputs = self.target_tokenizer(
            target_sequence,
            padding="max_length",
            truncation=True,
            max_length=self.MAX_SEQ_LENGTH,
            return_tensors="pt",
        ).to(self.device)

        all_smiles = [
            self.drug_tokenizer(
                smiles.strip(),
                padding="max_length",
                truncation=True,
                max_length=self.MAX_SEQ_LENGTH,
                return_tensors="pt",
            ).to(self.device)
            for smiles in drug_smiles
        ]
        return target_inputs, all_smiles

    def _predict_all(self, target_inputs, drug_smiles, all_drug_inputs):
        """Run the model once per drug and format one result line per SMILES.

        Caller must hold ``self._lock`` and have already set the desired
        interpretability mode.

        Returns:
            tuple: (list of "<smiles> predicted pKd: x" lines,
                    pKd of the last drug, or None if there were no drugs)
        """
        lines = []
        prediction_value = None
        with torch.no_grad():
            for smile_name, drug_inputs in zip(drug_smiles, all_drug_inputs):
                prediction = self.model(target_inputs, drug_inputs)
                if self.model.scaler is not None:
                    prediction = self.model.unscale(prediction)
                prediction_value = prediction.cpu().numpy()[0][0]
                lines.append(f"{smile_name} predicted pKd: {prediction_value:.4f}")
        return lines, prediction_value

    def predict_interaction(self, target_sequence, drug_smiles):
        """Predict pKd for one target against each SMILES in ``drug_smiles``.

        Returns:
            str: One "<smiles> predicted pKd: x" line per drug, or an error message.
        """
        if self.model is None:
            return "Error: Model not loaded."

        try:
            target_inputs, all_drug_inputs = self.get_target_and_smiles(target_sequence, drug_smiles)
            with self._lock:
                self.model.INTERPR_DISABLE_MODE()
                lines, _ = self._predict_all(target_inputs, drug_smiles, all_drug_inputs)
            return "\n".join(lines)
        except Exception as e:
            logger.exception("Prediction error")
            return f"Error during prediction: {str(e)}"

    def _plot_cross_attention(self, target_inputs, drug_inputs):
        """Render the cross-attention heatmap from the model's last forward pass.

        Returns:
            tuple: (PIL image or None, bool telling whether weights were available)
        """
        weights = self.model.model.crossattention_weights
        logger.info(f"Cross-attention weights type: {type(weights)}")
        if weights is None:
            logger.warning("Cross-attention weights are None")
            return None, False
        logger.info(f"Cross-attention weights shape: {getattr(weights, 'shape', 'No shape attr')}")

        try:
            # [0, 0] is the slice the plotting helper expects; axis meaning is
            # defined by the model — NOTE(review): confirm (batch, head) order.
            cross_attn_matrix = weights[0, 0]
            if cross_attn_matrix is None:
                logger.warning("Could not extract valid cross-attention matrix")
                return None, True
            logger.info(f"Extracted cross-attention matrix shape: {cross_attn_matrix.shape}")
        except (IndexError, TypeError, AttributeError) as e:
            logger.warning(f"Error extracting cross-attention matrix: {str(e)}")
            return None, True

        try:
            img = plot_crossattention_weights(
                target_inputs["attention_mask"][0],
                drug_inputs["attention_mask"][0],
                target_inputs,
                drug_inputs,
                cross_attn_matrix,
                self.target_tokenizer,
                self.drug_tokenizer,
            )
        except Exception:
            logger.exception("Cross-attention visualization error")
            return None, True
        return img, True

    def _plot_contributions(self, target_inputs, prediction_value):
        """Render raw and normalized per-token contribution plots (presum models only).

        The raw plot is skipped when ``prediction_value <= 0``.

        Returns:
            tuple: (raw image or None, normalized image or None)
        """
        inner = self.model.model
        presum_values = inner.presum_layer
        if presum_values is None:
            logger.warning("Cannot generate contribution visualizations: presum values are None")
            return None, None

        presum_values = presum_values.detach()
        w = inner.w.squeeze(1).detach()
        b = inner.b.detach()
        scaler = inner.scaler

        normalized_img = None
        try:
            normalized_img = plot_presum(
                target_inputs, presum_values, scaler, w, b,
                self.target_tokenizer, raw_affinities=False,
            )
        except Exception:
            logger.exception("Normalized contribution visualization error")

        raw_img = None
        if prediction_value > 0:
            try:
                raw_img = plot_presum(
                    target_inputs, presum_values, scaler, w, b,
                    self.target_tokenizer, raw_affinities=True,
                )
            except Exception:
                logger.exception("Raw contribution visualization error")
        else:
            logger.info("Skipping raw contribution visualization as pKd <= 0")

        return raw_img, normalized_img

    def visualize_interaction(self, target_sequence, drug_smiles):
        """
        Generate visualization images for drug-target interaction.

        Every SMILES is scored, but plots are produced only for the last one.

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (list[str]): Drug SMILES strings

        Returns:
            tuple: (cross_attention_image, raw_contribution_image,
                    normalized_contribution_image, status_message)
        """
        if self.model is None:
            return None, None, None, "Error: Model not loaded."
        if not drug_smiles:
            return None, None, None, "Error during visualization: no drug SMILES provided."

        supports_presum = MODEL_REGISTRY.get(self.current_model_name, {}).get("supports_presum", False)
        raw_img = normalized_img = None

        try:
            target_inputs, all_drug_inputs = self.get_target_and_smiles(target_sequence, drug_smiles)

            with self._lock:
                try:
                    self.model.INTERPR_ENABLE_MODE()
                    lines, prediction_value = self._predict_all(target_inputs, drug_smiles, all_drug_inputs)
                    # Plots read tensors stored by the last forward pass, so render
                    # them before another request can run the shared model.
                    cross_attention_img, weights_available = self._plot_cross_attention(
                        target_inputs, all_drug_inputs[-1]
                    )
                    if supports_presum:
                        raw_img, normalized_img = self._plot_contributions(target_inputs, prediction_value)
                finally:
                    # Always restore prediction mode while still holding the lock,
                    # so a failure here can't flip the mode under another request.
                    self.model.INTERPR_DISABLE_MODE()

            status_msg = "\n".join(lines)

            # --- Placeholder fallbacks ---
            if cross_attention_img is None:
                cross_attention_img = create_placeholder_image(text="Cross-Attention Heatmap\nFailed to generate")

            if supports_presum:
                if normalized_img is None:
                    normalized_img = create_placeholder_image(text="Normalized Contribution\nFailed to generate")
                if raw_img is None:
                    if prediction_value > 0:
                        raw_img = create_placeholder_image(text="Raw Contribution\nFailed to generate")
                    else:
                        raw_img = create_placeholder_image(text="Raw Contribution\nSkipped (pKd ≤ 0)")
                if prediction_value <= 0:
                    status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
            else:
                normalized_img = create_placeholder_image(text="Normalized Contribution\nNot supported for this model")
                raw_img = create_placeholder_image(text="Raw Contribution\nNot supported for this model")
                status_msg += " (Contribution visualizations not supported for this model)"

            if not weights_available:
                status_msg += " (Cross-attention visualization failed: weights not available)"

            return cross_attention_img, raw_img, normalized_img, status_msg

        except Exception as e:
            logger.exception("Visualization error")
            return None, None, None, f"Error during visualization: {str(e)}"


# ---------------------------------------------------------------------------
# Load BOTH models globally at startup — shared across all users.
# Each model gets its own lock (inside DrugTargetInteractionApp) so concurrent
# requests don't race on inference/mode-toggle state.
# ---------------------------------------------------------------------------
logger.info("Loading all models at startup...")
LOADED_MODELS: dict[str, DrugTargetInteractionApp] = {}
for _key in MODEL_REGISTRY:
    _instance = DrugTargetInteractionApp()
    if _instance.load_model(_key):
        LOADED_MODELS[_key] = _instance
        logger.info(f"  ✅ '{_key}' ready.")
    else:
        logger.error(f"  ❌ Failed to load '{_key}'.")

if not LOADED_MODELS:
    raise RuntimeError("No models could be loaded. Check your model files.")

# Fall back to whatever loaded if the preferred default didn't make it
_EFFECTIVE_DEFAULT = _DEFAULT_MODEL_KEY if _DEFAULT_MODEL_KEY in LOADED_MODELS else next(iter(LOADED_MODELS))

# Initial text for the Model Settings status box; only claims full readiness
# when every registry entry actually loaded.
if len(LOADED_MODELS) == len(MODEL_REGISTRY):
    _STARTUP_STATUS = "Both models are pre-loaded and ready."
else:
    _STARTUP_STATUS = (
        "Pre-loaded models: " + ", ".join(LOADED_MODELS)
        + ". Some models failed to load; check the server logs."
    )


# ---------------------------------------------------------------------------
# Helper: build the green "currently loaded" HTML label
# ---------------------------------------------------------------------------
def _make_label_html(model_key: str) -> str:
    """Return the HTML label announcing which model the session is using."""
    display_name = MODEL_REGISTRY[model_key]["name"]
    return (
        f"<div style='color: #15803d; font-weight: 600;'>"
        f"✅ Currently loaded model: <b>{display_name}</b> "
        f"(<code>{model_key}</code>)</div>"
    )


def _make_error_label_html() -> str:
    """Return the HTML label shown when no usable model is selected."""
    return (
        "<div style='color: #b91c1c; font-weight: 600;'>"
        "❌ No model loaded</div>"
    )


# ---------------------------------------------------------------------------
# Preprocessing helper
# ---------------------------------------------------------------------------
def _split_smiles(drug_smiles):
    """Split textbox input into stripped, non-empty SMILES lines (any newline style)."""
    return [line.strip() for line in drug_smiles.splitlines() if line.strip()]


def smiles_preprocessing(drug_smiles, remove_dupl, max_entries=MAX_SMILES_ENTRIES):
    """Turn multi-line SMILES text into a list of at most ``max_entries`` entries.

    Blank lines are dropped and surrounding whitespace is stripped. When
    ``remove_dupl`` is true, later duplicates are removed, keeping first-seen order.
    """
    drugs = _split_smiles(drug_smiles)
    if remove_dupl:
        unique_drugs = list(dict.fromkeys(drugs))
        logger.info(f"{len(drugs) - len(unique_drugs)} duplicate smiles removed!")
        drugs = unique_drugs
    return drugs[:max_entries]


# ---------------------------------------------------------------------------
# Wrapper functions — look up the globally loaded model from the user's state
# ---------------------------------------------------------------------------
def predict_wrapper(target_seq, drug_smiles, remove_dups, model_key):
    """Validate UI inputs and run prediction with the session's selected model."""
    if not target_seq.strip() or not drug_smiles.strip():
        return "Please provide both target sequence and drug SMILES."
    if model_key not in LOADED_MODELS:
        return f"Error: model '{model_key}' is not available."
    target_seq = target_seq.strip()
    drug_smiles = smiles_preprocessing(drug_smiles, remove_dups)
    return LOADED_MODELS[model_key].predict_interaction(target_seq, drug_smiles)


def visualize_wrapper(target_seq, drug_smiles, remove_dups, model_key):
    """Validate UI inputs and build visualizations with the session's selected model."""
    if not target_seq.strip() or not drug_smiles.strip():
        return None, None, None, "Please provide both target sequence and drug SMILES."
    if model_key not in LOADED_MODELS:
        return None, None, None, f"Error: model '{model_key}' is not available."
    target_seq = target_seq.strip()
    drug_smiles = smiles_preprocessing(drug_smiles, remove_dups)
    return LOADED_MODELS[model_key].visualize_interaction(target_seq, drug_smiles)


def load_model_wrapper(model_name, current_model_key):
    """
    No actual loading happens here — the model is already in memory.
    We just update the per-user state and label.

    On failure the session keeps ``current_model_key``, so the label keeps
    describing that model when it is available.
    """
    if model_name not in LOADED_MODELS:
        if current_model_key in LOADED_MODELS:
            label = _make_label_html(current_model_key)
        else:
            label = _make_error_label_html()
        return f"Model '{model_name}' failed to load at startup.", label, current_model_key
    return "Model selected successfully!", _make_label_html(model_name), model_name


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 1rem;">
        <h1>🧬 Drug-Target Interaction Predictor</h1>
        <p>Predict binding affinity between drugs and target RNA sequences using deep learning</p>
    </div>
    """)

    # Per-user state: just a string (model key), not a full model object.
    # Seeded with _EFFECTIVE_DEFAULT; demo.load() below only syncs the label.
    user_model_key = gr.State(_EFFECTIVE_DEFAULT)

    # Shared image states between Prediction and Visualizations tabs
    viz_state1 = gr.State()
    viz_state2 = gr.State()
    viz_state3 = gr.State()

    with gr.Tab("🔮 Prediction & Analysis"):
        model_label = gr.HTML(value=_make_label_html(_EFFECTIVE_DEFAULT))

        with gr.Row():
            with gr.Column(scale=1):
                target_input = gr.Textbox(
                    label="Target RNA Sequence",
                    placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
                    lines=5,
                    max_lines=5,
                )
                drug_input = gr.Textbox(
                    label="Drug SMILES",
                    placeholder=(
                        "Enter SMILES notation for one or more drugs.\n"
                        "For multiple SMILES, enter each on a new line (max 2000):\n"
                        "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O\n"
                        "C1CCCCC1O"
                    ),
                    lines=5,
                    max_lines=5,
                )
                remove_dups_checkbox = gr.Checkbox(
                    label="Remove duplicate SMILES",
                    value=False,
                )

                with gr.Row():
                    predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
                    visualize_btn = gr.Button("📊 Generate Visualizations", variant="secondary", size="lg")

            with gr.Column(scale=1):
                prediction_output = gr.Textbox(
                    label="Prediction Result",
                    interactive=False,
                    lines=4,
                )

        gr.HTML("<h3>📚 Example Inputs:</h3>")
        gr.Examples(
            examples=[
                [
                    "AUGCUAGCUAGUACGUAUAUCUGCACUGC",
                    "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O",
                ],
                [
                    "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
                    "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2",
                ],
            ],
            inputs=[target_input, drug_input, remove_dups_checkbox],
            outputs=prediction_output,
            fn=predict_wrapper,
            cache_examples=False,
        )

        predict_btn.click(
            fn=predict_wrapper,
            inputs=[target_input, drug_input, remove_dups_checkbox, user_model_key],
            outputs=prediction_output,
        )

        def visualize_and_update(target_seq, drug_smiles, remove_dups, model_key):
            """Build visualizations and append navigation hints to the status text."""
            img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles, remove_dups, model_key)
            combined_status = status + "\n\nVisualization analysis complete. Please navigate to the Visualizations tab to view the generated images."
            # drug_smiles is the raw textbox string: count SMILES lines, not characters.
            if len(_split_smiles(drug_smiles)) > 1:
                combined_status += "\nVisualizations are shown only for the last SMILES entry."
            return img1, img2, img3, combined_status

        visualize_btn.click(
            fn=visualize_and_update,
            inputs=[target_input, drug_input, remove_dups_checkbox, user_model_key],
            outputs=[viz_state1, viz_state2, viz_state3, prediction_output],
            api_name="visualize_and_update",
        )

    with gr.Tab("📊 Visualizations"):
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 1rem;">
            <h2>🔬 Interaction Analysis &amp; Visualizations</h2>
            <p>Generated visualizations will appear here after clicking "Generate Visualizations" in the Prediction tab</p>
        </div>
        """)

        viz_image1 = gr.Image(
            label="Cross-Attention Heatmap",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Cross-Attention Heatmap\n(Generate visualizations in the Prediction tab)"),
        )
        viz_image2 = gr.Image(
            label="Raw pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Raw pKd Contribution\n(Generate visualizations in the Prediction tab)"),
        )
        viz_image3 = gr.Image(
            label="Normalized pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Normalized pKd Contribution\n(Generate visualizations in the Prediction tab)"),
        )

        viz_state1.change(fn=lambda x: x, inputs=viz_state1, outputs=viz_image1)
        viz_state2.change(fn=lambda x: x, inputs=viz_state2, outputs=viz_image2)
        viz_state3.change(fn=lambda x: x, inputs=viz_state3, outputs=viz_image3)

    with gr.Tab("⚙️ Model Settings"):
        gr.HTML("<h3>Model Configuration</h3>")

        model_path_input = gr.Dropdown(
            label="Model Path",
            choices=MODEL_CHOICES,
            value=_EFFECTIVE_DEFAULT,
        )
        load_model_btn = gr.Button("📥 Select Model", variant="secondary")
        model_status = gr.Textbox(
            label="Status",
            interactive=False,
            value=_STARTUP_STATUS,
        )

        load_model_btn.click(
            fn=load_model_wrapper,
            inputs=[model_path_input, user_model_key],
            outputs=[model_status, model_label, user_model_key],
        )

    with gr.Tab("📊 Dataset"):
        # NOTE(review): 2,991 positives + 2,538 negatives = 5,529, but the text
        # states 5,534 pairs (which matches the per-class breakdown). Confirm figures.
        gr.Markdown("""
        ## Training and Test Datasets

        ### Fine-tuning Dataset (Training)
        The model was trained on a dataset comprising **1,439 RNA–drug interaction pairs**, including:
        - **759 unique compounds** (SMILES representations)
        - **294 unique RNA sequences**
        - Dissociation constants (pKd values) for binding affinity prediction

        **RNA Sequence Distribution by Type:**

        | RNA Sequence Type | Number of Interactions |
        |-------------------|------------------------|
        | Aptamers          | 520                    |
        | Ribosomal         | 295                    |
        | Viral RNAs        | 281                    |
        | miRNAs            | 146                    |
        | Riboswitches      | 100                    |
        | Repeats           | 97                     |
        | **Total**         | **1,439**              |

        ### External Evaluation Dataset (Test)
        Model validation was performed using external ROBIN classification datasets containing **5,534 RNA–drug pairs**:
        - **2,991 positive interactions**
        - **2,538 negative interactions**

        **Test Dataset Composition:**
        - **1,617 aptamer pairs** (5 unique RNA sequences)
        - **1,828 viral RNA pairs** (6 unique RNA sequences)
        - **1,459 riboswitch pairs** (5 unique RNA sequences)
        - **630 miRNA pairs** (3 unique RNA sequences)

        ### Dataset Downloads
        - [Training Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/training_data.csv?download=true)
        - [Test Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/test_data.csv?download=true)

        ### Citation
        Original datasets published by:
        **Krishnan et al.** - Available on the RSAPred website in PDF format.

        *Reference:*
        ```bibtex
        @article{krishnan2024reliable,
          title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
          author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
          journal={Briefings in Bioinformatics},
          volume={25},
          number={2},
          pages={bbae002},
          year={2024},
          publisher={Oxford University Press}
        }
        ```
        """)

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## About this application

        This application implements DLRNA-BERTa, a Dual Language RoBERTa Transformer model for predicting drug-to-RNA target interactions. The architecture combines:

        - **Target encoder**: RNA-BERTa for processing RNA sequences
        - **Drug encoder**: ChemBERTa for SMILES representation
        - **Cross-attention mechanism**: Captures interactions between drug and target
        - **Regression head**: Predicts binding affinity (pKd) with either mean pooling or interpretable pooling

        ### Available models
        - **Mean pooling (`meanpool`)**: has the higher AUROC on most external ROBIN classes. Use it for regular prediction and cross-attention visualization.
        - **Interpretable pooling (`interpr`)**: use it when you need both cross-attention and token-level prediction contribution plots.

        Both models are loaded into memory at startup. Switching between them is instant. The model selected when the page opens is marked "default" in the Model Settings tab.

        ### Input requirements
        - **Target sequence**: RNA sequence (A, U, G, C)
        - **Drug SMILES**: One or more SMILES strings
          - For batch mode, enter each SMILES on a new line (up to 2000 entries)
          - A checkbox option allows automatic removal of duplicate SMILES before prediction

        ### Model features
        - Cross-attention for drug-target interaction modeling
        - Regularization via dropout
        - Layer normalization for stable training
        - Dedicated interpretability mode for visualization
        - Batch prediction with optional de-duplication

        ### Usage tips
        1. Select a model in the Model Settings tab if you want to change from the default
        2. Enter an RNA sequence and one or more SMILES strings
        3. Use the **"Remove duplicate SMILES"** checkbox if you want duplicates filtered automatically
        4. Click *Predict Interaction* for affinity scores
        5. Click *Generate Visualizations* for interpretability plots
        6. Visualizations are produced only for the final SMILES entry in batch mode

        For best results, ensure your input sequences are properly formatted and within reasonable length limits (max 512 tokens).

        ### Visualization features:
        - **Cross-attention heatmap**: Shows cross-attention weights between drug and target tokens
        - **Unnormalized pKd contribution**: Shows unnormalized signed contributions from each target token with the `interpr` model only
        - **Normalized pKd contribution**: Shows normalized non-negative contributions from each target token with the `interpr` model only

        ### Performance metrics:
        - Training on diverse drug-target interaction datasets
        - Evaluated using RMSE, Pearson correlation, and Concordance Index
        - Optimized for both predictive accuracy and interpretability

        ### GitHub repository:
        - The full model GitHub repository can be found here: https://github.com/IlPakoZ/dlrnaberta-dti-prediction

        ### Contribution:
        - Special thanks to Umut Onur Özcan for help in developing this space:)

        ### Contact:
        - Ziaurrehman Tanoli (ziaurrehman.tanoli@helsinki.fi)
          Principal investigator at Institute for Molecular Medicine Finland
          HiLIFE, University of Helsinki, Finland.
        """)

    # Sync each user's label to their own state on every page load / refresh.
    # user_model_key is already seeded to _EFFECTIVE_DEFAULT via gr.State(...),
    # so this just makes the visible label consistent with it.
    demo.load(
        fn=lambda key: _make_label_html(key),
        inputs=[user_model_key],
        outputs=[model_label],
    )


# ---------------------------------------------------------------------------
# Launch
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )