refactoring and stabilization of multithreading

This commit is contained in:
Miroslav Stampar 2011-06-07 09:50:00 +00:00
parent 5f7858455d
commit 7a3cc38e3c
6 changed files with 129 additions and 142 deletions

View File

@ -20,6 +20,9 @@ import threading
import urllib2 import urllib2
import urlparse import urlparse
import lib.core.common
import lib.core.threads
from extra.clientform.clientform import ParseResponse from extra.clientform.clientform import ParseResponse
from extra.clientform.clientform import ParseError from extra.clientform.clientform import ParseError
from extra.keepalive import keepalive from extra.keepalive import keepalive
@ -109,6 +112,7 @@ from lib.request.basicauthhandler import SmartHTTPBasicAuthHandler
from lib.request.certhandler import HTTPSCertAuthHandler from lib.request.certhandler import HTTPSCertAuthHandler
from lib.request.rangehandler import HTTPRangeHandler from lib.request.rangehandler import HTTPRangeHandler
from lib.request.redirecthandler import SmartRedirectHandler from lib.request.redirecthandler import SmartRedirectHandler
from lib.request.templates import getPageTemplate
from lib.utils.google import Google from lib.utils.google import Google
authHandler = urllib2.BaseHandler() authHandler = urllib2.BaseHandler()
@ -1360,8 +1364,10 @@ def __setKnowledgeBaseAttributes(flushAll=True):
kb.locks.cacheLock = threading.Lock() kb.locks.cacheLock = threading.Lock()
kb.locks.logLock = threading.Lock() kb.locks.logLock = threading.Lock()
kb.locks.ioLock = threading.Lock() kb.locks.ioLock = threading.Lock()
kb.locks.countLock = threading.Lock()
kb.matchRatio = None kb.matchRatio = None
kb.multiThreadMode = False
kb.nullConnection = None kb.nullConnection = None
kb.pageTemplate = None kb.pageTemplate = None
kb.pageTemplates = dict() kb.pageTemplates = dict()
@ -1701,6 +1707,10 @@ def __basicOptionValidation():
errMsg += "to get the full list of supported charsets" errMsg += "to get the full list of supported charsets"
raise sqlmapSyntaxException, errMsg raise sqlmapSyntaxException, errMsg
def __resolveCrossReferences():
lib.core.threads.readInput = readInput
lib.core.common.getPageTemplate = getPageTemplate
def init(inputOptions=advancedDict(), overrideOptions=False): def init(inputOptions=advancedDict(), overrideOptions=False):
""" """
Set attributes into both configuration and knowledge base singletons Set attributes into both configuration and knowledge base singletons
@ -1720,6 +1730,7 @@ def init(inputOptions=advancedDict(), overrideOptions=False):
__setMultipleTargets() __setMultipleTargets()
__setTamperingFunctions() __setTamperingFunctions()
__setTrafficOutputFP() __setTrafficOutputFP()
__resolveCrossReferences()
parseTargetUrl() parseTargetUrl()
parseTargetDirect() parseTargetDirect()

View File

@ -14,6 +14,7 @@ from lib.core.data import kb
from lib.core.data import logger from lib.core.data import logger
from lib.core.datatype import advancedDict from lib.core.datatype import advancedDict
from lib.core.exception import sqlmapThreadException from lib.core.exception import sqlmapThreadException
from lib.core.settings import MAX_NUMBER_OF_THREADS
shared = advancedDict() shared = advancedDict()
@ -39,6 +40,9 @@ class ThreadData():
def getCurrentThreadUID(): def getCurrentThreadUID():
return hash(threading.currentThread()) return hash(threading.currentThread())
def readInput(message, default=None):
pass
def getCurrentThreadData(): def getCurrentThreadData():
""" """
Returns current thread's dependent data Returns current thread's dependent data
@ -49,12 +53,40 @@ def getCurrentThreadData():
kb.threadData[threadUID] = ThreadData() kb.threadData[threadUID] = ThreadData()
return kb.threadData[threadUID] return kb.threadData[threadUID]
def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True): def exceptionHandledFunction(threadFunction):
try:
threadFunction()
except KeyboardInterrupt:
kb.threadContinue = False
kb.threadException = True
raise
except:
kb.threadContinue = False
kb.threadException = True
def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True, threadChoice=False):
threads = [] threads = []
kb.multiThreadMode = True
kb.threadContinue = True kb.threadContinue = True
kb.threadException = False kb.threadException = False
if threadChoice and numThreads == 1:
while True:
message = "please enter number of threads? [Enter for %d (current)] " % numThreads
choice = readInput(message, default=str(numThreads))
if choice and choice.isdigit():
if int(choice) > MAX_NUMBER_OF_THREADS:
errMsg = "maximum number of used threads is %d avoiding possible connection issues" % MAX_NUMBER_OF_THREADS
logger.critical(errMsg)
else:
numThreads = int(choice)
break
if numThreads == 1:
warnMsg = "running in a single-thread mode. This could take a while."
logger.warn(warnMsg)
if numThreads > 1: if numThreads > 1:
infoMsg = "starting %d threads" % numThreads infoMsg = "starting %d threads" % numThreads
logger.info(infoMsg) logger.info(infoMsg)
@ -64,7 +96,7 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
# Start the threads # Start the threads
for numThread in range(numThreads): for numThread in range(numThreads):
thread = threading.Thread(target=threadFunction, name=str(numThread)) thread = threading.Thread(target=exceptionHandledFunction, name=str(numThread), args=[threadFunction])
thread.start() thread.start()
threads.append(thread) threads.append(thread)
@ -98,6 +130,8 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
raise raise
finally: finally:
kb.multiThreadMode = False
kb.bruteMode = False
kb.threadContinue = True kb.threadContinue = True
kb.threadException = False kb.threadException = False

