mirror of
				https://github.com/sqlmapproject/sqlmap.git
				synced 2025-10-25 21:21:03 +03:00 
			
		
		
		
	refactoring and stabilization of multithreading
This commit is contained in:
		
							parent
							
								
									5f7858455d
								
							
						
					
					
						commit
						7a3cc38e3c
					
				|  | @ -20,6 +20,9 @@ import threading | |||
| import urllib2 | ||||
| import urlparse | ||||
| 
 | ||||
| import lib.core.common | ||||
| import lib.core.threads | ||||
| 
 | ||||
| from extra.clientform.clientform import ParseResponse | ||||
| from extra.clientform.clientform import ParseError | ||||
| from extra.keepalive import keepalive | ||||
|  | @ -109,6 +112,7 @@ from lib.request.basicauthhandler import SmartHTTPBasicAuthHandler | |||
| from lib.request.certhandler import HTTPSCertAuthHandler | ||||
| from lib.request.rangehandler import HTTPRangeHandler | ||||
| from lib.request.redirecthandler import SmartRedirectHandler | ||||
| from lib.request.templates import getPageTemplate | ||||
| from lib.utils.google import Google | ||||
| 
 | ||||
| authHandler = urllib2.BaseHandler() | ||||
|  | @ -1360,8 +1364,10 @@ def __setKnowledgeBaseAttributes(flushAll=True): | |||
|     kb.locks.cacheLock = threading.Lock() | ||||
|     kb.locks.logLock = threading.Lock() | ||||
|     kb.locks.ioLock = threading.Lock() | ||||
|     kb.locks.countLock = threading.Lock() | ||||
| 
 | ||||
|     kb.matchRatio = None | ||||
|     kb.multiThreadMode = False | ||||
|     kb.nullConnection = None | ||||
|     kb.pageTemplate = None | ||||
|     kb.pageTemplates = dict() | ||||
|  | @ -1701,6 +1707,10 @@ def __basicOptionValidation(): | |||
|             errMsg += "to get the full list of supported charsets" | ||||
|             raise sqlmapSyntaxException, errMsg | ||||
| 
 | ||||
| def __resolveCrossReferences(): | ||||
|     lib.core.threads.readInput = readInput | ||||
|     lib.core.common.getPageTemplate = getPageTemplate | ||||
| 
 | ||||
| def init(inputOptions=advancedDict(), overrideOptions=False): | ||||
|     """ | ||||
|     Set attributes into both configuration and knowledge base singletons | ||||
|  | @ -1720,6 +1730,7 @@ def init(inputOptions=advancedDict(), overrideOptions=False): | |||
|     __setMultipleTargets() | ||||
|     __setTamperingFunctions() | ||||
|     __setTrafficOutputFP() | ||||
|     __resolveCrossReferences() | ||||
| 
 | ||||
|     parseTargetUrl() | ||||
|     parseTargetDirect() | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ from lib.core.data import kb | |||
| from lib.core.data import logger | ||||
| from lib.core.datatype import advancedDict | ||||
| from lib.core.exception import sqlmapThreadException | ||||
| from lib.core.settings import MAX_NUMBER_OF_THREADS | ||||
| 
 | ||||
| shared = advancedDict() | ||||
| 
 | ||||
|  | @ -39,6 +40,9 @@ class ThreadData(): | |||
| def getCurrentThreadUID(): | ||||
|     return hash(threading.currentThread()) | ||||
| 
 | ||||
| def readInput(message, default=None): | ||||
|     pass | ||||
| 
 | ||||
| def getCurrentThreadData(): | ||||
|     """ | ||||
|     Returns current thread's dependent data | ||||
|  | @ -49,12 +53,40 @@ def getCurrentThreadData(): | |||
|         kb.threadData[threadUID] = ThreadData() | ||||
|     return kb.threadData[threadUID] | ||||
| 
 | ||||
