mirror of
				https://github.com/sqlmapproject/sqlmap.git
				synced 2025-11-01 00:17:25 +03:00 
			
		
		
		
	implementation of multithreading for UNION and ERROR techniques
This commit is contained in:
		
							parent
							
								
									d51efa679d
								
							
						
					
					
						commit
						86455ceb9c
					
				|  | @ -1358,6 +1358,7 @@ def __setKnowledgeBaseAttributes(flushAll=True): | |||
|     kb.locks = advancedDict() | ||||
|     kb.locks.cacheLock = threading.Lock() | ||||
|     kb.locks.logLock = threading.Lock() | ||||
|     kb.locks.ioLock = threading.Lock() | ||||
| 
 | ||||
|     kb.matchRatio = None | ||||
|     kb.nullConnection = None | ||||
|  |  | |||
|  | @ -11,6 +11,11 @@ import difflib | |||
| import threading | ||||
| 
 | ||||
| from lib.core.data import kb | ||||
| from lib.core.data import logger | ||||
| from lib.core.datatype import advancedDict | ||||
| from lib.core.exception import sqlmapThreadException | ||||
| 
 | ||||
| shared = advancedDict() | ||||
| 
 | ||||
| class ThreadData(): | ||||
|     """ | ||||
|  | @ -18,6 +23,8 @@ class ThreadData(): | |||
|     """ | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         global shared | ||||
| 
 | ||||
|         self.disableStdOut = False | ||||
|         self.lastErrorPage = None | ||||
|         self.lastHTTPError = None | ||||
|  | @ -26,6 +33,7 @@ class ThreadData(): | |||
|         self.lastRequestUID = 0 | ||||
|         self.retriesCount = 0 | ||||
|         self.seqMatcher = difflib.SequenceMatcher(None) | ||||
|         self.shared = shared | ||||
|         self.valueStack = [] | ||||
| 
 | ||||
| def getCurrentThreadUID(): | ||||
|  | @ -40,3 +48,55 @@ def getCurrentThreadData(): | |||
|     if threadUID not in kb.threadData: | ||||
|         kb.threadData[threadUID] = ThreadData() | ||||
|     return kb.threadData[threadUID] | ||||
| 
 | ||||
| def runThreads(numThreads, threadFunction, cleanupFunction=None): | ||||
|     threads = [] | ||||
| 
 | ||||
|     kb.threadContinue = True | ||||
|     kb.threadException = False | ||||
| 
 | ||||
|     if numThreads > 1: | ||||
|         infoMsg = "starting %d threads" % numThreads | ||||
|         logger.info(infoMsg) | ||||
|     else: | ||||
|         threadFunction() | ||||
|         return | ||||
| 
 | ||||
|     # Start the threads | ||||
|     for numThread in range(numThreads): | ||||
|         thread = threading.Thread(target=threadFunction, name=str(numThread)) | ||||
|         thread.start() | ||||
|         threads.append(thread) | ||||
| 
 | ||||
|     # And wait for them to all finish | ||||
|     try: | ||||
|         alive = True | ||||
| 
 | ||||
|         while alive: | ||||
|             alive = False | ||||
| 
 | ||||
|             for thread in threads: | ||||
|                 if thread.isAlive(): | ||||
|                     alive = True | ||||
|                     thread.join(1) | ||||
| 
 | ||||
|     except KeyboardInterrupt: | ||||
|         kb.threadContinue = False | ||||
|         kb.threadException = True | ||||
| 
 | ||||
|         print | ||||
|         logger.debug("waiting for threads to finish") | ||||
| 
 | ||||
|         try: | ||||
|             while (threading.activeCount() > 1): | ||||
|                 pass | ||||
| 
 | ||||
|         except KeyboardInterrupt: | ||||
|             raise sqlmapThreadException, "user aborted (Ctrl+C was pressed multiple times)" | ||||
| 
 | ||||
|     finally: | ||||
|         kb.threadContinue = True | ||||
|         kb.threadException = False | ||||
| 
 | ||||
