improvement to the RESTful API to store partial entries in the IPC database, not yet functional (issue #297)

This commit is contained in:
Bernardo Damele 2013-02-03 11:31:05 +00:00
parent a92f1fb3b4
commit f8bc74758c
8 changed files with 61 additions and 37 deletions

View File

@ -129,8 +129,7 @@ def main():
if hasattr(conf, "api"):
try:
conf.database_cursor.close()
conf.database_connection.close()
conf.database_cursor.disconnect()
except KeyboardInterrupt:
pass
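
For context: the shutdown path in main() now goes through the Database wrapper instead of closing the sqlite3 cursor and connection by hand. A minimal sketch of what disconnect() is expected to cover, assuming it also closes the underlying connection the way main() used to do directly (the cursor part is shown in the API module further down):

def disconnect(self):
    # close the cursor first, then the underlying sqlite3 connection (assumed)
    self.cursor.close()
    self.connection.close()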

View File

@ -58,6 +58,7 @@ from lib.core.dicts import DBMS_DICT
from lib.core.dicts import DEPRECATED_OPTIONS
from lib.core.dicts import SQL_STATEMENTS
from lib.core.enums import ADJUST_TIME_DELAY
from lib.core.enums import CONTENT_STATUS
from lib.core.enums import CHARSET_TYPE
from lib.core.enums import DBMS
from lib.core.enums import EXPECTED
@ -744,7 +745,7 @@ def setColor(message, bold=False):
return retVal
def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=None):
def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=CONTENT_STATUS.IN_PROGRESS):
"""
Writes text to the stdout (console) stream
"""
@ -762,7 +763,6 @@ def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=
message = data
if hasattr(conf, "api"):
if content_type and status:
sys.stdout.write(message, status, content_type)
else:
sys.stdout.write(setColor(message, bold))
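
Worth noting here: when hasattr(conf, "api") is true, sys.stdout has been swapped for the StdDbOut object defined in the API module below, which is why write() can be handed the extra status and content_type arguments. A sketch of the two call paths, using placeholder variable names and only what the hunks in this commit show:

# sketch of the two call paths under the API; both end up in StdDbOut.write()
dataToStdout(partial_fragment)                          # inference engine: status defaults to IN_PROGRESS
dataToStdout(finished_data, content_type=content_type,
             status=CONTENT_STATUS.COMPLETE)            # Dump._write(): finished content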

View File

@ -26,7 +26,7 @@ from lib.core.data import conf
from lib.core.data import kb
from lib.core.data import logger
from lib.core.dicts import DUMP_REPLACEMENTS
from lib.core.enums import API_CONTENT_STATUS
from lib.core.enums import CONTENT_STATUS
from lib.core.enums import CONTENT_TYPE
from lib.core.enums import DBMS
from lib.core.enums import DUMP_FORMAT
@ -55,7 +55,7 @@ class Dump(object):
def _write(self, data, newline=True, console=True, content_type=None):
if hasattr(conf, "api"):
dataToStdout(data, content_type=content_type, status=API_CONTENT_STATUS.COMPLETE)
dataToStdout(data, content_type=content_type, status=CONTENT_STATUS.COMPLETE)
return
text = "%s%s" % (data, "\n" if newline else " ")

View File

@ -271,6 +271,6 @@ class CONTENT_TYPE:
OS_CMD = 23
REG_READ = 24
class API_CONTENT_STATUS:
class CONTENT_STATUS:
IN_PROGRESS = 0
COMPLETE = 1
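
The rename from API_CONTENT_STATUS to CONTENT_STATUS gives the storage layer its switch: output produced while a value is still being retrieved is tagged IN_PROGRESS, and finished content written through Dump._write() is tagged COMPLETE. A hypothetical client-side use of the two values (field names are illustrative, not the actual API schema):

# hypothetical consumer of the stored entries; field names are illustrative
for entry in entries:
    if entry["status"] == CONTENT_STATUS.IN_PROGRESS:
        print("still retrieving: %s" % entry["value"])
    else:
        print("final result: %s" % entry["value"])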

View File

@ -88,8 +88,8 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
try:
# Set kb.partRun in case "common prediction" feature (a.k.a. "good
# samaritan") is used
kb.partRun = getPartRun() if conf.predictOutput else None
# samaritan") is used or the engine is called from the API
kb.partRun = getPartRun() if conf.predictOutput or hasattr(conf, "api") else None
if partialValue:
firstChar = len(partialValue)
@ -486,7 +486,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
if result:
if showEta:
etaProgressUpdate(time.time() - charStart, len(commonValue))
elif conf.verbose in (1, 2):
elif conf.verbose in (1, 2) or hasattr(conf, "api"):
dataToStdout(filterControlChars(commonValue[index - 1:]))
finalValue = commonValue
@ -534,7 +534,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
if showEta:
etaProgressUpdate(time.time() - charStart, index)
elif conf.verbose in (1, 2):
elif conf.verbose in (1, 2) or hasattr(conf, "api"):
dataToStdout(filterControlChars(val))
# some DBMSes (e.g. Firebird, DB2, etc.) have issues with trailing spaces
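
The effect of the two changed conditions above is that every character recovered by the bisection algorithm is flushed immediately even when the engine runs headless under the API, so partially inferred values can be stored as they grow; kb.partRun records which enumeration routine triggered the inference. A simplified sketch of the streaming behaviour, where infer_character() is a made-up stand-in for the real per-character logic:

# simplified sketch, not the real bisection() loop
partial = ""
for position in range(firstChar + 1, length + 1):
    char = infer_character(payload, expression, position)   # made-up helper
    partial += char
    # under the API this call appends the character to the IN_PROGRESS row
    dataToStdout(filterControlChars(char))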

View File

@ -16,6 +16,7 @@ from lib.core.common import calculateDeltaSeconds
from lib.core.common import dataToStdout
from lib.core.common import decodeHexValue
from lib.core.common import extractRegexResult
from lib.core.common import getPartRun
from lib.core.common import getUnicode
from lib.core.common import hashDBRetrieve
from lib.core.common import hashDBWrite
@ -243,6 +244,9 @@ def errorUse(expression, dump=False):
_, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(expression)
# Set kb.partRun in case the engine is called from the API
kb.partRun = getPartRun() if hasattr(conf, "api") else None
# We have to check whether the SQL query might return multiple entries
# and in such cases forge the SQL to limit the query output to one
# entry at a time
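
The comment above refers to the usual error-based strategy: when the expression can return several rows, the engine first counts them and then forges a per-row limited query so entries are pulled one at a time (and, with kb.partRun set, can be reported individually to the API). A rough illustration of the idea with simplified MySQL-style syntax; retrieve() is a made-up stand-in for the error-based extraction primitive:

# rough illustration of one-entry-at-a-time retrieval; retrieve() is a made-up helper
count = int(retrieve("SELECT COUNT(*) FROM information_schema.schemata"))
for index in range(count):
    entry = retrieve("SELECT schema_name FROM information_schema.schemata LIMIT %d,1" % index)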

View File

@ -19,6 +19,7 @@ from lib.core.common import dataToStdout
from lib.core.common import extractRegexResult
from lib.core.common import flattenValue
from lib.core.common import getConsoleWidth
from lib.core.common import getPartRun
from lib.core.common import getUnicode
from lib.core.common import hashDBRetrieve
from lib.core.common import hashDBWrite
@ -163,6 +164,9 @@ def unionUse(expression, unpack=True, dump=False):
_, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(origExpr)
# Set kb.partRun in case the engine is called from the API
kb.partRun = getPartRun() if hasattr(conf, "api") else None
if expressionFieldsList and len(expressionFieldsList) > 1 and "ORDER BY" in expression.upper():
# Remove the ORDER BY clause because UNION does not play well with it
expression = re.sub("\s*ORDER BY\s+[\w,]+", "", expression, re.I)

View File

@ -17,15 +17,16 @@ from subprocess import PIPE
from lib.core.common import unArrayizeValue
from lib.core.convert import base64pickle
from lib.core.convert import base64unpickle
from lib.core.convert import hexencode
from lib.core.convert import dejsonize
from lib.core.convert import jsonize
from lib.core.data import conf
from lib.core.data import kb
from lib.core.data import paths
from lib.core.data import logger
from lib.core.datatype import AttribDict
from lib.core.defaults import _defaults
from lib.core.enums import CONTENT_STATUS
from lib.core.log import LOGGER_HANDLER
from lib.core.optiondict import optDict
from lib.core.subprocessng import Popen
@ -47,24 +48,27 @@ RESTAPI_SERVER_PORT = 8775
# Local global variables
adminid = ""
db = None
db_filepath = tempfile.mkstemp(prefix="sqlmapipc-", text=False)[1]
tasks = dict()
# API objects
class Database(object):
global db_filepath
LOGS_TABLE = "CREATE TABLE logs(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, time TEXT, level TEXT, message TEXT)"
DATA_TABLE = "CREATE TABLE data(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, status INTEGER, content_type INTEGER, value TEXT)"
ERRORS_TABLE = "CREATE TABLE errors(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, error TEXT)"
def __init__(self):
pass
def __init__(self, database=None):
if database:
self.database = database
else:
self.database = db_filepath
def create(self):
_, self.database = tempfile.mkstemp(prefix="sqlmapipc-", text=False)
logger.debug("IPC database: %s" % self.database)
def connect(self):
def connect(self, who="server"):
self.connection = sqlite3.connect(self.database, timeout=3, isolation_level=None)
self.cursor = self.connection.cursor()
logger.debug("REST-JSON API %s connected to IPC database" % who)
def disconnect(self):
self.cursor.close()
@ -79,18 +83,13 @@ class Database(object):
if statement.lstrip().upper().startswith("SELECT"):
return self.cursor.fetchall()
def initialize(self):
self.create()
self.connect()
def init(self):
self.execute(self.LOGS_TABLE)
self.execute(self.DATA_TABLE)
self.execute(self.ERRORS_TABLE)
def get_filepath(self):
return self.database
class Task(object):
global db
global db_filepath
def __init__(self, taskid):
self.process = None
@ -109,7 +108,7 @@ class Task(object):
# Let the sqlmap engine know it is being called by the API, and pass it the task ID and the file path of the IPC database
self.options.api = True
self.options.taskid = taskid
self.options.database = db.get_filepath()
self.options.database = db_filepath
# Enforce batch mode and disable coloring
self.options.batch = True
@ -174,10 +173,23 @@ class StdDbOut(object):
else:
sys.stderr = self
def write(self, value, status=None, content_type=None):
def write(self, value, status=CONTENT_STATUS.IN_PROGRESS, content_type=None):
if self.messagetype == "stdout":
#conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)",
# (self.taskid, status, content_type, base64pickle(value)))
if content_type is None:
content_type = 99
if status == CONTENT_STATUS.IN_PROGRESS:
output = conf.database_cursor.execute("SELECT id, value FROM data WHERE taskid = ? AND status = ? AND content_type = ? LIMIT 0,1",
(self.taskid, status, content_type))
if len(output) == 0:
conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)",
(self.taskid, status, content_type, jsonize(value)))
else:
new_value = "%s%s" % (output[0][1], value)
conf.database_cursor.execute("UPDATE data SET value = ? WHERE id = ?",
(jsonize(new_value), output[0][0]))
else:
conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)",
(self.taskid, status, content_type, jsonize(value)))
else:
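
This hunk is the core of the partial-entry support: stdout writes coming from the engine are no longer inserted blindly as new rows. For IN_PROGRESS content the method first looks for an existing row with the same task ID, status and content type and concatenates the new fragment onto its value, so a value being inferred character by character keeps growing inside a single row; COMPLETE content always gets a fresh row. Stripped of the cursor plumbing, the intended logic is roughly the following (find_row()/insert_row()/update_row() are made-up helpers standing in for the SQL above):

# condensed restatement of the write() logic above, not a drop-in replacement
def store(taskid, value, status, content_type):
    if status == CONTENT_STATUS.IN_PROGRESS:
        row = find_row(taskid, status, content_type)            # SELECT ... LIMIT 0,1
        if row is None:
            insert_row(taskid, status, content_type, jsonize(value))
        else:
            update_row(row_id=row[0], value=jsonize("%s%s" % (row[1], value)))
    else:
        insert_row(taskid, status, content_type, jsonize(value))
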
@ -205,8 +217,11 @@ class LogRecorder(logging.StreamHandler):
def setRestAPILog():
if hasattr(conf, "api"):
conf.database_connection = sqlite3.connect(conf.database, timeout=1, isolation_level=None)
conf.database_cursor = conf.database_connection.cursor()
#conf.database_connection = sqlite3.connect(conf.database, timeout=1, isolation_level=None)
#conf.database_cursor = conf.database_connection.cursor()
conf.database_cursor = Database(conf.database)
conf.database_cursor.connect("client")
# Set a logging handler that writes log messages to the IPC database
logger.removeHandler(LOGGER_HANDLER)
@ -455,7 +470,6 @@ def scan_data(taskid):
# Read all data from the IPC database for the taskid
for status, content_type, value in db.execute("SELECT status, content_type, value FROM data WHERE taskid = ? ORDER BY id ASC", (taskid,)):
#json_data_message.append({"status": status, "type": content_type, "value": base64unpickle(value)})
json_data_message.append({"status": status, "type": content_type, "value": dejsonize(value)})
# Read all error messages from the IPC database
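
On the retrieval side nothing is unpickled any more: values are stored as JSON and handed back via dejsonize(). A made-up example of what two such entries could look like while a value is still being inferred; 99 is the temporary placeholder content type assigned above when none is supplied:

# made-up illustration of stored entries while one value is still growing
[
    {"status": CONTENT_STATUS.COMPLETE,    "type": 99, "value": "back-end DBMS: MySQL >= 5.0"},
    {"status": CONTENT_STATUS.IN_PROGRESS, "type": 99, "value": "informa"},
]
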
@ -536,15 +550,18 @@ def server(host="0.0.0.0", port=RESTAPI_SERVER_PORT):
"""
global adminid
global db
global db_filepath
adminid = hexencode(os.urandom(16))
logger.info("Running REST-JSON API server at '%s:%d'.." % (host, port))
logger.info("Admin ID: %s" % adminid)
logger.debug("IPC database: %s" % db_filepath)
# Initialize IPC database
db = Database()
db.initialize()
db.connect()
db.init()
# Run RESTful API
run(host=host, port=port, quiet=True, debug=False)
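
Pieced together from the hunks above, the lifecycle this refactoring is heading towards (the method names come from the diff, the sequence itself is an interpretation): the IPC database path is fixed once at module level, the server process creates the tables, each started task forwards the same path to the engine through self.options.database, and the engine reattaches to it as a client from setRestAPILog().

# interpretation of the intended lifecycle, assembled from the hunks above
db = Database()                                   # server side: path defaults to the module-level db_filepath
db.connect()                                      # who defaults to "server"
db.init()                                         # create the logs/data/errors tables

# ... each task passes db_filepath to the engine via self.options.database ...

conf.database_cursor = Database(conf.database)    # engine side, setRestAPILog()
conf.database_cursor.connect("client")
# and on shutdown sqlmap.py now calls conf.database_cursor.disconnect()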