Refactor SQLMapCLI class for improved type hinting and code clarity

This commit is contained in:
Wilbert Chandra 2026-01-07 12:33:26 +00:00
parent 3a975b79c1
commit 656a0dcdf7

View File

@ -7,23 +7,23 @@ Automates comprehensive SQL injection testing with a single command
import subprocess import subprocess
import sys import sys
import argparse import argparse
import time
import re
from pathlib import Path from pathlib import Path
from typing import List, Dict, Tuple from typing import List, Dict, Tuple, Optional, TypedDict, Any
from datetime import datetime from datetime import datetime
try: try:
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
from rich.table import Table from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn from rich.progress import (
from rich.live import Live Progress,
from rich.layout import Layout SpinnerColumn,
from rich.text import Text TextColumn,
BarColumn,
TimeElapsedColumn,
)
from rich.prompt import Prompt, Confirm from rich.prompt import Prompt, Confirm
from rich import box from rich import box
from rich.style import Style
except ImportError: except ImportError:
print("Error: 'rich' library is required. Install it with: pip install rich") print("Error: 'rich' library is required. Install it with: pip install rich")
sys.exit(1) sys.exit(1)
@ -32,25 +32,33 @@ console = Console()
SQLMAP_PATH = Path(__file__).parent / "sqlmap.py" SQLMAP_PATH = Path(__file__).parent / "sqlmap.py"
# SQL injection techniques SQL_TECHNIQUES = {
TECHNIQUES = { "B": "Boolean-based blind",
'B': 'Boolean-based blind', "E": "Error-based",
'E': 'Error-based', "U": "Union query-based",
'U': 'Union query-based', "S": "Stacked queries",
'S': 'Stacked queries', "T": "Time-based blind",
'T': 'Time-based blind', "Q": "Inline queries",
'Q': 'Inline queries'
} }
class ScanResult(TypedDict):
total_tests: int
vulnerabilities: List[Dict[str, str]]
start_time: Optional[datetime]
end_time: Optional[datetime]
target: Optional[str]
class SQLMapCLI: class SQLMapCLI:
def __init__(self): def __init__(self):
self.console = Console() self.console = Console()
self.results = { self.results: ScanResult = {
'total_tests': 0, "total_tests": 0,
'vulnerabilities': [], "vulnerabilities": [],
'start_time': None, "start_time": None,
'end_time': None, "end_time": None,
'target': None "target": None,
} }
def print_banner(self): def print_banner(self):
@ -58,14 +66,14 @@ class SQLMapCLI:
banner = """ banner = """
CLI - Automated SQL Injection Testing CLI - Automated SQL Injection Testing
""" """
@ -74,23 +82,33 @@ class SQLMapCLI:
Panel( Panel(
"[yellow]⚠️ Legal Disclaimer: Only use on targets you have permission to test[/yellow]", "[yellow]⚠️ Legal Disclaimer: Only use on targets you have permission to test[/yellow]",
border_style="yellow", border_style="yellow",
box=box.ROUNDED box=box.ROUNDED,
) )
) )
self.console.print() self.console.print()
def run_sqlmap_test(self, url: str, level: int, risk: int, technique: str = "BEUSTQ", def run_sqlmap_test(
batch: bool = True, data: str = None, verbose: int = 1, self,
extra_args: List[str] = None) -> Tuple[bool, str]: url: str,
level: int,
risk: int,
technique: str = "BEUSTQ",
batch: bool = True,
data: Optional[str] = None,
verbose: int = 1,
extra_args: Optional[List[str]] = None,
) -> Tuple[bool, str]:
"""Run sqlmap with specified parameters""" """Run sqlmap with specified parameters"""
cmd = [ cmd = [
sys.executable, sys.executable,
str(SQLMAP_PATH), str(SQLMAP_PATH),
"-u", url, "-u",
url,
f"--level={level}", f"--level={level}",
f"--risk={risk}", f"--risk={risk}",
f"--technique={technique}", f"--technique={technique}",
"-v", str(verbose) "-v",
str(verbose),
] ]
if batch: if batch:
@ -107,7 +125,7 @@ class SQLMapCLI:
cmd, cmd,
capture_output=True, capture_output=True,
text=True, text=True,
timeout=600 # 10 minute timeout per test timeout=600, # 10 minute timeout per test
) )
return result.returncode == 0, result.stdout + result.stderr return result.returncode == 0, result.stdout + result.stderr
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
@ -115,15 +133,15 @@ class SQLMapCLI:
except Exception as e: except Exception as e:
return False, str(e) return False, str(e)
def parse_results(self, output: str) -> Dict: def parse_results(self, output: str) -> Dict[str, Any]:
"""Parse sqlmap output for vulnerabilities""" """Parse sqlmap output for vulnerabilities"""
vulns = [] vulns = []
# Look for vulnerability indicators # Look for vulnerability indicators
if "sqlmap identified the following injection point" in output: if "sqlmap identified the following injection point" in output:
# Extract injection details # Extract injection details
lines = output.split('\n') lines = output.split("\n")
current_param = 'Unknown' # Default parameter name current_param = "Unknown" # Default parameter name
for i, line in enumerate(lines): for i, line in enumerate(lines):
if "Parameter:" in line: if "Parameter:" in line:
@ -133,31 +151,40 @@ class SQLMapCLI:
# Check if next line contains the title # Check if next line contains the title
if i + 1 < len(lines) and "Title:" in lines[i + 1]: if i + 1 < len(lines) and "Title:" in lines[i + 1]:
title = lines[i + 1].split("Title:")[1].strip() title = lines[i + 1].split("Title:")[1].strip()
vulns.append({ vulns.append(
'parameter': current_param, {
'type': vuln_type, "parameter": current_param,
'title': title "type": vuln_type,
}) "title": title,
}
)
# Check for backend DBMS detection # Check for backend DBMS detection
backend_dbms = None backend_dbms = None
if "back-end DBMS:" in output.lower(): if "back-end DBMS:" in output.lower():
for line in output.split('\n'): for line in output.split("\n"):
if "back-end DBMS:" in line.lower(): if "back-end DBMS:" in line.lower():
backend_dbms = line.split(":", 1)[1].strip() backend_dbms = line.split(":", 1)[1].strip()
break break
return { return {
'vulnerabilities': vulns, "vulnerabilities": vulns,
'backend_dbms': backend_dbms, "backend_dbms": backend_dbms,
'is_vulnerable': len(vulns) > 0 or "vulnerable" in output.lower() "is_vulnerable": len(vulns) > 0 or "vulnerable" in output.lower(),
} }
def comprehensive_scan(self, url: str, max_level: int = 5, max_risk: int = 3, def comprehensive_scan(
techniques: str = "BEUSTQ", data: str = None, verbose: int = 1): self,
url: str,
max_level: int = 5,
max_risk: int = 3,
techniques: str = "BEUSTQ",
data: Optional[str] = None,
verbose: int = 1,
):
"""Run comprehensive scan with all levels and risks""" """Run comprehensive scan with all levels and risks"""
self.results['target'] = url self.results["target"] = url
self.results['start_time'] = datetime.now() self.results["start_time"] = datetime.now()
# Create results table # Create results table
results_table = Table(title="Scan Results", box=box.ROUNDED) results_table = Table(title="Scan Results", box=box.ROUNDED)
@ -167,7 +194,6 @@ class SQLMapCLI:
results_table.add_column("Findings", style="magenta") results_table.add_column("Findings", style="magenta")
total_tests = max_level * max_risk total_tests = max_level * max_risk
test_count = 0
with Progress( with Progress(
SpinnerColumn(), SpinnerColumn(),
@ -175,75 +201,84 @@ class SQLMapCLI:
BarColumn(), BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(), TimeElapsedColumn(),
console=self.console console=self.console,
) as progress: ) as progress:
overall_task = progress.add_task( overall_task = progress.add_task(
f"[cyan]Scanning {url}...", f"[cyan]Scanning {url}...", total=total_tests
total=total_tests
) )
for level in range(1, max_level + 1): for level in range(1, max_level + 1):
for risk in range(1, max_risk + 1): for risk in range(1, max_risk + 1):
test_count += 1
progress.update( progress.update(
overall_task, overall_task,
description=f"[cyan]Testing Level {level}, Risk {risk}..." description=f"[cyan]Testing Level {level}, Risk {risk}...",
) )
success, output = self.run_sqlmap_test(url, level, risk, techniques, data=data, verbose=verbose) success, output = self.run_sqlmap_test(
url, level, risk, techniques, data=data, verbose=verbose
)
parsed = self.parse_results(output) parsed = self.parse_results(output)
status = "" if success else "" status = "" if success else ""
status_style = "green" if success else "red" status_style = "green" if success else "red"
findings = "No vulnerabilities" if not parsed['is_vulnerable'] else f"{len(parsed['vulnerabilities'])} found!" findings = (
findings_style = "green" if not parsed['is_vulnerable'] else "bold red" "No vulnerabilities"
if not parsed["is_vulnerable"]
else f"{len(parsed['vulnerabilities'])} found!"
)
findings_style = (
"green" if not parsed["is_vulnerable"] else "bold red"
)
if parsed['is_vulnerable']: if parsed["is_vulnerable"]:
self.results['vulnerabilities'].extend(parsed['vulnerabilities']) self.results["vulnerabilities"].extend(
parsed["vulnerabilities"]
)
results_table.add_row( results_table.add_row(
str(level), str(level),
str(risk), str(risk),
f"[{status_style}]{status}[/{status_style}]", f"[{status_style}]{status}[/{status_style}]",
f"[{findings_style}]{findings}[/{findings_style}]" f"[{findings_style}]{findings}[/{findings_style}]",
) )
progress.update(overall_task, advance=1) progress.update(overall_task, advance=1)
self.results['total_tests'] += 1 self.results["total_tests"] += 1
self.results['end_time'] = datetime.now() self.results["end_time"] = datetime.now()
# Display results # Display results
self.console.print() self.console.print()
self.console.print(results_table) self.console.print(results_table)
self.display_summary() self.display_summary()
def quick_scan(self, url: str, level: int = 1, risk: int = 1, data: str = None, def quick_scan(
raw: bool = False, verbose: int = 1): self,
url: str,
level: int = 1,
risk: int = 1,
data: Optional[str] = None,
raw: bool = False,
verbose: int = 1,
):
"""Run a quick scan with default settings""" """Run a quick scan with default settings"""
self.results['target'] = url self.results["target"] = url
self.results['start_time'] = datetime.now() self.results["start_time"] = datetime.now()
if not raw: if not raw:
scan_info = f"[cyan]Running quick scan on:[/cyan]\n[yellow]{url}[/yellow]\n[dim]Level: {level}, Risk: {risk}[/dim]" scan_info = f"[cyan]Running quick scan on:[/cyan]\n[yellow]{url}[/yellow]\n[dim]Level: {level}, Risk: {risk}[/dim]"
if data: if data:
scan_info += f"\n[dim]POST Data: {data}[/dim]" scan_info += f"\n[dim]POST Data: {data}[/dim]"
self.console.print( self.console.print(Panel(scan_info, border_style="cyan", box=box.ROUNDED))
Panel(
scan_info,
border_style="cyan",
box=box.ROUNDED
)
)
if raw: if raw:
# Raw mode - just show sqlmap output directly # Raw mode - just show sqlmap output directly
self.console.print("[cyan]Running sqlmap...[/cyan]\n") self.console.print("[cyan]Running sqlmap...[/cyan]\n")
success, output = self.run_sqlmap_test(url, level, risk, data=data, verbose=verbose) success, output = self.run_sqlmap_test(
url, level, risk, data=data, verbose=verbose
)
self.console.print(output) self.console.print(output)
return return
@ -251,17 +286,20 @@ class SQLMapCLI:
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
TimeElapsedColumn(), TimeElapsedColumn(),
console=self.console console=self.console,
) as progress: ) as progress:
task = progress.add_task(
task = progress.add_task("[cyan]Scanning for vulnerabilities...", total=None) "[cyan]Scanning for vulnerabilities...", total=None
success, output = self.run_sqlmap_test(url, level, risk, data=data, verbose=verbose) )
success, output = self.run_sqlmap_test(
url, level, risk, data=data, verbose=verbose
)
progress.update(task, completed=True) progress.update(task, completed=True)
parsed = self.parse_results(output) parsed = self.parse_results(output)
self.results['vulnerabilities'] = parsed['vulnerabilities'] self.results["vulnerabilities"] = parsed["vulnerabilities"]
self.results['total_tests'] = 1 self.results["total_tests"] = 1
self.results['end_time'] = datetime.now() self.results["end_time"] = datetime.now()
self.display_summary() self.display_summary()
@ -270,38 +308,44 @@ class SQLMapCLI:
self.console.print() self.console.print()
# Calculate duration # Calculate duration
duration = (self.results['end_time'] - self.results['start_time']).total_seconds() duration = 0.0
if self.results["end_time"] and self.results["start_time"]:
duration = (
self.results["end_time"] - self.results["start_time"]
).total_seconds()
# Create summary panel # Create summary panel
summary_text = f""" summary_text = f"""
[cyan]Target:[/cyan] {self.results['target']} [cyan]Target:[/cyan] {self.results["target"] or "N/A"}
[cyan]Total Tests:[/cyan] {self.results['total_tests']} [cyan]Total Tests:[/cyan] {self.results["total_tests"]}
[cyan]Duration:[/cyan] {duration:.2f} seconds [cyan]Duration:[/cyan] {duration:.2f} seconds
[cyan]Vulnerabilities Found:[/cyan] {len(self.results['vulnerabilities'])} [cyan]Vulnerabilities Found:[/cyan] {len(self.results["vulnerabilities"])}
""" """
self.console.print( self.console.print(
Panel( Panel(
summary_text.strip(), summary_text.strip(),
title="[bold]Scan Summary[/bold]", title="[bold]Scan Summary[/bold]",
border_style="green" if len(self.results['vulnerabilities']) == 0 else "red", border_style="green"
box=box.DOUBLE if len(self.results["vulnerabilities"]) == 0
else "red",
box=box.DOUBLE,
) )
) )
# Display vulnerabilities if found # Display vulnerabilities if found
if self.results['vulnerabilities']: if self.results["vulnerabilities"]:
self.console.print() self.console.print()
vuln_table = Table(title="⚠️ Vulnerabilities Detected", box=box.HEAVY) vuln_table = Table(title="⚠️ Vulnerabilities Detected", box=box.HEAVY)
vuln_table.add_column("Parameter", style="cyan") vuln_table.add_column("Parameter", style="cyan")
vuln_table.add_column("Type", style="yellow") vuln_table.add_column("Type", style="yellow")
vuln_table.add_column("Title", style="red") vuln_table.add_column("Title", style="red")
for vuln in self.results['vulnerabilities']: for vuln in self.results["vulnerabilities"]:
vuln_table.add_row( vuln_table.add_row(
vuln.get('parameter', 'N/A'), vuln.get("parameter", "N/A"),
vuln.get('type', 'N/A'), vuln.get("type", "N/A"),
vuln.get('title', 'N/A') vuln.get("title", "N/A"),
) )
self.console.print(vuln_table) self.console.print(vuln_table)
@ -323,26 +367,30 @@ class SQLMapCLI:
self.console.print( self.console.print(
Panel( Panel(
"[cyan]Interactive Mode[/cyan]\n[dim]Enter target details for SQL injection testing[/dim]", "[cyan]Interactive Mode[/cyan]\n[dim]Enter target details for SQL injection testing[/dim]",
border_style="cyan" border_style="cyan",
) )
) )
url = Prompt.ask("\n[cyan]Enter target URL[/cyan]") url = Prompt.ask("\n[cyan]Enter target URL[/cyan]")
# Ask if this is a POST request # Ask if this is a POST request
has_data = Confirm.ask("[cyan]Does this request require POST data/body?[/cyan]", default=False) has_data = Confirm.ask(
"[cyan]Does this request require POST data/body?[/cyan]", default=False
)
data = None data = None
if has_data: if has_data:
self.console.print("\n[dim]Examples:[/dim]") self.console.print("\n[dim]Examples:[/dim]")
self.console.print("[dim] JSON: {\"email\":\"test@example.com\",\"password\":\"pass123\"}[/dim]") self.console.print(
'[dim] JSON: {"email":"test@example.com","password":"pass123"}[/dim]'
)
self.console.print("[dim] Form: username=admin&password=secret[/dim]") self.console.print("[dim] Form: username=admin&password=secret[/dim]")
data = Prompt.ask("\n[cyan]Enter POST data/body[/cyan]") data = Prompt.ask("\n[cyan]Enter POST data/body[/cyan]")
scan_type = Prompt.ask( scan_type = Prompt.ask(
"\n[cyan]Select scan type[/cyan]", "\n[cyan]Select scan type[/cyan]",
choices=["quick", "comprehensive"], choices=["quick", "comprehensive"],
default="quick" default="quick",
) )
if scan_type == "quick": if scan_type == "quick":
@ -350,8 +398,12 @@ class SQLMapCLI:
risk = int(Prompt.ask("[cyan]Test risk (1-3)[/cyan]", default="1")) risk = int(Prompt.ask("[cyan]Test risk (1-3)[/cyan]", default="1"))
self.quick_scan(url, level, risk, data=data) self.quick_scan(url, level, risk, data=data)
else: else:
max_level = int(Prompt.ask("[cyan]Maximum test level (1-5)[/cyan]", default="5")) max_level = int(
max_risk = int(Prompt.ask("[cyan]Maximum test risk (1-3)[/cyan]", default="3")) Prompt.ask("[cyan]Maximum test level (1-5)[/cyan]", default="5")
)
max_risk = int(
Prompt.ask("[cyan]Maximum test risk (1-3)[/cyan]", default="3")
)
self.comprehensive_scan(url, max_level, max_risk, data=data) self.comprehensive_scan(url, max_level, max_risk, data=data)
@ -375,82 +427,77 @@ Examples:
# Interactive mode # Interactive mode
python sqlmapcli.py --interactive python sqlmapcli.py --interactive
""" """,
) )
parser.add_argument( parser.add_argument(
'-u', '--url', "-u", "--url", help='Target URL (e.g., "http://example.com/page?id=1")'
help='Target URL (e.g., "http://example.com/page?id=1")'
) )
parser.add_argument( parser.add_argument(
'--comprehensive', "--comprehensive",
action='store_true', action="store_true",
help='Run comprehensive scan with all risk/level combinations' help="Run comprehensive scan with all risk/level combinations",
) )
parser.add_argument( parser.add_argument(
'--level', "--level",
type=int, type=int,
default=1, default=1,
choices=[1, 2, 3, 4, 5], choices=[1, 2, 3, 4, 5],
help='Level of tests to perform (1-5, default: 1)' help="Level of tests to perform (1-5, default: 1)",
) )
parser.add_argument( parser.add_argument(
'--risk', "--risk",
type=int, type=int,
default=1, default=1,
choices=[1, 2, 3], choices=[1, 2, 3],
help='Risk of tests to perform (1-3, default: 1)' help="Risk of tests to perform (1-3, default: 1)",
) )
parser.add_argument( parser.add_argument(
'--max-level', "--max-level",
type=int, type=int,
default=5, default=5,
choices=[1, 2, 3, 4, 5], choices=[1, 2, 3, 4, 5],
help='Maximum level for comprehensive scan (default: 5)' help="Maximum level for comprehensive scan (default: 5)",
) )
parser.add_argument( parser.add_argument(
'--max-risk', "--max-risk",
type=int, type=int,
default=3, default=3,
choices=[1, 2, 3], choices=[1, 2, 3],
help='Maximum risk for comprehensive scan (default: 3)' help="Maximum risk for comprehensive scan (default: 3)",
) )
parser.add_argument( parser.add_argument(
'--technique', "--technique",
type=str, type=str,
default='BEUSTQ', default="BEUSTQ",
help='SQL injection techniques to use (default: BEUSTQ)' help="SQL injection techniques to use (default: BEUSTQ)",
) )
parser.add_argument( parser.add_argument(
'--data', "--data",
type=str, type=str,
help='Data string to be sent through POST (e.g., "username=test&password=test")' help='Data string to be sent through POST (e.g., "username=test&password=test")',
) )
parser.add_argument( parser.add_argument(
'--raw', "--raw", action="store_true", help="Show raw sqlmap output without formatting"
action='store_true',
help='Show raw sqlmap output without formatting'
) )
parser.add_argument( parser.add_argument(
'--verbose', "--verbose",
type=int, type=int,
choices=[0, 1, 2, 3, 4, 5, 6], choices=[0, 1, 2, 3, 4, 5, 6],
help='Sqlmap verbosity level (0-6, default: 1)' help="Sqlmap verbosity level (0-6, default: 1)",
) )
parser.add_argument( parser.add_argument(
'-i', '--interactive', "-i", "--interactive", action="store_true", help="Run in interactive mode"
action='store_true',
help='Run in interactive mode'
) )
args = parser.parse_args() args = parser.parse_args()
@ -462,9 +509,11 @@ Examples:
if not SQLMAP_PATH.exists(): if not SQLMAP_PATH.exists():
console.print( console.print(
f"[bold red]Error: sqlmap.py not found at {SQLMAP_PATH}[/bold red]", f"[bold red]Error: sqlmap.py not found at {SQLMAP_PATH}[/bold red]",
style="bold red" style="bold red",
)
console.print(
"[yellow]Make sure you're running this script from the sqlmap directory[/yellow]"
) )
console.print("[yellow]Make sure you're running this script from the sqlmap directory[/yellow]")
sys.exit(1) sys.exit(1)
# Interactive mode # Interactive mode
@ -474,7 +523,9 @@ Examples:
# Check if URL is provided # Check if URL is provided
if not args.url: if not args.url:
console.print("[bold red]Error: URL is required (use -u or --interactive)[/bold red]") console.print(
"[bold red]Error: URL is required (use -u or --interactive)[/bold red]"
)
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
@ -488,7 +539,7 @@ Examples:
max_risk=args.max_risk, max_risk=args.max_risk,
techniques=args.technique, techniques=args.technique,
data=args.data, data=args.data,
verbose=verbose_level verbose=verbose_level,
) )
else: else:
cli.quick_scan( cli.quick_scan(
@ -497,7 +548,7 @@ Examples:
risk=args.risk, risk=args.risk,
data=args.data, data=args.data,
raw=args.raw, raw=args.raw,
verbose=verbose_level verbose=verbose_level,
) )