improvement to restful API to store to IPC database partial entries, not yet functional (issue #297)

This commit is contained in:
Bernardo Damele 2013-02-03 11:31:05 +00:00
parent a92f1fb3b4
commit f8bc74758c
8 changed files with 61 additions and 37 deletions

View File

@ -129,8 +129,7 @@ def main():
if hasattr(conf, "api"): if hasattr(conf, "api"):
try: try:
conf.database_cursor.close() conf.database_cursor.disconnect()
conf.database_connection.close()
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass

View File

@ -58,6 +58,7 @@ from lib.core.dicts import DBMS_DICT
from lib.core.dicts import DEPRECATED_OPTIONS from lib.core.dicts import DEPRECATED_OPTIONS
from lib.core.dicts import SQL_STATEMENTS from lib.core.dicts import SQL_STATEMENTS
from lib.core.enums import ADJUST_TIME_DELAY from lib.core.enums import ADJUST_TIME_DELAY
from lib.core.enums import CONTENT_STATUS
from lib.core.enums import CHARSET_TYPE from lib.core.enums import CHARSET_TYPE
from lib.core.enums import DBMS from lib.core.enums import DBMS
from lib.core.enums import EXPECTED from lib.core.enums import EXPECTED
@ -744,7 +745,7 @@ def setColor(message, bold=False):
return retVal return retVal
def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=None): def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=CONTENT_STATUS.IN_PROGRESS):
""" """
Writes text to the stdout (console) stream Writes text to the stdout (console) stream
""" """
@ -762,7 +763,6 @@ def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=
message = data message = data
if hasattr(conf, "api"): if hasattr(conf, "api"):
if content_type and status:
sys.stdout.write(message, status, content_type) sys.stdout.write(message, status, content_type)
else: else:
sys.stdout.write(setColor(message, bold)) sys.stdout.write(setColor(message, bold))

View File

@ -26,7 +26,7 @@ from lib.core.data import conf
from lib.core.data import kb from lib.core.data import kb
from lib.core.data import logger from lib.core.data import logger
from lib.core.dicts import DUMP_REPLACEMENTS from lib.core.dicts import DUMP_REPLACEMENTS
from lib.core.enums import API_CONTENT_STATUS from lib.core.enums import CONTENT_STATUS
from lib.core.enums import CONTENT_TYPE from lib.core.enums import CONTENT_TYPE
from lib.core.enums import DBMS from lib.core.enums import DBMS
from lib.core.enums import DUMP_FORMAT from lib.core.enums import DUMP_FORMAT
@ -55,7 +55,7 @@ class Dump(object):
def _write(self, data, newline=True, console=True, content_type=None): def _write(self, data, newline=True, console=True, content_type=None):
if hasattr(conf, "api"): if hasattr(conf, "api"):
dataToStdout(data, content_type=content_type, status=API_CONTENT_STATUS.COMPLETE) dataToStdout(data, content_type=content_type, status=CONTENT_STATUS.COMPLETE)
return return
text = "%s%s" % (data, "\n" if newline else " ") text = "%s%s" % (data, "\n" if newline else " ")

View File

@ -271,6 +271,6 @@ class CONTENT_TYPE:
OS_CMD = 23 OS_CMD = 23
REG_READ = 24 REG_READ = 24
class API_CONTENT_STATUS: class CONTENT_STATUS:
IN_PROGRESS = 0 IN_PROGRESS = 0
COMPLETE = 1 COMPLETE = 1

View File

@ -88,8 +88,8 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
try: try:
# Set kb.partRun in case "common prediction" feature (a.k.a. "good # Set kb.partRun in case "common prediction" feature (a.k.a. "good
# samaritan") is used # samaritan") is used or the engine is called from the API
kb.partRun = getPartRun() if conf.predictOutput else None kb.partRun = getPartRun() if conf.predictOutput or hasattr(conf, "api") else None
if partialValue: if partialValue:
firstChar = len(partialValue) firstChar = len(partialValue)
@ -486,7 +486,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
if result: if result:
if showEta: if showEta:
etaProgressUpdate(time.time() - charStart, len(commonValue)) etaProgressUpdate(time.time() - charStart, len(commonValue))
elif conf.verbose in (1, 2): elif conf.verbose in (1, 2) or hasattr(conf, "api"):
dataToStdout(filterControlChars(commonValue[index - 1:])) dataToStdout(filterControlChars(commonValue[index - 1:]))
finalValue = commonValue finalValue = commonValue
@ -534,7 +534,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
if showEta: if showEta:
etaProgressUpdate(time.time() - charStart, index) etaProgressUpdate(time.time() - charStart, index)
elif conf.verbose in (1, 2): elif conf.verbose in (1, 2) or hasattr(conf, "api"):
dataToStdout(filterControlChars(val)) dataToStdout(filterControlChars(val))
# some DBMSes (e.g. Firebird, DB2, etc.) have issues with trailing spaces # some DBMSes (e.g. Firebird, DB2, etc.) have issues with trailing spaces

