diff --git a/lib/core/common.py b/lib/core/common.py index 3824fe99f..6322e5e20 100644 --- a/lib/core/common.py +++ b/lib/core/common.py @@ -1544,29 +1544,37 @@ def commonFinderOnly(initial, sequence): def getCurrentThreadID(): return threading.currentThread().ident -def pushValue(value): +def getCurrentThreadData(): """ - Push value to the stack + Returns current thread's dependent data """ - threadId = getCurrentThreadID() - if threadId not in kb.valueStack: - kb.valueStack[threadId] = [] - kb.valueStack[threadId].append(value) + threadID = getCurrentThreadID() + if threadID not in kb.threadData: + kb.threadData[threadID] = ThreadData() + return kb.threadData[threadID] + +def pushValue(value): + """ + Push value to the stack (thread dependent) + """ + + getCurrentThreadData().valueStack.append(value) def popValue(): """ - Pop value from the stack + Pop value from the stack (thread dependent) """ - return kb.valueStack[getCurrentThreadID()].pop() + return getCurrentThreadData().valueStack.pop() def wasLastRequestDBMSError(): """ Returns True if the last web request resulted in a (recognized) DBMS error page """ - return kb.lastErrorPage and kb.lastErrorPage[0] == kb.lastRequestUID + threadData = getCurrentThreadData() + return threadData.lastErrorPage and threadData.lastErrorPage[0] == threadData.lastRequestUID def wasLastRequestDelayed(): """ @@ -1577,6 +1585,7 @@ def wasLastRequestDelayed(): # affected response times should be inside +-7*stdev([normal response times]) # (Math reference: http://www.answers.com/topic/standard-deviation) deviation = stdev(kb.responseTimes) + threadData = getCurrentThreadData() if deviation: if len(kb.responseTimes) < MIN_TIME_RESPONSES: @@ -1584,9 +1593,9 @@ def wasLastRequestDelayed(): warnMsg += "with less than %d response times" % MIN_TIME_RESPONSES logger.warn(warnMsg) - return (kb.lastQueryDuration >= average(kb.responseTimes) + 7 * deviation) + return (threadData.lastQueryDuration >= average(kb.responseTimes) + 7 * deviation) else: - return kb.lastQueryDuration - conf.timeSec + return threadData.lastQueryDuration - conf.timeSec def extractErrorMessage(page): """ @@ -1665,13 +1674,13 @@ def runningAsAdmin(): return isAdmin def logHTTPTraffic(requestLogMsg, responseLogMsg): - kb.locks.reqLock.acquire() + kb.locks.logLock.acquire() dataToTrafficFile("%s\n" % requestLogMsg) dataToTrafficFile("%s\n" % responseLogMsg) dataToTrafficFile("\n%s\n\n" % (76 * '#')) - kb.locks.reqLock.release() + kb.locks.logLock.release() def getPageTemplate(payload, place): pass diff --git a/lib/core/option.py b/lib/core/option.py index cb9a52645..5e5a85527 100644 --- a/lib/core/option.py +++ b/lib/core/option.py @@ -37,6 +37,7 @@ from lib.core.common import readCachedFileContent from lib.core.common import readInput from lib.core.common import runningAsAdmin from lib.core.common import sanitizeStr +from lib.core.common import ThreadData from lib.core.common import UnicodeRawConfigParser from lib.core.data import conf from lib.core.data import kb @@ -1164,13 +1165,10 @@ def __setKnowledgeBaseAttributes(flushAll=True): kb.htmlFp = [] kb.injection = injectionDict() kb.injections = [] - kb.lastErrorPage = None - kb.lastQueryDuration = 0 - kb.lastRequestUID = 0 kb.locks = advancedDict() kb.locks.cacheLock = threading.Lock() - kb.locks.reqLock = threading.Lock() + kb.locks.logLock = threading.Lock() kb.locks.seqLock = None kb.matchRatio = None @@ -1198,11 +1196,11 @@ def __setKnowledgeBaseAttributes(flushAll=True): kb.testMode = False kb.testQueryCount = 0 kb.threadContinue = True + kb.threadData = {} kb.unionComment = "" kb.unionCount = None kb.unionPosition = None kb.unionNegative = False - kb.valueStack = {} if flushAll: kb.keywords = set(getFileItems(paths.SQL_KEYWORDS)) diff --git a/lib/parse/html.py b/lib/parse/html.py index b5f8048d7..f4d2f6c01 100644 --- a/lib/parse/html.py +++ b/lib/parse/html.py @@ -12,6 +12,7 @@ import re from xml.sax.handler import ContentHandler from lib.core.common import checkFile +from lib.core.common import getCurrentThreadData from lib.core.common import parseXmlFile from lib.core.common import sanitizeStr from lib.core.data import kb @@ -42,7 +43,8 @@ class htmlHandler(ContentHandler): if self.__match: self.dbms = self.__dbms self.__match = None - kb.lastErrorPage = (kb.lastRequestUID, self.__page) + threadData = getCurrentThreadData() + threadData.lastErrorPage = (threadData.lastRequestUID, self.__page) def htmlParser(page): """ diff --git a/lib/request/connect.py b/lib/request/connect.py index f0ecb4851..2091e777f 100644 --- a/lib/request/connect.py +++ b/lib/request/connect.py @@ -20,6 +20,7 @@ from lib.core.agent import agent from lib.core.common import average from lib.core.common import calculateDeltaSeconds from lib.core.common import extractErrorMessage +from lib.core.common import getCurrentThreadData from lib.core.common import getFilteredPageContent from lib.core.common import getUnicode from lib.core.common import logHTTPTraffic @@ -69,12 +70,8 @@ class Connect: delay = 0.00001 * (conf.cpuThrottle ** 2) time.sleep(delay) - kb.locks.reqLock.acquire() - - kb.lastRequestUID += 1 - requestID = kb.lastRequestUID - - kb.locks.reqLock.release() + threadData = getCurrentThreadData() + threadData.lastRequestUID += 1 url = kwargs.get('url', conf.url).replace(" ", "%20") get = kwargs.get('get', None) @@ -92,7 +89,7 @@ class Connect: page = "" cookieStr = "" - requestMsg = "HTTP request [#%d]:\n%s " % (requestID, conf.method) + requestMsg = "HTTP request [#%d]:\n%s " % (threadData.lastRequestUID, conf.method) requestMsg += "%s" % urlparse.urlsplit(url)[2] or "/" responseMsg = "HTTP response " requestHeaders = "" @@ -246,7 +243,7 @@ class Connect: except: pass - responseMsg = "\n%s[#%d] (%d %s):\n" % (responseMsg, requestID, code, status) + responseMsg = "\n%s[#%d] (%d %s):\n" % (responseMsg, threadData.lastRequestUID, code, status) if responseHeaders: logHeaders = "\n".join(["%s: %s" % (key.capitalize() if isinstance(key, basestring) else key, value) for (key, value) in responseHeaders.items()]) @@ -310,7 +307,7 @@ class Connect: page = getUnicode(page) parseResponse(page, responseHeaders) - responseMsg += "[#%d] (%d %s):\n" % (requestID, code, status) + responseMsg += "[#%d] (%d %s):\n" % (threadData.lastRequestUID, code, status) logHeaders = "\n".join(["%s: %s" % (key.capitalize() if isinstance(key, basestring) else key, value) for (key, value) in responseHeaders.items()]) logHTTPTraffic(requestMsg, "%s%s\n\n%s" % (responseMsg, logHeaders, page)) @@ -355,6 +352,7 @@ class Connect: place = kb.injection.place payload = agent.extractPayload(value) + threadData = getCurrentThreadData() if payload: if kb.tamperFunctions: @@ -426,7 +424,7 @@ class Connect: if not pageLength: page, headers = Connect.getPage(url=uri, get=get, post=post, cookie=cookie, ua=ua, silent=silent, method=method, auxHeaders=auxHeaders, response=response, raise404=raise404, ignoreTimeout=timeBasedCompare) - kb.lastQueryDuration = calculateDeltaSeconds(start) + threadData.lastQueryDuration = calculateDeltaSeconds(start) if conf.textOnly: page = getFilteredPageContent(page) @@ -439,7 +437,7 @@ class Connect: if timeBasedCompare: return wasLastRequestDelayed() elif noteResponseTime: - kb.responseTimes.append(kb.lastQueryDuration) + kb.responseTimes.append(threadData.lastQueryDuration) if content or response: return page, headers diff --git a/plugins/dbms/access/fingerprint.py b/plugins/dbms/access/fingerprint.py index 6346714f4..218994423 100644 --- a/plugins/dbms/access/fingerprint.py +++ b/plugins/dbms/access/fingerprint.py @@ -12,6 +12,7 @@ import re from lib.core.agent import agent from lib.core.common import formatDBMSfp from lib.core.common import formatFingerprint +from lib.core.common import getCurrentThreadData from lib.core.common import getHtmlErrorFp from lib.core.common import randomInt from lib.core.common import randomStr @@ -93,7 +94,8 @@ class Fingerprint(GenericFingerprint): _ = inject.checkBooleanExpression("EXISTS(SELECT * FROM %s.%s WHERE %d=%d)" % (randStr, randStr, randInt, randInt)) if wasLastRequestDBMSError(): - match = re.search("Could not find file\s+'([^']+?)'", kb.lastErrorPage[1]) + threadData = getCurrentThreadData() + match = re.search("Could not find file\s+'([^']+?)'", threadData.lastErrorPage[1]) if match: retVal = match.group(1).rstrip("%s.mdb" % randStr)