| import gradio as gr |
| import joblib |
| import json |
| import numpy as np |
| import re |
| from urllib.parse import urlparse |
| import os |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Names of the model artifact and of the Hugging Face Hub repository hosting it.
MODEL_NAME = "XGBoost"
HF_USERNAME = "Devishetty100"
CUSTOM_MODEL_NAME = "NeoGuardianAI"
# Hub repository id in "<user>/<lower-cased model name>" form.
REPO_ID = "/".join((HF_USERNAME, CUSTOM_MODEL_NAME.lower()))
|
|
| |
# Registrable domains that predict_url reports as legitimate without running
# the model; subdomains of these entries are matched as well.
TRUSTED_DOMAINS = [
    "huggingface.co",
    "github.com",
    "google.com",
    "microsoft.com",
    "apple.com",
    "amazon.com",
    "facebook.com",
    "twitter.com",
    "linkedin.com",
    "youtube.com",
    "wikipedia.org",
]
|
|
| |
def _load_artifacts(model_path, scaler_path, feature_names_path):
    """Deserialize the model, the scaler and the feature-name list from disk."""
    # SECURITY NOTE(review): joblib.load is pickle-based and can execute
    # arbitrary code while loading. Only load artifacts from a repository or
    # directory you trust.
    model = joblib.load(model_path)
    scaler = joblib.load(scaler_path)

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(feature_names_path, 'r', encoding='utf-8') as f:
        feature_names = json.load(f)

    return model, scaler, feature_names


def load_model_files():
    """Load the classifier, its scaler and the ordered feature names.

    The three artifacts are first downloaded from the Hugging Face Hub
    repository ``REPO_ID``. If anything in that step fails, the same file
    names are loaded from the current working directory instead.

    Returns:
        A ``(model, scaler, feature_names)`` tuple.

    Raises:
        RuntimeError: if both the Hub download and the local fallback fail.
            The local error is attached as ``__cause__``.
    """
    model_filename = f"{MODEL_NAME.lower()}_model.joblib"

    try:
        print(f"Attempting to download model from Hugging Face Hub: {REPO_ID}")
        artifacts = _load_artifacts(
            hf_hub_download(repo_id=REPO_ID, filename=model_filename),
            hf_hub_download(repo_id=REPO_ID, filename="scaler.joblib"),
            hf_hub_download(repo_id=REPO_ID, filename="feature_names.json"),
        )
        print("Successfully downloaded model from Hugging Face Hub.")
        return artifacts
    except Exception as hub_error:
        # Best-effort: any Hub problem (network, auth, missing file, corrupt
        # artifact) falls through to the local copy below.
        print(f"Error downloading from Hugging Face Hub: {hub_error}")

    try:
        print("Attempting to load model from local files...")
        artifacts = _load_artifacts(model_filename, "scaler.joblib", "feature_names.json")
        print("Successfully loaded model from local files.")
        return artifacts
    except Exception as local_error:
        print(f"Error loading from local files: {local_error}")
        raise RuntimeError(
            "Failed to load model from both Hugging Face Hub and local files."
        ) from local_error
|
|
| |
# (feature name, character) pairs: each feature is the number of times the
# character occurs anywhere in the URL.
_CHAR_COUNT_FEATURES = (
    ('nb_dots', '.'),
    ('nb_hyphens', '-'),
    ('nb_at', '@'),
    ('nb_qm', '?'),
    ('nb_and', '&'),
    ('nb_or', '|'),
    ('nb_eq', '='),
    ('nb_underscore', '_'),
    ('nb_tilde', '~'),
    ('nb_percent', '%'),
    ('nb_slash', '/'),
    ('nb_star', '*'),
    ('nb_colon', ':'),
    ('nb_comma', ','),
    ('nb_semicolumn', ';'),
    ('nb_dollar', '$'),
    ('nb_space', ' '),
)

