some refactoring

This commit is contained in:
Miroslav Stampar 2011-12-28 16:27:17 +00:00
parent 37d78ffe01
commit 29f502fe29
8 changed files with 25 additions and 47 deletions

View File

@ -1773,7 +1773,7 @@ def readCachedFileContent(filename, mode='rb'):
""" """
if filename not in kb.cache.content: if filename not in kb.cache.content:
kb.locks.cacheLock.acquire() kb.locks.cache.acquire()
if filename not in kb.cache.content: if filename not in kb.cache.content:
checkFile(filename) checkFile(filename)
@ -1781,7 +1781,7 @@ def readCachedFileContent(filename, mode='rb'):
content = f.read() content = f.read()
kb.cache.content[filename] = content kb.cache.content[filename] = content
kb.locks.cacheLock.release() kb.locks.cache.release()
return kb.cache.content[filename] return kb.cache.content[filename]
@ -2241,13 +2241,13 @@ def logHTTPTraffic(requestLogMsg, responseLogMsg):
if not conf.trafficFile: if not conf.trafficFile:
return return
kb.locks.logLock.acquire() kb.locks.log.acquire()
dataToTrafficFile("%s%s" % (requestLogMsg, os.linesep)) dataToTrafficFile("%s%s" % (requestLogMsg, os.linesep))
dataToTrafficFile("%s%s" % (responseLogMsg, os.linesep)) dataToTrafficFile("%s%s" % (responseLogMsg, os.linesep))
dataToTrafficFile("%s%s%s%s" % (os.linesep, 76 * '#', os.linesep, os.linesep)) dataToTrafficFile("%s%s%s%s" % (os.linesep, 76 * '#', os.linesep, os.linesep))
kb.locks.logLock.release() kb.locks.log.release()
def getPageTemplate(payload, place): def getPageTemplate(payload, place):
""" """

View File

@ -1437,10 +1437,8 @@ def __setKnowledgeBaseAttributes(flushAll=True):
kb.lastParserStatus = None kb.lastParserStatus = None
kb.locks = AttribDict() kb.locks = AttribDict()
kb.locks.cacheLock = threading.Lock() for _ in ("cache", "count", "index", "io", "limits", "log", "outputs", "value"):
kb.locks.logLock = threading.Lock() kb.locks[_] = threading.Lock()
kb.locks.ioLock = threading.Lock()
kb.locks.countLock = threading.Lock()
kb.matchRatio = None kb.matchRatio = None
kb.multiThreadMode = False kb.multiThreadMode = False

View File

@ -188,6 +188,10 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
kb.threadContinue = True kb.threadContinue = True
kb.threadException = False kb.threadException = False
for lock in kb.locks.values():
if lock.locked_lock():
lock.release()
if conf.get("hashDB", None): if conf.get("hashDB", None):
conf.hashDB.flush(True) conf.hashDB.flush(True)

View File

@ -320,25 +320,21 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
threadData.shared.value = [ None ] * length threadData.shared.value = [ None ] * length
threadData.shared.index = [ firstChar ] # As list for python nested function scoping threadData.shared.index = [ firstChar ] # As list for python nested function scoping
lockNames = ('iolock', 'idxlock', 'valuelock')
for lock in lockNames:
kb.locks[lock] = threading.Lock()
try: try:
def blindThread(): def blindThread():
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
while kb.threadContinue: while kb.threadContinue:
kb.locks.idxlock.acquire() kb.locks.index.acquire()
if threadData.shared.index[0] >= length: if threadData.shared.index[0] >= length:
kb.locks.idxlock.release() kb.locks.index.release()
return return
threadData.shared.index[0] += 1 threadData.shared.index[0] += 1
curidx = threadData.shared.index[0] curidx = threadData.shared.index[0]
kb.locks.idxlock.release() kb.locks.index.release()
if kb.threadContinue: if kb.threadContinue:
charStart = time.time() charStart = time.time()
@ -348,10 +344,10 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
else: else:
break break
kb.locks.valuelock.acquire() kb.locks.value.acquire()
threadData.shared.value[curidx-1] = val threadData.shared.value[curidx-1] = val
currentValue = list(threadData.shared.value) currentValue = list(threadData.shared.value)
kb.locks.valuelock.release() kb.locks.value.release()
if kb.threadContinue: if kb.threadContinue:
if showEta: if showEta:
@ -388,9 +384,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
status = ' %d/%d (%d%s)' % (count, length, round(100.0*count/length), '%') status = ' %d/%d (%d%s)' % (count, length, round(100.0*count/length), '%')
output += status if count != length else " "*len(status) output += status if count != length else " "*len(status)
kb.locks.iolock.acquire()
dataToStdout("\r[%s] [INFO] retrieved: %s" % (time.strftime("%X"), filterControlChars(output))) dataToStdout("\r[%s] [INFO] retrieved: %s" % (time.strftime("%X"), filterControlChars(output)))
kb.locks.iolock.release()
if not kb.threadContinue: if not kb.threadContinue:
if int(threading.currentThread().getName()) == numThreads - 1: if int(threading.currentThread().getName()) == numThreads - 1:

View File

@ -81,13 +81,13 @@ def tableExists(tableFile, regex=None):
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
while kb.threadContinue: while kb.threadContinue:
kb.locks.countLock.acquire() kb.locks.count.acquire()
if threadData.shared.count < threadData.shared.limit: if threadData.shared.count < threadData.shared.limit:
table = safeSQLIdentificatorNaming(tables[threadData.shared.count], True) table = safeSQLIdentificatorNaming(tables[threadData.shared.count], True)
threadData.shared.count += 1 threadData.shared.count += 1
kb.locks.countLock.release() kb.locks.count.release()
else: else:
kb.locks.countLock.release() kb.locks.count.release()
break break
if conf.db and METADB_SUFFIX not in conf.db: if conf.db and METADB_SUFFIX not in conf.db:
@ -97,7 +97,7 @@ def tableExists(tableFile, regex=None):
result = inject.checkBooleanExpression("%s" % safeStringFormat(BRUTE_TABLE_EXISTS_TEMPLATE, (randomInt(1), fullTableName))) result = inject.checkBooleanExpression("%s" % safeStringFormat(BRUTE_TABLE_EXISTS_TEMPLATE, (randomInt(1), fullTableName)))
kb.locks.ioLock.acquire() kb.locks.io.acquire()
if result and table.lower() not in threadData.shared.unique: if result and table.lower() not in threadData.shared.unique:
threadData.shared.outputs.append(table) threadData.shared.outputs.append(table)
@ -112,7 +112,7 @@ def tableExists(tableFile, regex=None):
status = '%d/%d items (%d%s)' % (threadData.shared.count, threadData.shared.limit, round(100.0*threadData.shared.count/threadData.shared.limit), '%') status = '%d/%d items (%d%s)' % (threadData.shared.count, threadData.shared.limit, round(100.0*threadData.shared.count/threadData.shared.limit), '%')
dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True) dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True)
kb.locks.ioLock.release() kb.locks.io.release()
try: try:
runThreads(conf.threads, tableExistsThread, threadChoice=True) runThreads(conf.threads, tableExistsThread, threadChoice=True)
@ -180,18 +180,18 @@ def columnExists(columnFile, regex=None):
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
while kb.threadContinue: while kb.threadContinue:
kb.locks.countLock.acquire() kb.locks.count.acquire()
if threadData.shared.count < threadData.shared.limit: if threadData.shared.count < threadData.shared.limit:
column = safeSQLIdentificatorNaming(columns[threadData.shared.count]) column = safeSQLIdentificatorNaming(columns[threadData.shared.count])
threadData.shared.count += 1 threadData.shared.count += 1
kb.locks.countLock.release() kb.locks.count.release()
else: else:
kb.locks.countLock.release() kb.locks.count.release()
break break
result = inject.checkBooleanExpression(safeStringFormat(BRUTE_COLUMN_EXISTS_TEMPLATE, (column, table))) result = inject.checkBooleanExpression(safeStringFormat(BRUTE_COLUMN_EXISTS_TEMPLATE, (column, table)))
kb.locks.ioLock.acquire() kb.locks.io.acquire()
if result: if result:
threadData.shared.outputs.append(column) threadData.shared.outputs.append(column)
@ -205,7 +205,7 @@ def columnExists(columnFile, regex=None):
status = '%d/%d items (%d%s)' % (threadData.shared.count, threadData.shared.limit, round(100.0*threadData.shared.count/threadData.shared.limit), '%') status = '%d/%d items (%d%s)' % (threadData.shared.count, threadData.shared.limit, round(100.0*threadData.shared.count/threadData.shared.limit), '%')
dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True) dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True)
kb.locks.ioLock.release() kb.locks.io.release()
try: try:
runThreads(conf.threads, columnExistsThread, threadChoice=True) runThreads(conf.threads, columnExistsThread, threadChoice=True)

