refactoring and stabilization of multithreading

This commit is contained in:
Miroslav Stampar 2011-06-07 09:50:00 +00:00
parent 5f7858455d
commit 7a3cc38e3c
6 changed files with 129 additions and 142 deletions

View File

@ -20,6 +20,9 @@ import threading
import urllib2
import urlparse
import lib.core.common
import lib.core.threads
from extra.clientform.clientform import ParseResponse
from extra.clientform.clientform import ParseError
from extra.keepalive import keepalive
@ -109,6 +112,7 @@ from lib.request.basicauthhandler import SmartHTTPBasicAuthHandler
from lib.request.certhandler import HTTPSCertAuthHandler
from lib.request.rangehandler import HTTPRangeHandler
from lib.request.redirecthandler import SmartRedirectHandler
from lib.request.templates import getPageTemplate
from lib.utils.google import Google
authHandler = urllib2.BaseHandler()
@ -1360,8 +1364,10 @@ def __setKnowledgeBaseAttributes(flushAll=True):
kb.locks.cacheLock = threading.Lock()
kb.locks.logLock = threading.Lock()
kb.locks.ioLock = threading.Lock()
kb.locks.countLock = threading.Lock()
kb.matchRatio = None
kb.multiThreadMode = False
kb.nullConnection = None
kb.pageTemplate = None
kb.pageTemplates = dict()
@ -1701,6 +1707,10 @@ def __basicOptionValidation():
errMsg += "to get the full list of supported charsets"
raise sqlmapSyntaxException, errMsg
def __resolveCrossReferences():
lib.core.threads.readInput = readInput
lib.core.common.getPageTemplate = getPageTemplate
def init(inputOptions=advancedDict(), overrideOptions=False):
"""
Set attributes into both configuration and knowledge base singletons
@ -1720,6 +1730,7 @@ def init(inputOptions=advancedDict(), overrideOptions=False):
__setMultipleTargets()
__setTamperingFunctions()
__setTrafficOutputFP()
__resolveCrossReferences()
parseTargetUrl()
parseTargetDirect()

View File

@ -14,6 +14,7 @@ from lib.core.data import kb
from lib.core.data import logger
from lib.core.datatype import advancedDict
from lib.core.exception import sqlmapThreadException
from lib.core.settings import MAX_NUMBER_OF_THREADS
shared = advancedDict()
@ -39,6 +40,9 @@ class ThreadData():
def getCurrentThreadUID():
return hash(threading.currentThread())
def readInput(message, default=None):
pass
def getCurrentThreadData():
"""
Returns current thread's dependent data
@ -49,12 +53,40 @@ def getCurrentThreadData():
kb.threadData[threadUID] = ThreadData()
return kb.threadData[threadUID]
def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True):
def exceptionHandledFunction(threadFunction):
try:
threadFunction()
except KeyboardInterrupt:
kb.threadContinue = False
kb.threadException = True
raise
except:
kb.threadContinue = False
kb.threadException = True
def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True, threadChoice=False):
threads = []
kb.multiThreadMode = True
kb.threadContinue = True
kb.threadException = False
if threadChoice and numThreads == 1:
while True:
message = "please enter number of threads? [Enter for %d (current)] " % numThreads
choice = readInput(message, default=str(numThreads))
if choice and choice.isdigit():
if int(choice) > MAX_NUMBER_OF_THREADS:
errMsg = "maximum number of used threads is %d avoiding possible connection issues" % MAX_NUMBER_OF_THREADS
logger.critical(errMsg)
else:
numThreads = int(choice)
break
if numThreads == 1:
warnMsg = "running in a single-thread mode. This could take a while."
logger.warn(warnMsg)
if numThreads > 1:
infoMsg = "starting %d threads" % numThreads
logger.info(infoMsg)
@ -64,7 +96,7 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
# Start the threads
for numThread in range(numThreads):
thread = threading.Thread(target=threadFunction, name=str(numThread))
thread = threading.Thread(target=exceptionHandledFunction, name=str(numThread), args=[threadFunction])
thread.start()
threads.append(thread)
@ -98,6 +130,8 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
raise
finally:
kb.multiThreadMode = False
kb.bruteMode = False
kb.threadContinue = True
kb.threadException = False