View File

@ -7,8 +7,6 @@ Copyright (c) 2006-2011 sqlmap developers (http://sqlmap.sourceforge.net/)
See the file 'doc/COPYING' for copying permission See the file 'doc/COPYING' for copying permission
""" """
import lib.core.common
from lib.core.data import kb from lib.core.data import kb
from lib.request.connect import Connect as Request from lib.request.connect import Connect as Request
@ -24,4 +22,3 @@ def getPageTemplate(payload, place):
return retVal return retVal
lib.core.common.getPageTemplate = getPageTemplate

View File

@ -32,6 +32,8 @@ from lib.core.exception import sqlmapThreadException
from lib.core.settings import MAX_NUMBER_OF_THREADS from lib.core.settings import MAX_NUMBER_OF_THREADS
from lib.core.settings import METADB_SUFFIX from lib.core.settings import METADB_SUFFIX
from lib.core.session import safeFormatString from lib.core.session import safeFormatString
from lib.core.threads import getCurrentThreadData
from lib.core.threads import runThreads
from lib.request import inject from lib.request import inject
def tableExists(tableFile, regex=None): def tableExists(tableFile, regex=None):
@ -184,31 +186,36 @@ def columnExists(columnFile, regex=None):
table = conf.tbl table = conf.tbl
table = safeSQLIdentificatorNaming(table) table = safeSQLIdentificatorNaming(table)
retVal = []
infoMsg = "checking column existence using items from '%s'" % columnFile infoMsg = "checking column existence using items from '%s'" % columnFile
logger.info(infoMsg) logger.info(infoMsg)
count = [0]
length = len(columns)
threads = []
collock = threading.Lock()
iolock = threading.Lock()
kb.threadContinue = True kb.threadContinue = True
kb.bruteMode = True kb.bruteMode = True
threadData = getCurrentThreadData()
threadData.shared.count = 0
threadData.shared.limit = len(columns)
threadData.shared.outputs = []
def columnExistsThread(): def columnExistsThread():
while count[0] < length and kb.threadContinue: threadData = getCurrentThreadData()
collock.acquire()
column = safeSQLIdentificatorNaming(columns[count[0]]) while kb.threadContinue:
count[0] += 1 kb.locks.countLock.acquire()
collock.release() if threadData.shared.count < threadData.shared.limit:
column = safeSQLIdentificatorNaming(columns[threadData.shared.count])
threadData.shared.count += 1
kb.locks.countLock.release()
else:
kb.locks.countLock.release()
break
result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s)", (column, table))) result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s)", (column, table)))
iolock.acquire() kb.locks.ioLock.acquire()
if result: if result:
retVal.append(column) threadData.shared.outputs.append(column)
if conf.verbose in (1, 2): if conf.verbose in (1, 2):
clearConsoleLine(True) clearConsoleLine(True)
@ -216,79 +223,29 @@ def columnExists(columnFile, regex=None):
dataToStdout(infoMsg, True) dataToStdout(infoMsg, True)
if conf.verbose in (1, 2): if conf.verbose in (1, 2):
status = '%d/%d items (%d%s)' % (count[0], length, round(100.0*count[0]/length), '%') status = '%d/%d items (%d%s)' % (threadData.shared.count, threadData.shared.limit, round(100.0*threadData.shared.count/threadData.shared.limit), '%')
dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True) dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True)
iolock.release() kb.locks.ioLock.release()
if conf.threads > 1:
infoMsg = "starting %d threads" % conf.threads
logger.info(infoMsg)
else:
while True:
message = "please enter number of threads? [Enter for %d (current)] " % conf.threads
choice = readInput(message, default=str(conf.threads))
if choice and choice.isdigit():
if int(choice) > MAX_NUMBER_OF_THREADS:
errMsg = "maximum number of used threads is %d avoiding possible connection issues" % MAX_NUMBER_OF_THREADS
logger.critical(errMsg)
else:
conf.threads = int(choice)
break
if conf.threads == 1:
warnMsg = "running in a single-thread mode. This could take a while."
logger.warn(warnMsg)
# Start the threads
for numThread in range(conf.threads):
thread = threading.Thread(target=columnExistsThread, name=str(numThread))
thread.start()
threads.append(thread)
# And wait for them to all finish
try: try:
alive = True runThreads(conf.threads, columnExistsThread, threadChoice=True)
while alive:
alive = False
for thread in threads:
if thread.isAlive():
alive = True
thread.join(5)
except KeyboardInterrupt: except KeyboardInterrupt:
kb.threadContinue = False warnMsg = "user aborted during column existence "
kb.threadException = True warnMsg += "check. sqlmap will display partial output"
print
logger.debug("waiting for threads to finish")
warnMsg = "user aborted during common column existence check. "
warnMsg += "sqlmap will display some columns only"
logger.warn(warnMsg) logger.warn(warnMsg)
try:
while (threading.activeCount() > 1):
pass
except KeyboardInterrupt:
raise sqlmapThreadException, "user aborted"
finally:
kb.bruteMode = False
kb.threadContinue = True
kb.threadException = False
clearConsoleLine(True) clearConsoleLine(True)
dataToStdout("\n") dataToStdout("\n")
if not retVal: if not threadData.shared.outputs:
warnMsg = "no column(s) found" warnMsg = "no column(s) found"
logger.warn(warnMsg) logger.warn(warnMsg)
else: else:
columns = {} columns = {}
for column in retVal: for column in threadData.shared.outputs:
result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE ROUND(%s)=ROUND(%s))", (column, table, column, column))) result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE ROUND(%s)=ROUND(%s))", (column, table, column, column)))
if result: if result:

