anshumanatrey's picture
KB-driven dynamic architecture: 26 vuln types, procedural generation, parameter-level testing
d3c5d92 verified
"""
Security testing tool handlers.
Handles: test_injection, test_xss, test_auth, test_config, test_crypto,
check_secrets.
Each handler:
1. Accepts an optional ``parameter`` argument (backward-compatible).
2. Looks up vulnerabilities in the scenario that match host + endpoint + tool.
3. Filters by chain prerequisites (requires_found).
4. Generates output using KB response templates via formatters.
"""
from typing import Any, Dict, List, Optional, Set, Tuple
from .formatters import (
format_tool_output,
format_safe_output,
_map_vuln_to_type,
_get_sample_payload,
)
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def _filter_by_chain(vulns: List[Dict], discovered_vulns: Optional[Set[str]]) -> List[Dict]:
"""Filter vulnerabilities by chain prerequisites."""
if discovered_vulns is None:
discovered_vulns = set()
return [
v for v in vulns
if not v.get("requires_found") or all(r in discovered_vulns for r in v["requires_found"])
]
def _normalize_difficulty(scenario: Dict[str, Any]) -> str:
"""Extract difficulty tier from scenario ID."""
sid = scenario.get("id", "easy")
if "easy" in sid:
return "easy"
elif "medium" in sid:
return "medium"
elif "hard" in sid:
return "hard"
return "easy"
def _build_context(
vuln: Dict[str, Any],
host: str,
endpoint: str,
parameter: Optional[str],
) -> Dict[str, str]:
"""Build a template context dict from a vulnerability instance."""
vuln_type_id = _map_vuln_to_type(vuln)
return {
"host": host,
"endpoint": endpoint or vuln.get("endpoint", ""),
"parameter": parameter or "multiple",
"payload": _get_sample_payload(vuln_type_id),
"cvss": str(vuln.get("cvss", 0)),
"cwe": vuln.get("cwe", ""),
"severity": vuln.get("severity", ""),
"owasp": vuln.get("owasp", ""),
"evidence_detail": vuln.get("evidence", ""),
"remediation": vuln.get("remediation", ""),
"service_version": "",
"port": "",
}
# ---------------------------------------------------------------------------
# test_injection
# ---------------------------------------------------------------------------
def handle_test_injection(
args: Dict[str, Any],
scenario: Dict[str, Any],
discovered_hosts: List[str],
discovered_ports: Dict[str, List[int]],
discovered_vulns: Optional[Set[str]] = None,
) -> Tuple[str, List[str], Dict[str, List[int]], float]:
host = args.get("host", "")
endpoint = args.get("endpoint", "")
parameter = args.get("parameter")
if host not in discovered_hosts:
return ("Error: Host not discovered yet. Run network_scan first.", [], {}, -0.02)
difficulty = _normalize_difficulty(scenario)
vulns = [
v for v in scenario.get("vulnerabilities", [])
if v["host"] == host
and v.get("endpoint") == endpoint
and "test_injection" in v.get("discoverable_by", [])
]
vulns = _filter_by_chain(vulns, discovered_vulns)
if not vulns:
return (
f"Injection testing on {host}{endpoint}:\n"
f" Tested: SQLi (error-based, blind, time-based), command injection, SSTI, SSRF\n"
f" Payloads: 47 injection patterns tested\n"
f" Result: No injection vulnerabilities detected on this endpoint.",
[], {}, 0.01,
)
output_parts = [f"Injection testing on {host}{endpoint}:", ""]
for v in vulns:
context = _build_context(v, host, endpoint, parameter)
output_parts.append(format_tool_output(v, difficulty, context))
output_parts.append("")
return "\n".join(output_parts), [], {}, 0.08
# ---------------------------------------------------------------------------
# test_xss
# ---------------------------------------------------------------------------
def handle_test_xss(
args: Dict[str, Any],
scenario: Dict[str, Any],
discovered_hosts: List[str],
discovered_ports: Dict[str, List[int]],
discovered_vulns: Optional[Set[str]] = None,
) -> Tuple[str, List[str], Dict[str, List[int]], float]:
host = args.get("host", "")
endpoint = args.get("endpoint", "")
parameter = args.get("parameter")
if host not in discovered_hosts:
return ("Error: Host not discovered yet. Run network_scan first.", [], {}, -0.02)
difficulty = _normalize_difficulty(scenario)
vulns = [
v for v in scenario.get("vulnerabilities", [])
if v["host"] == host
and v.get("endpoint") == endpoint
and "test_xss" in v.get("discoverable_by", [])
]
vulns = _filter_by_chain(vulns, discovered_vulns)
if not vulns:
return (
f"XSS testing on {host}{endpoint}:\n"
f" Tested: reflected XSS, stored XSS, DOM-based XSS\n"
f" Payloads: 23 XSS vectors tested across all parameters\n"
f" Result: No XSS vulnerabilities detected.",
[], {}, 0.01,
)
output_parts = [f"XSS testing on {host}{endpoint}:", ""]
for v in vulns:
context = _build_context(v, host, endpoint, parameter)
output_parts.append(format_tool_output(v, difficulty, context))
output_parts.append("")
return "\n".join(output_parts), [], {}, 0.08
# ---------------------------------------------------------------------------
# test_auth
# ---------------------------------------------------------------------------
def handle_test_auth(
args: Dict[str, Any],
scenario: Dict[str, Any],
discovered_hosts: List[str],
discovered_ports: Dict[str, List[int]],
discovered_vulns: Optional[Set[str]] = None,
) -> Tuple[str, List[str], Dict[str, List[int]], float]:
host = args.get("host", "")
endpoint = args.get("endpoint")
parameter = args.get("parameter")
if host not in discovered_hosts:
return ("Error: Host not discovered yet. Run network_scan first.", [], {}, -0.02)
difficulty = _normalize_difficulty(scenario)
# test_auth handles both endpoint-level (IDOR) and host-level (default creds)
vulns = [
v for v in scenario.get("vulnerabilities", [])
if v["host"] == host
and "test_auth" in v.get("discoverable_by", [])
and (endpoint is None or v.get("endpoint") is None or v.get("endpoint") == endpoint)
]
vulns = _filter_by_chain(vulns, discovered_vulns)
if not vulns:
target = f"{host}{endpoint}" if endpoint else host
return (
f"Auth testing on {target}:\n"
f" Default credentials: 15 common sets tested -- none accepted\n"
f" Session management: tokens properly rotated\n"
f" Access controls: authorization checks present\n"
f" Brute force: rate limiting detected after 5 attempts\n"
f" Result: PASS -- no authentication weaknesses found.",
[], {}, 0.01,
)
ep_display = endpoint or ""
output_parts = [f"Auth & access control testing on {host}{ep_display}:", ""]
for v in vulns:
context = _build_context(v, host, endpoint or "", parameter)
output_parts.append(format_tool_output(v, difficulty, context))
output_parts.append("")
return "\n".join(output_parts), [], {}, 0.08
# ---------------------------------------------------------------------------
# test_config
# ---------------------------------------------------------------------------
def handle_test_config(
args: Dict[str, Any],
scenario: Dict[str, Any],
discovered_hosts: List[str],
discovered_ports: Dict[str, List[int]],
discovered_vulns: Optional[Set[str]] = None,
) -> Tuple[str, List[str], Dict[str, List[int]], float]:
host = args.get("host", "")
if host not in discovered_hosts:
return ("Error: Host not discovered yet. Run network_scan first.", [], {}, -0.02)
difficulty = _normalize_difficulty(scenario)
vulns = [
v for v in scenario.get("vulnerabilities", [])
if v["host"] == host
and "test_config" in v.get("discoverable_by", [])
]
vulns = _filter_by_chain(vulns, discovered_vulns)
if not vulns:
return (
f"Configuration audit for {host}:\n"
f" Directory listing: disabled\n"
f" Debug mode: off\n"
f" Server headers: version info suppressed\n"
f" Admin panels: not exposed publicly\n"
f" Result: PASS -- no significant misconfigurations.",
[], {}, 0.01,
)
output_parts = [f"Configuration audit for {host}:", ""]
for v in vulns:
context = _build_context(v, host, "", None)
# Enrich with port/service if available from scenario
host_ports = scenario.get("ports", {}).get(host, [])
if host_ports:
context["port"] = str(host_ports[0].get("port", ""))
context["service_version"] = host_ports[0].get("version", "")
output_parts.append(format_tool_output(v, difficulty, context))
output_parts.append("")
return "\n".join(output_parts), [], {}, 0.08
# ---------------------------------------------------------------------------
# test_crypto
# ---------------------------------------------------------------------------
def handle_test_crypto(
args: Dict[str, Any],
scenario: Dict[str, Any],
discovered_hosts: List[str],
discovered_ports: Dict[str, List[int]],
discovered_vulns: Optional[Set[str]] = None,
) -> Tuple[str, List[str], Dict[str, List[int]], float]:
host = args.get("host", "")
if host not in discovered_hosts:
return ("Error: Host not discovered yet. Run network_scan first.", [], {}, -0.02)
difficulty = _normalize_difficulty(scenario)
vulns = [
v for v in scenario.get("vulnerabilities", [])
if v["host"] == host
and "test_crypto" in v.get("discoverable_by", [])
]
vulns = _filter_by_chain(vulns, discovered_vulns)
if not vulns:
return (
f"Cryptographic analysis for {host}:\n"
f" TLS version: 1.2+ only (1.0/1.1 disabled)\n"
f" Cipher suites: strong (AES-256-GCM preferred)\n"
f" Certificate: valid, not expired\n"
f" HSTS: enabled\n"
f" Result: PASS -- TLS/SSL configuration is secure.",
[], {}, 0.01,
)
output_parts = [f"Cryptographic analysis for {host}:", ""]
for v in vulns:
context = _build_context(v, host, "", None)
# Enrich with port info
host_ports = scenario.get("ports", {}).get(host, [])
for p in host_ports:
if p.get("service") in ("https", "ssl", "tls") or p.get("port") in (443, 8443):
context["port"] = str(p["port"])
context["service_version"] = p.get("version", "")
break
else:
if host_ports:
context["port"] = str(host_ports[0].get("port", "443"))
context["service_version"] = host_ports[0].get("version", "")
output_parts.append(format_tool_output(v, difficulty, context))
output_parts.append("")
return "\n".join(output_parts), [], {}, 0.06
# ---------------------------------------------------------------------------
# check_secrets
# ---------------------------------------------------------------------------
def handle_check_secrets(
args: Dict[str, Any],
scenario: Dict[str, Any],
discovered_hosts: List[str],
discovered_ports: Dict[str, List[int]],
discovered_vulns: Optional[Set[str]] = None,
) -> Tuple[str, List[str], Dict[str, List[int]], float]:
host = args.get("host", "")
endpoint = args.get("endpoint")
parameter = args.get("parameter")
if host not in discovered_hosts:
return ("Error: Host not discovered yet. Run network_scan first.", [], {}, -0.02)
difficulty = _normalize_difficulty(scenario)
vulns = [
v for v in scenario.get("vulnerabilities", [])
if v["host"] == host
and "check_secrets" in v.get("discoverable_by", [])
and (endpoint is None or v.get("endpoint") is None or v.get("endpoint") == endpoint)
]
vulns = _filter_by_chain(vulns, discovered_vulns)
if not vulns:
target = f"{host}{endpoint}" if endpoint else host
return (
f"Secret scanning on {target}:\n"
f" Scanned: source files, config files, environment variables, HTTP responses\n"
f" Patterns: 34 secret patterns checked (AWS, Stripe, JWT, private keys, etc.)\n"
f" Entropy analysis: no high-entropy strings detected\n"
f" Result: PASS -- no exposed secrets found.",
[], {}, 0.01,
)
ep_display = endpoint or ""
output_parts = [f"Secret scanning on {host}{ep_display}:", ""]
for v in vulns:
context = _build_context(v, host, endpoint or "", parameter)
output_parts.append(format_tool_output(v, difficulty, context))
output_parts.append("")
return "\n".join(output_parts), [], {}, 0.08