diff --git a/lib/core/option.py b/lib/core/option.py index 67baeff09..95d0ca955 100644 --- a/lib/core/option.py +++ b/lib/core/option.py @@ -20,6 +20,9 @@ import threading import urllib2 import urlparse +import lib.core.common +import lib.core.threads + from extra.clientform.clientform import ParseResponse from extra.clientform.clientform import ParseError from extra.keepalive import keepalive @@ -109,6 +112,7 @@ from lib.request.basicauthhandler import SmartHTTPBasicAuthHandler from lib.request.certhandler import HTTPSCertAuthHandler from lib.request.rangehandler import HTTPRangeHandler from lib.request.redirecthandler import SmartRedirectHandler +from lib.request.templates import getPageTemplate from lib.utils.google import Google authHandler = urllib2.BaseHandler() @@ -1360,8 +1364,10 @@ def __setKnowledgeBaseAttributes(flushAll=True): kb.locks.cacheLock = threading.Lock() kb.locks.logLock = threading.Lock() kb.locks.ioLock = threading.Lock() + kb.locks.countLock = threading.Lock() kb.matchRatio = None + kb.multiThreadMode = False kb.nullConnection = None kb.pageTemplate = None kb.pageTemplates = dict() @@ -1701,6 +1707,10 @@ def __basicOptionValidation(): errMsg += "to get the full list of supported charsets" raise sqlmapSyntaxException, errMsg +def __resolveCrossReferences(): + lib.core.threads.readInput = readInput + lib.core.common.getPageTemplate = getPageTemplate + def init(inputOptions=advancedDict(), overrideOptions=False): """ Set attributes into both configuration and knowledge base singletons @@ -1720,6 +1730,7 @@ def init(inputOptions=advancedDict(), overrideOptions=False): __setMultipleTargets() __setTamperingFunctions() __setTrafficOutputFP() + __resolveCrossReferences() parseTargetUrl() parseTargetDirect() diff --git a/lib/core/threads.py b/lib/core/threads.py index 0f03856dd..470559a41 100644 --- a/lib/core/threads.py +++ b/lib/core/threads.py @@ -14,6 +14,7 @@ from lib.core.data import kb from lib.core.data import logger from lib.core.datatype import advancedDict from lib.core.exception import sqlmapThreadException +from lib.core.settings import MAX_NUMBER_OF_THREADS shared = advancedDict() @@ -39,6 +40,9 @@ class ThreadData(): def getCurrentThreadUID(): return hash(threading.currentThread()) +def readInput(message, default=None): + pass + def getCurrentThreadData(): """ Returns current thread's dependent data @@ -49,12 +53,40 @@ def getCurrentThreadData(): kb.threadData[threadUID] = ThreadData() return kb.threadData[threadUID] -def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True): +def exceptionHandledFunction(threadFunction): + try: + threadFunction() + except KeyboardInterrupt: + kb.threadContinue = False + kb.threadException = True + raise + except: + kb.threadContinue = False + kb.threadException = True + +def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True, threadChoice=False): threads = [] + kb.multiThreadMode = True kb.threadContinue = True kb.threadException = False + if threadChoice and numThreads == 1: + while True: + message = "please enter number of threads? [Enter for %d (current)] " % numThreads + choice = readInput(message, default=str(numThreads)) + if choice and choice.isdigit(): + if int(choice) > MAX_NUMBER_OF_THREADS: + errMsg = "maximum number of used threads is %d avoiding possible connection issues" % MAX_NUMBER_OF_THREADS + logger.critical(errMsg) + else: + numThreads = int(choice) + break + + if numThreads == 1: + warnMsg = "running in a single-thread mode. This could take a while." + logger.warn(warnMsg) + if numThreads > 1: infoMsg = "starting %d threads" % numThreads logger.info(infoMsg) @@ -64,7 +96,7 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio # Start the threads for numThread in range(numThreads): - thread = threading.Thread(target=threadFunction, name=str(numThread)) + thread = threading.Thread(target=exceptionHandledFunction, name=str(numThread), args=[threadFunction]) thread.start() threads.append(thread) @@ -98,6 +130,8 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio raise finally: + kb.multiThreadMode = False + kb.bruteMode = False kb.threadContinue = True kb.threadException = False diff --git a/lib/request/templates.py b/lib/request/templates.py index 70b4f6b53..f15a18265 100644 --- a/lib/request/templates.py +++ b/lib/request/templates.py @@ -7,8 +7,6 @@ Copyright (c) 2006-2011 sqlmap developers (http://sqlmap.sourceforge.net/) See the file 'doc/COPYING' for copying permission """ -import lib.core.common - from lib.core.data import kb from lib.request.connect import Connect as Request @@ -24,4 +22,3 @@ def getPageTemplate(payload, place): return retVal -lib.core.common.getPageTemplate = getPageTemplate diff --git a/lib/techniques/brute/use.py b/lib/techniques/brute/use.py index 08134a947..83e319b47 100644 --- a/lib/techniques/brute/use.py +++ b/lib/techniques/brute/use.py @@ -32,6 +32,8 @@ from lib.core.exception import sqlmapThreadException from lib.core.settings import MAX_NUMBER_OF_THREADS from lib.core.settings import METADB_SUFFIX from lib.core.session import safeFormatString +from lib.core.threads import getCurrentThreadData +from lib.core.threads import runThreads from lib.request import inject def tableExists(tableFile, regex=None): @@ -184,31 +186,36 @@ def columnExists(columnFile, regex=None): table = conf.tbl table = safeSQLIdentificatorNaming(table) - retVal = [] infoMsg = "checking column existence using items from '%s'" % columnFile logger.info(infoMsg) - count = [0] - length = len(columns) - threads = [] - collock = threading.Lock() - iolock = threading.Lock() kb.threadContinue = True kb.bruteMode = True + threadData = getCurrentThreadData() + threadData.shared.count = 0 + threadData.shared.limit = len(columns) + threadData.shared.outputs = [] + def columnExistsThread(): - while count[0] < length and kb.threadContinue: - collock.acquire() - column = safeSQLIdentificatorNaming(columns[count[0]]) - count[0] += 1 - collock.release() + threadData = getCurrentThreadData() + + while kb.threadContinue: + kb.locks.countLock.acquire() + if threadData.shared.count < threadData.shared.limit: + column = safeSQLIdentificatorNaming(columns[threadData.shared.count]) + threadData.shared.count += 1 + kb.locks.countLock.release() + else: + kb.locks.countLock.release() + break result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s)", (column, table))) - iolock.acquire() + kb.locks.ioLock.acquire() if result: - retVal.append(column) + threadData.shared.outputs.append(column) if conf.verbose in (1, 2): clearConsoleLine(True) @@ -216,79 +223,29 @@ def columnExists(columnFile, regex=None): dataToStdout(infoMsg, True) if conf.verbose in (1, 2): - status = '%d/%d items (%d%s)' % (count[0], length, round(100.0*count[0]/length), '%') + status = '%d/%d items (%d%s)' % (threadData.shared.count, threadData.shared.limit, round(100.0*threadData.shared.count/threadData.shared.limit), '%') dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True) - iolock.release() + kb.locks.ioLock.release() - if conf.threads > 1: - infoMsg = "starting %d threads" % conf.threads - logger.info(infoMsg) - else: - while True: - message = "please enter number of threads? [Enter for %d (current)] " % conf.threads - choice = readInput(message, default=str(conf.threads)) - if choice and choice.isdigit(): - if int(choice) > MAX_NUMBER_OF_THREADS: - errMsg = "maximum number of used threads is %d avoiding possible connection issues" % MAX_NUMBER_OF_THREADS - logger.critical(errMsg) - else: - conf.threads = int(choice) - break - - if conf.threads == 1: - warnMsg = "running in a single-thread mode. This could take a while." - logger.warn(warnMsg) - - # Start the threads - for numThread in range(conf.threads): - thread = threading.Thread(target=columnExistsThread, name=str(numThread)) - thread.start() - threads.append(thread) - - # And wait for them to all finish try: - alive = True + runThreads(conf.threads, columnExistsThread, threadChoice=True) - while alive: - alive = False - - for thread in threads: - if thread.isAlive(): - alive = True - thread.join(5) except KeyboardInterrupt: - kb.threadContinue = False - kb.threadException = True - - print - logger.debug("waiting for threads to finish") - - warnMsg = "user aborted during common column existence check. " - warnMsg += "sqlmap will display some columns only" + warnMsg = "user aborted during column existence " + warnMsg += "check. sqlmap will display partial output" logger.warn(warnMsg) - try: - while (threading.activeCount() > 1): - pass - - except KeyboardInterrupt: - raise sqlmapThreadException, "user aborted" - finally: - kb.bruteMode = False - kb.threadContinue = True - kb.threadException = False - clearConsoleLine(True) dataToStdout("\n") - if not retVal: + if not threadData.shared.outputs: warnMsg = "no column(s) found" logger.warn(warnMsg) else: columns = {} - for column in retVal: + for column in threadData.shared.outputs: result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE ROUND(%s)=ROUND(%s))", (column, table, column, column))) if result: diff --git a/lib/techniques/error/use.py b/lib/techniques/error/use.py index 5a2069168..adadc704b 100644 --- a/lib/techniques/error/use.py +++ b/lib/techniques/error/use.py @@ -335,35 +335,29 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False): threadData.shared.outputs = [] def errorThread(): - try: - threadData = getCurrentThreadData() + threadData = getCurrentThreadData() - while kb.threadContinue: - kb.locks.limits.acquire() - if threadData.shared.limits: - num = threadData.shared.limits[-1] - del threadData.shared.limits[-1] - kb.locks.limits.release() - else: - kb.locks.limits.release() - break + while kb.threadContinue: + kb.locks.limits.acquire() + if threadData.shared.limits: + num = threadData.shared.limits[-1] + del threadData.shared.limits[-1] + kb.locks.limits.release() + else: + kb.locks.limits.release() + break - output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue) + output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue) - if not kb.threadContinue: - break + if not kb.threadContinue: + break - if output and isinstance(output, list) and len(output) == 1: - output = output[0] + if output and isinstance(output, list) and len(output) == 1: + output = output[0] - kb.locks.outputs.acquire() - threadData.shared.outputs.append(output) - kb.locks.outputs.release() - - except KeyboardInterrupt: - kb.threadContinue = False - kb.threadException = True - raise + kb.locks.outputs.acquire() + threadData.shared.outputs.append(output) + kb.locks.outputs.release() runThreads(numThreads, errorThread) diff --git a/lib/techniques/inband/union/use.py b/lib/techniques/inband/union/use.py index a1c95b72d..cc28048ef 100644 --- a/lib/techniques/inband/union/use.py +++ b/lib/techniques/inband/union/use.py @@ -275,59 +275,53 @@ def unionUse(expression, unpack=True, dump=False): threadData.shared.value = "" def unionThread(): - try: - threadData = getCurrentThreadData() + threadData = getCurrentThreadData() - while kb.threadContinue: - kb.locks.limits.acquire() - if threadData.shared.limits: - num = threadData.shared.limits[-1] - del threadData.shared.limits[-1] - kb.locks.limits.release() - else: - kb.locks.limits.release() - break + while kb.threadContinue: + kb.locks.limits.acquire() + if threadData.shared.limits: + num = threadData.shared.limits[-1] + del threadData.shared.limits[-1] + kb.locks.limits.release() + else: + kb.locks.limits.release() + break - if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE): - field = expressionFieldsList[0] - elif Backend.isDbms(DBMS.ORACLE): - field = expressionFieldsList - else: - field = None + if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE): + field = expressionFieldsList[0] + elif Backend.isDbms(DBMS.ORACLE): + field = expressionFieldsList + else: + field = None - limitedExpr = agent.limitQuery(num, expression, field) - output = resume(limitedExpr, None) + limitedExpr = agent.limitQuery(num, expression, field) + output = resume(limitedExpr, None) - if not output: - output = __oneShotUnionUse(limitedExpr, unpack) + if not output: + output = __oneShotUnionUse(limitedExpr, unpack) - if not kb.threadContinue: - break + if not kb.threadContinue: + break - if output: - kb.locks.value.acquire() - threadData.shared.value += output - kb.locks.value.release() + if output: + kb.locks.value.acquire() + threadData.shared.value += output + kb.locks.value.release() - if conf.verbose == 1: - if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])): - items = extractRegexResult(r'%s(?P.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter) - else: - items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter) + if conf.verbose == 1: + if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])): + items = extractRegexResult(r'%s(?P.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter) + else: + items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter) - status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items)))) + status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items)))) - if len(status) > width: - status = "%s..." % status[:width - 3] + if len(status) > width: + status = "%s..." % status[:width - 3] - kb.locks.ioLock.acquire() - dataToStdout(status, True) - kb.locks.ioLock.release() - - except KeyboardInterrupt: - kb.threadContinue = False - kb.threadException = True - raise + kb.locks.ioLock.acquire() + dataToStdout(status, True) + kb.locks.ioLock.release() runThreads(numThreads, unionThread)