View File

@ -16,6 +16,7 @@ from lib.core.common import calculateDeltaSeconds
from lib.core.common import dataToStdout from lib.core.common import dataToStdout
from lib.core.common import decodeHexValue from lib.core.common import decodeHexValue
from lib.core.common import extractRegexResult from lib.core.common import extractRegexResult
from lib.core.common import getPartRun
from lib.core.common import getUnicode from lib.core.common import getUnicode
from lib.core.common import hashDBRetrieve from lib.core.common import hashDBRetrieve
from lib.core.common import hashDBWrite from lib.core.common import hashDBWrite
@ -243,6 +244,9 @@ def errorUse(expression, dump=False):
_, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(expression) _, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(expression)
# Set kb.partRun in case the engine is called from the API
kb.partRun = getPartRun() if hasattr(conf, "api") else None
# We have to check if the SQL query might return multiple entries # We have to check if the SQL query might return multiple entries
# and in such case forge the SQL limiting the query output one # and in such case forge the SQL limiting the query output one
# entry at a time # entry at a time

View File

@ -19,6 +19,7 @@ from lib.core.common import dataToStdout
from lib.core.common import extractRegexResult from lib.core.common import extractRegexResult
from lib.core.common import flattenValue from lib.core.common import flattenValue
from lib.core.common import getConsoleWidth from lib.core.common import getConsoleWidth
from lib.core.common import getPartRun
from lib.core.common import getUnicode from lib.core.common import getUnicode
from lib.core.common import hashDBRetrieve from lib.core.common import hashDBRetrieve
from lib.core.common import hashDBWrite from lib.core.common import hashDBWrite
@ -163,6 +164,9 @@ def unionUse(expression, unpack=True, dump=False):
_, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(origExpr) _, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(origExpr)
# Set kb.partRun in case the engine is called from the API
kb.partRun = getPartRun() if hasattr(conf, "api") else None
if expressionFieldsList and len(expressionFieldsList) > 1 and "ORDER BY" in expression.upper(): if expressionFieldsList and len(expressionFieldsList) > 1 and "ORDER BY" in expression.upper():
# Removed ORDER BY clause because UNION does not play well with it # Removed ORDER BY clause because UNION does not play well with it
expression = re.sub("\s*ORDER BY\s+[\w,]+", "", expression, re.I) expression = re.sub("\s*ORDER BY\s+[\w,]+", "", expression, re.I)

View File