# TLD tokens looked for inside the path and the left-most host label.
_TLD_TOKENS = ('.com', '.org', '.net', '.edu', '.gov', '.mil', '.int')

# URL-shortener names; matched as plain substrings of the network location.
_SHORTENING_SERVICES = (
    'bit.ly', 'goo.gl', 'tinyurl.com', 't.co', 'tr.im', 'is.gd', 'cli.gs',
    'ow.ly', 'yfrog.com', 'migre.me', 'ff.im', 'tiny.cc', 'url4.eu',
    'twit.ac', 'su.pr', 'twurl.nl', 'snipurl.com', 'short.to', 'budurl.com',
    'ping.fm', 'post.ly', 'just.as', 'bkite.com', 'snipr.com', 'fic.kr',
    'loopt.us', 'doiop.com', 'twitthis.com', 'htxt.it', 'ak.im', 'shar.es',
    'kl.am', 'wp.me', 'rubyurl.com', 'om.ly', 'to.ly', 'bit.do', 't.co',
    'lnkd.in', 'db.tt', 'qr.ae', 'adf.ly', 'goo.gl', 'bitly.com', 'cur.lv',
    'tinyurl.com', 'ow.ly', 'bit.ly', 'ity.im', 'q.gs', 'is.gd', 'po.st',
    'bc.vc', 'twitthis.com', 'u.to', 'j.mp', 'buzurl.com', 'cutt.us', 'u.bb',
    'yourls.org', 'x.co', 'prettylinkpro.com', 'scrnch.me', 'filoops.info',
    'vzturl.com', 'qr.net', '1url.com', 'tweez.me', 'v.gd', 'tr.im',
    'link.zip.net',
)

# Features that are not derived from the URL string here; they are filled
# with 0 so the returned dict always carries these keys.
_PLACEHOLDER_FEATURES = (
    'nb_redirection', 'nb_external_redirection', 'length_words_raw',
    'char_repeat', 'shortest_words_raw', 'shortest_word_host',
    'shortest_word_path', 'longest_words_raw', 'longest_word_host',
    'longest_word_path', 'avg_words_raw', 'avg_word_host',
    'avg_word_path', 'phish_hints', 'domain_in_brand',
    'brand_in_subdomain', 'brand_in_path', 'suspecious_tld',
    'statistical_report', 'nb_hyperlinks', 'ratio_intHyperlinks',
    'ratio_extHyperlinks', 'ratio_nullHyperlinks', 'nb_extCSS',
    'ratio_intRedirection', 'ratio_extRedirection', 'ratio_intErrors',
    'ratio_extErrors', 'login_form', 'external_favicon',
    'links_in_tags', 'submit_email', 'ratio_intMedia',
    'ratio_extMedia', 'sfh', 'iframe', 'popup_window',
    'safe_anchor', 'onmouseover', 'right_clic', 'empty_title',
    'domain_in_title', 'domain_with_copyright', 'whois_registered_domain',
    'domain_registration_length', 'domain_age', 'web_traffic',
    'dns_record', 'google_index', 'page_rank',
)