View File

@ -7,8 +7,6 @@ Copyright (c) 2006-2011 sqlmap developers (http://sqlmap.sourceforge.net/)
See the file 'doc/COPYING' for copying permission
"""
import lib.core.common
from lib.core.data import kb
from lib.request.connect import Connect as Request
@ -24,4 +22,3 @@ def getPageTemplate(payload, place):
return retVal
lib.core.common.getPageTemplate = getPageTemplate

View File

@ -32,6 +32,8 @@ from lib.core.exception import sqlmapThreadException
from lib.core.settings import MAX_NUMBER_OF_THREADS
from lib.core.settings import METADB_SUFFIX
from lib.core.session import safeFormatString
from lib.core.threads import getCurrentThreadData
from lib.core.threads import runThreads
from lib.request import inject
def tableExists(tableFile, regex=None):
@ -184,31 +186,36 @@ def columnExists(columnFile, regex=None):
table = conf.tbl
table = safeSQLIdentificatorNaming(table)
retVal = []
infoMsg = "checking column existence using items from '%s'" % columnFile
logger.info(infoMsg)
count = [0]
length = len(columns)
threads = []
collock = threading.Lock()
iolock = threading.Lock()
kb.threadContinue = True
kb.bruteMode = True
threadData = getCurrentThreadData()
threadData.shared.count = 0
threadData.shared.limit = len(columns)
threadData.shared.outputs = []
def columnExistsThread():
while count[0] < length and kb.threadContinue:
collock.acquire()
column = safeSQLIdentificatorNaming(columns[count[0]])
count[0] += 1
collock.release()
threadData = getCurrentThreadData()
while kb.threadContinue:
kb.locks.countLock.acquire()
if threadData.shared.count < threadData.shared.limit:
column = safeSQLIdentificatorNaming(columns[threadData.shared.count])
threadData.shared.count += 1
kb.locks.countLock.release()
else:
kb.locks.countLock.release()
break
result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s)", (column, table)))
iolock.acquire()
kb.locks.ioLock.acquire()
if result:
retVal.append(column)
threadData.shared.outputs.append(column)
if conf.verbose in (1, 2):
clearConsoleLine(True)
@ -216,79 +223,29 @@ def columnExists(columnFile, regex=None):
dataToStdout(infoMsg, True)
if conf.verbose in (1, 2):
status = '%d/%d items (%d%s)' % (count[0], length, round(100.0*count[0]/length), '%')
status = '%d/%d items (%d%s)' % (threadData.shared.count, threadData.shared.limit, round(100.0*threadData.shared.count/threadData.shared.limit), '%')
dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True)
iolock.release()
kb.locks.ioLock.release()
if conf.threads > 1:
infoMsg = "starting %d threads" % conf.threads
logger.info(infoMsg)
else:
while True:
message = "please enter number of threads? [Enter for %d (current)] " % conf.threads
choice = readInput(message, default=str(conf.threads))
if choice and choice.isdigit():
if int(choice) > MAX_NUMBER_OF_THREADS:
errMsg = "maximum number of used threads is %d avoiding possible connection issues" % MAX_NUMBER_OF_THREADS
logger.critical(errMsg)
else:
conf.threads = int(choice)
break
if conf.threads == 1:
warnMsg = "running in a single-thread mode. This could take a while."
logger.warn(warnMsg)
# Start the threads
for numThread in range(conf.threads):
thread = threading.Thread(target=columnExistsThread, name=str(numThread))
thread.start()
threads.append(thread)
# And wait for them to all finish
try:
alive = True
runThreads(conf.threads, columnExistsThread, threadChoice=True)
while alive:
alive = False
for thread in threads:
if thread.isAlive():
alive = True
thread.join(5)
except KeyboardInterrupt:
kb.threadContinue = False
kb.threadException = True
print
logger.debug("waiting for threads to finish")
warnMsg = "user aborted during common column existence check. "
warnMsg += "sqlmap will display some columns only"
warnMsg = "user aborted during column existence "
warnMsg += "check. sqlmap will display partial output"
logger.warn(warnMsg)
try:
while (threading.activeCount() > 1):
pass
except KeyboardInterrupt:
raise sqlmapThreadException, "user aborted"
finally:
kb.bruteMode = False
kb.threadContinue = True
kb.threadException = False
clearConsoleLine(True)
dataToStdout("\n")
if not retVal:
if not threadData.shared.outputs:
warnMsg = "no column(s) found"
logger.warn(warnMsg)
else:
columns = {}
for column in retVal:
for column in threadData.shared.outputs:
result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE ROUND(%s)=ROUND(%s))", (column, table, column, column)))
if result:

View File

@ -335,35 +335,29 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False):
threadData.shared.outputs = []
def errorThread():
try:
threadData = getCurrentThreadData()
threadData = getCurrentThreadData()
while kb.threadContinue:
kb.locks.limits.acquire()
if threadData.shared.limits:
num = threadData.shared.limits[-1]
del threadData.shared.limits[-1]
kb.locks.limits.release()
else:
kb.locks.limits.release()
break
while kb.threadContinue:
kb.locks.limits.acquire()
if threadData.shared.limits:
num = threadData.shared.limits[-1]
del threadData.shared.limits[-1]
kb.locks.limits.release()
else:
kb.locks.limits.release()
break
output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue)
output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue)
if not kb.threadContinue:
break
if not kb.threadContinue:
break
if output and isinstance(output, list) and len(output) == 1:
output = output[0]
if output and isinstance(output, list) and len(output) == 1:
output = output[0]
kb.locks.outputs.acquire()
threadData.shared.outputs.append(output)
kb.locks.outputs.release()
except KeyboardInterrupt:
kb.threadContinue = False
kb.threadException = True
raise
kb.locks.outputs.acquire()
threadData.shared.outputs.append(output)
kb.locks.outputs.release()
runThreads(numThreads, errorThread)

View File

@ -275,59 +275,53 @@ def unionUse(expression, unpack=True, dump=False):
threadData.shared.value = ""
def unionThread():
try:
threadData = getCurrentThreadData()
threadData = getCurrentThreadData()
while kb.threadContinue:
kb.locks.limits.acquire()
if threadData.shared.limits:
num = threadData.shared.limits[-1]
del threadData.shared.limits[-1]
kb.locks.limits.release()
else:
kb.locks.limits.release()
break
while kb.threadContinue:
kb.locks.limits.acquire()
if threadData.shared.limits:
num = threadData.shared.limits[-1]
del threadData.shared.limits[-1]
kb.locks.limits.release()
else:
kb.locks.limits.release()
break
if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE):
field = expressionFieldsList[0]
elif Backend.isDbms(DBMS.ORACLE):
field = expressionFieldsList
else:
field = None
if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE):
field = expressionFieldsList[0]
elif Backend.isDbms(DBMS.ORACLE):
field = expressionFieldsList
else:
field = None
limitedExpr = agent.limitQuery(num, expression, field)
output = resume(limitedExpr, None)
limitedExpr = agent.limitQuery(num, expression, field)
output = resume(limitedExpr, None)
if not output:
output = __oneShotUnionUse(limitedExpr, unpack)
if not output:
output = __oneShotUnionUse(limitedExpr, unpack)
if not kb.threadContinue:
break
if not kb.threadContinue:
break
if output:
kb.locks.value.acquire()
threadData.shared.value += output
kb.locks.value.release()
if output:
kb.locks.value.acquire()
threadData.shared.value += output
kb.locks.value.release()
if conf.verbose == 1:
if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])):
items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter)
else:
items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter)
if conf.verbose == 1:
if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])):
items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter)
else:
items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter)
status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items))))
status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items))))
if len(status) > width:
status = "%s..." % status[:width - 3]
if len(status) > width:
status = "%s..." % status[:width - 3]
kb.locks.ioLock.acquire()
dataToStdout(status, True)
kb.locks.ioLock.release()
except KeyboardInterrupt:
kb.threadContinue = False
kb.threadException = True
raise
kb.locks.ioLock.acquire()
dataToStdout(status, True)
kb.locks.ioLock.release()
runThreads(numThreads, unionThread)