Enhance SQLMapScanner with real-time progress updates and temporary output handling; improve UI display for batch scan results

This commit is contained in:
Wilbert Chandra 2026-01-07 12:55:57 +00:00
parent 2270c8981b
commit 93a204ef01
3 changed files with 298 additions and 108 deletions

View File

@ -2,6 +2,9 @@ from rich.table import Table
import sys import sys
import subprocess import subprocess
import json import json
import os
import tempfile
import shutil
from datetime import datetime from datetime import datetime
from typing import List, Dict, Tuple, Optional, Any from typing import List, Dict, Tuple, Optional, Any
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
@ -22,15 +25,16 @@ from .ui import display_summary, display_batch_results
console = Console() console = Console()
class SQLMapScanner: class SQLMapScanner:
def __init__(self, enable_logging: bool = True): def __init__(self, enable_logging: bool = True):
self.enable_logging = enable_logging self.enable_logging = enable_logging
self.results: ScanResult = { self.results: ScanResult = {
'total_tests': 0, "total_tests": 0,
'vulnerabilities': [], "vulnerabilities": [],
'start_time': None, "start_time": None,
'end_time': None, "end_time": None,
'target': None "target": None,
} }
def run_sqlmap_test( def run_sqlmap_test(
@ -44,8 +48,10 @@ class SQLMapScanner:
headers: Optional[str] = None, headers: Optional[str] = None,
verbose: int = 1, verbose: int = 1,
extra_args: Optional[List[str]] = None, extra_args: Optional[List[str]] = None,
progress: Optional[Progress] = None,
task_id: Any = None,
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:
"""Run sqlmap with specified parameters""" """Run sqlmap with specified parameters and optional real-time progress"""
cmd = [ cmd = [
sys.executable, sys.executable,
str(SQLMAP_PATH), str(SQLMAP_PATH),
@ -70,16 +76,80 @@ class SQLMapScanner:
if extra_args: if extra_args:
cmd.extend(extra_args) cmd.extend(extra_args)
# Create a unique temporary directory for this run to avoid session database locks
# which are the primary cause of concurrency bottlenecks in sqlmap
tmp_output_dir = tempfile.mkdtemp(prefix="sqlmap_scan_")
cmd.extend(["--output-dir", tmp_output_dir])
try: try:
result = subprocess.run( if progress and task_id:
cmd, # Run with real-time output parsing
capture_output=True, process = subprocess.Popen(
text=True cmd,
) stdout=subprocess.PIPE,
return result.returncode == 0, result.stdout + result.stderr stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
output_lines = []
if process.stdout is None:
return False, "Failed to capture sqlmap output"
for line in process.stdout:
output_lines.append(line)
# Update progress description based on sqlmap status
clean_line = line.strip()
if "[INFO]" in clean_line:
status = clean_line.split("[INFO]", 1)[1].strip()
# Clean up status message
if "testing" in status.lower():
progress.update(
task_id, description=f"[cyan]{status[:50]}[/cyan]"
)
elif "detecting" in status.lower():
progress.update(
task_id, description=f"[yellow]{status[:50]}[/yellow]"
)
elif "identified" in status.lower():
progress.update(
task_id, description=f"[green]{status[:50]}[/green]"
)
process.wait()
full_output = "".join(output_lines)
# Cleanup temporary output directory
try:
shutil.rmtree(tmp_output_dir)
except:
pass
return process.returncode == 0, full_output
else:
# Standard blocking run
result = subprocess.run(cmd, capture_output=True, text=True)
# Cleanup temporary output directory
try:
shutil.rmtree(tmp_output_dir)
except:
pass
return result.returncode == 0, result.stdout + result.stderr
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
# Cleanup on timeout
try:
shutil.rmtree(tmp_output_dir)
except:
pass
return False, "Test timed out after 10 minutes" return False, "Test timed out after 10 minutes"
except Exception as e: except Exception as e:
# Cleanup on error
try:
shutil.rmtree(tmp_output_dir)
except:
pass
return False, str(e) return False, str(e)
def parse_results(self, output: str) -> Dict[str, Any]: def parse_results(self, output: str) -> Dict[str, Any]:
@ -157,11 +227,19 @@ class SQLMapScanner:
for risk in range(1, max_risk + 1): for risk in range(1, max_risk + 1):
progress.update( progress.update(
overall_task, overall_task,
description=f"[cyan]Testing Level {level}, Risk {risk}...", description=f"[cyan]Testing Level {level}, Risk {risk}...[/cyan]",
) )
success, output = self.run_sqlmap_test( success, output = self.run_sqlmap_test(
url, level, risk, techniques, data=data, headers=headers, verbose=verbose url,
level,
risk,
techniques,
data=data,
headers=headers,
verbose=max(verbose, 3),
progress=progress,
task_id=overall_task,
) )
parsed = self.parse_results(output) parsed = self.parse_results(output)
@ -226,7 +304,7 @@ class SQLMapScanner:
url, level, risk, data=data, headers=headers, verbose=verbose url, level, risk, data=data, headers=headers, verbose=verbose
) )
console.print(output) console.print(output)
if self.enable_logging: if self.enable_logging:
log_file = get_log_filename(url) log_file = get_log_filename(url)
save_log(log_file, output) save_log(log_file, output)
@ -238,18 +316,25 @@ class SQLMapScanner:
TimeElapsedColumn(), TimeElapsedColumn(),
console=console, console=console,
) as progress: ) as progress:
task = progress.add_task( task = progress.add_task(f"[cyan]Scanning {url[:40]}...", total=None)
"[cyan]Scanning for vulnerabilities...", total=None
)
success, output = self.run_sqlmap_test( success, output = self.run_sqlmap_test(
url, level, risk, data=data, headers=headers, verbose=verbose url,
level,
risk,
data=data,
headers=headers,
verbose=max(verbose, 3),
progress=progress,
task_id=task,
) )
progress.update(task, completed=True) progress.update(
task, completed=True, description="[green]✓ Scan Complete[/green]"
)
if self.enable_logging: if self.enable_logging:
log_file = get_log_filename(url) log_file = get_log_filename(url)
save_log(log_file, output) save_log(log_file, output)
parsed = self.parse_results(output) parsed = self.parse_results(output)
self.results["vulnerabilities"] = parsed["vulnerabilities"] self.results["vulnerabilities"] = parsed["vulnerabilities"]
self.results["total_tests"] = 1 self.results["total_tests"] = 1
@ -257,88 +342,148 @@ class SQLMapScanner:
display_summary(self.results) display_summary(self.results)
def process_single_endpoint(self, endpoint: Dict, level: int, risk: int, verbose: int) -> Dict: def process_single_endpoint(
self,
endpoint: Dict,
level: int,
risk: int,
verbose: int,
progress: Optional[Progress] = None,
task_id: Any = None,
) -> Dict:
"""Process a single endpoint for batch mode""" """Process a single endpoint for batch mode"""
url = str(endpoint.get('url')) if endpoint.get('url') else '' url = str(endpoint.get("url")) if endpoint.get("url") else ""
data = endpoint.get('data') data = endpoint.get("data")
if data is not None and not isinstance(data, str): if data is not None and not isinstance(data, str):
data = json.dumps(data) data = json.dumps(data)
headers = endpoint.get('headers') headers = endpoint.get("headers")
if headers is not None and isinstance(headers, list): if headers is not None and isinstance(headers, list):
headers = "\\n".join(headers) headers = "\\n".join(headers)
try: try:
success, output = self.run_sqlmap_test(url, level, risk, data=data, headers=headers, verbose=verbose) # Force verbosity 3 for better progress tracking if in batch mode
# unless a specific high verbosity is already requested
exec_verbose = max(verbose, 3) if progress else verbose
success, output = self.run_sqlmap_test(
url,
level,
risk,
data=data,
headers=headers,
verbose=exec_verbose,
progress=progress,
task_id=task_id,
)
if self.enable_logging: if self.enable_logging:
log_file = get_log_filename(url) log_file = get_log_filename(url)
save_log(log_file, output) save_log(log_file, output)
parsed = self.parse_results(output) parsed = self.parse_results(output)
if progress and task_id:
status_color = "red" if parsed["is_vulnerable"] else "green"
status_text = "Vulnerable" if parsed["is_vulnerable"] else "Clean"
progress.update(
task_id,
description=f"[{status_color}]✓ {status_text}[/{status_color}] - {url[:30]}...",
completed=100,
)
return { return {
'url': url, "url": url,
'data': data, "data": data,
'success': success, "success": success,
'vulnerabilities': parsed['vulnerabilities'], "vulnerabilities": parsed["vulnerabilities"],
'is_vulnerable': parsed['is_vulnerable'] "is_vulnerable": parsed["is_vulnerable"],
} }
except Exception as e: except Exception as e:
if progress and task_id:
progress.update(
task_id,
description=f"[bold red]✗ Error: {str(e)[:30]}[/bold red]",
completed=100,
)
return { return {
'url': url, "url": url,
'data': data, "data": data,
'success': False, "success": False,
'error': str(e), "error": str(e),
'vulnerabilities': [], "vulnerabilities": [],
'is_vulnerable': False "is_vulnerable": False,
} }
def batch_scan(self, endpoints: List[Dict], level: int = 1, risk: int = 1, def batch_scan(
concurrency: int = 5, verbose: int = 1): self,
endpoints: List[Dict],
level: int = 1,
risk: int = 1,
concurrency: int = 5,
verbose: int = 1,
):
"""Run batch scan on multiple endpoints with concurrency""" """Run batch scan on multiple endpoints with concurrency"""
# Determine actual concurrency
if concurrency <= 0:
# For I/O bound tasks like scanning, we can use 2x CPU count
concurrency = (os.cpu_count() or 2) * 2
console.print( console.print(
Panel( Panel(
f"[cyan]Batch Scan Mode[/cyan]\n" f"[cyan]Batch Scan Mode[/cyan]\n"
f"[dim]Testing {len(endpoints)} endpoint(s) with concurrency={concurrency}[/dim]\n" f"[dim]Testing {len(endpoints)} endpoint(s) with concurrency={concurrency}[/dim]\n"
f"[dim]Level: {level}, Risk: {risk}[/dim]", f"[dim]Level: {level}, Risk: {risk}[/dim]",
border_style="cyan", border_style="cyan",
box=box.ROUNDED box=box.ROUNDED,
) )
) )
results = [] results = []
with Progress( with Progress(
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
BarColumn(), BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TextColumn("({task.completed}/{task.total})"),
TimeElapsedColumn(), TimeElapsedColumn(),
console=console console=console,
expand=True,
) as progress: ) as progress:
task = progress.add_task("[cyan]Processing endpoints...", total=len(endpoints))
with ThreadPoolExecutor(max_workers=concurrency) as executor: with ThreadPoolExecutor(max_workers=concurrency) as executor:
future_to_endpoint = { future_to_endpoint = {}
executor.submit(self.process_single_endpoint, endpoint, level, risk, verbose): endpoint
for endpoint in endpoints for endpoint in endpoints:
} url = endpoint.get("url", "Unknown")
task_id = progress.add_task(
f"[dim]Waiting...[/dim] {url[:40]}...", total=100
)
future = executor.submit(
self.process_single_endpoint,
endpoint,
level,
risk,
verbose,
progress,
task_id,
)
future_to_endpoint[future] = endpoint
for future in as_completed(future_to_endpoint): for future in as_completed(future_to_endpoint):
endpoint = future_to_endpoint[future] endpoint = future_to_endpoint[future]
try: try:
results.append(future.result()) results.append(future.result())
except Exception as e: except Exception as e:
results.append({ results.append(
'url': endpoint.get('url'), {
'data': endpoint.get('data'), "url": endpoint.get("url"),
'success': False, "data": endpoint.get("data"),
'error': str(e), "success": False,
'vulnerabilities': [], "error": str(e),
'is_vulnerable': False "vulnerabilities": [],
}) "is_vulnerable": False,
progress.update(task, advance=1) }
)
display_batch_results(results) display_batch_results(results)
return results return results

View File

@ -7,6 +7,7 @@ from .models import ScanResult
console = Console() console = Console()
def print_banner(): def print_banner():
"""Display a beautiful banner""" """Display a beautiful banner"""
banner = """ banner = """
@ -33,6 +34,7 @@ def print_banner():
) )
console.print() console.print()
def display_summary(results: ScanResult): def display_summary(results: ScanResult):
"""Display a comprehensive summary of results""" """Display a comprehensive summary of results"""
console.print() console.print()
@ -87,28 +89,29 @@ def display_summary(results: ScanResult):
console.print() console.print()
def display_batch_results(results: List[Dict]): def display_batch_results(results: List[Dict]):
"""Display batch scan results in a table""" """Display batch scan results in a table"""
console.print() console.print()
# Create results table # Create results table
results_table = Table(title="Batch Scan Results", box=box.ROUNDED) results_table = Table(title="Batch Scan Results", box=box.ROUNDED)
results_table.add_column("URL", style="cyan", no_wrap=False) results_table.add_column("URL", style="cyan", no_wrap=False)
results_table.add_column("Status", justify="center") results_table.add_column("Status", justify="center")
results_table.add_column("Vulnerabilities", style="magenta") results_table.add_column("Vulnerabilities", style="magenta")
vulnerable_count = 0 vulnerable_count = 0
successful_count = 0 successful_count = 0
for result in results: for result in results:
url = result['url'][:60] + '...' if len(result['url']) > 60 else result['url'] url = result["url"][:60] + "..." if len(result["url"]) > 60 else result["url"]
if result.get('error'): if result.get("error"):
status = "[red]✗ Error[/red]" status = "[red]✗ Error[/red]"
vulns = f"[red]{result['error'][:40]}[/red]" vulns = f"[red]{result['error'][:40]}[/red]"
elif result['success']: elif result["success"]:
successful_count += 1 successful_count += 1
if result['is_vulnerable']: if result["is_vulnerable"]:
vulnerable_count += 1 vulnerable_count += 1
status = "[red]✓ Vulnerable[/red]" status = "[red]✓ Vulnerable[/red]"
vulns = f"[red]{len(result['vulnerabilities'])} found[/red]" vulns = f"[red]{len(result['vulnerabilities'])} found[/red]"
@ -118,11 +121,11 @@ def display_batch_results(results: List[Dict]):
else: else:
status = "[yellow]✗ Failed[/yellow]" status = "[yellow]✗ Failed[/yellow]"
vulns = "[yellow]N/A[/yellow]" vulns = "[yellow]N/A[/yellow]"
results_table.add_row(url, status, vulns) results_table.add_row(url, status, vulns)
console.print(results_table) console.print(results_table)
# Summary # Summary
console.print() console.print()
summary = f""" summary = f"""
@ -132,14 +135,14 @@ def display_batch_results(results: List[Dict]):
Vulnerable: [red]{vulnerable_count}[/red] Vulnerable: [red]{vulnerable_count}[/red]
Clean: [green]{successful_count - vulnerable_count}[/green] Clean: [green]{successful_count - vulnerable_count}[/green]
""" """
border_color = "red" if vulnerable_count > 0 else "green" border_color = "red" if vulnerable_count > 0 else "green"
console.print( console.print(
Panel( Panel(
summary.strip(), summary.strip(),
title="[bold]Summary[/bold]", title="[bold]Summary[/bold]",
border_style=border_color, border_style=border_color,
box=box.DOUBLE box=box.DOUBLE,
) )
) )
console.print() console.print()

