From 86455ceb9c56bbe2ad48f36c12ebe89b8cbfb333 Mon Sep 17 00:00:00 2001 From: Miroslav Stampar Date: Sun, 29 May 2011 23:17:50 +0000 Subject: [PATCH] implementation of multithreading for UNION and ERROR techniques --- lib/core/option.py | 1 + lib/core/threads.py | 60 ++++++++++++++++++++++++ lib/techniques/error/use.py | 40 ++++++++++++++-- lib/techniques/inband/union/use.py | 75 +++++++++++++++++++++--------- 4 files changed, 149 insertions(+), 27 deletions(-) diff --git a/lib/core/option.py b/lib/core/option.py index 13cd5353d..ddc293fa1 100644 --- a/lib/core/option.py +++ b/lib/core/option.py @@ -1358,6 +1358,7 @@ def __setKnowledgeBaseAttributes(flushAll=True): kb.locks = advancedDict() kb.locks.cacheLock = threading.Lock() kb.locks.logLock = threading.Lock() + kb.locks.ioLock = threading.Lock() kb.matchRatio = None kb.nullConnection = None diff --git a/lib/core/threads.py b/lib/core/threads.py index d54a05406..20f207a2d 100644 --- a/lib/core/threads.py +++ b/lib/core/threads.py @@ -11,6 +11,11 @@ import difflib import threading from lib.core.data import kb +from lib.core.data import logger +from lib.core.datatype import advancedDict +from lib.core.exception import sqlmapThreadException + +shared = advancedDict() class ThreadData(): """ @@ -18,6 +23,8 @@ class ThreadData(): """ def __init__(self): + global shared + self.disableStdOut = False self.lastErrorPage = None self.lastHTTPError = None @@ -26,6 +33,7 @@ class ThreadData(): self.lastRequestUID = 0 self.retriesCount = 0 self.seqMatcher = difflib.SequenceMatcher(None) + self.shared = shared self.valueStack = [] def getCurrentThreadUID(): @@ -40,3 +48,55 @@ def getCurrentThreadData(): if threadUID not in kb.threadData: kb.threadData[threadUID] = ThreadData() return kb.threadData[threadUID] + +def runThreads(numThreads, threadFunction, cleanupFunction=None): + threads = [] + + kb.threadContinue = True + kb.threadException = False + + if numThreads > 1: + infoMsg = "starting %d threads" % numThreads + logger.info(infoMsg) + else: + threadFunction() + return + + # Start the threads + for numThread in range(numThreads): + thread = threading.Thread(target=threadFunction, name=str(numThread)) + thread.start() + threads.append(thread) + + # And wait for them to all finish + try: + alive = True + + while alive: + alive = False + + for thread in threads: + if thread.isAlive(): + alive = True + thread.join(1) + + except KeyboardInterrupt: + kb.threadContinue = False + kb.threadException = True + + print + logger.debug("waiting for threads to finish") + + try: + while (threading.activeCount() > 1): + pass + + except KeyboardInterrupt: + raise sqlmapThreadException, "user aborted (Ctrl+C was pressed multiple times)" + + finally: + kb.threadContinue = True + kb.threadException = False + + if cleanupFunction: + cleanupFunction() diff --git a/lib/techniques/error/use.py b/lib/techniques/error/use.py index b5ad1c3a0..7f37bb506 100644 --- a/lib/techniques/error/use.py +++ b/lib/techniques/error/use.py @@ -8,6 +8,7 @@ See the file 'doc/COPYING' for copying permission """ import re +import threading import time from lib.core.agent import agent @@ -39,6 +40,7 @@ from lib.core.settings import MSSQL_ERROR_CHUNK_LENGTH from lib.core.settings import SQL_SCALAR_REGEX from lib.core.settings import TURN_OFF_RESUME_INFO_LIMIT from lib.core.threads import getCurrentThreadData +from lib.core.threads import runThreads from lib.core.unescaper import unescaper from lib.request.connect import Connect as Request from lib.utils.resume import resume @@ -159,7 +161,9 @@ def __errorFields(expression, expressionFields, expressionFieldsList, expected=N output = __oneShotErrorUse(expressionReplaced, field) if output is not None: + kb.locks.ioLock.acquire() dataToStdout("[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(output))) + kb.locks.ioLock.release() if isinstance(num, int): expression = origExpr @@ -316,13 +320,39 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False): infoMsg += "large number of rows (possible slowdown)" logger.info(infoMsg) - for num in xrange(startLimit, stopLimit): - output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue) + lockNames = ('limits', 'outputs') + for lock in lockNames: + kb.locks[lock] = threading.Lock() - if output and isinstance(output, list) and len(output) == 1: - output = output[0] + threadData = getCurrentThreadData() + numThreads = min(conf.threads, stopLimit-startLimit) + threadData.shared.limits = range(startLimit, stopLimit) + threadData.shared.outputs = [] - outputs.append(output) + def errorThread(): + try: + threadData = getCurrentThreadData() + + while threadData.shared.limits and kb.threadContinue: + kb.locks.limits.acquire() + num = threadData.shared.limits[-1] + del threadData.shared.limits[-1] + kb.locks.limits.release() + + output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue) + + if output and isinstance(output, list) and len(output) == 1: + output = output[0] + + kb.locks.outputs.acquire() + threadData.shared.outputs.append(output) + kb.locks.outputs.release() + except KeyboardInterrupt: + raise + + runThreads(numThreads, errorThread) + + outputs = threadData.shared.outputs except KeyboardInterrupt: warnMsg = "user aborted during enumeration. sqlmap " diff --git a/lib/techniques/inband/union/use.py b/lib/techniques/inband/union/use.py index ffcb1f586..497d76a26 100644 --- a/lib/techniques/inband/union/use.py +++ b/lib/techniques/inband/union/use.py @@ -9,6 +9,7 @@ See the file 'doc/COPYING' for copying permission import logging import re +import threading import time from lib.core.agent import agent @@ -39,6 +40,8 @@ from lib.core.exception import sqlmapSyntaxException from lib.core.settings import FROM_TABLE from lib.core.settings import SQL_SCALAR_REGEX from lib.core.settings import TURN_OFF_RESUME_INFO_LIMIT +from lib.core.threads import getCurrentThreadData +from lib.core.threads import runThreads from lib.core.unescaper import unescaper from lib.request.connect import Connect as Request from lib.utils.resume import resume @@ -260,32 +263,60 @@ def unionUse(expression, unpack=True, dump=False): infoMsg += "large number of rows (possible slowdown)" logger.info(infoMsg) - for num in xrange(startLimit, stopLimit): - if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE): - field = expressionFieldsList[0] - elif Backend.isDbms(DBMS.ORACLE): - field = expressionFieldsList - else: - field = None + lockNames = ('limits', 'value') + for lock in lockNames: + kb.locks[lock] = threading.Lock() - limitedExpr = agent.limitQuery(num, expression, field) - output = resume(limitedExpr, None) + threadData = getCurrentThreadData() + numThreads = min(conf.threads, stopLimit-startLimit) + threadData.shared.limits = range(startLimit, stopLimit) + threadData.shared.value = "" - if not output: - output = __oneShotUnionUse(limitedExpr, unpack) + def unionThread(): + threadData = getCurrentThreadData() - if output: - value += output + while threadData.shared.limits and kb.threadContinue: + kb.locks.limits.acquire() + num = threadData.shared.limits[-1] + del threadData.shared.limits[-1] + kb.locks.limits.release() - if conf.verbose == 1: - if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])): - items = extractRegexResult(r'%s(?P.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter) - else: - items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter) - status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items)))) - if len(status) > width: - status = "%s..." % status[:width - 3] - dataToStdout(status, True) + if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE): + field = expressionFieldsList[0] + elif Backend.isDbms(DBMS.ORACLE): + field = expressionFieldsList + else: + field = None + + limitedExpr = agent.limitQuery(num, expression, field) + output = resume(limitedExpr, None) + + if not output: + output = __oneShotUnionUse(limitedExpr, unpack) + + if output: + kb.locks.value.acquire() + threadData.shared.value += output + kb.locks.value.release() + + if conf.verbose == 1: + if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])): + items = extractRegexResult(r'%s(?P.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter) + else: + items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter) + + status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items)))) + + if len(status) > width: + status = "%s..." % status[:width - 3] + + kb.locks.ioLock.acquire() + dataToStdout(status, True) + kb.locks.ioLock.release() + + runThreads(numThreads, unionThread) + + value = threadData.shared.value if conf.verbose == 1: clearConsoleLine(True)