From 29f502fe29e8fc86b2ce64bfdd733302557eacfb Mon Sep 17 00:00:00 2001 From: Miroslav Stampar Date: Wed, 28 Dec 2011 16:27:17 +0000 Subject: [PATCH] some refactoring --- lib/core/common.py | 8 ++++---- lib/core/option.py | 6 ++---- lib/core/threads.py | 4 ++++ lib/techniques/blind/inference.py | 16 +++++----------- lib/techniques/brute/use.py | 20 ++++++++++---------- lib/techniques/error/use.py | 6 ------ lib/techniques/union/use.py | 6 ------ lib/utils/crawler.py | 6 ------ 8 files changed, 25 insertions(+), 47 deletions(-) diff --git a/lib/core/common.py b/lib/core/common.py index 56a391f23..9ded587a3 100644 --- a/lib/core/common.py +++ b/lib/core/common.py @@ -1773,7 +1773,7 @@ def readCachedFileContent(filename, mode='rb'): """ if filename not in kb.cache.content: - kb.locks.cacheLock.acquire() + kb.locks.cache.acquire() if filename not in kb.cache.content: checkFile(filename) @@ -1781,7 +1781,7 @@ def readCachedFileContent(filename, mode='rb'): content = f.read() kb.cache.content[filename] = content - kb.locks.cacheLock.release() + kb.locks.cache.release() return kb.cache.content[filename] @@ -2241,13 +2241,13 @@ def logHTTPTraffic(requestLogMsg, responseLogMsg): if not conf.trafficFile: return - kb.locks.logLock.acquire() + kb.locks.log.acquire() dataToTrafficFile("%s%s" % (requestLogMsg, os.linesep)) dataToTrafficFile("%s%s" % (responseLogMsg, os.linesep)) dataToTrafficFile("%s%s%s%s" % (os.linesep, 76 * '#', os.linesep, os.linesep)) - kb.locks.logLock.release() + kb.locks.log.release() def getPageTemplate(payload, place): """ diff --git a/lib/core/option.py b/lib/core/option.py index a2692f949..810b390c4 100644 --- a/lib/core/option.py +++ b/lib/core/option.py @@ -1437,10 +1437,8 @@ def __setKnowledgeBaseAttributes(flushAll=True): kb.lastParserStatus = None kb.locks = AttribDict() - kb.locks.cacheLock = threading.Lock() - kb.locks.logLock = threading.Lock() - kb.locks.ioLock = threading.Lock() - kb.locks.countLock = threading.Lock() + for _ in ("cache", "count", "index", "io", "limits", "log", "outputs", "value"): + kb.locks[_] = threading.Lock() kb.matchRatio = None kb.multiThreadMode = False diff --git a/lib/core/threads.py b/lib/core/threads.py index a6e48f4c6..d4c707bf3 100644 --- a/lib/core/threads.py +++ b/lib/core/threads.py @@ -188,6 +188,10 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio kb.threadContinue = True kb.threadException = False + for lock in kb.locks.values(): + if lock.locked_lock(): + lock.release() + if conf.get("hashDB", None): conf.hashDB.flush(True) diff --git a/lib/techniques/blind/inference.py b/lib/techniques/blind/inference.py index 498945160..da2969e34 100644 --- a/lib/techniques/blind/inference.py +++ b/lib/techniques/blind/inference.py @@ -320,25 +320,21 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None threadData.shared.value = [ None ] * length threadData.shared.index = [ firstChar ] # As list for python nested function scoping - lockNames = ('iolock', 'idxlock', 'valuelock') - for lock in lockNames: - kb.locks[lock] = threading.Lock() - try: def blindThread(): threadData = getCurrentThreadData() while kb.threadContinue: - kb.locks.idxlock.acquire() + kb.locks.index.acquire() if threadData.shared.index[0] >= length: - kb.locks.idxlock.release() + kb.locks.index.release() return threadData.shared.index[0] += 1 curidx = threadData.shared.index[0] - kb.locks.idxlock.release() + kb.locks.index.release() if kb.threadContinue: charStart = time.time() @@ -348,10 +344,10 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None else: break - kb.locks.valuelock.acquire() + kb.locks.value.acquire() threadData.shared.value[curidx-1] = val currentValue = list(threadData.shared.value) - kb.locks.valuelock.release() + kb.locks.value.release() if kb.threadContinue: if showEta: @@ -388,9 +384,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None status = ' %d/%d (%d%s)' % (count, length, round(100.0*count/length), '%') output += status if count != length else " "*len(status) - kb.locks.iolock.acquire() dataToStdout("\r[%s] [INFO] retrieved: %s" % (time.strftime("%X"), filterControlChars(output))) - kb.locks.iolock.release() if not kb.threadContinue: if int(threading.currentThread().getName()) == numThreads - 1: diff --git a/lib/techniques/brute/use.py b/lib/techniques/brute/use.py index 9944e4992..3b653e059 100644 --- a/lib/techniques/brute/use.py +++ b/lib/techniques/brute/use.py @@ -81,13 +81,13 @@ def tableExists(tableFile, regex=None): threadData = getCurrentThreadData() while kb.threadContinue: - kb.locks.countLock.acquire() + kb.locks.count.acquire() if threadData.shared.count < threadData.shared.limit: table = safeSQLIdentificatorNaming(tables[threadData.shared.count], True) threadData.shared.count += 1 - kb.locks.countLock.release() + kb.locks.count.release() else: - kb.locks.countLock.release() + kb.locks.count.release() break if conf.db and METADB_SUFFIX not in conf.db: @@ -97,7 +97,7 @@ def tableExists(tableFile, regex=None): result = inject.checkBooleanExpression("%s" % safeStringFormat(BRUTE_TABLE_EXISTS_TEMPLATE, (randomInt(1), fullTableName))) - kb.locks.ioLock.acquire() + kb.locks.io.acquire() if result and table.lower() not in threadData.shared.unique: threadData.shared.outputs.append(table) @@ -112,7 +112,7 @@ def tableExists(tableFile, regex=None): status = '%d/%d items (%d%s)' % (threadData.shared.count, threadData.shared.limit, round(100.0*threadData.shared.count/threadData.shared.limit), '%') dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True) - kb.locks.ioLock.release() + kb.locks.io.release() try: runThreads(conf.threads, tableExistsThread, threadChoice=True) @@ -180,18 +180,18 @@ def columnExists(columnFile, regex=None): threadData = getCurrentThreadData() while kb.threadContinue: - kb.locks.countLock.acquire() + kb.locks.count.acquire() if threadData.shared.count < threadData.shared.limit: column = safeSQLIdentificatorNaming(columns[threadData.shared.count]) threadData.shared.count += 1 - kb.locks.countLock.release() + kb.locks.count.release() else: - kb.locks.countLock.release() + kb.locks.count.release() break result = inject.checkBooleanExpression(safeStringFormat(BRUTE_COLUMN_EXISTS_TEMPLATE, (column, table))) - kb.locks.ioLock.acquire() + kb.locks.io.acquire() if result: threadData.shared.outputs.append(column) @@ -205,7 +205,7 @@ def columnExists(columnFile, regex=None): status = '%d/%d items (%d%s)' % (threadData.shared.count, threadData.shared.limit, round(100.0*threadData.shared.count/threadData.shared.limit), '%') dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True) - kb.locks.ioLock.release() + kb.locks.io.release() try: runThreads(conf.threads, columnExistsThread, threadChoice=True) diff --git a/lib/techniques/error/use.py b/lib/techniques/error/use.py index a52566bf0..13291a148 100644 --- a/lib/techniques/error/use.py +++ b/lib/techniques/error/use.py @@ -173,9 +173,7 @@ def __errorFields(expression, expressionFields, expressionFieldsList, expected=N return None if output is not None: - kb.locks.ioLock.acquire() dataToStdout("[%s] [INFO] %s: %s\r\n" % (time.strftime("%X"), "resumed" if threadData.resumed else "retrieved", safecharencode(output))) - kb.locks.ioLock.release() if isinstance(num, int): expression = origExpr @@ -347,10 +345,6 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False): debugMsg += "large number of rows. It might take too long" logger.debug(debugMsg) - lockNames = ('limits', 'outputs') - for lock in lockNames: - kb.locks[lock] = threading.Lock() - try: def errorThread(): threadData = getCurrentThreadData() diff --git a/lib/techniques/union/use.py b/lib/techniques/union/use.py index 411819ded..8178c6422 100644 --- a/lib/techniques/union/use.py +++ b/lib/techniques/union/use.py @@ -278,10 +278,6 @@ def unionUse(expression, unpack=True, dump=False): debugMsg += "large number of rows. It might take too long" logger.debug(debugMsg) - lockNames = ('limits', 'value') - for lock in lockNames: - kb.locks[lock] = threading.Lock() - try: def unionThread(): threadData = getCurrentThreadData() @@ -326,9 +322,7 @@ def unionUse(expression, unpack=True, dump=False): if len(status) > width: status = "%s..." % status[:width - 3] - kb.locks.ioLock.acquire() dataToStdout(status, True) - kb.locks.ioLock.release() runThreads(numThreads, unionThread) diff --git a/lib/utils/crawler.py b/lib/utils/crawler.py index e25165238..c08818ab8 100644 --- a/lib/utils/crawler.py +++ b/lib/utils/crawler.py @@ -39,10 +39,6 @@ class Crawler: threadData = getCurrentThreadData() threadData.shared.outputs = oset() - lockNames = ('limits', 'outputs', 'ioLock') - for lock in lockNames: - kb.locks[lock] = threading.Lock() - def crawlThread(): threadData = getCurrentThreadData() @@ -100,11 +96,9 @@ class Crawler: findPageForms(content, current, False, True) if conf.verbose in (1, 2): - kb.locks.ioLock.acquire() threadData.shared.count += 1 status = '%d/%d links visited (%d%s)' % (threadData.shared.count, threadData.shared.length, round(100.0*threadData.shared.count/threadData.shared.length), '%') dataToStdout("\r[%s] [INFO] %s" % (time.strftime("%X"), status), True) - kb.locks.ioLock.release() threadData.shared.deeper = set() threadData.shared.unprocessed = set([conf.url])