Fixing mess with --common-files --threads>1 (threads in threads - '.shared.' hell)

This commit is contained in:
Miroslav Stampar 2019-07-18 14:59:42 +02:00
parent b62680b4bc
commit db90ff9c3f
5 changed files with 48 additions and 28 deletions

View File

@ -14,7 +14,8 @@ from lib.core.settings import MAX_CACHE_ITEMS
from lib.core.settings import UNICODE_ENCODING from lib.core.settings import UNICODE_ENCODING
from lib.core.threads import getCurrentThreadData from lib.core.threads import getCurrentThreadData
_lock = threading.Lock() _cache_lock = threading.Lock()
_method_locks = {}
def cachedmethod(f, cache=LRUDict(capacity=MAX_CACHE_ITEMS)): def cachedmethod(f, cache=LRUDict(capacity=MAX_CACHE_ITEMS)):
""" """
@ -38,12 +39,12 @@ def cachedmethod(f, cache=LRUDict(capacity=MAX_CACHE_ITEMS)):
key = int(hashlib.md5("|".join(str(_) for _ in (f, args, kwargs)).encode(UNICODE_ENCODING)).hexdigest(), 16) & 0x7fffffffffffffff key = int(hashlib.md5("|".join(str(_) for _ in (f, args, kwargs)).encode(UNICODE_ENCODING)).hexdigest(), 16) & 0x7fffffffffffffff
try: try:
with _lock: with _cache_lock:
result = cache[key] result = cache[key]
except KeyError: except KeyError:
result = f(*args, **kwargs) result = f(*args, **kwargs)
with _lock: with _cache_lock:
cache[key] = result cache[key] = result
return result return result
@ -76,3 +77,16 @@ def stackedmethod(f):
return result return result
return _ return _
def lockedmethod(f):
@functools.wraps(f)
def _(*args, **kwargs):
if f not in _method_locks:
_method_locks[f] = threading.Lock()
with _method_locks[f]:
result = f(*args, **kwargs)
return result
return _

View File

@ -18,7 +18,7 @@ from lib.core.enums import OS
from thirdparty.six import unichr as _unichr from thirdparty.six import unichr as _unichr
# sqlmap version (<major>.<minor>.<month>.<monthly commit>) # sqlmap version (<major>.<minor>.<month>.<monthly commit>)
VERSION = "1.3.7.34" VERSION = "1.3.7.35"
TYPE = "dev" if VERSION.count('.') > 2 and VERSION.split('.')[-1] != '0' else "stable" TYPE = "dev" if VERSION.count('.') > 2 and VERSION.split('.')[-1] != '0' else "stable"
TYPE_COLORS = {"dev": 33, "stable": 90, "pip": 34} TYPE_COLORS = {"dev": 33, "stable": 90, "pip": 34}
VERSION_STRING = "sqlmap/%s#%s" % ('.'.join(VERSION.split('.')[:-1]) if VERSION.count('.') > 2 and VERSION.split('.')[-1] == '0' else VERSION, TYPE) VERSION_STRING = "sqlmap/%s#%s" % ('.'.join(VERSION.split('.')[:-1]) if VERSION.count('.') > 2 and VERSION.split('.')[-1] == '0' else VERSION, TYPE)

View File

@ -39,6 +39,7 @@ from lib.core.data import conf
from lib.core.data import kb from lib.core.data import kb
from lib.core.data import logger from lib.core.data import logger
from lib.core.data import queries from lib.core.data import queries
from lib.core.decorators import lockedmethod
from lib.core.decorators import stackedmethod from lib.core.decorators import stackedmethod
from lib.core.dicts import FROM_DUMMY_TABLE from lib.core.dicts import FROM_DUMMY_TABLE
from lib.core.enums import CHARSET_TYPE from lib.core.enums import CHARSET_TYPE
@ -351,6 +352,7 @@ def _goUnion(expression, unpack=True, dump=False):
return output return output
@lockedmethod
@stackedmethod @stackedmethod
def getValue(expression, blind=True, union=True, error=True, time=True, fromUser=False, expected=None, batch=False, unpack=True, resumeValue=True, charsetType=None, firstChar=None, lastChar=None, dump=False, suppressOutput=None, expectingNone=False, safeCharEncode=True): def getValue(expression, blind=True, union=True, error=True, time=True, fromUser=False, expected=None, batch=False, unpack=True, resumeValue=True, charsetType=None, firstChar=None, lastChar=None, dump=False, suppressOutput=None, expectingNone=False, safeCharEncode=True):
""" """

View File