def extract_features(url):
    """Build the feature dictionary for *url* used as model input.

    Lexical features are computed from the raw URL string, from its network
    location (``urlparse(url).netloc``, so any port or userinfo is included)
    and from its path. Every name in ``_PLACEHOLDER_FEATURES`` is added with
    the value 0.

    NOTE(review): the substring-based checks (e.g. the shortener lookup) are
    kept exactly as they were; presumably the model was trained on features
    computed the same way, so do not "tighten" them without retraining.

    Returns:
        dict mapping feature name to an int or float value.
    """
    url_length = len(url)

    pieces = urlparse(url)
    host = pieces.netloc
    url_path = pieces.path

    # Sub-expressions shared by several features below.
    host_length = len(host)
    host_dots = host.count('.')
    host_digits = sum(ch.isdigit() for ch in host)
    url_digits = sum(ch.isdigit() for ch in url)
    first_label = host.split('.')[0]
    last_segment = url_path.rsplit('/', 1)[-1]

    features = {
        'length_url': url_length,
        'length_hostname': host_length,
        # Anchored at the start of the host only; a trailing port still matches.
        'ip': int(re.match(r'\d+\.\d+\.\d+\.\d+', host) is not None),
    }

    for name, char in _CHAR_COUNT_FEATURES:
        features[name] = url.count(char)

    features['nb_www'] = int('www' in host)
    features['nb_com'] = int('.com' in host)
    features['nb_dslash'] = url.count('//')
    features['http_in_path'] = int('http' in url_path)
    features['https_token'] = int('https' in url and 'http://' not in url)

    # The ratios stay the int 0 (not 0.0) when the denominator is empty.
    features['ratio_digits_url'] = url_digits / url_length if url_length > 0 else 0
    features['ratio_digits_host'] = host_digits / host_length if host_length > 0 else 0

    features['punycode'] = int('xn--' in host)

    # Only the text between the first and second ':' is inspected for digits.
    features['port'] = int(
        ':' in host and any(ch.isdigit() for ch in host.split(':')[1])
    )

    features['tld_in_path'] = int(any(tld in url_path for tld in _TLD_TOKENS))
    features['tld_in_subdomain'] = int(
        host_dots > 1 and any(tld in first_label for tld in _TLD_TOKENS)
    )

    features['abnormal_subdomain'] = int(host_dots > 2)
    features['nb_subdomains'] = host_dots

    features['prefix_suffix'] = int('-' in host)
    features['random_domain'] = int(host_length > 12 and host_digits > 4)

    features['shortening_service'] = int(
        any(service in host for service in _SHORTENING_SERVICES)
    )

    features['path_extension'] = int('.' in last_segment)

    for name in _PLACEHOLDER_FEATURES:
        features.setdefault(name, 0)

    return features
|
|
| |
class _IdentityScaler:
    """Stand-in scaler used when no real scaler could be loaded."""

    def transform(self, x):
        """Return *x* unchanged."""
        return x

    # The previous placeholder was a plain function object, so keep it callable.
    __call__ = transform


class _DummyModel:
    """Stand-in classifier used when no real model could be loaded.

    It always predicts class 0 and reports a 0.5/0.5 probability split, so
    it needs neither training data nor scikit-learn.
    """

    def predict(self, x):
        """Return class 0 for every row of *x*."""
        return np.zeros(len(x), dtype=int)

    def predict_proba(self, x):
        """Return a single fixed ``[[0.5, 0.5]]`` probability row."""
        return np.array([[0.5, 0.5]])


# Load the real artifacts at import time; fall back to the stand-ins so the
# app can still start for demonstration purposes.
try:
    model, scaler, feature_names = load_model_files()
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Using dummy model for demonstration purposes.")

    # NOTE(review): with the stand-in model, predict_url labels every URL that
    # is not on the trusted list as "Legitimate" with confidence 0.5.
    model = _DummyModel()
    scaler = _IdentityScaler()
    feature_names = ['length_url', 'length_hostname']
|
|
def _is_trusted_host(host):
    """Return True if *host* is, or is a subdomain of, a TRUSTED_DOMAINS entry."""
    labels = host.rstrip('.').split('.')
    # Try every suffix made of at least two labels ("a.b.c" -> "a.b.c", "b.c"),
    # so a bare TLD can never match.
    return any(
        '.'.join(labels[start:]) in TRUSTED_DOMAINS
        for start in range(len(labels) - 1)
    )