|         if cleanupFunction: | ||||
|             cleanupFunction() | ||||
|  |  | |||
|  | @ -8,6 +8,7 @@ See the file 'doc/COPYING' for copying permission | |||
| """ | ||||
| 
 | ||||
| import re | ||||
| import threading | ||||
| import time | ||||
| 
 | ||||
| from lib.core.agent import agent | ||||
|  | @ -39,6 +40,7 @@ from lib.core.settings import MSSQL_ERROR_CHUNK_LENGTH | |||
| from lib.core.settings import SQL_SCALAR_REGEX | ||||
| from lib.core.settings import TURN_OFF_RESUME_INFO_LIMIT | ||||
| from lib.core.threads import getCurrentThreadData | ||||
| from lib.core.threads import runThreads | ||||
| from lib.core.unescaper import unescaper | ||||
| from lib.request.connect import Connect as Request | ||||
| from lib.utils.resume import resume | ||||
|  | @ -159,7 +161,9 @@ def __errorFields(expression, expressionFields, expressionFieldsList, expected=N | |||
|             output = __oneShotErrorUse(expressionReplaced, field) | ||||
| 
 | ||||
|             if output is not None: | ||||
|                 kb.locks.ioLock.acquire() | ||||
|                 dataToStdout("[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(output))) | ||||
|                 kb.locks.ioLock.release() | ||||
| 
 | ||||
|         if isinstance(num, int): | ||||
|             expression = origExpr | ||||
|  | @ -316,13 +320,39 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False): | |||
|                     infoMsg += "large number of rows (possible slowdown)" | ||||
|                     logger.info(infoMsg) | ||||
| 
 | ||||
|                 for num in xrange(startLimit, stopLimit): | ||||
|                     output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue) | ||||
|                 lockNames = ('limits', 'outputs') | ||||
|                 for lock in lockNames: | ||||
|                     kb.locks[lock] = threading.Lock() | ||||
| 
 | ||||
|                     if output and isinstance(output, list) and len(output) == 1: | ||||
|                         output = output[0] | ||||
|                 threadData = getCurrentThreadData() | ||||
|                 numThreads = min(conf.threads, stopLimit-startLimit) | ||||
|                 threadData.shared.limits = range(startLimit, stopLimit) | ||||
|                 threadData.shared.outputs = [] | ||||
| 
 | ||||
|                     outputs.append(output) | ||||
|                 def errorThread(): | ||||
|                     try: | ||||
|                         threadData = getCurrentThreadData() | ||||
| 
 | ||||
|                         while threadData.shared.limits and kb.threadContinue: | ||||
|                             kb.locks.limits.acquire() | ||||
|                             num = threadData.shared.limits[-1] | ||||
|                             del threadData.shared.limits[-1] | ||||
|                             kb.locks.limits.release() | ||||
| 
 | ||||
|                             output = __errorFields(expression, expressionFields, expressionFieldsList, expected, num, resumeValue) | ||||
| 
 | ||||
|                             if output and isinstance(output, list) and len(output) == 1: | ||||
|                                 output = output[0] | ||||
| 
 | ||||
|                             kb.locks.outputs.acquire() | ||||
|                             threadData.shared.outputs.append(output) | ||||
|                             kb.locks.outputs.release() | ||||
|                     except KeyboardInterrupt: | ||||
|                         raise | ||||
| 
 | ||||
|                 runThreads(numThreads, errorThread) | ||||
| 
 | ||||
|                 outputs = threadData.shared.outputs | ||||
| 
 | ||||
|             except KeyboardInterrupt: | ||||
|                 warnMsg = "user aborted during enumeration. sqlmap " | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ See the file 'doc/COPYING' for copying permission | |||
| 
 | ||||
| import logging | ||||
| import re | ||||
| import threading | ||||
| import time | ||||
| 
 | ||||
| from lib.core.agent import agent | ||||
|  | @ -39,6 +40,8 @@ from lib.core.exception import sqlmapSyntaxException | |||
| from lib.core.settings import FROM_TABLE | ||||
| from lib.core.settings import SQL_SCALAR_REGEX | ||||
| from lib.core.settings import TURN_OFF_RESUME_INFO_LIMIT | ||||
| from lib.core.threads import getCurrentThreadData | ||||
| from lib.core.threads import runThreads | ||||
| from lib.core.unescaper import unescaper | ||||
| from lib.request.connect import Connect as Request | ||||
| from lib.utils.resume import resume | ||||
|  | @ -260,32 +263,60 @@ def unionUse(expression, unpack=True, dump=False): | |||
|                     infoMsg += "large number of rows (possible slowdown)" | ||||
|                     logger.info(infoMsg) | ||||
| 
 | ||||
|                 for num in xrange(startLimit, stopLimit): | ||||
|                     if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE): | ||||
|                         field = expressionFieldsList[0] | ||||
|                     elif Backend.isDbms(DBMS.ORACLE): | ||||
|                         field = expressionFieldsList | ||||
|                     else: | ||||
|                         field = None | ||||
|                 lockNames = ('limits', 'value') | ||||
|                 for lock in lockNames: | ||||
|                     kb.locks[lock] = threading.Lock() | ||||
| 
 | ||||
|                     limitedExpr = agent.limitQuery(num, expression, field) | ||||
|                     output = resume(limitedExpr, None) | ||||
|                 threadData = getCurrentThreadData() | ||||
|                 numThreads = min(conf.threads, stopLimit-startLimit) | ||||
|                 threadData.shared.limits = range(startLimit, stopLimit) | ||||
|                 threadData.shared.value = "" | ||||
| 
 | ||||
|                     if not output: | ||||
|                         output = __oneShotUnionUse(limitedExpr, unpack) | ||||
|                 def unionThread(): | ||||
|                     threadData = getCurrentThreadData() | ||||
| 
 | ||||
|                     if output: | ||||
|                         value += output | ||||
|                     while threadData.shared.limits and kb.threadContinue: | ||||
|                         kb.locks.limits.acquire() | ||||
|                         num = threadData.shared.limits[-1] | ||||
|                         del threadData.shared.limits[-1] | ||||
|                         kb.locks.limits.release() | ||||
| 
 | ||||
|                         if conf.verbose == 1: | ||||
|                             if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])): | ||||
|                                 items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter) | ||||
|                             else: | ||||
|                                 items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter) | ||||
|                             status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items)))) | ||||
|                             if len(status) > width: | ||||
|                                 status = "%s..." % status[:width - 3] | ||||
|                             dataToStdout(status, True) | ||||
|                         if Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE): | ||||
|                             field = expressionFieldsList[0] | ||||
|                         elif Backend.isDbms(DBMS.ORACLE): | ||||
|                             field = expressionFieldsList | ||||
|                         else: | ||||
|                             field = None | ||||
| 
 | ||||
|                         limitedExpr = agent.limitQuery(num, expression, field) | ||||
|                         output = resume(limitedExpr, None) | ||||
| 
 | ||||
|                         if not output: | ||||
|                             output = __oneShotUnionUse(limitedExpr, unpack) | ||||
| 
 | ||||
|                         if output: | ||||
|                             kb.locks.value.acquire() | ||||
|                             threadData.shared.value += output | ||||
|                             kb.locks.value.release() | ||||
| 
 | ||||
|                             if conf.verbose == 1: | ||||
|                                 if all(map(lambda x: x in output, [kb.misc.start, kb.misc.stop])): | ||||
|                                     items = extractRegexResult(r'%s(?P<result>.*?)%s' % (kb.misc.start, kb.misc.stop), output, re.DOTALL | re.IGNORECASE).split(kb.misc.delimiter) | ||||
|                                 else: | ||||
|                                     items = output.replace(kb.misc.start, "").replace(kb.misc.stop, "").split(kb.misc.delimiter) | ||||
| 
 | ||||
|                                 status = "[%s] [INFO] retrieved: %s\r\n" % (time.strftime("%X"), safecharencode(",".join(map(lambda x: "\"%s\"" % x, items)))) | ||||
| 
 | ||||
|                                 if len(status) > width: | ||||
|                                     status = "%s..." % status[:width - 3] | ||||
| 
 | ||||
|                                 kb.locks.ioLock.acquire() | ||||
|                                 dataToStdout(status, True) | ||||
|                                 kb.locks.ioLock.release() | ||||
| 
 | ||||
|                 runThreads(numThreads, unionThread) | ||||
| 
 | ||||
|                 value = threadData.shared.value | ||||
| 
 | ||||
|                 if conf.verbose == 1: | ||||
|                     clearConsoleLine(True) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user