@ -162,6 +162,10 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
length = None length = None
showEta = conf.eta and isinstance(length, int) showEta = conf.eta and isinstance(length, int)
if kb.bruteMode:
numThreads = 1
else:
numThreads = min(conf.threads or 0, length or 0) or 1 numThreads = min(conf.threads or 0, length or 0) or 1
if showEta: if showEta:
@ -174,13 +178,13 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
else: else:
numThreads = 1 numThreads = 1
if conf.threads == 1 and not timeBasedCompare and not conf.predictOutput: if numThreads == 1 and not timeBasedCompare and not conf.predictOutput:
warnMsg = "running in a single-thread mode. Please consider " warnMsg = "running in a single-thread mode. Please consider "
warnMsg += "usage of option '--threads' for faster data retrieval" warnMsg += "usage of option '--threads' for faster data retrieval"
singleTimeWarnMessage(warnMsg) singleTimeWarnMessage(warnMsg)
if conf.verbose in (1, 2) and not showEta and not conf.api: if conf.verbose in (1, 2) and not any((showEta, conf.api, kb.bruteMode)):
if isinstance(length, int) and conf.threads > 1: if isinstance(length, int) and numThreads > 1:
dataToStdout("[%s] [INFO] retrieved: %s" % (time.strftime("%X"), "_" * min(length, conf.progressWidth))) dataToStdout("[%s] [INFO] retrieved: %s" % (time.strftime("%X"), "_" * min(length, conf.progressWidth)))
dataToStdout("\r[%s] [INFO] retrieved: " % time.strftime("%X")) dataToStdout("\r[%s] [INFO] retrieved: " % time.strftime("%X"))
else: else:
@ -459,7 +463,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
return decodeIntToUnicode(candidates[0]) return decodeIntToUnicode(candidates[0])
# Go multi-threading (--threads > 1) # Go multi-threading (--threads > 1)
if conf.threads > 1 and isinstance(length, int) and length > 1: if numThreads > 1 and isinstance(length, int) and length > 1:
threadData.shared.value = [None] * length threadData.shared.value = [None] * length
threadData.shared.index = [firstChar] # As list for python nested function scoping threadData.shared.index = [firstChar] # As list for python nested function scoping
threadData.shared.start = firstChar threadData.shared.start = firstChar
@ -517,7 +521,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
if (endCharIndex - startCharIndex == conf.progressWidth) and (endCharIndex < length - 1): if (endCharIndex - startCharIndex == conf.progressWidth) and (endCharIndex < length - 1):
output = output[:-2] + ".." output = output[:-2] + ".."
if conf.verbose in (1, 2) and not showEta and not conf.api: if conf.verbose in (1, 2) and not any((showEta, conf.api, kb.bruteMode)):
_ = count - firstChar _ = count - firstChar
output += '_' * (min(length, conf.progressWidth) - len(output)) output += '_' * (min(length, conf.progressWidth) - len(output))
status = ' %d/%d (%d%%)' % (_, length, int(100.0 * _ / length)) status = ' %d/%d (%d%%)' % (_, length, int(100.0 * _ / length))
@ -547,7 +551,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
finalValue = "".join(value) finalValue = "".join(value)
infoMsg = "\r[%s] [INFO] retrieved: %s" % (time.strftime("%X"), filterControlChars(finalValue)) infoMsg = "\r[%s] [INFO] retrieved: %s" % (time.strftime("%X"), filterControlChars(finalValue))
if conf.verbose in (1, 2) and not showEta and infoMsg and not conf.api: if conf.verbose in (1, 2) and infoMsg and not any((showEta, conf.api, kb.bruteMode)):
dataToStdout(infoMsg) dataToStdout(infoMsg)
# No multi-threading (--threads = 1) # No multi-threading (--threads = 1)
@ -632,7 +636,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
if showEta: if showEta:
progress.progress(index) progress.progress(index)
elif conf.verbose in (1, 2) or conf.api: elif (conf.verbose in (1, 2) and not kb.bruteMode) or conf.api:
dataToStdout(filterControlChars(val)) dataToStdout(filterControlChars(val))
# some DBMSes (e.g. Firebird, DB2, etc.) have issues with trailing spaces # some DBMSes (e.g. Firebird, DB2, etc.) have issues with trailing spaces
@ -661,11 +665,11 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
elif partialValue: elif partialValue:
hashDBWrite(expression, "%s%s" % (PARTIAL_VALUE_MARKER if not conf.hexConvert else PARTIAL_HEX_VALUE_MARKER, partialValue)) hashDBWrite(expression, "%s%s" % (PARTIAL_VALUE_MARKER if not conf.hexConvert else PARTIAL_HEX_VALUE_MARKER, partialValue))
if conf.hexConvert and not abortedFlag and not conf.api: if conf.hexConvert and not any((abortedFlag, conf.api, kb.bruteMode)):
infoMsg = "\r[%s] [INFO] retrieved: %s %s\n" % (time.strftime("%X"), filterControlChars(finalValue), " " * retrievedLength) infoMsg = "\r[%s] [INFO] retrieved: %s %s\n" % (time.strftime("%X"), filterControlChars(finalValue), " " * retrievedLength)
dataToStdout(infoMsg) dataToStdout(infoMsg)
else: else:
if conf.verbose in (1, 2) and not showEta and not conf.api: if conf.verbose in (1, 2) and not any((showEta, conf.api, kb.bruteMode)):
dataToStdout("\n") dataToStdout("\n")
if (conf.verbose in (1, 2) and showEta) or conf.verbose >= 3: if (conf.verbose in (1, 2) and showEta) or conf.verbose >= 3:

View File

@ -102,7 +102,7 @@ def tableExists(tableFile, regex=None):
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
threadData.shared.count = 0 threadData.shared.count = 0
threadData.shared.limit = len(tables) threadData.shared.limit = len(tables)
threadData.shared.value = [] threadData.shared.files = []
threadData.shared.unique = set() threadData.shared.unique = set()
def tableExistsThread(): def tableExistsThread():
@ -128,7 +128,7 @@ def tableExists(tableFile, regex=None):
kb.locks.io.acquire() kb.locks.io.acquire()
if result and table.lower() not in threadData.shared.unique: if result and table.lower() not in threadData.shared.unique:
threadData.shared.value.append(table) threadData.shared.files.append(table)
threadData.shared.unique.add(table.lower()) threadData.shared.unique.add(table.lower())
if conf.verbose in (1, 2) and not conf.api: if conf.verbose in (1, 2) and not conf.api:
@ -152,17 +152,17 @@ def tableExists(tableFile, regex=None):
clearConsoleLine(True) clearConsoleLine(True)
dataToStdout("\n") dataToStdout("\n")
if not threadData.shared.value: if not threadData.shared.files:
warnMsg = "no table(s) found" warnMsg = "no table(s) found"
logger.warn(warnMsg) logger.warn(warnMsg)
else: else:
for item in threadData.shared.value: for item in threadData.shared.files:
if conf.db not in kb.data.cachedTables: if conf.db not in kb.data.cachedTables:
kb.data.cachedTables[conf.db] = [item] kb.data.cachedTables[conf.db] = [item]
else: else:
kb.data.cachedTables[conf.db].append(item) kb.data.cachedTables[conf.db].append(item)
for _ in ((conf.db, item) for item in threadData.shared.value): for _ in ((conf.db, item) for item in threadData.shared.files):
if _ not in kb.brute.tables: if _ not in kb.brute.tables:
kb.brute.tables.append(_) kb.brute.tables.append(_)
@ -224,7 +224,7 @@ def columnExists(columnFile, regex=None):
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
threadData.shared.count = 0 threadData.shared.count = 0
threadData.shared.limit = len(columns) threadData.shared.limit = len(columns)
threadData.shared.value = [] threadData.shared.files = []
def columnExistsThread(): def columnExistsThread():
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
@ -244,7 +244,7 @@ def columnExists(columnFile, regex=None):
kb.locks.io.acquire() kb.locks.io.acquire()
if result: if result:
threadData.shared.value.append(column) threadData.shared.files.append(column)
if conf.verbose in (1, 2) and not conf.api: if conf.verbose in (1, 2) and not conf.api:
clearConsoleLine(True) clearConsoleLine(True)
@ -269,13 +269,13 @@ def columnExists(columnFile, regex=None):
clearConsoleLine(True) clearConsoleLine(True)
dataToStdout("\n") dataToStdout("\n")
if not threadData.shared.value: if not threadData.shared.files:
warnMsg = "no column(s) found" warnMsg = "no column(s) found"
logger.warn(warnMsg) logger.warn(warnMsg)
else: else:
columns = {} columns = {}
for column in threadData.shared.value: for column in threadData.shared.files:
if Backend.getIdentifiedDbms() in (DBMS.MYSQL,): if Backend.getIdentifiedDbms() in (DBMS.MYSQL,):
result = not inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE %s REGEXP '[^0-9]')", (column, table, column))) result = not inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE %s REGEXP '[^0-9]')", (column, table, column)))
else: else:
@ -327,7 +327,7 @@ def fileExists(pathFile):
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
threadData.shared.count = 0 threadData.shared.count = 0
threadData.shared.limit = len(paths) threadData.shared.limit = len(paths)
threadData.shared.value = [] threadData.shared.files = []
def fileExistsThread(): def fileExistsThread():
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
@ -350,9 +350,9 @@ def fileExists(pathFile):
kb.locks.io.acquire() kb.locks.io.acquire()
if not isNoneValue(result): if not isNoneValue(result):
threadData.shared.value.append(result) threadData.shared.files.append(result)
if conf.verbose in (1, 2) and not conf.api: if not conf.api:
clearConsoleLine(True) clearConsoleLine(True)
infoMsg = "[%s] [INFO] retrieved: '%s'\n" % (time.strftime("%X"), path) infoMsg = "[%s] [INFO] retrieved: '%s'\n" % (time.strftime("%X"), path)
dataToStdout(infoMsg, True) dataToStdout(infoMsg, True)
@ -379,10 +379,10 @@ def fileExists(pathFile):
clearConsoleLine(True) clearConsoleLine(True)
dataToStdout("\n") dataToStdout("\n")
if not threadData.shared.value: if not threadData.shared.files:
warnMsg = "no file(s) found" warnMsg = "no file(s) found"
logger.warn(warnMsg) logger.warn(warnMsg)
else: else:
retVal = threadData.shared.value retVal = threadData.shared.files
return retVal return retVal