implementation of multithreading for UNION and ERROR techniques

This commit is contained in:
Miroslav Stampar 2011-05-29 23:17:50 +00:00
parent d51efa679d
commit 86455ceb9c
4 changed files with 149 additions and 27 deletions

View File

@ -1358,6 +1358,7 @@ def __setKnowledgeBaseAttributes(flushAll=True):
kb.locks = advancedDict()
kb.locks.cacheLock = threading.Lock()
kb.locks.logLock = threading.Lock()
kb.locks.ioLock = threading.Lock()
kb.matchRatio = None
kb.nullConnection = None

View File

@ -11,6 +11,11 @@ import difflib
import threading
from lib.core.data import kb
from lib.core.data import logger
from lib.core.datatype import advancedDict
from lib.core.exception import sqlmapThreadException
shared = advancedDict()
class ThreadData():
"""
@ -18,6 +23,8 @@ class ThreadData():
"""
def __init__(self):
global shared
self.disableStdOut = False
self.lastErrorPage = None
self.lastHTTPError = None
@ -26,6 +33,7 @@ class ThreadData():
self.lastRequestUID = 0
self.retriesCount = 0
self.seqMatcher = difflib.SequenceMatcher(None)
self.shared = shared
self.valueStack = []
def getCurrentThreadUID():
@ -40,3 +48,55 @@ def getCurrentThreadData():
if threadUID not in kb.threadData:
kb.threadData[threadUID] = ThreadData()
return kb.threadData[threadUID]
def runThreads(numThreads, threadFunction, cleanupFunction=None):
threads = []
kb.threadContinue = True
kb.threadException = False
if numThreads > 1:
infoMsg = "starting %d threads" % numThreads
logger.info(infoMsg)
else:
threadFunction()
return
# Start the threads
for numThread in range(numThreads):
thread = threading.Thread(target=threadFunction, name=str(numThread))
thread.start()
threads.append(thread)
# And wait for them to all finish
try:
alive = True
while alive:
alive = False
for thread in threads:
if thread.isAlive():
alive = True
thread.join(1)
except KeyboardInterrupt:
kb.threadContinue = False
kb.threadException = True
print
logger.debug("waiting for threads to finish")
try:
while (threading.activeCount() > 1):
pass
except KeyboardInterrupt:
raise sqlmapThreadException, "user aborted (Ctrl+C was pressed multiple times)"
finally:
kb.threadContinue = True
kb.threadException = False
if cleanupFunction:
cleanupFunction()

View File

@ -8,6 +8,7 @@ See the file 'doc/COPYING' for copying permission
"""
import re
import threading
import time
from lib.core.agent import agent
@ -39,6 +40,7 @@ from lib.core.settings import MSSQL_ERROR_CHUNK_LENGTH
from lib.core.settings import SQL_SCALAR_REGEX
from lib.core.settings import TURN_OFF_RESUME_INFO_LIMIT
from lib.core.threads import getCurrentThreadData
from lib.core.threads import runThreads
from lib.core.unescaper import unescaper
from lib.request.connect import Connect as Request
from lib.utils.resume import resume
@ -159,7 +161,9 @@ def __errorFields(expression, expressionFields, expressionFieldsList, expected=N
output = __oneShotErrorUse(expressionReplaced, field)
if output is not None:
kb.locks.ioLock.acquire()
dataToStdout("[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(output)))
kb.locks.ioLock.release()
if isinstance(num, int):
expression = origExpr
@ -316,13 +320,39 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False):
infoMsg += "large number of rows (possible slowdown)"
logger.info(infoMsg)
for num in xrange(startLimit, stopLimit):
lockNames = ('limits', 'outputs')
for lock in lockNames:
kb.locks[lock] = threading.Lock()
threadData = getCurrentThreadData()
numThreads = min(conf.threads, stopLimit-startLimit)
threadData.shared.limits = range(startLimit, stopLimit)
threadData.shared.outputs = []
def errorThread():
try:
threadData = getCurrentThreadData()
while threadData.shared.limits and kb.threadContinue:
kb.locks.limits.acquire()
num = threadData.shared.limits[-1]
del threadData.shared.limits[-1]
kb.locks.limits.release()
output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue)
if output and isinstance(output, list) and len(output) == 1:
output = output[0]
outputs.append(output)
kb.locks.outputs.acquire()
threadData.shared.outputs.append(output)
kb.locks.outputs.release()
except KeyboardInterrupt:
raise
runThreads(numThreads, errorThread)
outputs = threadData.shared.outputs
except KeyboardInterrupt:
warnMsg = "user aborted during enumeration. sqlmap "

View File

@ -9,6 +9,7 @@ See the file 'doc/COPYING' for copying permission
import logging
import re
import threading
import time
from lib.core.agent import agent
@ -39,6 +40,8 @@ from lib.core.exception import sqlmapSyntaxException
from lib.core.settings import FROM_TABLE
from lib.core.settings import SQL_SCALAR_REGEX
from lib.core.settings import TURN_OFF_RESUME_INFO_LIMIT
from lib.core.threads import getCurrentThreadData
from lib.core.threads import runThreads
from lib.core.unescaper import unescaper
from lib.request.connect import Connect as Request
from lib.utils.resume import resume
@ -260,7 +263,24 @@ def unionUse(expression, unpack=True, dump=False):
infoMsg += "large number of rows (possible slowdown)"
logger.info(infoMsg)
for num in xrange(startLimit, stopLimit):
lockNames = ('limits', 'value')
for lock in lockNames:
kb.locks[lock] = threading.Lock()
threadData = getCurrentThreadData()
numThreads = min(conf.threads, stopLimit-startLimit)
threadData.shared.limits = range(startLimit, stopLimit)
threadData.shared.value = ""
def unionThread():
threadData = getCurrentThreadData()
while threadData.shared.limits and kb.threadContinue:
kb.locks.limits.acquire()
num = threadData.shared.limits[-1]
del threadData.shared.limits[-1]
kb.locks.limits.release()
if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE):
field = expressionFieldsList[0]
elif Backend.isDbms(DBMS.ORACLE):
@ -275,17 +295,28 @@ def unionUse(expression, unpack=True, dump=False):
output = __oneShotUnionUse(limitedExpr, unpack)
if output:
value += output
kb.locks.value.acquire()
threadData.shared.value += output
kb.locks.value.release()
if conf.verbose == 1:
if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])):
items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter)
else:
items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter)
status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items))))
if len(status) > width:
status = "%s..." % status[:width - 3]
kb.locks.ioLock.acquire()
dataToStdout(status, True)
kb.locks.ioLock.release()
runThreads(numThreads, unionThread)
value = threadData.shared.value
if conf.verbose == 1:
clearConsoleLine(True)