View File

@ -26,6 +26,7 @@ from sql_cli.ui import print_banner
console = Console() console = Console()
def interactive_mode(scanner: SQLMapScanner): def interactive_mode(scanner: SQLMapScanner):
"""Interactive mode for user input""" """Interactive mode for user input"""
console.print() console.print()
@ -79,9 +80,7 @@ def interactive_mode(scanner: SQLMapScanner):
max_level = int( max_level = int(
Prompt.ask("[cyan]Maximum test level (1-5)[/cyan]", default="5") Prompt.ask("[cyan]Maximum test level (1-5)[/cyan]", default="5")
) )
max_risk = int( max_risk = int(Prompt.ask("[cyan]Maximum test risk (1-3)[/cyan]", default="3"))
Prompt.ask("[cyan]Maximum test risk (1-3)[/cyan]", default="3")
)
scanner.comprehensive_scan(url, max_level, max_risk, data=data, headers=headers) scanner.comprehensive_scan(url, max_level, max_risk, data=data, headers=headers)
@ -101,64 +100,107 @@ Examples:
python sqlmapcli.py -u "https://demo.owasp-juice.shop/rest/products/search?q=test" --comprehensive python sqlmapcli.py -u "https://demo.owasp-juice.shop/rest/products/search?q=test" --comprehensive
# Batch mode - test multiple endpoints from JSON file # Batch mode - test multiple endpoints from JSON file
python sqlmapcli.py -b endpoints.json --level 2 --risk 2 --concurrency 10 python sqlmapcli.py -b endpoints.json --level 2 --risk 2
# Interactive mode # Interactive mode
python sqlmapcli.py --interactive python sqlmapcli.py --interactive
""", """,
) )
parser.add_argument("-u", "--url", help='Target URL (e.g., "http://example.com/page?id=1")') parser.add_argument(
parser.add_argument("--comprehensive", action="store_true", help="Run comprehensive scan") "-u", "--url", help='Target URL (e.g., "http://example.com/page?id=1")'
parser.add_argument("--level", type=int, default=1, choices=[1, 2, 3, 4, 5], help="Level (1-5, default: 1)") )
parser.add_argument("--risk", type=int, default=1, choices=[1, 2, 3], help="Risk (1-3, default: 1)") parser.add_argument(
parser.add_argument("--max-level", type=int, default=5, choices=[1, 2, 3, 4, 5], help="Max level for comprehensive") "--comprehensive", action="store_true", help="Run comprehensive scan"
parser.add_argument("--max-risk", type=int, default=3, choices=[1, 2, 3], help="Max risk for comprehensive") )
parser.add_argument("--technique", type=str, default="BEUSTQ", help="SQL techniques (default: BEUSTQ)") parser.add_argument(
parser.add_argument("--data", type=str, help='POST data') "--level",
parser.add_argument("--headers", type=str, help='Extra headers') type=int,
default=1,
choices=[1, 2, 3, 4, 5],
help="Level (1-5, default: 1)",
)
parser.add_argument(
"--risk", type=int, default=1, choices=[1, 2, 3], help="Risk (1-3, default: 1)"
)
parser.add_argument(
"--max-level",
type=int,
default=5,
choices=[1, 2, 3, 4, 5],
help="Max level for comprehensive",
)
parser.add_argument(
"--max-risk",
type=int,
default=3,
choices=[1, 2, 3],
help="Max risk for comprehensive",
)
parser.add_argument(
"--technique",
type=str,
default="BEUSTQ",
help="SQL techniques (default: BEUSTQ)",
)
parser.add_argument("--data", type=str, help="POST data")
parser.add_argument("--headers", type=str, help="Extra headers")
parser.add_argument("--raw", action="store_true", help="Show raw sqlmap output") parser.add_argument("--raw", action="store_true", help="Show raw sqlmap output")
parser.add_argument("--verbose", type=int, choices=[0, 1, 2, 3, 4, 5, 6], help="Verbosity (0-6)") parser.add_argument(
parser.add_argument("-i", "--interactive", action="store_true", help="Interactive mode") "--verbose", type=int, choices=[0, 1, 2, 3, 4, 5, 6], help="Verbosity (0-6)"
parser.add_argument('-b', '--batch-file', type=str, help='Path to batch JSON') )
parser.add_argument('-c', '--concurrency', type=int, default=5, help='Concurrency (default: 5)') parser.add_argument(
parser.add_argument('--no-logs', action='store_true', help='Disable logs') "-i", "--interactive", action="store_true", help="Interactive mode"
)
parser.add_argument("-b", "--batch-file", type=str, help="Path to batch JSON")
parser.add_argument(
"-c",
"--concurrency",
type=int,
default=0,
help="Number of concurrent scans (default: 0 for auto-scale)",
)
parser.add_argument("--no-logs", action="store_true", help="Disable logs")
args = parser.parse_args() args = parser.parse_args()
scanner = SQLMapScanner(enable_logging=not args.no_logs) scanner = SQLMapScanner(enable_logging=not args.no_logs)
print_banner() print_banner()
if not SQLMAP_PATH.exists(): if not SQLMAP_PATH.exists():
console.print(f"[bold red]Error: sqlmap.py not found at {SQLMAP_PATH}[/bold red]") console.print(
f"[bold red]Error: sqlmap.py not found at {SQLMAP_PATH}[/bold red]"
)
sys.exit(1) sys.exit(1)
if args.interactive: if args.interactive:
interactive_mode(scanner) interactive_mode(scanner)
return return
if args.batch_file: if args.batch_file:
try: try:
with open(args.batch_file, 'r') as f: with open(args.batch_file, "r") as f:
endpoints = json.load(f) endpoints = json.load(f)
if not isinstance(endpoints, list): if not isinstance(endpoints, list):
console.print("[bold red]Error: Batch file must contain a JSON array[/bold red]") console.print(
"[bold red]Error: Batch file must contain a JSON array[/bold red]"
)
sys.exit(1) sys.exit(1)
verbose_level = args.verbose if args.verbose is not None else 1 verbose_level = args.verbose if args.verbose is not None else 1
scanner.batch_scan( scanner.batch_scan(
endpoints, endpoints,
level=args.level, level=args.level,
risk=args.risk, risk=args.risk,
concurrency=args.concurrency, concurrency=args.concurrency,
verbose=verbose_level verbose=verbose_level,
) )
return return
except Exception as e: except Exception as e:
console.print(f"[bold red]Error loading batch file: {e}[/bold red]") console.print(f"[bold red]Error loading batch file: {e}[/bold red]")
sys.exit(1) sys.exit(1)
if not args.url: if not args.url:
console.print("[bold red]Error: URL is required[/bold red]") console.print("[bold red]Error: URL is required[/bold red]")
parser.print_help() parser.print_help()