diff --git a/_sqlmap.py b/_sqlmap.py index e77f7b4e1..935c54f52 100644 --- a/_sqlmap.py +++ b/_sqlmap.py @@ -129,8 +129,7 @@ def main(): if hasattr(conf, "api"): try: - conf.database_cursor.close() - conf.database_connection.close() + conf.database_cursor.disconnect() except KeyboardInterrupt: pass diff --git a/lib/core/common.py b/lib/core/common.py index 5ccfaaff0..57040d4d6 100644 --- a/lib/core/common.py +++ b/lib/core/common.py @@ -58,6 +58,7 @@ from lib.core.dicts import DBMS_DICT from lib.core.dicts import DEPRECATED_OPTIONS from lib.core.dicts import SQL_STATEMENTS from lib.core.enums import ADJUST_TIME_DELAY +from lib.core.enums import CONTENT_STATUS from lib.core.enums import CHARSET_TYPE from lib.core.enums import DBMS from lib.core.enums import EXPECTED @@ -744,7 +745,7 @@ def setColor(message, bold=False): return retVal -def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=None): +def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=CONTENT_STATUS.IN_PROGRESS): """ Writes text to the stdout (console) stream """ @@ -762,8 +763,7 @@ def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status= message = data if hasattr(conf, "api"): - if content_type and status: - sys.stdout.write(message, status, content_type) + sys.stdout.write(message, status, content_type) else: sys.stdout.write(setColor(message, bold)) diff --git a/lib/core/dump.py b/lib/core/dump.py index f140adfb7..86a6c0a7c 100644 --- a/lib/core/dump.py +++ b/lib/core/dump.py @@ -26,7 +26,7 @@ from lib.core.data import conf from lib.core.data import kb from lib.core.data import logger from lib.core.dicts import DUMP_REPLACEMENTS -from lib.core.enums import API_CONTENT_STATUS +from lib.core.enums import CONTENT_STATUS from lib.core.enums import CONTENT_TYPE from lib.core.enums import DBMS from lib.core.enums import DUMP_FORMAT @@ -55,7 +55,7 @@ class Dump(object): def _write(self, data, newline=True, console=True, content_type=None): if hasattr(conf, "api"): - dataToStdout(data, content_type=content_type, status=API_CONTENT_STATUS.COMPLETE) + dataToStdout(data, content_type=content_type, status=CONTENT_STATUS.COMPLETE) return text = "%s%s" % (data, "\n" if newline else " ") diff --git a/lib/core/enums.py b/lib/core/enums.py index f4ae5f3f3..e1bd0eb3a 100644 --- a/lib/core/enums.py +++ b/lib/core/enums.py @@ -271,6 +271,6 @@ class CONTENT_TYPE: OS_CMD = 23 REG_READ = 24 -class API_CONTENT_STATUS: +class CONTENT_STATUS: IN_PROGRESS = 0 COMPLETE = 1 diff --git a/lib/techniques/blind/inference.py b/lib/techniques/blind/inference.py index 0fd592662..43e4e486b 100644 --- a/lib/techniques/blind/inference.py +++ b/lib/techniques/blind/inference.py @@ -88,8 +88,8 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None try: # Set kb.partRun in case "common prediction" feature (a.k.a. "good - # samaritan") is used - kb.partRun = getPartRun() if conf.predictOutput else None + # samaritan") is used or the engine is called from the API + kb.partRun = getPartRun() if conf.predictOutput or hasattr(conf, "api") else None if partialValue: firstChar = len(partialValue) @@ -486,7 +486,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None if result: if showEta: etaProgressUpdate(time.time() - charStart, len(commonValue)) - elif conf.verbose in (1, 2): + elif conf.verbose in (1, 2) or hasattr(conf, "api"): dataToStdout(filterControlChars(commonValue[index - 1:])) finalValue = commonValue @@ -534,7 +534,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None if showEta: etaProgressUpdate(time.time() - charStart, index) - elif conf.verbose in (1, 2): + elif conf.verbose in (1, 2) or hasattr(conf, "api"): dataToStdout(filterControlChars(val)) # some DBMSes (e.g. Firebird, DB2, etc.) have issues with trailing spaces diff --git a/lib/techniques/error/use.py b/lib/techniques/error/use.py index 54b509b73..d041dca5c 100644 --- a/lib/techniques/error/use.py +++ b/lib/techniques/error/use.py @@ -16,6 +16,7 @@ from lib.core.common import calculateDeltaSeconds from lib.core.common import dataToStdout from lib.core.common import decodeHexValue from lib.core.common import extractRegexResult +from lib.core.common import getPartRun from lib.core.common import getUnicode from lib.core.common import hashDBRetrieve from lib.core.common import hashDBWrite @@ -243,6 +244,9 @@ def errorUse(expression, dump=False): _, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(expression) + # Set kb.partRun in case the engine is called from the API + kb.partRun = getPartRun() if hasattr(conf, "api") else None + # We have to check if the SQL query might return multiple entries # and in such case forge the SQL limiting the query output one # entry at a time diff --git a/lib/techniques/union/use.py b/lib/techniques/union/use.py index 1132fc3ea..2062cf410 100644 --- a/lib/techniques/union/use.py +++ b/lib/techniques/union/use.py @@ -19,6 +19,7 @@ from lib.core.common import dataToStdout from lib.core.common import extractRegexResult from lib.core.common import flattenValue from lib.core.common import getConsoleWidth +from lib.core.common import getPartRun from lib.core.common import getUnicode from lib.core.common import hashDBRetrieve from lib.core.common import hashDBWrite @@ -163,6 +164,9 @@ def unionUse(expression, unpack=True, dump=False): _, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(origExpr) + # Set kb.partRun in case the engine is called from the API + kb.partRun = getPartRun() if hasattr(conf, "api") else None + if expressionFieldsList and len(expressionFieldsList) > 1 and "ORDER BY" in expression.upper(): # Removed ORDER BY clause because UNION does not play well with it expression = re.sub("\s*ORDER BY\s+[\w,]+", "", expression, re.I) diff --git a/lib/utils/api.py b/lib/utils/api.py index 3d5143c99..6dd8aa06d 100644 --- a/lib/utils/api.py +++ b/lib/utils/api.py @@ -17,15 +17,16 @@ from subprocess import PIPE from lib.core.common import unArrayizeValue from lib.core.convert import base64pickle -from lib.core.convert import base64unpickle from lib.core.convert import hexencode from lib.core.convert import dejsonize from lib.core.convert import jsonize from lib.core.data import conf +from lib.core.data import kb from lib.core.data import paths from lib.core.data import logger from lib.core.datatype import AttribDict from lib.core.defaults import _defaults +from lib.core.enums import CONTENT_STATUS from lib.core.log import LOGGER_HANDLER from lib.core.optiondict import optDict from lib.core.subprocessng import Popen @@ -47,24 +48,27 @@ RESTAPI_SERVER_PORT = 8775 # Local global variables adminid = "" db = None +db_filepath = tempfile.mkstemp(prefix="sqlmapipc-", text=False)[1] tasks = dict() # API objects class Database(object): + global db_filepath + LOGS_TABLE = "CREATE TABLE logs(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, time TEXT, level TEXT, message TEXT)" DATA_TABLE = "CREATE TABLE data(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, status INTEGER, content_type INTEGER, value TEXT)" ERRORS_TABLE = "CREATE TABLE errors(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, error TEXT)" - def __init__(self): - pass + def __init__(self, database=None): + if database: + self.database = database + else: + self.database = db_filepath - def create(self): - _, self.database = tempfile.mkstemp(prefix="sqlmapipc-", text=False) - logger.debug("IPC database: %s" % self.database) - - def connect(self): + def connect(self, who="server"): self.connection = sqlite3.connect(self.database, timeout=3, isolation_level=None) self.cursor = self.connection.cursor() + logger.debug("REST-JSON API %s connected to IPC database" % who) def disconnect(self): self.cursor.close() @@ -79,18 +83,13 @@ class Database(object): if statement.lstrip().upper().startswith("SELECT"): return self.cursor.fetchall() - def initialize(self): - self.create() - self.connect() + def init(self): self.execute(self.LOGS_TABLE) self.execute(self.DATA_TABLE) self.execute(self.ERRORS_TABLE) - def get_filepath(self): - return self.database - class Task(object): - global db + global db_filepath def __init__(self, taskid): self.process = None @@ -109,7 +108,7 @@ class Task(object): # Let sqlmap engine knows it is getting called by the API, the task ID and the file path of the IPC database self.options.api = True self.options.taskid = taskid - self.options.database = db.get_filepath() + self.options.database = db_filepath # Enforce batch mode and disable coloring self.options.batch = True @@ -174,12 +173,25 @@ class StdDbOut(object): else: sys.stderr = self - def write(self, value, status=None, content_type=None): + def write(self, value, status=CONTENT_STATUS.IN_PROGRESS, content_type=None): if self.messagetype == "stdout": - #conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)", - # (self.taskid, status, content_type, base64pickle(value))) - conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)", - (self.taskid, status, content_type, jsonize(value))) + if content_type is None: + content_type = 99 + + if status == CONTENT_STATUS.IN_PROGRESS: + output = conf.database_cursor.execute("SELECT id, value FROM data WHERE taskid = ? AND status = ? AND content_type = ? LIMIT 0,1", + (self.taskid, status, content_type)) + + if len(output) == 0: + conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)", + (self.taskid, status, content_type, jsonize(value))) + else: + new_value = "%s%s" % (output[0][1], value) + conf.database_cursor.execute("UPDATE data SET value = ? WHERE id = ?", + (jsonize(new_value), output[0][0])) + else: + conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)", + (self.taskid, status, content_type, jsonize(value))) else: conf.database_cursor.execute("INSERT INTO errors VALUES(NULL, ?, ?)", (self.taskid, str(value) if value else "")) @@ -205,8 +217,11 @@ class LogRecorder(logging.StreamHandler): def setRestAPILog(): if hasattr(conf, "api"): - conf.database_connection = sqlite3.connect(conf.database, timeout=1, isolation_level=None) - conf.database_cursor = conf.database_connection.cursor() + #conf.database_connection = sqlite3.connect(conf.database, timeout=1, isolation_level=None) + #conf.database_cursor = conf.database_connection.cursor() + + conf.database_cursor = Database(conf.database) + conf.database_cursor.connect("client") # Set a logging handler that writes log messages to a IPC database logger.removeHandler(LOGGER_HANDLER) @@ -455,7 +470,6 @@ def scan_data(taskid): # Read all data from the IPC database for the taskid for status, content_type, value in db.execute("SELECT status, content_type, value FROM data WHERE taskid = ? ORDER BY id ASC", (taskid,)): - #json_data_message.append({"status": status, "type": content_type, "value": base64unpickle(value)}) json_data_message.append({"status": status, "type": content_type, "value": dejsonize(value)}) # Read all error messages from the IPC database @@ -536,15 +550,18 @@ def server(host="0.0.0.0", port=RESTAPI_SERVER_PORT): """ global adminid global db + global db_filepath adminid = hexencode(os.urandom(16)) logger.info("Running REST-JSON API server at '%s:%d'.." % (host, port)) logger.info("Admin ID: %s" % adminid) + logger.debug("IPC database: %s" % db_filepath) # Initialize IPC database db = Database() - db.initialize() + db.connect() + db.init() # Run RESTful API run(host=host, port=port, quiet=True, debug=False)