| def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True): | ||||
| def exceptionHandledFunction(threadFunction): | ||||
|     try: | ||||
|         threadFunction() | ||||
|     except KeyboardInterrupt: | ||||
|         kb.threadContinue = False | ||||
|         kb.threadException = True | ||||
|         raise | ||||
|     except: | ||||
|         kb.threadContinue = False | ||||
|         kb.threadException = True | ||||
| 
 | ||||
| def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True, threadChoice=False): | ||||
|     threads = [] | ||||
| 
 | ||||
|     kb.multiThreadMode = True | ||||
|     kb.threadContinue = True | ||||
|     kb.threadException = False | ||||
| 
 | ||||
|     if threadChoice and numThreads == 1: | ||||
|         while True: | ||||
|             message = "please enter number of threads? [Enter for %d (current)] " % numThreads | ||||
|             choice = readInput(message, default=str(numThreads)) | ||||
|             if choice and choice.isdigit(): | ||||
|                 if int(choice) > MAX_NUMBER_OF_THREADS: | ||||
|                     errMsg = "maximum number of used threads is %d avoiding possible connection issues" % MAX_NUMBER_OF_THREADS | ||||
|                     logger.critical(errMsg) | ||||
|                 else: | ||||
|                     numThreads = int(choice) | ||||
|                     break | ||||
| 
 | ||||
|         if numThreads == 1: | ||||
|             warnMsg = "running in a single-thread mode. This could take a while." | ||||
|             logger.warn(warnMsg) | ||||
| 
 | ||||
|     if numThreads > 1: | ||||
|         infoMsg = "starting %d threads" % numThreads | ||||
|         logger.info(infoMsg) | ||||
|  | @ -64,7 +96,7 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio | |||
| 
 | ||||
|     # Start the threads | ||||
|     for numThread in range(numThreads): | ||||
|         thread = threading.Thread(target=threadFunction, name=str(numThread)) | ||||
|         thread = threading.Thread(target=exceptionHandledFunction, name=str(numThread), args=[threadFunction]) | ||||
|         thread.start() | ||||
|         threads.append(thread) | ||||
| 
 | ||||
|  | @ -98,6 +130,8 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio | |||
|             raise | ||||
| 
 | ||||
|     finally: | ||||
|         kb.multiThreadMode = False | ||||
|         kb.bruteMode = False | ||||
|         kb.threadContinue = True | ||||
|         kb.threadException = False | ||||
| 
 | ||||
