mirror of
https://github.com/sqlmapproject/sqlmap.git
synced 2024-11-25 02:53:46 +03:00
improvement to restful API to store to IPC database partial entries, not yet functional (issue #297)
This commit is contained in:
parent
a92f1fb3b4
commit
f8bc74758c
|
@ -129,8 +129,7 @@ def main():
|
|||
|
||||
if hasattr(conf, "api"):
|
||||
try:
|
||||
conf.database_cursor.close()
|
||||
conf.database_connection.close()
|
||||
conf.database_cursor.disconnect()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
|
|
@ -58,6 +58,7 @@ from lib.core.dicts import DBMS_DICT
|
|||
from lib.core.dicts import DEPRECATED_OPTIONS
|
||||
from lib.core.dicts import SQL_STATEMENTS
|
||||
from lib.core.enums import ADJUST_TIME_DELAY
|
||||
from lib.core.enums import CONTENT_STATUS
|
||||
from lib.core.enums import CHARSET_TYPE
|
||||
from lib.core.enums import DBMS
|
||||
from lib.core.enums import EXPECTED
|
||||
|
@ -744,7 +745,7 @@ def setColor(message, bold=False):
|
|||
|
||||
return retVal
|
||||
|
||||
def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=None):
|
||||
def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=CONTENT_STATUS.IN_PROGRESS):
|
||||
"""
|
||||
Writes text to the stdout (console) stream
|
||||
"""
|
||||
|
@ -762,7 +763,6 @@ def dataToStdout(data, forceOutput=False, bold=False, content_type=None, status=
|
|||
message = data
|
||||
|
||||
if hasattr(conf, "api"):
|
||||
if content_type and status:
|
||||
sys.stdout.write(message, status, content_type)
|
||||
else:
|
||||
sys.stdout.write(setColor(message, bold))
|
||||
|
|
|
@ -26,7 +26,7 @@ from lib.core.data import conf
|
|||
from lib.core.data import kb
|
||||
from lib.core.data import logger
|
||||
from lib.core.dicts import DUMP_REPLACEMENTS
|
||||
from lib.core.enums import API_CONTENT_STATUS
|
||||
from lib.core.enums import CONTENT_STATUS
|
||||
from lib.core.enums import CONTENT_TYPE
|
||||
from lib.core.enums import DBMS
|
||||
from lib.core.enums import DUMP_FORMAT
|
||||
|
@ -55,7 +55,7 @@ class Dump(object):
|
|||
|
||||
def _write(self, data, newline=True, console=True, content_type=None):
|
||||
if hasattr(conf, "api"):
|
||||
dataToStdout(data, content_type=content_type, status=API_CONTENT_STATUS.COMPLETE)
|
||||
dataToStdout(data, content_type=content_type, status=CONTENT_STATUS.COMPLETE)
|
||||
return
|
||||
|
||||
text = "%s%s" % (data, "\n" if newline else " ")
|
||||
|
|
|
@ -271,6 +271,6 @@ class CONTENT_TYPE:
|
|||
OS_CMD = 23
|
||||
REG_READ = 24
|
||||
|
||||
class API_CONTENT_STATUS:
|
||||
class CONTENT_STATUS:
|
||||
IN_PROGRESS = 0
|
||||
COMPLETE = 1
|
||||
|
|
|
@ -88,8 +88,8 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
|
|||
|
||||
try:
|
||||
# Set kb.partRun in case "common prediction" feature (a.k.a. "good
|
||||
# samaritan") is used
|
||||
kb.partRun = getPartRun() if conf.predictOutput else None
|
||||
# samaritan") is used or the engine is called from the API
|
||||
kb.partRun = getPartRun() if conf.predictOutput or hasattr(conf, "api") else None
|
||||
|
||||
if partialValue:
|
||||
firstChar = len(partialValue)
|
||||
|
@ -486,7 +486,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
|
|||
if result:
|
||||
if showEta:
|
||||
etaProgressUpdate(time.time() - charStart, len(commonValue))
|
||||
elif conf.verbose in (1, 2):
|
||||
elif conf.verbose in (1, 2) or hasattr(conf, "api"):
|
||||
dataToStdout(filterControlChars(commonValue[index - 1:]))
|
||||
|
||||
finalValue = commonValue
|
||||
|
@ -534,7 +534,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
|
|||
|
||||
if showEta:
|
||||
etaProgressUpdate(time.time() - charStart, index)
|
||||
elif conf.verbose in (1, 2):
|
||||
elif conf.verbose in (1, 2) or hasattr(conf, "api"):
|
||||
dataToStdout(filterControlChars(val))
|
||||
|
||||
# some DBMSes (e.g. Firebird, DB2, etc.) have issues with trailing spaces
|
||||
|
|
|
@ -16,6 +16,7 @@ from lib.core.common import calculateDeltaSeconds
|
|||
from lib.core.common import dataToStdout
|
||||
from lib.core.common import decodeHexValue
|
||||
from lib.core.common import extractRegexResult
|
||||
from lib.core.common import getPartRun
|
||||
from lib.core.common import getUnicode
|
||||
from lib.core.common import hashDBRetrieve
|
||||
from lib.core.common import hashDBWrite
|
||||
|
@ -243,6 +244,9 @@ def errorUse(expression, dump=False):
|
|||
|
||||
_, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(expression)
|
||||
|
||||
# Set kb.partRun in case the engine is called from the API
|
||||
kb.partRun = getPartRun() if hasattr(conf, "api") else None
|
||||
|
||||
# We have to check if the SQL query might return multiple entries
|
||||
# and in such case forge the SQL limiting the query output one
|
||||
# entry at a time
|
||||
|
|
|
@ -19,6 +19,7 @@ from lib.core.common import dataToStdout
|
|||
from lib.core.common import extractRegexResult
|
||||
from lib.core.common import flattenValue
|
||||
from lib.core.common import getConsoleWidth
|
||||
from lib.core.common import getPartRun
|
||||
from lib.core.common import getUnicode
|
||||
from lib.core.common import hashDBRetrieve
|
||||
from lib.core.common import hashDBWrite
|
||||
|
@ -163,6 +164,9 @@ def unionUse(expression, unpack=True, dump=False):
|
|||
|
||||
_, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(origExpr)
|
||||
|
||||
# Set kb.partRun in case the engine is called from the API
|
||||
kb.partRun = getPartRun() if hasattr(conf, "api") else None
|
||||
|
||||
if expressionFieldsList and len(expressionFieldsList) > 1 and "ORDER BY" in expression.upper():
|
||||
# Removed ORDER BY clause because UNION does not play well with it
|
||||
expression = re.sub("\s*ORDER BY\s+[\w,]+", "", expression, re.I)
|
||||
|
|
|
@ -17,15 +17,16 @@ from subprocess import PIPE
|
|||
|
||||
from lib.core.common import unArrayizeValue
|
||||
from lib.core.convert import base64pickle
|
||||
from lib.core.convert import base64unpickle
|
||||
from lib.core.convert import hexencode
|
||||
from lib.core.convert import dejsonize
|
||||
from lib.core.convert import jsonize
|
||||
from lib.core.data import conf
|
||||
from lib.core.data import kb
|
||||
from lib.core.data import paths
|
||||
from lib.core.data import logger
|
||||
from lib.core.datatype import AttribDict
|
||||
from lib.core.defaults import _defaults
|
||||
from lib.core.enums import CONTENT_STATUS
|
||||
from lib.core.log import LOGGER_HANDLER
|
||||
from lib.core.optiondict import optDict
|
||||
from lib.core.subprocessng import Popen
|
||||
|
@ -47,24 +48,27 @@ RESTAPI_SERVER_PORT = 8775
|
|||
# Local global variables
|
||||
adminid = ""
|
||||
db = None
|
||||
db_filepath = tempfile.mkstemp(prefix="sqlmapipc-", text=False)[1]
|
||||
tasks = dict()
|
||||
|
||||
# API objects
|
||||
class Database(object):
|
||||
global db_filepath
|
||||
|
||||
LOGS_TABLE = "CREATE TABLE logs(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, time TEXT, level TEXT, message TEXT)"
|
||||
DATA_TABLE = "CREATE TABLE data(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, status INTEGER, content_type INTEGER, value TEXT)"
|
||||
ERRORS_TABLE = "CREATE TABLE errors(id INTEGER PRIMARY KEY AUTOINCREMENT, taskid INTEGER, error TEXT)"
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, database=None):
|
||||
if database:
|
||||
self.database = database
|
||||
else:
|
||||
self.database = db_filepath
|
||||
|
||||
def create(self):
|
||||
_, self.database = tempfile.mkstemp(prefix="sqlmapipc-", text=False)
|
||||
logger.debug("IPC database: %s" % self.database)
|
||||
|
||||
def connect(self):
|
||||
def connect(self, who="server"):
|
||||
self.connection = sqlite3.connect(self.database, timeout=3, isolation_level=None)
|
||||
self.cursor = self.connection.cursor()
|
||||
logger.debug("REST-JSON API %s connected to IPC database" % who)
|
||||
|
||||
def disconnect(self):
|
||||
self.cursor.close()
|
||||
|
@ -79,18 +83,13 @@ class Database(object):
|
|||
if statement.lstrip().upper().startswith("SELECT"):
|
||||
return self.cursor.fetchall()
|
||||
|
||||
def initialize(self):
|
||||
self.create()
|
||||
self.connect()
|
||||
def init(self):
|
||||
self.execute(self.LOGS_TABLE)
|
||||
self.execute(self.DATA_TABLE)
|
||||
self.execute(self.ERRORS_TABLE)
|
||||
|
||||
def get_filepath(self):
|
||||
return self.database
|
||||
|
||||
class Task(object):
|
||||
global db
|
||||
global db_filepath
|
||||
|
||||
def __init__(self, taskid):
|
||||
self.process = None
|
||||
|
@ -109,7 +108,7 @@ class Task(object):
|
|||
# Let sqlmap engine knows it is getting called by the API, the task ID and the file path of the IPC database
|
||||
self.options.api = True
|
||||
self.options.taskid = taskid
|
||||
self.options.database = db.get_filepath()
|
||||
self.options.database = db_filepath
|
||||
|
||||
# Enforce batch mode and disable coloring
|
||||
self.options.batch = True
|
||||
|
@ -174,10 +173,23 @@ class StdDbOut(object):
|
|||
else:
|
||||
sys.stderr = self
|
||||
|
||||
def write(self, value, status=None, content_type=None):
|
||||
def write(self, value, status=CONTENT_STATUS.IN_PROGRESS, content_type=None):
|
||||
if self.messagetype == "stdout":
|
||||
#conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)",
|
||||
# (self.taskid, status, content_type, base64pickle(value)))
|
||||
if content_type is None:
|
||||
content_type = 99
|
||||
|
||||
if status == CONTENT_STATUS.IN_PROGRESS:
|
||||
output = conf.database_cursor.execute("SELECT id, value FROM data WHERE taskid = ? AND status = ? AND content_type = ? LIMIT 0,1",
|
||||
(self.taskid, status, content_type))
|
||||
|
||||
if len(output) == 0:
|
||||
conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)",
|
||||
(self.taskid, status, content_type, jsonize(value)))
|
||||
else:
|
||||
new_value = "%s%s" % (output[0][1], value)
|
||||
conf.database_cursor.execute("UPDATE data SET value = ? WHERE id = ?",
|
||||
(jsonize(new_value), output[0][0]))
|
||||
else:
|
||||
conf.database_cursor.execute("INSERT INTO data VALUES(NULL, ?, ?, ?, ?)",
|
||||
(self.taskid, status, content_type, jsonize(value)))
|
||||
else:
|
||||
|
@ -205,8 +217,11 @@ class LogRecorder(logging.StreamHandler):
|
|||
|
||||
def setRestAPILog():
|
||||
if hasattr(conf, "api"):
|
||||
conf.database_connection = sqlite3.connect(conf.database, timeout=1, isolation_level=None)
|
||||
conf.database_cursor = conf.database_connection.cursor()
|
||||
#conf.database_connection = sqlite3.connect(conf.database, timeout=1, isolation_level=None)
|
||||
#conf.database_cursor = conf.database_connection.cursor()
|
||||
|
||||
conf.database_cursor = Database(conf.database)
|
||||
conf.database_cursor.connect("client")
|
||||
|
||||
# Set a logging handler that writes log messages to a IPC database
|
||||
logger.removeHandler(LOGGER_HANDLER)
|
||||
|
@ -455,7 +470,6 @@ def scan_data(taskid):
|
|||
|
||||
# Read all data from the IPC database for the taskid
|
||||
for status, content_type, value in db.execute("SELECT status, content_type, value FROM data WHERE taskid = ? ORDER BY id ASC", (taskid,)):
|
||||
#json_data_message.append({"status": status, "type": content_type, "value": base64unpickle(value)})
|
||||
json_data_message.append({"status": status, "type": content_type, "value": dejsonize(value)})
|
||||
|
||||
# Read all error messages from the IPC database
|
||||
|
@ -536,15 +550,18 @@ def server(host="0.0.0.0", port=RESTAPI_SERVER_PORT):
|
|||
"""
|
||||
global adminid
|
||||
global db
|
||||
global db_filepath
|
||||
|
||||
adminid = hexencode(os.urandom(16))
|
||||
|
||||
logger.info("Running REST-JSON API server at '%s:%d'.." % (host, port))
|
||||
logger.info("Admin ID: %s" % adminid)
|
||||
logger.debug("IPC database: %s" % db_filepath)
|
||||
|
||||
# Initialize IPC database
|
||||
db = Database()
|
||||
db.initialize()
|
||||
db.connect()
|
||||
db.init()
|
||||
|
||||
# Run RESTful API
|
||||
run(host=host, port=port, quiet=True, debug=False)
|
||||
|
|
Loading…
Reference in New Issue
Block a user