View File

@ -173,9 +173,7 @@ def __errorFields(expression, expressionFields, expressionFieldsList, expected=N
return None return None
if output is not None: if output is not None:
kb.locks.ioLock.acquire()
dataToStdout("[%s] [INFO] %s: %s\r\n" % (time.strftime("%X"), "resumed" if threadData.resumed else "retrieved", safecharencode(output))) dataToStdout("[%s] [INFO] %s: %s\r\n" % (time.strftime("%X"), "resumed" if threadData.resumed else "retrieved", safecharencode(output)))
kb.locks.ioLock.release()
if isinstance(num, int): if isinstance(num, int):
expression = origExpr expression = origExpr
@ -347,10 +345,6 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False):
debugMsg += "large number of rows. It might take too long" debugMsg += "large number of rows. It might take too long"
logger.debug(debugMsg) logger.debug(debugMsg)
lockNames = ('limits', 'outputs')
for lock in lockNames:
kb.locks[lock] = threading.Lock()
try: try:
def errorThread(): def errorThread():
threadData = getCurrentThreadData() threadData = getCurrentThreadData()

View File

@ -278,10 +278,6 @@ def unionUse(expression, unpack=True, dump=False):
debugMsg += "large number of rows. It might take too long" debugMsg += "large number of rows. It might take too long"
logger.debug(debugMsg) logger.debug(debugMsg)
lockNames = ('limits', 'value')
for lock in lockNames:
kb.locks[lock] = threading.Lock()
try: try:
def unionThread(): def unionThread():
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
@ -326,9 +322,7 @@ def unionUse(expression, unpack=True, dump=False):
if len(status) > width: if len(status) > width:
status = "%s..." % status[:width - 3] status = "%s..." % status[:width - 3]
kb.locks.ioLock.acquire()
dataToStdout(status, True) dataToStdout(status, True)
kb.locks.ioLock.release()
runThreads(numThreads, unionThread) runThreads(numThreads, unionThread)

View File

@ -39,10 +39,6 @@ class Crawler:
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
threadData.shared.outputs = oset() threadData.shared.outputs = oset()
lockNames = ('limits', 'outputs', 'ioLock')
for lock in lockNames:
kb.locks[lock] = threading.Lock()
def crawlThread(): def crawlThread():
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
@ -100,11 +96,9 @@ class Crawler:
findPageForms(content, current, False, True) findPageForms(content, current, False, True)
if conf.verbose in (1, 2): if conf.verbose in (1, 2):
kb.locks.ioLock.acquire()
threadData.shared.count += 1 threadData.shared.count += 1
status = '%d/%d links visited (%d%s)' % (threadData.shared.count, threadData.shared.length, round(100.0*threadData.shared.count/threadData.shared.length), '%') status = '%d/%d links visited (%d%s)' % (threadData.shared.count, threadData.shared.length, round(100.0*threadData.shared.count/threadData.shared.length), '%')
dataToStdout("\r[%s] [INFO] %s" % (time.strftime("%X"), status), True) dataToStdout("\r[%s] [INFO] %s" % (time.strftime("%X"), status), True)
kb.locks.ioLock.release()
threadData.shared.deeper = set() threadData.shared.deeper = set()
threadData.shared.unprocessed = set([conf.url]) threadData.shared.unprocessed = set([conf.url])