|  |  | |||
|  | @ -7,8 +7,6 @@ Copyright (c) 2006-2011 sqlmap developers (http://sqlmap.sourceforge.net/) | |||
| See the file 'doc/COPYING' for copying permission | ||||
| """ | ||||
| 
 | ||||
| import lib.core.common | ||||
| 
 | ||||
| from lib.core.data import kb | ||||
| from lib.request.connect import Connect as Request | ||||
| 
 | ||||
|  | @ -24,4 +22,3 @@ def getPageTemplate(payload, place): | |||
| 
 | ||||
|     return retVal | ||||
| 
 | ||||
| lib.core.common.getPageTemplate = getPageTemplate | ||||
|  |  | |||
|  | @ -32,6 +32,8 @@ from lib.core.exception import sqlmapThreadException | |||
| from lib.core.settings import MAX_NUMBER_OF_THREADS | ||||
| from lib.core.settings import METADB_SUFFIX | ||||
| from lib.core.session import safeFormatString | ||||
| from lib.core.threads import getCurrentThreadData | ||||
| from lib.core.threads import runThreads | ||||
| from lib.request import inject | ||||
| 
 | ||||
| def tableExists(tableFile, regex=None): | ||||
|  | @ -184,31 +186,36 @@ def columnExists(columnFile, regex=None): | |||
|         table = conf.tbl | ||||
|     table = safeSQLIdentificatorNaming(table) | ||||
| 
 | ||||
|     retVal = [] | ||||
|     infoMsg = "checking column existence using items from '%s'" % columnFile | ||||
|     logger.info(infoMsg) | ||||
| 
 | ||||
|     count = [0] | ||||
|     length = len(columns) | ||||
|     threads = [] | ||||
|     collock = threading.Lock() | ||||
|     iolock = threading.Lock() | ||||
|     kb.threadContinue = True | ||||
|     kb.bruteMode = True | ||||
| 
 | ||||
|     threadData = getCurrentThreadData() | ||||
|     threadData.shared.count = 0 | ||||
|     threadData.shared.limit = len(columns) | ||||
|     threadData.shared.outputs = [] | ||||
| 
 | ||||
|     def columnExistsThread(): | ||||
|         while count[0] < length and kb.threadContinue: | ||||
|             collock.acquire() | ||||
|             column = safeSQLIdentificatorNaming(columns[count[0]]) | ||||
|             count[0] += 1 | ||||
|             collock.release() | ||||
|         threadData = getCurrentThreadData() | ||||
| 
 | ||||
|         while kb.threadContinue: | ||||
|             kb.locks.countLock.acquire() | ||||
|             if threadData.shared.count < threadData.shared.limit: | ||||
|                 column = safeSQLIdentificatorNaming(columns[threadData.shared.count]) | ||||
|                 threadData.shared.count += 1 | ||||
|                 kb.locks.countLock.release() | ||||
|             else: | ||||
|                 kb.locks.countLock.release() | ||||
|                 break | ||||
| 
 | ||||
|             result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s)", (column, table))) | ||||
| 
 | ||||
|             iolock.acquire() | ||||
|             kb.locks.ioLock.acquire() | ||||
| 
 | ||||
|             if result: | ||||
|                 retVal.append(column) | ||||
|                 threadData.shared.outputs.append(column) | ||||
| 
 | ||||
|                 if conf.verbose in (1, 2): | ||||
|                     clearConsoleLine(True) | ||||
|  | @ -216,79 +223,29 @@ def columnExists(columnFile, regex=None): | |||
|                     dataToStdout(infoMsg, True) | ||||
| 
 | ||||
|             if conf.verbose in (1, 2): | ||||
|                 status = '%d/%d items (%d%s)' % (count[0], length, round(100.0*count[0]/length), '%') | ||||
|                 status = '%d/%d items (%d%s)' % (threadData.shared.count, threadData.shared.limit, round(100.0*threadData.shared.count/threadData.shared.limit), '%') | ||||
|                 dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True) | ||||
| 
 | ||||
|             iolock.release() | ||||
|             kb.locks.ioLock.release() | ||||
| 
 | ||||
|     if conf.threads > 1: | ||||
|         infoMsg = "starting %d threads" % conf.threads | ||||
|         logger.info(infoMsg) | ||||
|     else: | ||||
|         while True: | ||||
|             message = "please enter number of threads? [Enter for %d (current)] " % conf.threads | ||||
|             choice = readInput(message, default=str(conf.threads)) | ||||
|             if choice and choice.isdigit(): | ||||
|                 if int(choice) > MAX_NUMBER_OF_THREADS: | ||||
|                     errMsg = "maximum number of used threads is %d avoiding possible connection issues" % MAX_NUMBER_OF_THREADS | ||||
|                     logger.critical(errMsg) | ||||
|                 else: | ||||
|                     conf.threads = int(choice) | ||||
|                     break | ||||
| 
 | ||||
|     if conf.threads == 1: | ||||
|         warnMsg = "running in a single-thread mode. This could take a while." | ||||
|         logger.warn(warnMsg) | ||||
| 
 | ||||
|     # Start the threads | ||||
|     for numThread in range(conf.threads): | ||||
|         thread = threading.Thread(target=columnExistsThread, name=str(numThread)) | ||||
|         thread.start() | ||||
|         threads.append(thread) | ||||
| 
 | ||||
|     # And wait for them to all finish | ||||
|     try: | ||||
|         alive = True | ||||
|         runThreads(conf.threads, columnExistsThread, threadChoice=True) | ||||
| 
 | ||||
|         while alive: | ||||
|             alive = False | ||||
| 
 | ||||
|             for thread in threads: | ||||
|                 if thread.isAlive(): | ||||
|                     alive = True | ||||
|                     thread.join(5) | ||||
|     except KeyboardInterrupt: | ||||
|         kb.threadContinue = False | ||||
|         kb.threadException = True | ||||
| 
 | ||||
|         print | ||||
|         logger.debug("waiting for threads to finish") | ||||
| 
 | ||||
|         warnMsg = "user aborted during common column existence check. " | ||||
|         warnMsg += "sqlmap will display some columns only" | ||||
|         warnMsg = "user aborted during column existence " | ||||
|         warnMsg += "check. sqlmap will display partial output" | ||||
|         logger.warn(warnMsg) | ||||
| 
 | ||||
|         try: | ||||
|             while (threading.activeCount() > 1): | ||||
|                 pass | ||||
| 
 | ||||
|         except KeyboardInterrupt: | ||||
|             raise sqlmapThreadException, "user aborted" | ||||
|     finally: | ||||
|         kb.bruteMode = False | ||||
|         kb.threadContinue = True | ||||
|         kb.threadException = False | ||||
| 
 | ||||
|     clearConsoleLine(True) | ||||
|     dataToStdout("\n") | ||||
| 
 | ||||
|     if not retVal: | ||||
|     if not threadData.shared.outputs: | ||||
|         warnMsg = "no column(s) found" | ||||
|         logger.warn(warnMsg) | ||||
|     else: | ||||
|         columns = {} | ||||
| 
 | ||||
|         for column in retVal: | ||||
|         for column in threadData.shared.outputs: | ||||
|             result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE ROUND(%s)=ROUND(%s))", (column, table, column, column))) | ||||
| 
 | ||||
|             if result: | ||||
|  |  | |||
|  | @ -335,35 +335,29 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False): | |||
|                 threadData.shared.outputs = [] | ||||
| 
 | ||||
|                 def errorThread(): | ||||
|                     try: | ||||
|                         threadData = getCurrentThreadData() | ||||
|                     threadData = getCurrentThreadData() | ||||
| 
 | ||||
|                         while kb.threadContinue: | ||||
|                             kb.locks.limits.acquire() | ||||
|                             if threadData.shared.limits: | ||||
|                                 num = threadData.shared.limits[-1] | ||||
|                                 del threadData.shared.limits[-1] | ||||
|                                 kb.locks.limits.release() | ||||
|                             else: | ||||
|                                 kb.locks.limits.release() | ||||
|                                 break | ||||
|                     while kb.threadContinue: | ||||
|                         kb.locks.limits.acquire() | ||||
|                         if threadData.shared.limits: | ||||
|                             num = threadData.shared.limits[-1] | ||||
|                             del threadData.shared.limits[-1] | ||||
|                             kb.locks.limits.release() | ||||
|                         else: | ||||
|                             kb.locks.limits.release() | ||||
|                             break | ||||
| 
 | ||||
|                             output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue) | ||||
|                         output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue) | ||||
| 
 | ||||
|                             if not kb.threadContinue: | ||||
|                                 break | ||||
|                         if not kb.threadContinue: | ||||
|                             break | ||||
| 
 | ||||
|                             if output and isinstance(output, list) and len(output) == 1: | ||||
|                                 output = output[0] | ||||
|                         if output and isinstance(output, list) and len(output) == 1: | ||||
|                             output = output[0] | ||||
| 
 | ||||
|                             kb.locks.outputs.acquire() | ||||
|                             threadData.shared.outputs.append(output) | ||||
|                             kb.locks.outputs.release() | ||||
| 
 | ||||
|                     except KeyboardInterrupt: | ||||
|                         kb.threadContinue = False | ||||
|                         kb.threadException = True | ||||
|                         raise | ||||
|                         kb.locks.outputs.acquire() | ||||
|                         threadData.shared.outputs.append(output) | ||||
|                         kb.locks.outputs.release() | ||||
| 
 | ||||
|                 runThreads(numThreads, errorThread) | ||||
| 
 | ||||
|  |  | |||
|  | @ -275,59 +275,53 @@ def unionUse(expression, unpack=True, dump=False): | |||
|                 threadData.shared.value = "" | ||||
| 
 | ||||
|                 def unionThread(): | ||||
|                     try: | ||||
|                         threadData = getCurrentThreadData() | ||||
|                     threadData = getCurrentThreadData() | ||||
| 
 | ||||
|                         while kb.threadContinue: | ||||
|                             kb.locks.limits.acquire() | ||||
|                             if threadData.shared.limits: | ||||
|                                 num = threadData.shared.limits[-1] | ||||
|                                 del threadData.shared.limits[-1] | ||||
|                                 kb.locks.limits.release() | ||||
|                             else: | ||||
|                                 kb.locks.limits.release() | ||||
|                                 break | ||||
|                     while kb.threadContinue: | ||||
|                         kb.locks.limits.acquire() | ||||
|                         if threadData.shared.limits: | ||||
|                             num = threadData.shared.limits[-1] | ||||
|                             del threadData.shared.limits[-1] | ||||
|                             kb.locks.limits.release() | ||||
|                         else: | ||||
|                             kb.locks.limits.release() | ||||
|                             break | ||||
| 
 | ||||
|                             if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE): | ||||
|                                 field = expressionFieldsList[0] | ||||
|                             elif Backend.isDbms(DBMS.ORACLE): | ||||
|                                 field = expressionFieldsList | ||||
|                             else: | ||||
|                                 field = None | ||||
|                         if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE): | ||||
|                             field = expressionFieldsList[0] | ||||
|                         elif Backend.isDbms(DBMS.ORACLE): | ||||
|                             field = expressionFieldsList | ||||
|                         else: | ||||
|                             field = None | ||||
| 
 | ||||
|                             limitedExpr = agent.limitQuery(num, expression, field) | ||||
|                             output = resume(limitedExpr, None) | ||||
|                         limitedExpr = agent.limitQuery(num, expression, field) | ||||
|                         output = resume(limitedExpr, None) | ||||
| 
 | ||||
|                             if not output: | ||||
|                                 output = __oneShotUnionUse(limitedExpr, unpack) | ||||
|                         if not output: | ||||
|                             output = __oneShotUnionUse(limitedExpr, unpack) | ||||
| 
 | ||||
|                             if not kb.threadContinue: | ||||
|                                 break | ||||
|                         if not kb.threadContinue: | ||||
|                             break | ||||
| 
 | ||||
|                             if output: | ||||
|                                 kb.locks.value.acquire() | ||||
|                                 threadData.shared.value += output | ||||
|                                 kb.locks.value.release() | ||||
|                         if output: | ||||
|                             kb.locks.value.acquire() | ||||
|                             threadData.shared.value += output | ||||
|                             kb.locks.value.release() | ||||
| 
 | ||||
|                                 if conf.verbose == 1: | ||||
|                                     if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])): | ||||
|                                         items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter) | ||||
|                                     else: | ||||
|                                         items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter) | ||||
|                             if conf.verbose == 1: | ||||
|                                 if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])): | ||||
|                                     items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter) | ||||
|                                 else: | ||||
|                                     items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter) | ||||
| 
 | ||||
|                                     status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items)))) | ||||
|                                 status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items)))) | ||||
| 
 | ||||
|                                     if len(status) > width: | ||||
|                                         status = "%s..." % status[:width - 3] | ||||
|                                 if len(status) > width: | ||||
|                                     status = "%s..." % status[:width - 3] | ||||
| 
 | ||||
|                                     kb.locks.ioLock.acquire() | ||||
|                                     dataToStdout(status, True) | ||||
|                                     kb.locks.ioLock.release() | ||||
| 
 | ||||
|                     except KeyboardInterrupt: | ||||
|                         kb.threadContinue = False | ||||
|                         kb.threadException = True | ||||
|                         raise | ||||
|                                 kb.locks.ioLock.acquire() | ||||
|                                 dataToStdout(status, True) | ||||
|                                 kb.locks.ioLock.release() | ||||
| 
 | ||||
|                 runThreads(numThreads, unionThread) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user