thread based data added

This commit is contained in:
Miroslav Stampar 2010-12-20 22:45:01 +00:00
parent c9e8aae8a2
commit 8fd3e7ba1f
5 changed files with 40 additions and 31 deletions

View File

@ -1544,29 +1544,37 @@ def commonFinderOnly(initial, sequence):
def getCurrentThreadID():
return threading.currentThread().ident
def pushValue(value):
def getCurrentThreadData():
"""
Push value to the stack
Returns current thread's dependent data
"""
threadId = getCurrentThreadID()
if threadId not in kb.valueStack:
kb.valueStack[threadId] = []
kb.valueStack[threadId].append(value)
threadID = getCurrentThreadID()
if threadID not in kb.threadData:
kb.threadData[threadID] = ThreadData()
return kb.threadData[threadID]
def pushValue(value):
"""
Push value to the stack (thread dependent)
"""
getCurrentThreadData().valueStack.append(value)
def popValue():
"""
Pop value from the stack
Pop value from the stack (thread dependent)
"""
return kb.valueStack[getCurrentThreadID()].pop()
return getCurrentThreadData().valueStack.pop()
def wasLastRequestDBMSError():
"""
Returns True if the last web request resulted in a (recognized) DBMS error page
"""
return kb.lastErrorPage and kb.lastErrorPage[0] == kb.lastRequestUID
threadData = getCurrentThreadData()
return threadData.lastErrorPage and threadData.lastErrorPage[0] == threadData.lastRequestUID
def wasLastRequestDelayed():
"""
@ -1577,6 +1585,7 @@ def wasLastRequestDelayed():
# affected response times should be inside +-7*stdev([normal response times])
# (Math reference: http://www.answers.com/topic/standard-deviation)
deviation = stdev(kb.responseTimes)
threadData = getCurrentThreadData()
if deviation:
if len(kb.responseTimes) < MIN_TIME_RESPONSES:
@ -1584,9 +1593,9 @@ def wasLastRequestDelayed():
warnMsg += "with less than %d response times" % MIN_TIME_RESPONSES
logger.warn(warnMsg)
return (kb.lastQueryDuration >= average(kb.responseTimes) + 7 * deviation)
return (threadData.lastQueryDuration >= average(kb.responseTimes) + 7 * deviation)
else:
return kb.lastQueryDuration - conf.timeSec
return threadData.lastQueryDuration - conf.timeSec
def extractErrorMessage(page):
"""
@ -1665,13 +1674,13 @@ def runningAsAdmin():
return isAdmin
def logHTTPTraffic(requestLogMsg, responseLogMsg):
kb.locks.reqLock.acquire()
kb.locks.logLock.acquire()
dataToTrafficFile("%s\n" % requestLogMsg)
dataToTrafficFile("%s\n" % responseLogMsg)
dataToTrafficFile("\n%s\n\n" % (76 * '#'))
kb.locks.reqLock.release()
kb.locks.logLock.release()
def getPageTemplate(payload, place):
pass

View File

@ -37,6 +37,7 @@ from lib.core.common import readCachedFileContent
from lib.core.common import readInput
from lib.core.common import runningAsAdmin
from lib.core.common import sanitizeStr
from lib.core.common import ThreadData
from lib.core.common import UnicodeRawConfigParser
from lib.core.data import conf
from lib.core.data import kb
@ -1164,13 +1165,10 @@ def __setKnowledgeBaseAttributes(flushAll=True):
kb.htmlFp = []
kb.injection = injectionDict()
kb.injections = []
kb.lastErrorPage = None
kb.lastQueryDuration = 0
kb.lastRequestUID = 0
kb.locks = advancedDict()
kb.locks.cacheLock = threading.Lock()
kb.locks.reqLock = threading.Lock()
kb.locks.logLock = threading.Lock()
kb.locks.seqLock = None
kb.matchRatio = None
@ -1198,11 +1196,11 @@ def __setKnowledgeBaseAttributes(flushAll=True):
kb.testMode = False
kb.testQueryCount = 0
kb.threadContinue = True
kb.threadData = {}
kb.unionComment = ""
kb.unionCount = None
kb.unionPosition = None
kb.unionNegative = False
kb.valueStack = {}
if flushAll:
kb.keywords = set(getFileItems(paths.SQL_KEYWORDS))

View File