def predict_url(url):
    """Predict if a URL is phishing or legitimate.

    Args:
        url: URL text entered by the user; surrounding whitespace is ignored.

    Returns:
        A ``(prediction_text, confidence, status)`` tuple. Blank input yields
        a prompt with status ``"N/A"``; any exception raised while analysing
        the URL is reported as ``("Error: ...", 0.0, "Error")`` instead of
        propagating to the UI.
    """
    if not url or not url.strip():
        return "Please enter a URL", 0.0, "N/A"

    try:
        # Pasted whitespace is not part of the URL and would skew the features.
        url = url.strip()

        parsed_url = urlparse(url)
        netloc = parsed_url.netloc
        # .hostname is lower-cased and excludes any port, so
        # "https://GOOGLE.com:443/" is recognised as google.com.
        host = parsed_url.hostname or ''

        # A netloc with userinfo ("@") or a backslash is a common way to
        # disguise the real destination, and browsers may parse it differently
        # from urlparse. Such URLs never take the shortcut; the model decides.
        looks_disguised = '@' in netloc or '\\' in netloc
        if not looks_disguised and _is_trusted_host(host):
            return "Legitimate (Trusted Domain)", 1.0, "✅ SAFE"

        url_features = extract_features(url)

        # Order the values as the model expects; unknown features default to 0.
        features_array = [url_features.get(name, 0) for name in feature_names]

        scaled_features = scaler.transform([features_array])

        prediction = model.predict(scaled_features)[0]
        probability = model.predict_proba(scaled_features)[0][1]

        # `probability` is the score of class 1 (phishing); report the score of
        # whichever class was actually predicted.
        is_phishing = prediction == 1
        prediction_text = "Phishing" if is_phishing else "Legitimate"
        confidence = float(probability) if is_phishing else float(1 - probability)
        status = "⚠️ UNSAFE" if is_phishing else "✅ SAFE"

        return prediction_text, confidence, status
    except Exception as e:
        # UI boundary: surface the failure to the user rather than crashing.
        error_msg = f"Error: {str(e)}"
        return error_msg, 0.0, "Error"
|
|
| |
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching it.

    The page has a URL text box with a "Check URL" button, three read-only
    result boxes (status, prediction, confidence) filled by ``predict_url``,
    and explanatory Markdown sections.
    """
    with gr.Blocks(title="NeoGuardianAI - URL Phishing Detection", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # NeoGuardianAI - URL Phishing Detection

            This app uses a machine learning model to detect if a URL is legitimate or phishing.

            Enter a URL below to check if it's safe or potentially malicious.
            """
        )

        # Input row.
        with gr.Row():
            url_box = gr.Textbox(label="Enter URL", placeholder="https://example.com")
            check_button = gr.Button("Check URL", variant="primary")

        # Result row; the on-screen order is status, prediction, confidence.
        with gr.Row():
            status_box = gr.Textbox(label="Status")
            prediction_box = gr.Textbox(label="Prediction")
            confidence_box = gr.Textbox(label="Confidence")

        # predict_url returns (prediction, confidence, status), so the outputs
        # are listed in that order rather than in display order.
        result_boxes = [prediction_box, confidence_box, status_box]
        check_button.click(fn=predict_url, inputs=url_box, outputs=result_boxes)

        gr.Markdown(
            """
            ## How it works

            This model was trained on the [pirocheto/phishing-url](https://huggingface.co/datasets/pirocheto/phishing-url) dataset from Hugging Face.

            The model extracts various features from the URL and uses a machine learning algorithm to classify it as legitimate or phishing.

            **Note**: While this model is highly accurate, it's not perfect. Always exercise caution when visiting unfamiliar websites.

            ## API Usage

            You can also use this model via the Hugging Face Inference API:

            ```python
            import requests

            API_URL = "https://api-inference.huggingface.co/models/Devishetty100/neoguardianai"
            headers = {"Authorization": "Bearer YOUR_API_TOKEN"}

            def query(url):
                payload = {"inputs": url}
                response = requests.post(API_URL, headers=headers, json=payload)
                return response.json()

            # Example
            result = query("https://example.com")
            print(result)
            ```
            """
        )

    return demo
|
|
| |
# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
|
|