View File

@ -335,35 +335,29 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False):
threadData.shared.outputs = [] threadData.shared.outputs = []
def errorThread(): def errorThread():
try: threadData = getCurrentThreadData()
threadData = getCurrentThreadData()
while kb.threadContinue: while kb.threadContinue:
kb.locks.limits.acquire() kb.locks.limits.acquire()
if threadData.shared.limits: if threadData.shared.limits:
num = threadData.shared.limits[-1] num = threadData.shared.limits[-1]
del threadData.shared.limits[-1] del threadData.shared.limits[-1]
kb.locks.limits.release() kb.locks.limits.release()
else: else:
kb.locks.limits.release() kb.locks.limits.release()
break break
output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue) output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue)
if not kb.threadContinue: if not kb.threadContinue:
break break
if output and isinstance(output, list) and len(output) == 1: if output and isinstance(output, list) and len(output) == 1:
output = output[0] output = output[0]
kb.locks.outputs.acquire() kb.locks.outputs.acquire()
threadData.shared.outputs.append(output) threadData.shared.outputs.append(output)
kb.locks.outputs.release() kb.locks.outputs.release()
except KeyboardInterrupt:
kb.threadContinue = False
kb.threadException = True
raise
runThreads(numThreads, errorThread) runThreads(numThreads, errorThread)

View File

@ -275,59 +275,53 @@ def unionUse(expression, unpack=True, dump=False):
threadData.shared.value = "" threadData.shared.value = ""
def unionThread(): def unionThread():
try: threadData = getCurrentThreadData()
threadData = getCurrentThreadData()
while kb.threadContinue: while kb.threadContinue:
kb.locks.limits.acquire() kb.locks.limits.acquire()
if threadData.shared.limits: if threadData.shared.limits:
num = threadData.shared.limits[-1] num = threadData.shared.limits[-1]
del threadData.shared.limits[-1] del threadData.shared.limits[-1]
kb.locks.limits.release() kb.locks.limits.release()
else: else:
kb.locks.limits.release() kb.locks.limits.release()
break break
if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE): if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE):
field = expressionFieldsList[0] field = expressionFieldsList[0]
elif Backend.isDbms(DBMS.ORACLE): elif Backend.isDbms(DBMS.ORACLE):
field = expressionFieldsList field = expressionFieldsList
else: else:
field = None field = None
limitedExpr = agent.limitQuery(num, expression, field) limitedExpr = agent.limitQuery(num, expression, field)
output = resume(limitedExpr, None) output = resume(limitedExpr, None)
if not output: if not output:
output = __oneShotUnionUse(limitedExpr, unpack) output = __oneShotUnionUse(limitedExpr, unpack)
if not kb.threadContinue: if not kb.threadContinue:
break break
if output: if output:
kb.locks.value.acquire() kb.locks.value.acquire()
threadData.shared.value += output threadData.shared.value += output
kb.locks.value.release() kb.locks.value.release()
if conf.verbose == 1: if conf.verbose == 1:
if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])): if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])):
items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter) items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter)
else: else:
items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter) items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter)
status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items)))) status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items))))
if len(status) > width: if len(status) > width:
status = "%s..." % status[:width - 3] status = "%s..." % status[:width - 3]
kb.locks.ioLock.acquire() kb.locks.ioLock.acquire()
dataToStdout(status, True) dataToStdout(status, True)
kb.locks.ioLock.release() kb.locks.ioLock.release()
except KeyboardInterrupt:
kb.threadContinue = False
kb.threadException = True
raise
runThreads(numThreads, unionThread) runThreads(numThreads, unionThread)