implementation of multithreading for UNION and ERROR techniques

This commit is contained in:
Miroslav Stampar 2011-05-29 23:17:50 +00:00
parent d51efa679d
commit 86455ceb9c
4 changed files with 149 additions and 27 deletions

View File

@ -1358,6 +1358,7 @@ def __setKnowledgeBaseAttributes(flushAll=True):
kb.locks = advancedDict() kb.locks = advancedDict()
kb.locks.cacheLock = threading.Lock() kb.locks.cacheLock = threading.Lock()
kb.locks.logLock = threading.Lock() kb.locks.logLock = threading.Lock()
kb.locks.ioLock = threading.Lock()
kb.matchRatio = None kb.matchRatio = None
kb.nullConnection = None kb.nullConnection = None

View File

@ -11,6 +11,11 @@ import difflib
import threading import threading
from lib.core.data import kb from lib.core.data import kb
from lib.core.data import logger
from lib.core.datatype import advancedDict
from lib.core.exception import sqlmapThreadException
shared = advancedDict()
class ThreadData(): class ThreadData():
""" """
@ -18,6 +23,8 @@ class ThreadData():
""" """
def __init__(self): def __init__(self):
global shared
self.disableStdOut = False self.disableStdOut = False
self.lastErrorPage = None self.lastErrorPage = None
self.lastHTTPError = None self.lastHTTPError = None
@ -26,6 +33,7 @@ class ThreadData():
self.lastRequestUID = 0 self.lastRequestUID = 0
self.retriesCount = 0 self.retriesCount = 0
self.seqMatcher = difflib.SequenceMatcher(None) self.seqMatcher = difflib.SequenceMatcher(None)
self.shared = shared
self.valueStack = [] self.valueStack = []
def getCurrentThreadUID(): def getCurrentThreadUID():
@ -40,3 +48,55 @@ def getCurrentThreadData():
if threadUID not in kb.threadData: if threadUID not in kb.threadData:
kb.threadData[threadUID] = ThreadData() kb.threadData[threadUID] = ThreadData()
return kb.threadData[threadUID] return kb.threadData[threadUID]
def runThreads(numThreads, threadFunction, cleanupFunction=None):
threads = []
kb.threadContinue = True
kb.threadException = False
if numThreads > 1:
infoMsg = "starting %d threads" % numThreads
logger.info(infoMsg)
else:
threadFunction()
return
# Start the threads
for numThread in range(numThreads):
thread = threading.Thread(target=threadFunction, name=str(numThread))
thread.start()
threads.append(thread)
# And wait for them to all finish
try:
alive = True
while alive:
alive = False
for thread in threads:
if thread.isAlive():
alive = True
thread.join(1)
except KeyboardInterrupt:
kb.threadContinue = False
kb.threadException = True
print
logger.debug("waiting for threads to finish")
try:
while (threading.activeCount() > 1):
pass
except KeyboardInterrupt:
raise sqlmapThreadException, "user aborted (Ctrl+C was pressed multiple times)"
finally:
kb.threadContinue = True
kb.threadException = False
if cleanupFunction:
cleanupFunction()

View File

@ -8,6 +8,7 @@ See the file 'doc/COPYING' for copying permission
""" """
import re import re
import threading
import time import time
from lib.core.agent import agent from lib.core.agent import agent
@ -39,6 +40,7 @@ from lib.core.settings import MSSQL_ERROR_CHUNK_LENGTH
from lib.core.settings import SQL_SCALAR_REGEX from lib.core.settings import SQL_SCALAR_REGEX
from lib.core.settings import TURN_OFF_RESUME_INFO_LIMIT from lib.core.settings import TURN_OFF_RESUME_INFO_LIMIT
from lib.core.threads import getCurrentThreadData from lib.core.threads import getCurrentThreadData
from lib.core.threads import runThreads
from lib.core.unescaper import unescaper from lib.core.unescaper import unescaper
from lib.request.connect import Connect as Request from lib.request.connect import Connect as Request
from lib.utils.resume import resume from lib.utils.resume import resume
@ -159,7 +161,9 @@ def __errorFields(expression, expressionFields, expressionFieldsList, expected=N
output = __oneShotErrorUse(expressionReplaced, field) output = __oneShotErrorUse(expressionReplaced, field)
if output is not None: if output is not None:
kb.locks.ioLock.acquire()
dataToStdout("[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(output))) dataToStdout("[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(output)))
kb.locks.ioLock.release()
if isinstance(num, int): if isinstance(num, int):
expression = origExpr expression = origExpr
@ -316,13 +320,39 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False):
infoMsg += "large number of rows (possible slowdown)" infoMsg += "large number of rows (possible slowdown)"
logger.info(infoMsg) logger.info(infoMsg)
for num in xrange(startLimit, stopLimit): lockNames = ('limits', 'outputs')
output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue) for lock in lockNames:
kb.locks[lock] = threading.Lock()
if output and isinstance(output, list) and len(output) == 1: threadData = getCurrentThreadData()
output = output[0] numThreads = min(conf.threads, stopLimit-startLimit)
threadData.shared.limits = range(startLimit, stopLimit)
threadData.shared.outputs = []
outputs.append(output) def errorThread():
try:
threadData = getCurrentThreadData()
while threadData.shared.limits and kb.threadContinue:
kb.locks.limits.acquire()
num = threadData.shared.limits[-1]
del threadData.shared.limits[-1]
kb.locks.limits.release()
output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue)
if output and isinstance(output, list) and len(output) == 1:
output = output[0]
kb.locks.outputs.acquire()
threadData.shared.outputs.append(output)
kb.locks.outputs.release()
except KeyboardInterrupt:
raise
runThreads(numThreads, errorThread)
outputs = threadData.shared.outputs
except KeyboardInterrupt: except KeyboardInterrupt:
warnMsg = "user aborted during enumeration. sqlmap " warnMsg = "user aborted during enumeration. sqlmap "

