diff --git a/_sqlmap.py b/_sqlmap.py index 4e74d10d6..9de694a2b 100644 --- a/_sqlmap.py +++ b/_sqlmap.py @@ -36,6 +36,7 @@ from lib.core.settings import LEGAL_DISCLAIMER from lib.core.testing import smokeTest from lib.core.testing import liveTest from lib.parse.cmdline import cmdLineParser +from lib.utils.api import StdDbOut def modulePath(): """ @@ -53,16 +54,22 @@ def main(): try: paths.SQLMAP_ROOT_PATH = modulePath() setPaths() + + # Store original command line options for possible later restoration + cmdLineOptions.update(cmdLineParser().__dict__) + init(cmdLineOptions) + + if hasattr(conf, "api"): + # Overwrite system standard output and standard error to write + # to an IPC database + sys.stdout = StdDbOut(conf.taskid, messagetype="stdout") + sys.stderr = StdDbOut(conf.taskid, messagetype="stderr") + banner() dataToStdout("[!] legal disclaimer: %s\n\n" % LEGAL_DISCLAIMER, forceOutput=True) dataToStdout("[*] starting at %s\n\n" % time.strftime("%X"), forceOutput=True) - # Store original command line options for possible later restoration - cmdLineOptions.update(cmdLineParser().__dict__) - - init(cmdLineOptions) - if conf.profile: profile() elif conf.smokeTest: @@ -115,6 +122,13 @@ def main(): except KeyboardInterrupt: pass + if hasattr(conf, "api"): + try: + conf.database_cursor.close() + conf.database_connection.close() + except KeyboardInterrupt: + pass + # Reference: http://stackoverflow.com/questions/1635080/terminate-a-multi-thread-python-program if conf.get("threads", 0) > 1 or conf.get("dnsServer"): os._exit(0) diff --git a/extra/shutils/regressiontest.py b/extra/shutils/regressiontest.py index 06cc37c5a..0ec23769f 100644 --- a/extra/shutils/regressiontest.py +++ b/extra/shutils/regressiontest.py @@ -4,17 +4,19 @@ # See the file 'doc/COPYING' for copying permission import codecs +import inspect import os import re import smtplib import subprocess import sys import time +import traceback from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -sys.path.append("../../") +sys.path.append(os.path.normpath("%s/../../" % os.path.dirname(inspect.getfile(inspect.currentframe())))) from lib.core.revision import getRevisionNumber @@ -64,7 +66,7 @@ def main(): test_counts = [] attachments = {} - command_line = "cd %s && python sqlmap.py --live-test" % SQLMAP_HOME + command_line = "python /opt/sqlmap/sqlmap.py --live-test" proc = subprocess.Popen(command_line, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) proc.wait() @@ -138,4 +140,13 @@ def main(): send_email(msg) if __name__ == "__main__": - main() + log_fd = open("/tmp/sqlmapregressiontest.log", "wb") + log_fd.write("Regression test started at %s\n" % TIME) + + try: + main() + except Exception, e: + log_fd.write("An exception has occurred:\n%s" % str(traceback.format_exc())) + + log_fd.write("Regression test finished at %s\n\n" % TIME) + log_fd.close() diff --git a/extra/shutils/regressiontest_cronjob.sh b/extra/shutils/regressiontest_cronjob.sh deleted file mode 100755 index 44661a852..000000000 --- a/extra/shutils/regressiontest_cronjob.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env bash - -# Copyright (c) 2006-2013 sqlmap developers (http://sqlmap.org/) -# See the file 'doc/COPYING' for copying permission - -SQLMAP_HOME="/opt/sqlmap" -REGRESSION_SCRIPT="${SQLMAP_HOME}/extra/shutils" - -FROM="regressiontest@sqlmap.org" -TO="bernardo.damele@gmail.com, miroslav.stampar@gmail.com" -SUBJECT="Automated regression test failed on $(date)" - -cd $SQLMAP_HOME -git pull -rm -f output 2>/dev/null - -cd $REGRESSION_SCRIPT -echo "Regression test started at $(date)" 1>/tmp/regressiontest.log 2>&1 -python regressiontest.py 1>>/tmp/regressiontest.log 2>&1 - -if [ $? -ne 0 ] -then - echo "Regression test finished at $(date)" 1>>/tmp/regressiontest.log 2>&1 - cat /tmp/regressiontest.log | mailx -s "${SUBJECT}" -aFrom:${FROM} ${TO} -else - echo "Regression test finished at $(date)" 1>>/tmp/regressiontest.log 2>&1 -fi diff --git a/lib/core/common.py b/lib/core/common.py index 2e5733f6f..233648258 100644 --- a/lib/core/common.py +++ b/lib/core/common.py @@ -742,7 +742,7 @@ def setColor(message, bold=False): return retVal -def dataToStdout(data, forceOutput=False, bold=False): +def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=None): """ Writes text to the stdout (console) stream """ @@ -754,8 +754,15 @@ def dataToStdout(data, forceOutput=False, bold=False): if kb.get("multiThreadMode"): logging._acquireLock() - message = stdoutencode(data) - sys.stdout.write(setColor(message, bold)) + if isinstance(data, basestring): + message = stdoutencode(data) + else: + message = data + + if hasattr(conf, "api"): + sys.stdout.write(message, status=status, content_type=content_type) + else: + sys.stdout.write(setColor(message, bold)) try: sys.stdout.flush() diff --git a/lib/core/convert.py b/lib/core/convert.py index 890328615..28afe17bd 100644 --- a/lib/core/convert.py +++ b/lib/core/convert.py @@ -104,3 +104,6 @@ def stdoutencode(data): def jsonize(data): return json.dumps(data, sort_keys=False, indent=4) + +def dejsonize(data): + return json.loads(data) diff --git a/lib/core/dump.py b/lib/core/dump.py index 4b553f424..f2e7c9318 100644 --- a/lib/core/dump.py +++ b/lib/core/dump.py @@ -26,6 +26,8 @@ from lib.core.data import conf from lib.core.data import kb from lib.core.data import logger from lib.core.dicts import DUMP_REPLACEMENTS +from lib.core.enums import API_CONTENT_STATUS +from lib.core.enums import API_CONTENT_TYPE from lib.core.enums import DBMS from lib.core.enums import DUMP_FORMAT from lib.core.exception import SqlmapGenericException @@ -52,8 +54,13 @@ class Dump(object): self._outputFP = None self._lock = threading.Lock() - def _write(self, data, newline=True, console=True): + def _write(self, data, newline=True, console=True, content_type=None): + if hasattr(conf, "api"): + dataToStdout(data, content_type=content_type, status=API_CONTENT_STATUS.COMPLETE) + return + text = "%s%s" % (data, "\n" if newline else " ") + if console: dataToStdout(text) @@ -81,7 +88,7 @@ class Dump(object): def singleString(self, data): self._write(data) - def string(self, header, data, sort=True): + def string(self, header, data, content_type=None, sort=True): kb.stickyLevel = None if isListLike(data): @@ -92,18 +99,19 @@ class Dump(object): if _ and _[-1] == '\n': _ = _[:-1] - if "\n" in _: + if hasattr(conf, "api"): + self._write(data, content_type=content_type) + elif "\n" in _: self._write("%s:\n---\n%s\n---" % (header, _)) else: self._write("%s: %s" % (header, ("'%s'" % _) if isinstance(data, basestring) else _)) + elif hasattr(conf, "api"): + self._write(data, content_type=content_type) else: self._write("%s:\tNone" % header) - def lister(self, header, elements, sort=True): - if elements: - self._write("%s [%d]:" % (header, len(elements))) - - if sort: + def lister(self, header, elements, content_type=None, sort=True): + if elements and sort: try: elements = set(elements) elements = list(elements) @@ -111,6 +119,13 @@ class Dump(object): except: pass + if hasattr(conf, "api"): + self._write(elements, content_type=content_type) + return + + if elements: + self._write("%s [%d]:" % (header, len(elements))) + for element in elements: if isinstance(element, basestring): self._write("[*] %s" % element) @@ -121,29 +136,29 @@ class Dump(object): self._write("") def banner(self, data): - self.string("banner", data) + self.string("banner", data, content_type=API_CONTENT_TYPE.BANNER) def currentUser(self, data): - self.string("current user", data) + self.string("current user", data, content_type=API_CONTENT_TYPE.CURRENT_USER) def currentDb(self, data): if Backend.isDbms(DBMS.MAXDB): - self.string("current database (no practical usage on %s)" % Backend.getIdentifiedDbms(), data) + self.string("current database (no practical usage on %s)" % Backend.getIdentifiedDbms(), data, content_type=API_CONTENT_TYPE.CURRENT_DB) elif Backend.isDbms(DBMS.ORACLE): - self.string("current schema (equivalent to database on %s)" % Backend.getIdentifiedDbms(), data) + self.string("current schema (equivalent to database on %s)" % Backend.getIdentifiedDbms(), data, content_type=API_CONTENT_TYPE.CURRENT_DB) else: - self.string("current database", data) + self.string("current database", data, content_type=API_CONTENT_TYPE.CURRENT_DB) def hostname(self, data): - self.string("hostname", data) + self.string("hostname", data, content_type=API_CONTENT_TYPE.HOSTNAME) def dba(self, data): - self.string("current user is DBA", data) + self.string("current user is DBA", data, content_type=API_CONTENT_TYPE.IS_DBA) def users(self, users): - self.lister("database management system users", users) + self.lister("database management system users", users, content_type=API_CONTENT_TYPE.USERS) - def userSettings(self, header, userSettings, subHeader): + def userSettings(self, header, userSettings, subHeader, content_type=None): self._areAdmins = set() if userSettings: @@ -179,9 +194,9 @@ class Dump(object): self.singleString("") def dbs(self, dbs): - self.lister("available databases", dbs) + self.lister("available databases", dbs, content_type=API_CONTENT_TYPE.DBS) - def dbTables(self, dbTables): + def dbTables(self, dbTables, content_type=API_CONTENT_TYPE.TABLES): if isinstance(dbTables, dict) and len(dbTables) > 0: maxlength = 0 @@ -219,7 +234,7 @@ class Dump(object): else: self.string("tables", dbTables) - def dbTableColumns(self, tableColumns): + def dbTableColumns(self, tableColumns, content_type=API_CONTENT_TYPE.COLUMNS): if isinstance(tableColumns, dict) and len(tableColumns) > 0: for db, tables in tableColumns.items(): if not db: @@ -286,7 +301,7 @@ class Dump(object): else: self._write("+%s+\n" % lines1) - def dbTablesCount(self, dbTables): + def dbTablesCount(self, dbTables, content_type=API_CONTENT_TYPE.COUNT): if isinstance(dbTables, dict) and len(dbTables) > 0: maxlength1 = len("Table") maxlength2 = len("Entries") @@ -328,7 +343,7 @@ class Dump(object): else: logger.error("unable to retrieve the number of entries for any table") - def dbTableValues(self, tableValues): + def dbTableValues(self, tableValues, content_type=API_CONTENT_TYPE.DUMP_TABLE): replication = None rtable = None dumpFP = None @@ -534,7 +549,7 @@ class Dump(object): dumpFP.close() logger.info("table '%s.%s' dumped to %s file '%s'" % (db, table, conf.dumpFormat, dumpFileName)) - def dbColumns(self, dbColumnsDict, colConsider, dbs): + def dbColumns(self, dbColumnsDict, colConsider, dbs, content_type=API_CONTENT_TYPE.COLUMNS): for column in dbColumnsDict.keys(): if colConsider == "1": colConsiderStr = "s like '" + column + "' were" @@ -565,13 +580,13 @@ class Dump(object): self.dbTableColumns(_) def query(self, query, queryRes): - self.string(query, queryRes) + self.string(query, queryRes, content_type=API_CONTENT_TYPE.SQL_QUERY) def rFile(self, fileData): - self.lister("files saved to", fileData, sort=False) + self.lister("files saved to", fileData, sort=False, content_type=API_CONTENT_TYPE.FILE_READ) - def registerValue(self, registerData): - self.string("Registry key value data", registerData, sort=False) + def registerValue(self): + self.string("Registry key value data", registerData, registerData, content_type=API_CONTENT_TYPE.REG_READ, sort=False) # object to manage how to print the retrieved queries output to # standard output and sessions file diff --git a/lib/core/enums.py b/lib/core/enums.py index a9019c416..d05c05347 100644 --- a/lib/core/enums.py +++ b/lib/core/enums.py @@ -243,3 +243,33 @@ class WEB_API: ASP = "asp" ASPX = "aspx" JSP = "jsp" + +class API_CONTENT_TYPE: + TECHNIQUES = 0 + BANNER = 1 + CURRENT_USER = 2 + CURRENT_DB = 3 + HOSTNAME = 4 + IS_DBA = 5 + USERS = 6 + PASSWORDS = 7 + PRIVILEGES = 8 + ROLES = 9 + DBS = 10 + TABLES = 11 + COLUMNS = 12 + SCHEMA = 13 + COUNT = 14 + DUMP_TABLE = 15 + SEARCH = 16 + SQL_QUERY = 17 + COMMON_TABLES = 18 + COMMON_COLUMNS = 19 + FILE_READ = 20 + FILE_WRITE = 21 + OS_CMD = 22 + REG_READ = 23 + +class API_CONTENT_STATUS: + IN_PROGRESS = 0 + COMPLETE = 1 diff --git a/lib/core/option.py b/lib/core/option.py index f5579dc64..39476b559 100644 --- a/lib/core/option.py +++ b/lib/core/option.py @@ -87,7 +87,6 @@ from lib.core.exception import SqlmapSyntaxException from lib.core.exception import SqlmapUnsupportedDBMSException from lib.core.exception import SqlmapUserQuitException from lib.core.log import FORMATTER -from lib.core.log import LOGGER_HANDLER from lib.core.optiondict import optDict from lib.core.purge import purge from lib.core.settings import ACCESS_ALIASES @@ -137,6 +136,7 @@ from lib.request.httpshandler import HTTPSHandler from lib.request.rangehandler import HTTPRangeHandler from lib.request.redirecthandler import SmartRedirectHandler from lib.request.templates import getPageTemplate +from lib.utils.api import setRestAPILog from lib.utils.crawler import crawl from lib.utils.deps import checkDependencies from lib.utils.google import Google @@ -1795,25 +1795,6 @@ def _mergeOptions(inputOptions, overrideOptions): if hasattr(conf, key) and conf[key] is None: conf[key] = value -class LogRecorder(logging.StreamHandler): - def emit(self, record): - """ - Record emitted events to temporary database for asynchronous I/O - communication with the parent process - """ - connection = sqlite3.connect(conf.ipc, isolation_level=None) - cursor = connection.cursor() - cursor.execute("INSERT INTO logs VALUES(NULL, ?, ?, ?)", - (time.strftime("%X"), record.levelname, record.msg % record.args if record.args else record.msg)) - cursor.close() - connection.close() - -def _setRestAPILog(): - if hasattr(conf, "ipc"): - logger.removeHandler(LOGGER_HANDLER) - LOGGER_RECORDER = LogRecorder() - logger.addHandler(LOGGER_RECORDER) - def _setTrafficOutputFP(): if conf.trafficFile: infoMsg = "setting file for logging HTTP traffic" @@ -2085,7 +2066,7 @@ def init(inputOptions=AttribDict(), overrideOptions=False): _mergeOptions(inputOptions, overrideOptions) _useWizardInterface() setVerbosity() - _setRestAPILog() + setRestAPILog() _saveCmdline() _setRequestFromFile() _cleanupOptions() diff --git a/lib/core/testing.py b/lib/core/testing.py index ab502ba7f..0b72fd39f 100644 --- a/lib/core/testing.py +++ b/lib/core/testing.py @@ -265,7 +265,7 @@ def runCase(switches=None, parse=None): try: result = start() except KeyboardInterrupt: - raise + pass except SqlmapBaseException, e: handled_exception = e except Exception, e: diff --git a/lib/utils/api.py b/lib/utils/api.py index 446c89c0e..926673536 100644 --- a/lib/utils/api.py +++ b/lib/utils/api.py @@ -5,10 +5,13 @@ Copyright (c) 2006-2013 sqlmap developers (http://sqlmap.org/) See the file 'doc/COPYING' for copying permission """ +import logging import os import shutil import sqlite3 +import sys import tempfile +import time from subprocess import PIPE @@ -16,13 +19,16 @@ from lib.core.common import unArrayizeValue from lib.core.convert import base64pickle from lib.core.convert import base64unpickle from lib.core.convert import hexencode +from lib.core.convert import dejsonize from lib.core.convert import jsonize +from lib.core.data import conf from lib.core.data import paths from lib.core.data import logger from lib.core.datatype import AttribDict from lib.core.defaults import _defaults +from lib.core.log import LOGGER_HANDLER from lib.core.optiondict import optDict -from lib.core.subprocessng import Popen as execute +from lib.core.subprocessng import Popen from lib.core.subprocessng import send_all from lib.core.subprocessng import recv_some from thirdparty.bottle.bottle import abort @@ -40,8 +46,154 @@ RESTAPI_SERVER_PORT = 8775 # Local global variables adminid = "" -procs = dict() -tasks = AttribDict() +db = None +tasks = dict() + +# API objects +class Database(object): + LOGS_TABLE = "CREATE TABLE logs(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, time TEXT, level TEXT, message TEXT)" + DATA_TABLE = "CREATE TABLE data(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, status INTEGER, content_type INTEGER, value TEXT)" + ERRORS_TABLE = "CREATE TABLE errors(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, error TEXT)" + + def __init__(self): + pass + + def create(self): + _, self.database = tempfile.mkstemp(prefix="sqlmapipc-", text=False) + logger.info("IPC database is %s" % self.database) + + def connect(self): + self.connection = sqlite3.connect(self.database, timeout=1, isolation_level=None) + self.cursor = self.connection.cursor() + + def disconnect(self): + self.cursor.close() + self.connection.close() + + def execute(self, statement, arguments=None): + if arguments: + self.cursor.execute(statement, arguments) + else: + self.cursor.execute(statement) + + if statement.lstrip().upper().startswith("SELECT"): + return self.cursor.fetchall() + + def initialize(self): + self.create() + self.connect() + self.execute(self.LOGS_TABLE) + self.execute(self.DATA_TABLE) + self.execute(self.ERRORS_TABLE) + + def get_filepath(self): + return self.database + +class Task(object): + global db + + def __init__(self, taskid): + self.process = None + self.output_directory = None + self.initialize_options(taskid) + + def initialize_options(self, taskid): + dataype = {"boolean": False, "string": None, "integer": None, "float": None} + self.options = AttribDict() + + for _ in optDict: + for name, type_ in optDict[_].items(): + type_ = unArrayizeValue(type_) + self.options[name] = _defaults.get(name, dataype[type_]) + + # Let sqlmap engine knows it is getting called by the API, the task ID and the file path of the IPC database + self.options.api = True + self.options.taskid = taskid + self.options.database = db.get_filepath() + + # Enforce batch mode and disable coloring + self.options.batch = True + self.options.disableColoring = True + + def set_option(self, option, value): + self.options[option] = value + + def get_option(self, option): + return self.options[option] + + def get_options(self): + return self.options + + def set_output_directory(self): + self.output_directory = tempfile.mkdtemp(prefix="sqlmapoutput-") + self.set_option("oDir", self.output_directory) + + def clean_filesystem(self): + shutil.rmtree(self.output_directory) + + def engine_start(self): + self.process = Popen("python sqlmap.py --pickled-options %s" % base64pickle(self.options), shell=True, stdin=PIPE) + + def engine_stop(self): + if self.process: + self.process.terminate() + + def engine_kill(self): + if self.process: + self.process.kill() + + def engine_get_pid(self): + return self.processid.pid + +# Wrapper functions for sqlmap engine +class StdDbOut(object): + encoding = "UTF-8" + + def __init__(self, taskid, messagetype="stdout"): + # Overwrite system standard output and standard error to write + # to an IPC database + self.messagetype = messagetype + self.taskid = taskid + + if self.messagetype == "stdout": + sys.stdout = self + else: + sys.stderr = self + + def write(self, value, status=None, content_type=None): + if self.messagetype == "stdout": + conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)", (self.taskid, status, content_type, jsonize(value))) + else: + conf.database_cursor.execute("INSERT INTO errors VALUES(NULL, ?, ?)", (self.taskid, value)) + + def flush(self): + pass + + def close(self): + pass + + def seek(self): + pass + +class LogRecorder(logging.StreamHandler): + def emit(self, record): + """ + Record emitted events to IPC database for asynchronous I/O + communication with the parent process + """ + conf.database_cursor.execute("INSERT INTO logs VALUES(NULL, ?, ?, ?, ?)", + (conf.taskid, time.strftime("%X"), record.levelname, + record.msg % record.args if record.args else record.msg)) + +def setRestAPILog(): + if hasattr(conf, "api"): + conf.database_connection = sqlite3.connect(conf.database, timeout=1, isolation_level=None) + conf.database_cursor = conf.database_connection.cursor() + + # Set a logging handler that writes log messages to a IPC database + logger.removeHandler(LOGGER_HANDLER) + LOGGER_RECORDER = LogRecorder() + logger.addHandler(LOGGER_RECORDER) # Generic functions def is_admin(taskid): @@ -51,23 +203,8 @@ def is_admin(taskid): else: return True -def init_options(): - dataype = {"boolean": False, "string": None, "integer": None, "float": None} - options = AttribDict() - - for _ in optDict: - for name, type_ in optDict[_].items(): - type_ = unArrayizeValue(type_) - options[name] = _defaults.get(name, dataype[type_]) - - # Enforce batch mode and disable coloring - options.batch = True - options.disableColoring = True - - return options - @hook("after_request") -def security_headers(): +def security_headers(json_header=True): """ Set some headers across all HTTP responses """ @@ -78,7 +215,8 @@ def security_headers(): response.headers["Pragma"] = "no-cache" response.headers["Cache-Control"] = "no-cache" response.headers["Expires"] = "0" - response.content_type = "application/json; charset=UTF-8" + if json_header: + response.content_type = "application/json; charset=UTF-8" ############################## # HTTP Status Code functions # @@ -86,18 +224,22 @@ def security_headers(): @error(401) # Access Denied def error401(error=None): + security_headers(False) return "Access denied" @error(404) # Not Found def error404(error=None): + security_headers(False) return "Nothing here" @error(405) # Method Not Allowed (e.g. when requesting a POST method via GET) def error405(error=None): + security_headers(False) return "Method not allowed" @error(500) # Internal Server Error def error500(error=None): + security_headers(False) return "Internal server error" ############################# @@ -112,21 +254,8 @@ def task_new(): """ global tasks - taskid = hexencode(os.urandom(16)) - tasks[taskid] = init_options() - - # Initiate the temporary database for asynchronous I/O with the - # sqlmap engine (children processes) - _, ipc_filepath = tempfile.mkstemp(prefix="sqlmapipc-", suffix=".db", text=False) - connection = sqlite3.connect(ipc_filepath, isolation_level=None) - cursor = connection.cursor() - cursor.execute("DROP TABLE IF EXISTS logs") - cursor.execute("CREATE TABLE logs(id INTEGER PRIMARY KEY AUTOINCREMENT, time TEXT, level TEXT, message TEXT)") - cursor.close() - connection.close() - - # Set the temporary database to use for asynchronous I/O communication - tasks[taskid].ipc = ipc_filepath + taskid = hexencode(os.urandom(8)) + tasks[taskid] = Task(taskid) return jsonize({"taskid": taskid}) @@ -135,7 +264,8 @@ def task_destroy(taskid): """ Destroy own task ID """ - if taskid in tasks and not is_admin(taskid): + if taskid in tasks: + tasks[taskid].clean_filesystem() tasks.pop(taskid) return jsonize({"success": True}) else: @@ -155,16 +285,15 @@ def task_list(taskid): @get("/task//flush") def task_flush(taskid): """ - Flush task spool (destroy all tasks except admin) + Flush task spool (destroy all tasks) """ - global adminid global tasks if is_admin(taskid): - admin_task = tasks[adminid] - tasks = AttribDict() - tasks[adminid] = admin_task + for task in tasks: + tasks[task].clean_filesystem() + tasks = dict() return jsonize({"success": True}) else: abort(401) @@ -186,26 +315,6 @@ def status(taskid): else: abort(401) -@get("/cleanup/") -def cleanup(taskid): - """ - Destroy all sessions except admin ID and all output directories - """ - global tasks - - if is_admin(taskid): - for task, options in tasks.items(): - if "oDir" in options and options.oDir is not None: - shutil.rmtree(options.oDir) - - admin_task = tasks[adminid] - tasks = AttribDict() - tasks[adminid] = admin_task - - return jsonize({"success": True}) - else: - abort(401) - # Functions to handle options @get("/option//list") def option_list(taskid): @@ -215,7 +324,7 @@ def option_list(taskid): if taskid not in tasks: abort(500, "Invalid task ID") - return jsonize(tasks[taskid]) + return jsonize(tasks[taskid].get_options()) @post("/option//get") def option_get(taskid): @@ -228,7 +337,7 @@ def option_get(taskid): option = request.json.get("option", "") if option in tasks[taskid]: - return jsonize({option: tasks[taskid][option]}) + return jsonize({option: tasks[taskid].get_option(option)}) else: return jsonize({option: None}) @@ -242,8 +351,8 @@ def option_set(taskid): if taskid not in tasks: abort(500, "Invalid task ID") - for key, value in request.json.items(): - tasks[taskid][key] = value + for option, value in request.json.items(): + tasks[taskid].set_option(option, value) return jsonize({"success": True}) @@ -254,83 +363,109 @@ def scan_start(taskid): Launch a scan """ global tasks - global procs if taskid not in tasks: abort(500, "Invalid task ID") - # Initialize sqlmap engine's options with user's provided options - # within the JSON request - for key, value in request.json.items(): - tasks[taskid][key] = value + # Initialize sqlmap engine's options with user's provided options, if any + for option, value in request.json.items(): + tasks[taskid].set_option(option, value) - # Overwrite output directory (oDir) value to a temporary directory - tasks[taskid].oDir = tempfile.mkdtemp(prefix="sqlmaptask-") + # Overwrite output directory value to a temporary directory + tasks[taskid].set_output_directory() # Launch sqlmap engine in a separate thread logger.debug("starting a scan for task ID %s" % taskid) # Launch sqlmap engine - procs[taskid] = execute("python sqlmap.py --pickled-options %s" % base64pickle(tasks[taskid]), shell=True, stdin=PIPE, stdout=PIPE, stderr=PIPE, close_fds=False) + tasks[taskid].engine_start() return jsonize({"success": True}) -@get("/scan//output") -def scan_output(taskid): +@get("/scan//stop") +def scan_stop(taskid): """ - Read the standard output of sqlmap core execution + Stop a scan """ global tasks if taskid not in tasks: abort(500, "Invalid task ID") - stdout = recv_some(procs[taskid], t=1, e=0, stderr=0) - stderr = recv_some(procs[taskid], t=1, e=0, stderr=1) + return jsonize({"success": tasks[taskid].engine_stop()}) - return jsonize({"stdout": stdout, "stderr": stderr}) +@get("/scan//kill") +def scan_kill(taskid): + """ + Kill a scan + """ + global tasks + + if taskid not in tasks: + abort(500, "Invalid task ID") + + return jsonize({"success": tasks[taskid].engine_kill()}) @get("/scan//delete") def scan_delete(taskid): """ - Delete a scan and corresponding temporary output directory + Delete a scan and corresponding temporary output directory and IPC database """ global tasks if taskid not in tasks: abort(500, "Invalid task ID") - if "oDir" in tasks[taskid] and tasks[taskid].oDir is not None: - shutil.rmtree(tasks[taskid].oDir) + scan_stop(taskid) + tasks[taskid].clean_filesystem() return jsonize({"success": True}) +@get("/scan//data") +def scan_data(taskid): + """ + Retrieve the data of a scan + """ + global db + global tasks + json_data_message = list() + json_errors_message = list() + + if taskid not in tasks: + abort(500, "Invalid task ID") + + # Read all data from the IPC database for the taskid + for status, content_type, value in db.execute("SELECT status, content_type, value FROM data WHERE taskid = ? ORDER BY id ASC", (taskid,)): + json_data_message.append([status, content_type, dejsonize(value)]) + + # Read all error messages from the IPC database + for error in db.execute("SELECT error FROM errors WHERE taskid = ? ORDER BY id ASC", (taskid,)): + json_errors_message.append(error) + + return jsonize({"data": json_data_message, "error": json_errors_message}) + # Functions to handle scans' logs @get("/scan//log//") def scan_log_limited(taskid, start, end): """ Retrieve a subset of log messages """ - json_log_messages = {} + global db + global tasks + json_log_messages = list() if taskid not in tasks: abort(500, "Invalid task ID") - # Temporary "protection" against SQL injection FTW ;) - if not start.isdigit() or not end.isdigit() or end <= start: + if not start.isdigit() or not end.isdigit() or end < start: abort(500, "Invalid start or end value, must be digits") start = max(1, int(start)) end = max(1, int(end)) - # Read a subset of log messages from the temporary I/O database - connection = sqlite3.connect(tasks[taskid].ipc, isolation_level=None) - cursor = connection.cursor() - cursor.execute("SELECT id, time, level, message FROM logs WHERE id >= %d AND id <= %d" % (start, end)) - db_log_messages = cursor.fetchall() - - for (id_, time_, level, message) in db_log_messages: - json_log_messages[id_] = {"time": time_, "level": level, "message": message} + # Read a subset of log messages from the IPC database + for time_, level, message in db.execute("SELECT time, level, message FROM logs WHERE taskid = ? AND id >= ? AND id <= ? ORDER BY id ASC", (taskid, start, end)): + json_log_messages.append({"time": time_, "level": level, "message": message}) return jsonize({"log": json_log_messages}) @@ -339,19 +474,16 @@ def scan_log(taskid): """ Retrieve the log messages """ - json_log_messages = {} + global db + global tasks + json_log_messages = list() if taskid not in tasks: abort(500, "Invalid task ID") - # Read all log messages from the temporary I/O database - connection = sqlite3.connect(tasks[taskid].ipc, isolation_level=None) - cursor = connection.cursor() - cursor.execute("SELECT id, time, level, message FROM logs") - db_log_messages = cursor.fetchall() - - for (id_, time_, level, message) in db_log_messages: - json_log_messages[id_] = {"time": time_, "level": level, "message": message} + # Read all log messages from the IPC database + for time_, level, message in db.execute("SELECT time, level, message FROM logs WHERE taskid = ? ORDER BY id ASC", (taskid,)): + json_log_messages.append({"time": time_, "level": level, "message": message}) return jsonize({"log": json_log_messages}) @@ -369,6 +501,7 @@ def download(taskid, target, filename): abort(500) path = os.path.join(paths.SQLMAP_OUTPUT_PATH, target) + if os.path.exists(path): return static_file(filename, root=path) else: @@ -379,10 +512,11 @@ def server(host="0.0.0.0", port=RESTAPI_SERVER_PORT): REST-JSON API server """ global adminid - global tasks + global db adminid = hexencode(os.urandom(16)) - tasks[adminid] = init_options() + db = Database() + db.initialize() logger.info("running REST-JSON API server at '%s:%d'.." % (host, port)) logger.info("the admin task ID is: %s" % adminid) diff --git a/xml/livetests.xml b/xml/livetests.xml index 455e1da55..5a27d28e3 100644 --- a/xml/livetests.xml +++ b/xml/livetests.xml @@ -3133,7 +3133,7 @@ - + @@ -3167,7 +3167,7 @@ - +