@ -12,6 +12,7 @@ import re
from xml.sax.handler import ContentHandler
from lib.core.common import checkFile
from lib.core.common import getCurrentThreadData
from lib.core.common import parseXmlFile
from lib.core.common import sanitizeStr
from lib.core.data import kb
@ -42,7 +43,8 @@ class htmlHandler(ContentHandler):
if self.__match:
self.dbms = self.__dbms
self.__match = None
kb.lastErrorPage = (kb.lastRequestUID, self.__page)
threadData = getCurrentThreadData()
threadData.lastErrorPage = (threadData.lastRequestUID, self.__page)
def htmlParser(page):
"""

View File

@ -20,6 +20,7 @@ from lib.core.agent import agent
from lib.core.common import average
from lib.core.common import calculateDeltaSeconds
from lib.core.common import extractErrorMessage
from lib.core.common import getCurrentThreadData
from lib.core.common import getFilteredPageContent
from lib.core.common import getUnicode
from lib.core.common import logHTTPTraffic
@ -69,12 +70,8 @@ class Connect:
delay = 0.00001 * (conf.cpuThrottle ** 2)
time.sleep(delay)
kb.locks.reqLock.acquire()
kb.lastRequestUID += 1
requestID = kb.lastRequestUID
kb.locks.reqLock.release()
threadData = getCurrentThreadData()
threadData.lastRequestUID += 1
url = kwargs.get('url', conf.url).replace(" ", "%20")
get = kwargs.get('get', None)
@ -92,7 +89,7 @@ class Connect:
page = ""
cookieStr = ""
requestMsg = "HTTP request [#%d]:\n%s " % (requestID, conf.method)
requestMsg = "HTTP request [#%d]:\n%s " % (threadData.lastRequestUID, conf.method)
requestMsg += "%s" % urlparse.urlsplit(url)[2] or "/"
responseMsg = "HTTP response "
requestHeaders = ""
@ -246,7 +243,7 @@ class Connect:
except:
pass
responseMsg = "\n%s[#%d] (%d %s):\n" % (responseMsg, requestID, code, status)
responseMsg = "\n%s[#%d] (%d %s):\n" % (responseMsg, threadData.lastRequestUID, code, status)
if responseHeaders:
logHeaders = "\n".join(["%s: %s" % (key.capitalize() if isinstance(key, basestring) else key, value) for (key, value) in responseHeaders.items()])
@ -310,7 +307,7 @@ class Connect:
page = getUnicode(page)
parseResponse(page, responseHeaders)
responseMsg += "[#%d] (%d %s):\n" % (requestID, code, status)
responseMsg += "[#%d] (%d %s):\n" % (threadData.lastRequestUID, code, status)
logHeaders = "\n".join(["%s: %s" % (key.capitalize() if isinstance(key, basestring) else key, value) for (key, value) in responseHeaders.items()])
logHTTPTraffic(requestMsg, "%s%s\n\n%s" % (responseMsg, logHeaders, page))
@ -355,6 +352,7 @@ class Connect:
place = kb.injection.place
payload = agent.extractPayload(value)
threadData = getCurrentThreadData()
if payload:
if kb.tamperFunctions:
@ -426,7 +424,7 @@ class Connect:
if not pageLength:
page, headers = Connect.getPage(url=uri, get=get, post=post, cookie=cookie, ua=ua, silent=silent, method=method, auxHeaders=auxHeaders, response=response, raise404=raise404, ignoreTimeout=timeBasedCompare)
kb.lastQueryDuration = calculateDeltaSeconds(start)
threadData.lastQueryDuration = calculateDeltaSeconds(start)
if conf.textOnly:
page = getFilteredPageContent(page)
@ -439,7 +437,7 @@ class Connect:
if timeBasedCompare:
return wasLastRequestDelayed()
elif noteResponseTime:
kb.responseTimes.append(kb.lastQueryDuration)
kb.responseTimes.append(threadData.lastQueryDuration)
if content or response:
return page, headers

View File

@ -12,6 +12,7 @@ import re
from lib.core.agent import agent
from lib.core.common import formatDBMSfp
from lib.core.common import formatFingerprint
from lib.core.common import getCurrentThreadData
from lib.core.common import getHtmlErrorFp
from lib.core.common import randomInt
from lib.core.common import randomStr
@ -93,7 +94,8 @@ class Fingerprint(GenericFingerprint):
_ = inject.checkBooleanExpression("EXISTS(SELECT * FROM %s.%s WHERE %d=%d)" % (randStr, randStr, randInt, randInt))
if wasLastRequestDBMSError():
match = re.search("Could not find file\s+'([^']+?)'", kb.lastErrorPage[1])
threadData = getCurrentThreadData()
match = re.search("Could not find file\s+'([^']+?)'", threadData.lastErrorPage[1])
if match:
retVal = match.group(1).rstrip("%s.mdb" % randStr)