View File

@ -9,6 +9,7 @@ See the file 'doc/COPYING' for copying permission
import logging import logging
import re import re
import threading
import time import time
from lib.core.agent import agent from lib.core.agent import agent
@ -39,6 +40,8 @@ from lib.core.exception import sqlmapSyntaxException
from lib.core.settings import FROM_TABLE from lib.core.settings import FROM_TABLE
from lib.core.settings import SQL_SCALAR_REGEX from lib.core.settings import SQL_SCALAR_REGEX
from lib.core.settings import TURN_OFF_RESUME_INFO_LIMIT from lib.core.settings import TURN_OFF_RESUME_INFO_LIMIT
from lib.core.threads import getCurrentThreadData
from lib.core.threads import runThreads
from lib.core.unescaper import unescaper from lib.core.unescaper import unescaper
from lib.request.connect import Connect as Request from lib.request.connect import Connect as Request
from lib.utils.resume import resume from lib.utils.resume import resume
@ -260,32 +263,60 @@ def unionUse(expression, unpack=True, dump=False):
infoMsg += "large number of rows (possible slowdown)" infoMsg += "large number of rows (possible slowdown)"
logger.info(infoMsg) logger.info(infoMsg)
for num in xrange(startLimit, stopLimit): lockNames = ('limits', 'value')
if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE): for lock in lockNames:
field = expressionFieldsList[0] kb.locks[lock] = threading.Lock()
elif Backend.isDbms(DBMS.ORACLE):
field = expressionFieldsList
else:
field = None
limitedExpr = agent.limitQuery(num, expression, field) threadData = getCurrentThreadData()
output = resume(limitedExpr, None) numThreads = min(conf.threads, stopLimit-startLimit)
threadData.shared.limits = range(startLimit, stopLimit)
threadData.shared.value = ""
if not output: def unionThread():
output = __oneShotUnionUse(limitedExpr, unpack) threadData = getCurrentThreadData()
if output: while threadData.shared.limits and kb.threadContinue:
value += output kb.locks.limits.acquire()
num = threadData.shared.limits[-1]
del threadData.shared.limits[-1]
kb.locks.limits.release()
if conf.verbose == 1: if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE):
if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])): field = expressionFieldsList[0]
items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter) elif Backend.isDbms(DBMS.ORACLE):
else: field = expressionFieldsList
items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter) else:
status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items)))) field = None
if len(status) > width:
status = "%s..." % status[:width - 3] limitedExpr = agent.limitQuery(num, expression, field)
dataToStdout(status, True) output = resume(limitedExpr, None)
if not output:
output = __oneShotUnionUse(limitedExpr, unpack)
if output:
kb.locks.value.acquire()
threadData.shared.value += output
kb.locks.value.release()
if conf.verbose == 1:
if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])):
items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter)
else:
items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter)
status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items))))
if len(status) > width:
status = "%s..." % status[:width - 3]
kb.locks.ioLock.acquire()
dataToStdout(status, True)
kb.locks.ioLock.release()
runThreads(numThreads, unionThread)
value = threadData.shared.value
if conf.verbose == 1: if conf.verbose == 1:
clearConsoleLine(True) clearConsoleLine(True)