@ -17,15 +17,16 @@ from subprocess import PIPE
from lib.core.common import unArrayizeValue from lib.core.common import unArrayizeValue
from lib.core.convert import base64pickle from lib.core.convert import base64pickle
from lib.core.convert import base64unpickle
from lib.core.convert import hexencode from lib.core.convert import hexencode
from lib.core.convert import dejsonize from lib.core.convert import dejsonize
from lib.core.convert import jsonize from lib.core.convert import jsonize
from lib.core.data import conf from lib.core.data import conf
from lib.core.data import kb
from lib.core.data import paths from lib.core.data import paths
from lib.core.data import logger from lib.core.data import logger
from lib.core.datatype import AttribDict from lib.core.datatype import AttribDict
from lib.core.defaults import _defaults from lib.core.defaults import _defaults
from lib.core.enums import CONTENT_STATUS
from lib.core.log import LOGGER_HANDLER from lib.core.log import LOGGER_HANDLER
from lib.core.optiondict import optDict from lib.core.optiondict import optDict
from lib.core.subprocessng import Popen from lib.core.subprocessng import Popen
@ -47,24 +48,27 @@ RESTAPI_SERVER_PORT = 8775
# Local global variables # Local global variables
adminid = "" adminid = ""
db = None db = None
db_filepath = tempfile.mkstemp(prefix="sqlmapipc-", text=False)[1]
tasks = dict() tasks = dict()
# API objects # API objects
class Database(object): class Database(object):
global db_filepath
LOGS_TABLE = "CREATE TABLE logs(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, time TEXT, level TEXT, message TEXT)" LOGS_TABLE = "CREATE TABLE logs(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, time TEXT, level TEXT, message TEXT)"
DATA_TABLE = "CREATE TABLE data(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, status INTEGER, content_type INTEGER, value TEXT)" DATA_TABLE = "CREATE TABLE data(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, status INTEGER, content_type INTEGER, value TEXT)"
ERRORS_TABLE = "CREATE TABLE errors(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, error TEXT)" ERRORS_TABLE = "CREATE TABLE errors(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, error TEXT)"
def __init__(self): def __init__(self, database=None):
pass if database:
self.database = database
else:
self.database = db_filepath
def create(self): def connect(self, who="server"):
_, self.database = tempfile.mkstemp(prefix="sqlmapipc-", text=False)
logger.debug("IPC database: %s" % self.database)
def connect(self):
self.connection = sqlite3.connect(self.database, timeout=3, isolation_level=None) self.connection = sqlite3.connect(self.database, timeout=3, isolation_level=None)
self.cursor = self.connection.cursor() self.cursor = self.connection.cursor()
logger.debug("REST-JSON API %s connected to IPC database" % who)
def disconnect(self): def disconnect(self):
self.cursor.close() self.cursor.close()
@ -79,18 +83,13 @@ class Database(object):
if statement.lstrip().upper().startswith("SELECT"): if statement.lstrip().upper().startswith("SELECT"):
return self.cursor.fetchall() return self.cursor.fetchall()
def initialize(self): def init(self):
self.create()
self.connect()
self.execute(self.LOGS_TABLE) self.execute(self.LOGS_TABLE)
self.execute(self.DATA_TABLE) self.execute(self.DATA_TABLE)
self.execute(self.ERRORS_TABLE) self.execute(self.ERRORS_TABLE)
def get_filepath(self):
return self.database
class Task(object): class Task(object):
global db global db_filepath
def __init__(self, taskid): def __init__(self, taskid):
self.process = None self.process = None
@ -109,7 +108,7 @@ class Task(object):
# Let sqlmap engine knows it is getting called by the API, the task ID and the file path of the IPC database # Let sqlmap engine knows it is getting called by the API, the task ID and the file path of the IPC database
self.options.api = True self.options.api = True
self.options.taskid = taskid self.options.taskid = taskid
self.options.database = db.get_filepath() self.options.database = db_filepath
# Enforce batch mode and disable coloring # Enforce batch mode and disable coloring
self.options.batch = True self.options.batch = True
@ -174,10 +173,23 @@ class StdDbOut(object):
else: else:
sys.stderr = self sys.stderr = self
def write(self, value, status=None, content_type=None): def write(self, value, status=CONTENT_STATUS.IN_PROGRESS, content_type=None):
if self.messagetype == "stdout": if self.messagetype == "stdout":
#conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)", if content_type is None:
# (self.taskid, status, content_type, base64pickle(value))) content_type = 99
if status == CONTENT_STATUS.IN_PROGRESS:
output = conf.database_cursor.execute("SELECT id, value FROM data WHERE taskid = ? AND status = ? AND content_type = ? LIMIT 0,1",
(self.taskid, status, content_type))
if len(output) == 0:
conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)",
(self.taskid, status, content_type, jsonize(value)))
else:
new_value = "%s%s" % (output[0][1], value)
conf.database_cursor.execute("UPDATE data SET value = ? WHERE id = ?",
(jsonize(new_value), output[0][0]))
else:
conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)", conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)",
(self.taskid, status, content_type, jsonize(value))) (self.taskid, status, content_type, jsonize(value)))
else: else:
@ -205,8 +217,11 @@ class LogRecorder(logging.StreamHandler):
def setRestAPILog(): def setRestAPILog():
if hasattr(conf, "api"): if hasattr(conf, "api"):
conf.database_connection = sqlite3.connect(conf.database, timeout=1, isolation_level=None) #conf.database_connection = sqlite3.connect(conf.database, timeout=1, isolation_level=None)
conf.database_cursor = conf.database_connection.cursor() #conf.database_cursor = conf.database_connection.cursor()
conf.database_cursor = Database(conf.database)
conf.database_cursor.connect("client")
# Set a logging handler that writes log messages to a IPC database # Set a logging handler that writes log messages to a IPC database
logger.removeHandler(LOGGER_HANDLER) logger.removeHandler(LOGGER_HANDLER)
@ -455,7 +470,6 @@ def scan_data(taskid):
# Read all data from the IPC database for the taskid # Read all data from the IPC database for the taskid
for status, content_type, value in db.execute("SELECT status, content_type, value FROM data WHERE taskid = ? ORDER BY id ASC", (taskid,)): for status, content_type, value in db.execute("SELECT status, content_type, value FROM data WHERE taskid = ? ORDER BY id ASC", (taskid,)):
#json_data_message.append({"status": status, "type": content_type, "value": base64unpickle(value)})
json_data_message.append({"status": status, "type": content_type, "value": dejsonize(value)}) json_data_message.append({"status": status, "type": content_type, "value": dejsonize(value)})
# Read all error messages from the IPC database # Read all error messages from the IPC database
@ -536,15 +550,18 @@ def server(host="0.0.0.0", port=RESTAPI_SERVER_PORT):
""" """
global adminid global adminid
global db global db
global db_filepath
adminid = hexencode(os.urandom(16)) adminid = hexencode(os.urandom(16))
logger.info("Running REST-JSON API server at '%s:%d'.." % (host, port)) logger.info("Running REST-JSON API server at '%s:%d'.." % (host, port))
logger.info("Admin ID: %s" % adminid) logger.info("Admin ID: %s" % adminid)
logger.debug("IPC database: %s" % db_filepath)
# Initialize IPC database # Initialize IPC database
db = Database() db = Database()
db.initialize() db.connect()
db.init()
# Run RESTful API # Run RESTful API
run(host=host, port=port, quiet=True, debug=False) run(host=host, port=port, quiet=True, debug=False)