mirror of
https://github.com/sqlmapproject/sqlmap.git
synced 2025-02-03 05:04:11 +03:00
refactoring and stabilization of multithreading
This commit is contained in:
parent
5f7858455d
commit
7a3cc38e3c
|
@ -20,6 +20,9 @@ import threading
|
||||||
import urllib2
|
import urllib2
|
||||||
import urlparse
|
import urlparse
|
||||||
|
|
||||||
|
import lib.core.common
|
||||||
|
import lib.core.threads
|
||||||
|
|
||||||
from extra.clientform.clientform import ParseResponse
|
from extra.clientform.clientform import ParseResponse
|
||||||
from extra.clientform.clientform import ParseError
|
from extra.clientform.clientform import ParseError
|
||||||
from extra.keepalive import keepalive
|
from extra.keepalive import keepalive
|
||||||
|
@ -109,6 +112,7 @@ from lib.request.basicauthhandler import SmartHTTPBasicAuthHandler
|
||||||
from lib.request.certhandler import HTTPSCertAuthHandler
|
from lib.request.certhandler import HTTPSCertAuthHandler
|
||||||
from lib.request.rangehandler import HTTPRangeHandler
|
from lib.request.rangehandler import HTTPRangeHandler
|
||||||
from lib.request.redirecthandler import SmartRedirectHandler
|
from lib.request.redirecthandler import SmartRedirectHandler
|
||||||
|
from lib.request.templates import getPageTemplate
|
||||||
from lib.utils.google import Google
|
from lib.utils.google import Google
|
||||||
|
|
||||||
authHandler = urllib2.BaseHandler()
|
authHandler = urllib2.BaseHandler()
|
||||||
|
@ -1360,8 +1364,10 @@ def __setKnowledgeBaseAttributes(flushAll=True):
|
||||||
kb.locks.cacheLock = threading.Lock()
|
kb.locks.cacheLock = threading.Lock()
|
||||||
kb.locks.logLock = threading.Lock()
|
kb.locks.logLock = threading.Lock()
|
||||||
kb.locks.ioLock = threading.Lock()
|
kb.locks.ioLock = threading.Lock()
|
||||||
|
kb.locks.countLock = threading.Lock()
|
||||||
|
|
||||||
kb.matchRatio = None
|
kb.matchRatio = None
|
||||||
|
kb.multiThreadMode = False
|
||||||
kb.nullConnection = None
|
kb.nullConnection = None
|
||||||
kb.pageTemplate = None
|
kb.pageTemplate = None
|
||||||
kb.pageTemplates = dict()
|
kb.pageTemplates = dict()
|
||||||
|
@ -1701,6 +1707,10 @@ def __basicOptionValidation():
|
||||||
errMsg += "to get the full list of supported charsets"
|
errMsg += "to get the full list of supported charsets"
|
||||||
raise sqlmapSyntaxException, errMsg
|
raise sqlmapSyntaxException, errMsg
|
||||||
|
|
||||||
|
def __resolveCrossReferences():
|
||||||
|
lib.core.threads.readInput = readInput
|
||||||
|
lib.core.common.getPageTemplate = getPageTemplate
|
||||||
|
|
||||||
def init(inputOptions=advancedDict(), overrideOptions=False):
|
def init(inputOptions=advancedDict(), overrideOptions=False):
|
||||||
"""
|
"""
|
||||||
Set attributes into both configuration and knowledge base singletons
|
Set attributes into both configuration and knowledge base singletons
|
||||||
|
@ -1720,6 +1730,7 @@ def init(inputOptions=advancedDict(), overrideOptions=False):
|
||||||
__setMultipleTargets()
|
__setMultipleTargets()
|
||||||
__setTamperingFunctions()
|
__setTamperingFunctions()
|
||||||
__setTrafficOutputFP()
|
__setTrafficOutputFP()
|
||||||
|
__resolveCrossReferences()
|
||||||
|
|
||||||
parseTargetUrl()
|
parseTargetUrl()
|
||||||
parseTargetDirect()
|
parseTargetDirect()
|
||||||
|
|
|
@ -14,6 +14,7 @@ from lib.core.data import kb
|
||||||
from lib.core.data import logger
|
from lib.core.data import logger
|
||||||
from lib.core.datatype import advancedDict
|
from lib.core.datatype import advancedDict
|
||||||
from lib.core.exception import sqlmapThreadException
|
from lib.core.exception import sqlmapThreadException
|
||||||
|
from lib.core.settings import MAX_NUMBER_OF_THREADS
|
||||||
|
|
||||||
shared = advancedDict()
|
shared = advancedDict()
|
||||||
|
|
||||||
|
@ -39,6 +40,9 @@ class ThreadData():
|
||||||
def getCurrentThreadUID():
|
def getCurrentThreadUID():
|
||||||
return hash(threading.currentThread())
|
return hash(threading.currentThread())
|
||||||
|
|
||||||
|
def readInput(message, default=None):
|
||||||
|
pass
|
||||||
|
|
||||||
def getCurrentThreadData():
|
def getCurrentThreadData():
|
||||||
"""
|
"""
|
||||||
Returns current thread's dependent data
|
Returns current thread's dependent data
|
||||||
|
@ -49,12 +53,40 @@ def getCurrentThreadData():
|
||||||
kb.threadData[threadUID] = ThreadData()
|
kb.threadData[threadUID] = ThreadData()
|
||||||
return kb.threadData[threadUID]
|
return kb.threadData[threadUID]
|
||||||
|
|
||||||
def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True):
|
def exceptionHandledFunction(threadFunction):
|
||||||
|
try:
|
||||||
|
threadFunction()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
kb.threadContinue = False
|
||||||
|
kb.threadException = True
|
||||||
|
raise
|
||||||
|
except:
|
||||||
|
kb.threadContinue = False
|
||||||
|
kb.threadException = True
|
||||||
|
|
||||||
|
def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True, threadChoice=False):
|
||||||
threads = []
|
threads = []
|
||||||
|
|
||||||
|
kb.multiThreadMode = True
|
||||||
kb.threadContinue = True
|
kb.threadContinue = True
|
||||||
kb.threadException = False
|
kb.threadException = False
|
||||||
|
|
||||||
|
if threadChoice and numThreads == 1:
|
||||||
|
while True:
|
||||||
|
message = "please enter number of threads? [Enter for %d (current)] " % numThreads
|
||||||
|
choice = readInput(message, default=str(numThreads))
|
||||||
|
if choice and choice.isdigit():
|
||||||
|
if int(choice) > MAX_NUMBER_OF_THREADS:
|
||||||
|
errMsg = "maximum number of used threads is %d avoiding possible connection issues" % MAX_NUMBER_OF_THREADS
|
||||||
|
logger.critical(errMsg)
|
||||||
|
else:
|
||||||
|
numThreads = int(choice)
|
||||||
|
break
|
||||||
|
|
||||||
|
if numThreads == 1:
|
||||||
|
warnMsg = "running in a single-thread mode. This could take a while."
|
||||||
|
logger.warn(warnMsg)
|
||||||
|
|
||||||
if numThreads > 1:
|
if numThreads > 1:
|
||||||
infoMsg = "starting %d threads" % numThreads
|
infoMsg = "starting %d threads" % numThreads
|
||||||
logger.info(infoMsg)
|
logger.info(infoMsg)
|
||||||
|
@ -64,7 +96,7 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
|
||||||
|
|
||||||
# Start the threads
|
# Start the threads
|
||||||
for numThread in range(numThreads):
|
for numThread in range(numThreads):
|
||||||
thread = threading.Thread(target=threadFunction, name=str(numThread))
|
thread = threading.Thread(target=exceptionHandledFunction, name=str(numThread), args=[threadFunction])
|
||||||
thread.start()
|
thread.start()
|
||||||
threads.append(thread)
|
threads.append(thread)
|
||||||
|
|
||||||
|
@ -98,6 +130,8 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
|
||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
kb.multiThreadMode = False
|
||||||
|
kb.bruteMode = False
|
||||||
kb.threadContinue = True
|
kb.threadContinue = True
|
||||||
kb.threadException = False
|
kb.threadException = False
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,6 @@ Copyright (c) 2006-2011 sqlmap developers (http://sqlmap.sourceforge.net/)
|
||||||
See the file 'doc/COPYING' for copying permission
|
See the file 'doc/COPYING' for copying permission
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import lib.core.common
|
|
||||||
|
|
||||||
from lib.core.data import kb
|
from lib.core.data import kb
|
||||||
from lib.request.connect import Connect as Request
|
from lib.request.connect import Connect as Request
|
||||||
|
|
||||||
|
@ -24,4 +22,3 @@ def getPageTemplate(payload, place):
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
|
|
||||||
lib.core.common.getPageTemplate = getPageTemplate
|
|
||||||
|
|
|
@ -32,6 +32,8 @@ from lib.core.exception import sqlmapThreadException
|
||||||
from lib.core.settings import MAX_NUMBER_OF_THREADS
|
from lib.core.settings import MAX_NUMBER_OF_THREADS
|
||||||
from lib.core.settings import METADB_SUFFIX
|
from lib.core.settings import METADB_SUFFIX
|
||||||
from lib.core.session import safeFormatString
|
from lib.core.session import safeFormatString
|
||||||
|
from lib.core.threads import getCurrentThreadData
|
||||||
|
from lib.core.threads import runThreads
|
||||||
from lib.request import inject
|
from lib.request import inject
|
||||||
|
|
||||||
def tableExists(tableFile, regex=None):
|
def tableExists(tableFile, regex=None):
|
||||||
|
@ -184,31 +186,36 @@ def columnExists(columnFile, regex=None):
|
||||||
table = conf.tbl
|
table = conf.tbl
|
||||||
table = safeSQLIdentificatorNaming(table)
|
table = safeSQLIdentificatorNaming(table)
|
||||||
|
|
||||||
retVal = []
|
|
||||||
infoMsg = "checking column existence using items from '%s'" % columnFile
|
infoMsg = "checking column existence using items from '%s'" % columnFile
|
||||||
logger.info(infoMsg)
|
logger.info(infoMsg)
|
||||||
|
|
||||||
count = [0]
|
|
||||||
length = len(columns)
|
|
||||||
threads = []
|
|
||||||
collock = threading.Lock()
|
|
||||||
iolock = threading.Lock()
|
|
||||||
kb.threadContinue = True
|
kb.threadContinue = True
|
||||||
kb.bruteMode = True
|
kb.bruteMode = True
|
||||||
|
|
||||||
|
threadData = getCurrentThreadData()
|
||||||
|
threadData.shared.count = 0
|
||||||
|
threadData.shared.limit = len(columns)
|
||||||
|
threadData.shared.outputs = []
|
||||||
|
|
||||||
def columnExistsThread():
|
def columnExistsThread():
|
||||||
while count[0] < length and kb.threadContinue:
|
threadData = getCurrentThreadData()
|
||||||
collock.acquire()
|
|
||||||
column = safeSQLIdentificatorNaming(columns[count[0]])
|
while kb.threadContinue:
|
||||||
count[0] += 1
|
kb.locks.countLock.acquire()
|
||||||
collock.release()
|
if threadData.shared.count < threadData.shared.limit:
|
||||||
|
column = safeSQLIdentificatorNaming(columns[threadData.shared.count])
|
||||||
|
threadData.shared.count += 1
|
||||||
|
kb.locks.countLock.release()
|
||||||
|
else:
|
||||||
|
kb.locks.countLock.release()
|
||||||
|
break
|
||||||
|
|
||||||
result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s)", (column, table)))
|
result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s)", (column, table)))
|
||||||
|
|
||||||
iolock.acquire()
|
kb.locks.ioLock.acquire()
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
retVal.append(column)
|
threadData.shared.outputs.append(column)
|
||||||
|
|
||||||
if conf.verbose in (1, 2):
|
if conf.verbose in (1, 2):
|
||||||
clearConsoleLine(True)
|
clearConsoleLine(True)
|
||||||
|
@ -216,79 +223,29 @@ def columnExists(columnFile, regex=None):
|
||||||
dataToStdout(infoMsg, True)
|
dataToStdout(infoMsg, True)
|
||||||
|
|
||||||
if conf.verbose in (1, 2):
|
if conf.verbose in (1, 2):
|
||||||
status = '%d/%d items (%d%s)' % (count[0], length, round(100.0*count[0]/length), '%')
|
status = '%d/%d items (%d%s)' % (threadData.shared.count, threadData.shared.limit, round(100.0*threadData.shared.count/threadData.shared.limit), '%')
|
||||||
dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True)
|
dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True)
|
||||||
|
|
||||||
iolock.release()
|
kb.locks.ioLock.release()
|
||||||
|
|
||||||
if conf.threads > 1:
|
|
||||||
infoMsg = "starting %d threads" % conf.threads
|
|
||||||
logger.info(infoMsg)
|
|
||||||
else:
|
|
||||||
while True:
|
|
||||||
message = "please enter number of threads? [Enter for %d (current)] " % conf.threads
|
|
||||||
choice = readInput(message, default=str(conf.threads))
|
|
||||||
if choice and choice.isdigit():
|
|
||||||
if int(choice) > MAX_NUMBER_OF_THREADS:
|
|
||||||
errMsg = "maximum number of used threads is %d avoiding possible connection issues" % MAX_NUMBER_OF_THREADS
|
|
||||||
logger.critical(errMsg)
|
|
||||||
else:
|
|
||||||
conf.threads = int(choice)
|
|
||||||
break
|
|
||||||
|
|
||||||
if conf.threads == 1:
|
|
||||||
warnMsg = "running in a single-thread mode. This could take a while."
|
|
||||||
logger.warn(warnMsg)
|
|
||||||
|
|
||||||
# Start the threads
|
|
||||||
for numThread in range(conf.threads):
|
|
||||||
thread = threading.Thread(target=columnExistsThread, name=str(numThread))
|
|
||||||
thread.start()
|
|
||||||
threads.append(thread)
|
|
||||||
|
|
||||||
# And wait for them to all finish
|
|
||||||
try:
|
try:
|
||||||
alive = True
|
runThreads(conf.threads, columnExistsThread, threadChoice=True)
|
||||||
|
|
||||||
while alive:
|
|
||||||
alive = False
|
|
||||||
|
|
||||||
for thread in threads:
|
|
||||||
if thread.isAlive():
|
|
||||||
alive = True
|
|
||||||
thread.join(5)
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
kb.threadContinue = False
|
warnMsg = "user aborted during column existence "
|
||||||
kb.threadException = True
|
warnMsg += "check. sqlmap will display partial output"
|
||||||
|
|
||||||
print
|
|
||||||
logger.debug("waiting for threads to finish")
|
|
||||||
|
|
||||||
warnMsg = "user aborted during common column existence check. "
|
|
||||||
warnMsg += "sqlmap will display some columns only"
|
|
||||||
logger.warn(warnMsg)
|
logger.warn(warnMsg)
|
||||||
|
|
||||||
try:
|
|
||||||
while (threading.activeCount() > 1):
|
|
||||||
pass
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise sqlmapThreadException, "user aborted"
|
|
||||||
finally:
|
|
||||||
kb.bruteMode = False
|
|
||||||
kb.threadContinue = True
|
|
||||||
kb.threadException = False
|
|
||||||
|
|
||||||
clearConsoleLine(True)
|
clearConsoleLine(True)
|
||||||
dataToStdout("\n")
|
dataToStdout("\n")
|
||||||
|
|
||||||
if not retVal:
|
if not threadData.shared.outputs:
|
||||||
warnMsg = "no column(s) found"
|
warnMsg = "no column(s) found"
|
||||||
logger.warn(warnMsg)
|
logger.warn(warnMsg)
|
||||||
else:
|
else:
|
||||||
columns = {}
|
columns = {}
|
||||||
|
|
||||||
for column in retVal:
|
for column in threadData.shared.outputs:
|
||||||
result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE ROUND(%s)=ROUND(%s))", (column, table, column, column)))
|
result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE ROUND(%s)=ROUND(%s))", (column, table, column, column)))
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
|
|
|
@ -335,35 +335,29 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False):
|
||||||
threadData.shared.outputs = []
|
threadData.shared.outputs = []
|
||||||
|
|
||||||
def errorThread():
|
def errorThread():
|
||||||
try:
|
threadData = getCurrentThreadData()
|
||||||
threadData = getCurrentThreadData()
|
|
||||||
|
|
||||||
while kb.threadContinue:
|
while kb.threadContinue:
|
||||||
kb.locks.limits.acquire()
|
kb.locks.limits.acquire()
|
||||||
if threadData.shared.limits:
|
if threadData.shared.limits:
|
||||||
num = threadData.shared.limits[-1]
|
num = threadData.shared.limits[-1]
|
||||||
del threadData.shared.limits[-1]
|
del threadData.shared.limits[-1]
|
||||||
kb.locks.limits.release()
|
kb.locks.limits.release()
|
||||||
else:
|
else:
|
||||||
kb.locks.limits.release()
|
kb.locks.limits.release()
|
||||||
break
|
break
|
||||||
|
|
||||||
output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue)
|
output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue)
|
||||||
|
|
||||||
if not kb.threadContinue:
|
if not kb.threadContinue:
|
||||||
break
|
break
|
||||||
|
|
||||||
if output and isinstance(output, list) and len(output) == 1:
|
if output and isinstance(output, list) and len(output) == 1:
|
||||||
output = output[0]
|
output = output[0]
|
||||||
|
|
||||||
kb.locks.outputs.acquire()
|
kb.locks.outputs.acquire()
|
||||||
threadData.shared.outputs.append(output)
|
threadData.shared.outputs.append(output)
|
||||||
kb.locks.outputs.release()
|
kb.locks.outputs.release()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
kb.threadContinue = False
|
|
||||||
kb.threadException = True
|
|
||||||
raise
|
|
||||||
|
|
||||||
runThreads(numThreads, errorThread)
|
runThreads(numThreads, errorThread)
|
||||||
|
|
||||||
|
|
|
@ -275,59 +275,53 @@ def unionUse(expression, unpack=True, dump=False):
|
||||||
threadData.shared.value = ""
|
threadData.shared.value = ""
|
||||||
|
|
||||||
def unionThread():
|
def unionThread():
|
||||||
try:
|
threadData = getCurrentThreadData()
|
||||||
threadData = getCurrentThreadData()
|
|
||||||
|
|
||||||
while kb.threadContinue:
|
while kb.threadContinue:
|
||||||
kb.locks.limits.acquire()
|
kb.locks.limits.acquire()
|
||||||
if threadData.shared.limits:
|
if threadData.shared.limits:
|
||||||
num = threadData.shared.limits[-1]
|
num = threadData.shared.limits[-1]
|
||||||
del threadData.shared.limits[-1]
|
del threadData.shared.limits[-1]
|
||||||
kb.locks.limits.release()
|
kb.locks.limits.release()
|
||||||
else:
|
else:
|
||||||
kb.locks.limits.release()
|
kb.locks.limits.release()
|
||||||
break
|
break
|
||||||
|
|
||||||
if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE):
|
if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE):
|
||||||
field = expressionFieldsList[0]
|
field = expressionFieldsList[0]
|
||||||
elif Backend.isDbms(DBMS.ORACLE):
|
elif Backend.isDbms(DBMS.ORACLE):
|
||||||
field = expressionFieldsList
|
field = expressionFieldsList
|
||||||
else:
|
else:
|
||||||
field = None
|
field = None
|
||||||
|
|
||||||
limitedExpr = agent.limitQuery(num, expression, field)
|
limitedExpr = agent.limitQuery(num, expression, field)
|
||||||
output = resume(limitedExpr, None)
|
output = resume(limitedExpr, None)
|
||||||
|
|
||||||
if not output:
|
if not output:
|
||||||
output = __oneShotUnionUse(limitedExpr, unpack)
|
output = __oneShotUnionUse(limitedExpr, unpack)
|
||||||
|
|
||||||
if not kb.threadContinue:
|
if not kb.threadContinue:
|
||||||
break
|
break
|
||||||
|
|
||||||
if output:
|
if output:
|
||||||
kb.locks.value.acquire()
|
kb.locks.value.acquire()
|
||||||
threadData.shared.value += output
|
threadData.shared.value += output
|
||||||
kb.locks.value.release()
|
kb.locks.value.release()
|
||||||
|
|
||||||
if conf.verbose == 1:
|
if conf.verbose == 1:
|
||||||
if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])):
|
if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])):
|
||||||
items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter)
|
items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter)
|
||||||
else:
|
else:
|
||||||
items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter)
|
items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter)
|
||||||
|
|
||||||
status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items))))
|
status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items))))
|
||||||
|
|
||||||
if len(status) > width:
|
if len(status) > width:
|
||||||
status = "%s..." % status[:width - 3]
|
status = "%s..." % status[:width - 3]
|
||||||
|
|
||||||
kb.locks.ioLock.acquire()
|
kb.locks.ioLock.acquire()
|
||||||
dataToStdout(status, True)
|
dataToStdout(status, True)
|
||||||
kb.locks.ioLock.release()
|
kb.locks.ioLock.release()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
kb.threadContinue = False
|
|
||||||
kb.threadException = True
|
|
||||||
raise
|
|
||||||
|
|
||||||
runThreads(numThreads, unionThread)
|
runThreads(numThreads, unionThread)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user