From 93a204ef01be0d850353878b1bbd21d5e0a14f53 Mon Sep 17 00:00:00 2001 From: Wilbert Chandra <90319182+GilbertKrantz@users.noreply.github.com> Date: Wed, 7 Jan 2026 12:55:57 +0000 Subject: [PATCH] Enhance SQLMapScanner with real-time progress updates and temporary output handling; improve UI display for batch scan results --- sql_cli/scanner.py | 279 ++++++++++++++++++++++++++++++++++----------- sql_cli/ui.py | 29 ++--- sqlmapcli.py | 98 +++++++++++----- 3 files changed, 298 insertions(+), 108 deletions(-) diff --git a/sql_cli/scanner.py b/sql_cli/scanner.py index 8a21a76a6..3ccfc8b93 100644 --- a/sql_cli/scanner.py +++ b/sql_cli/scanner.py @@ -2,6 +2,9 @@ from rich.table import Table import sys import subprocess import json +import os +import tempfile +import shutil from datetime import datetime from typing import List, Dict, Tuple, Optional, Any from concurrent.futures import ThreadPoolExecutor, as_completed @@ -22,15 +25,16 @@ from .ui import display_summary, display_batch_results console = Console() + class SQLMapScanner: def __init__(self, enable_logging: bool = True): self.enable_logging = enable_logging self.results: ScanResult = { - 'total_tests': 0, - 'vulnerabilities': [], - 'start_time': None, - 'end_time': None, - 'target': None + "total_tests": 0, + "vulnerabilities": [], + "start_time": None, + "end_time": None, + "target": None, } def run_sqlmap_test( @@ -44,8 +48,10 @@ class SQLMapScanner: headers: Optional[str] = None, verbose: int = 1, extra_args: Optional[List[str]] = None, + progress: Optional[Progress] = None, + task_id: Any = None, ) -> Tuple[bool, str]: - """Run sqlmap with specified parameters""" + """Run sqlmap with specified parameters and optional real-time progress""" cmd = [ sys.executable, str(SQLMAP_PATH), @@ -70,16 +76,80 @@ class SQLMapScanner: if extra_args: cmd.extend(extra_args) + # Create a unique temporary directory for this run to avoid session database locks + # which are the primary cause of concurrency bottlenecks in sqlmap + tmp_output_dir = tempfile.mkdtemp(prefix="sqlmap_scan_") + cmd.extend(["--output-dir", tmp_output_dir]) + try: - result = subprocess.run( - cmd, - capture_output=True, - text=True - ) - return result.returncode == 0, result.stdout + result.stderr + if progress and task_id: + # Run with real-time output parsing + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + + output_lines = [] + if process.stdout is None: + return False, "Failed to capture sqlmap output" + for line in process.stdout: + output_lines.append(line) + + # Update progress description based on sqlmap status + clean_line = line.strip() + if "[INFO]" in clean_line: + status = clean_line.split("[INFO]", 1)[1].strip() + # Clean up status message + if "testing" in status.lower(): + progress.update( + task_id, description=f"[cyan]{status[:50]}[/cyan]" + ) + elif "detecting" in status.lower(): + progress.update( + task_id, description=f"[yellow]{status[:50]}[/yellow]" + ) + elif "identified" in status.lower(): + progress.update( + task_id, description=f"[green]{status[:50]}[/green]" + ) + + process.wait() + full_output = "".join(output_lines) + + # Cleanup temporary output directory + try: + shutil.rmtree(tmp_output_dir) + except: + pass + + return process.returncode == 0, full_output + else: + # Standard blocking run + result = subprocess.run(cmd, capture_output=True, text=True) + + # Cleanup temporary output directory + try: + shutil.rmtree(tmp_output_dir) + except: + pass + + return result.returncode == 0, result.stdout + result.stderr except subprocess.TimeoutExpired: + # Cleanup on timeout + try: + shutil.rmtree(tmp_output_dir) + except: + pass return False, "Test timed out after 10 minutes" except Exception as e: + # Cleanup on error + try: + shutil.rmtree(tmp_output_dir) + except: + pass return False, str(e) def parse_results(self, output: str) -> Dict[str, Any]: @@ -157,11 +227,19 @@ class SQLMapScanner: for risk in range(1, max_risk + 1): progress.update( overall_task, - description=f"[cyan]Testing Level {level}, Risk {risk}...", + description=f"[cyan]Testing Level {level}, Risk {risk}...[/cyan]", ) success, output = self.run_sqlmap_test( - url, level, risk, techniques, data=data, headers=headers, verbose=verbose + url, + level, + risk, + techniques, + data=data, + headers=headers, + verbose=max(verbose, 3), + progress=progress, + task_id=overall_task, ) parsed = self.parse_results(output) @@ -226,7 +304,7 @@ class SQLMapScanner: url, level, risk, data=data, headers=headers, verbose=verbose ) console.print(output) - + if self.enable_logging: log_file = get_log_filename(url) save_log(log_file, output) @@ -238,18 +316,25 @@ class SQLMapScanner: TimeElapsedColumn(), console=console, ) as progress: - task = progress.add_task( - "[cyan]Scanning for vulnerabilities...", total=None - ) + task = progress.add_task(f"[cyan]Scanning {url[:40]}...", total=None) success, output = self.run_sqlmap_test( - url, level, risk, data=data, headers=headers, verbose=verbose + url, + level, + risk, + data=data, + headers=headers, + verbose=max(verbose, 3), + progress=progress, + task_id=task, ) - progress.update(task, completed=True) - + progress.update( + task, completed=True, description="[green]✓ Scan Complete[/green]" + ) + if self.enable_logging: log_file = get_log_filename(url) save_log(log_file, output) - + parsed = self.parse_results(output) self.results["vulnerabilities"] = parsed["vulnerabilities"] self.results["total_tests"] = 1 @@ -257,88 +342,148 @@ class SQLMapScanner: display_summary(self.results) - def process_single_endpoint(self, endpoint: Dict, level: int, risk: int, verbose: int) -> Dict: + def process_single_endpoint( + self, + endpoint: Dict, + level: int, + risk: int, + verbose: int, + progress: Optional[Progress] = None, + task_id: Any = None, + ) -> Dict: """Process a single endpoint for batch mode""" - url = str(endpoint.get('url')) if endpoint.get('url') else '' - - data = endpoint.get('data') + url = str(endpoint.get("url")) if endpoint.get("url") else "" + + data = endpoint.get("data") if data is not None and not isinstance(data, str): data = json.dumps(data) - - headers = endpoint.get('headers') + + headers = endpoint.get("headers") if headers is not None and isinstance(headers, list): headers = "\\n".join(headers) - + try: - success, output = self.run_sqlmap_test(url, level, risk, data=data, headers=headers, verbose=verbose) - + # Force verbosity 3 for better progress tracking if in batch mode + # unless a specific high verbosity is already requested + exec_verbose = max(verbose, 3) if progress else verbose + + success, output = self.run_sqlmap_test( + url, + level, + risk, + data=data, + headers=headers, + verbose=exec_verbose, + progress=progress, + task_id=task_id, + ) + if self.enable_logging: log_file = get_log_filename(url) save_log(log_file, output) - + parsed = self.parse_results(output) + + if progress and task_id: + status_color = "red" if parsed["is_vulnerable"] else "green" + status_text = "Vulnerable" if parsed["is_vulnerable"] else "Clean" + progress.update( + task_id, + description=f"[{status_color}]✓ {status_text}[/{status_color}] - {url[:30]}...", + completed=100, + ) + return { - 'url': url, - 'data': data, - 'success': success, - 'vulnerabilities': parsed['vulnerabilities'], - 'is_vulnerable': parsed['is_vulnerable'] + "url": url, + "data": data, + "success": success, + "vulnerabilities": parsed["vulnerabilities"], + "is_vulnerable": parsed["is_vulnerable"], } except Exception as e: + if progress and task_id: + progress.update( + task_id, + description=f"[bold red]✗ Error: {str(e)[:30]}[/bold red]", + completed=100, + ) return { - 'url': url, - 'data': data, - 'success': False, - 'error': str(e), - 'vulnerabilities': [], - 'is_vulnerable': False + "url": url, + "data": data, + "success": False, + "error": str(e), + "vulnerabilities": [], + "is_vulnerable": False, } - def batch_scan(self, endpoints: List[Dict], level: int = 1, risk: int = 1, - concurrency: int = 5, verbose: int = 1): + def batch_scan( + self, + endpoints: List[Dict], + level: int = 1, + risk: int = 1, + concurrency: int = 5, + verbose: int = 1, + ): """Run batch scan on multiple endpoints with concurrency""" + + # Determine actual concurrency + if concurrency <= 0: + # For I/O bound tasks like scanning, we can use 2x CPU count + concurrency = (os.cpu_count() or 2) * 2 + console.print( Panel( f"[cyan]Batch Scan Mode[/cyan]\n" f"[dim]Testing {len(endpoints)} endpoint(s) with concurrency={concurrency}[/dim]\n" f"[dim]Level: {level}, Risk: {risk}[/dim]", border_style="cyan", - box=box.ROUNDED + box=box.ROUNDED, ) ) - + results = [] with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), - TextColumn("({task.completed}/{task.total})"), TimeElapsedColumn(), - console=console + console=console, + expand=True, ) as progress: - task = progress.add_task("[cyan]Processing endpoints...", total=len(endpoints)) - with ThreadPoolExecutor(max_workers=concurrency) as executor: - future_to_endpoint = { - executor.submit(self.process_single_endpoint, endpoint, level, risk, verbose): endpoint - for endpoint in endpoints - } - + future_to_endpoint = {} + + for endpoint in endpoints: + url = endpoint.get("url", "Unknown") + task_id = progress.add_task( + f"[dim]Waiting...[/dim] {url[:40]}...", total=100 + ) + future = executor.submit( + self.process_single_endpoint, + endpoint, + level, + risk, + verbose, + progress, + task_id, + ) + future_to_endpoint[future] = endpoint + for future in as_completed(future_to_endpoint): endpoint = future_to_endpoint[future] try: results.append(future.result()) except Exception as e: - results.append({ - 'url': endpoint.get('url'), - 'data': endpoint.get('data'), - 'success': False, - 'error': str(e), - 'vulnerabilities': [], - 'is_vulnerable': False - }) - progress.update(task, advance=1) - + results.append( + { + "url": endpoint.get("url"), + "data": endpoint.get("data"), + "success": False, + "error": str(e), + "vulnerabilities": [], + "is_vulnerable": False, + } + ) + display_batch_results(results) return results diff --git a/sql_cli/ui.py b/sql_cli/ui.py index 7f0daaac2..171977474 100644 --- a/sql_cli/ui.py +++ b/sql_cli/ui.py @@ -7,6 +7,7 @@ from .models import ScanResult console = Console() + def print_banner(): """Display a beautiful banner""" banner = """ @@ -33,6 +34,7 @@ def print_banner(): ) console.print() + def display_summary(results: ScanResult): """Display a comprehensive summary of results""" console.print() @@ -87,28 +89,29 @@ def display_summary(results: ScanResult): console.print() + def display_batch_results(results: List[Dict]): """Display batch scan results in a table""" console.print() - + # Create results table results_table = Table(title="Batch Scan Results", box=box.ROUNDED) results_table.add_column("URL", style="cyan", no_wrap=False) results_table.add_column("Status", justify="center") results_table.add_column("Vulnerabilities", style="magenta") - + vulnerable_count = 0 successful_count = 0 - + for result in results: - url = result['url'][:60] + '...' if len(result['url']) > 60 else result['url'] - - if result.get('error'): + url = result["url"][:60] + "..." if len(result["url"]) > 60 else result["url"] + + if result.get("error"): status = "[red]✗ Error[/red]" vulns = f"[red]{result['error'][:40]}[/red]" - elif result['success']: + elif result["success"]: successful_count += 1 - if result['is_vulnerable']: + if result["is_vulnerable"]: vulnerable_count += 1 status = "[red]✓ Vulnerable[/red]" vulns = f"[red]{len(result['vulnerabilities'])} found[/red]" @@ -118,11 +121,11 @@ def display_batch_results(results: List[Dict]): else: status = "[yellow]✗ Failed[/yellow]" vulns = "[yellow]N/A[/yellow]" - + results_table.add_row(url, status, vulns) - + console.print(results_table) - + # Summary console.print() summary = f""" @@ -132,14 +135,14 @@ def display_batch_results(results: List[Dict]): Vulnerable: [red]{vulnerable_count}[/red] Clean: [green]{successful_count - vulnerable_count}[/green] """ - + border_color = "red" if vulnerable_count > 0 else "green" console.print( Panel( summary.strip(), title="[bold]Summary[/bold]", border_style=border_color, - box=box.DOUBLE + box=box.DOUBLE, ) ) console.print() diff --git a/sqlmapcli.py b/sqlmapcli.py index 7ffaf9fbe..6262aea4d 100755 --- a/sqlmapcli.py +++ b/sqlmapcli.py @@ -26,6 +26,7 @@ from sql_cli.ui import print_banner console = Console() + def interactive_mode(scanner: SQLMapScanner): """Interactive mode for user input""" console.print() @@ -79,9 +80,7 @@ def interactive_mode(scanner: SQLMapScanner): max_level = int( Prompt.ask("[cyan]Maximum test level (1-5)[/cyan]", default="5") ) - max_risk = int( - Prompt.ask("[cyan]Maximum test risk (1-3)[/cyan]", default="3") - ) + max_risk = int(Prompt.ask("[cyan]Maximum test risk (1-3)[/cyan]", default="3")) scanner.comprehensive_scan(url, max_level, max_risk, data=data, headers=headers) @@ -101,64 +100,107 @@ Examples: python sqlmapcli.py -u "https://demo.owasp-juice.shop/rest/products/search?q=test" --comprehensive # Batch mode - test multiple endpoints from JSON file - python sqlmapcli.py -b endpoints.json --level 2 --risk 2 --concurrency 10 + python sqlmapcli.py -b endpoints.json --level 2 --risk 2 # Interactive mode python sqlmapcli.py --interactive """, ) - parser.add_argument("-u", "--url", help='Target URL (e.g., "http://example.com/page?id=1")') - parser.add_argument("--comprehensive", action="store_true", help="Run comprehensive scan") - parser.add_argument("--level", type=int, default=1, choices=[1, 2, 3, 4, 5], help="Level (1-5, default: 1)") - parser.add_argument("--risk", type=int, default=1, choices=[1, 2, 3], help="Risk (1-3, default: 1)") - parser.add_argument("--max-level", type=int, default=5, choices=[1, 2, 3, 4, 5], help="Max level for comprehensive") - parser.add_argument("--max-risk", type=int, default=3, choices=[1, 2, 3], help="Max risk for comprehensive") - parser.add_argument("--technique", type=str, default="BEUSTQ", help="SQL techniques (default: BEUSTQ)") - parser.add_argument("--data", type=str, help='POST data') - parser.add_argument("--headers", type=str, help='Extra headers') + parser.add_argument( + "-u", "--url", help='Target URL (e.g., "http://example.com/page?id=1")' + ) + parser.add_argument( + "--comprehensive", action="store_true", help="Run comprehensive scan" + ) + parser.add_argument( + "--level", + type=int, + default=1, + choices=[1, 2, 3, 4, 5], + help="Level (1-5, default: 1)", + ) + parser.add_argument( + "--risk", type=int, default=1, choices=[1, 2, 3], help="Risk (1-3, default: 1)" + ) + parser.add_argument( + "--max-level", + type=int, + default=5, + choices=[1, 2, 3, 4, 5], + help="Max level for comprehensive", + ) + parser.add_argument( + "--max-risk", + type=int, + default=3, + choices=[1, 2, 3], + help="Max risk for comprehensive", + ) + parser.add_argument( + "--technique", + type=str, + default="BEUSTQ", + help="SQL techniques (default: BEUSTQ)", + ) + parser.add_argument("--data", type=str, help="POST data") + parser.add_argument("--headers", type=str, help="Extra headers") parser.add_argument("--raw", action="store_true", help="Show raw sqlmap output") - parser.add_argument("--verbose", type=int, choices=[0, 1, 2, 3, 4, 5, 6], help="Verbosity (0-6)") - parser.add_argument("-i", "--interactive", action="store_true", help="Interactive mode") - parser.add_argument('-b', '--batch-file', type=str, help='Path to batch JSON') - parser.add_argument('-c', '--concurrency', type=int, default=5, help='Concurrency (default: 5)') - parser.add_argument('--no-logs', action='store_true', help='Disable logs') - + parser.add_argument( + "--verbose", type=int, choices=[0, 1, 2, 3, 4, 5, 6], help="Verbosity (0-6)" + ) + parser.add_argument( + "-i", "--interactive", action="store_true", help="Interactive mode" + ) + parser.add_argument("-b", "--batch-file", type=str, help="Path to batch JSON") + parser.add_argument( + "-c", + "--concurrency", + type=int, + default=0, + help="Number of concurrent scans (default: 0 for auto-scale)", + ) + parser.add_argument("--no-logs", action="store_true", help="Disable logs") + args = parser.parse_args() - + scanner = SQLMapScanner(enable_logging=not args.no_logs) print_banner() if not SQLMAP_PATH.exists(): - console.print(f"[bold red]Error: sqlmap.py not found at {SQLMAP_PATH}[/bold red]") + console.print( + f"[bold red]Error: sqlmap.py not found at {SQLMAP_PATH}[/bold red]" + ) sys.exit(1) if args.interactive: interactive_mode(scanner) return - + if args.batch_file: try: - with open(args.batch_file, 'r') as f: + with open(args.batch_file, "r") as f: endpoints = json.load(f) - + if not isinstance(endpoints, list): - console.print("[bold red]Error: Batch file must contain a JSON array[/bold red]") + console.print( + "[bold red]Error: Batch file must contain a JSON array[/bold red]" + ) sys.exit(1) - + verbose_level = args.verbose if args.verbose is not None else 1 scanner.batch_scan( endpoints, level=args.level, risk=args.risk, concurrency=args.concurrency, - verbose=verbose_level + verbose=verbose_level, ) return except Exception as e: console.print(f"[bold red]Error loading batch file: {e}[/bold red]") sys.exit(1) - + if not args.url: console.print("[bold red]Error: URL is required[/bold red]") parser.print_help()