Fixing mess with --common-files --threads>1 (threads in threads - '.shared.' hell)

This commit is contained in:
Miroslav Stampar 2019-07-18 14:59:42 +02:00
parent b62680b4bc
commit db90ff9c3f
5 changed files with 48 additions and 28 deletions

View File

@ -14,7 +14,8 @@ from lib.core.settings import MAX_CACHE_ITEMS
from lib.core.settings import UNICODE_ENCODING
from lib.core.threads import getCurrentThreadData
_lock = threading.Lock()
_cache_lock = threading.Lock()
_method_locks = {}
def cachedmethod(f, cache=LRUDict(capacity=MAX_CACHE_ITEMS)):
"""
@ -38,12 +39,12 @@ def cachedmethod(f, cache=LRUDict(capacity=MAX_CACHE_ITEMS)):
key = int(hashlib.md5("|".join(str(_) for _ in (f, args, kwargs)).encode(UNICODE_ENCODING)).hexdigest(), 16) & 0x7fffffffffffffff
try:
with _lock:
with _cache_lock:
result = cache[key]
except KeyError:
result = f(*args, **kwargs)
with _lock:
with _cache_lock:
cache[key] = result
return result
@ -76,3 +77,16 @@ def stackedmethod(f):
return result
return _
def lockedmethod(f):
@functools.wraps(f)
def _(*args, **kwargs):
if f not in _method_locks:
_method_locks[f] = threading.Lock()
with _method_locks[f]:
result = f(*args, **kwargs)
return result
return _

View File

@ -18,7 +18,7 @@ from lib.core.enums import OS
from thirdparty.six import unichr as _unichr
# sqlmap version (<major>.<minor>.<month>.<monthly commit>)
VERSION = "1.3.7.34"
VERSION = "1.3.7.35"
TYPE = "dev" if VERSION.count('.') > 2 and VERSION.split('.')[-1] != '0' else "stable"
TYPE_COLORS = {"dev": 33, "stable": 90, "pip": 34}
VERSION_STRING = "sqlmap/%s#%s" % ('.'.join(VERSION.split('.')[:-1]) if VERSION.count('.') > 2 and VERSION.split('.')[-1] == '0' else VERSION, TYPE)

View File

@ -39,6 +39,7 @@ from lib.core.data import conf
from lib.core.data import kb
from lib.core.data import logger
from lib.core.data import queries
from lib.core.decorators import lockedmethod
from lib.core.decorators import stackedmethod
from lib.core.dicts import FROM_DUMMY_TABLE
from lib.core.enums import CHARSET_TYPE
@ -351,6 +352,7 @@ def _goUnion(expression, unpack=True, dump=False):
return output
@lockedmethod
@stackedmethod
def getValue(expression, blind=True, union=True, error=True, time=True, fromUser=False, expected=None, batch=False, unpack=True, resumeValue=True, charsetType=None, firstChar=None, lastChar=None, dump=False, suppressOutput=None, expectingNone=False, safeCharEncode=True):
"""

View File

@ -162,7 +162,11 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
length = None
showEta = conf.eta and isinstance(length, int)
numThreads = min(conf.threads or 0, length or 0) or 1
if kb.bruteMode:
numThreads = 1
else:
numThreads = min(conf.threads or 0, length or 0) or 1
if showEta:
progress = ProgressBar(maxValue=length)
@ -174,13 +178,13 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
else:
numThreads = 1
if conf.threads == 1 and not timeBasedCompare and not conf.predictOutput:
if numThreads == 1 and not timeBasedCompare and not conf.predictOutput:
warnMsg = "running in a single-thread mode. Please consider "
warnMsg += "usage of option '--threads' for faster data retrieval"
singleTimeWarnMessage(warnMsg)
if conf.verbose in (1, 2) and not showEta and not conf.api:
if isinstance(length, int) and conf.threads > 1:
if conf.verbose in (1, 2) and not any((showEta, conf.api, kb.bruteMode)):
if isinstance(length, int) and numThreads > 1:
dataToStdout("[%s] [INFO] retrieved: %s" % (time.strftime("%X"), "_" * min(length, conf.progressWidth)))
dataToStdout("\r[%s] [INFO] retrieved: " % time.strftime("%X"))
else:
@ -459,7 +463,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
return decodeIntToUnicode(candidates[0])
# Go multi-threading (--threads > 1)
if conf.threads > 1 and isinstance(length, int) and length > 1:
if numThreads > 1 and isinstance(length, int) and length > 1:
threadData.shared.value = [None] * length
threadData.shared.index = [firstChar] # As list for python nested function scoping
threadData.shared.start = firstChar
@ -517,7 +521,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
if (endCharIndex - startCharIndex == conf.progressWidth) and (endCharIndex < length - 1):
output = output[:-2] + ".."
if conf.verbose in (1, 2) and not showEta and not conf.api:
if conf.verbose in (1, 2) and not any((showEta, conf.api, kb.bruteMode)):
_ = count - firstChar
output += '_' * (min(length, conf.progressWidth) - len(output))
status = ' %d/%d (%d%%)' % (_, length, int(100.0 * _ / length))
@ -547,7 +551,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
finalValue = "".join(value)
infoMsg = "\r[%s] [INFO] retrieved: %s" % (time.strftime("%X"), filterControlChars(finalValue))
if conf.verbose in (1, 2) and not showEta and infoMsg and not conf.api:
if conf.verbose in (1, 2) and infoMsg and not any((showEta, conf.api, kb.bruteMode)):
dataToStdout(infoMsg)
# No multi-threading (--threads = 1)
@ -632,7 +636,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
if showEta:
progress.progress(index)
elif conf.verbose in (1, 2) or conf.api:
elif (conf.verbose in (1, 2) and not kb.bruteMode) or conf.api:
dataToStdout(filterControlChars(val))
# some DBMSes (e.g. Firebird, DB2, etc.) have issues with trailing spaces
@ -661,11 +665,11 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
elif partialValue:
hashDBWrite(expression, "%s%s" % (PARTIAL_VALUE_MARKER if not conf.hexConvert else PARTIAL_HEX_VALUE_MARKER, partialValue))
if conf.hexConvert and not abortedFlag and not conf.api:
if conf.hexConvert and not any((abortedFlag, conf.api, kb.bruteMode)):
infoMsg = "\r[%s] [INFO] retrieved: %s %s\n" % (time.strftime("%X"), filterControlChars(finalValue), " " * retrievedLength)
dataToStdout(infoMsg)
else:
if conf.verbose in (1, 2) and not showEta and not conf.api:
if conf.verbose in (1, 2) and not any((showEta, conf.api, kb.bruteMode)):
dataToStdout("\n")
if (conf.verbose in (1, 2) and showEta) or conf.verbose >= 3:

View File

@ -102,7 +102,7 @@ def tableExists(tableFile, regex=None):
threadData = getCurrentThreadData()
threadData.shared.count = 0
threadData.shared.limit = len(tables)
threadData.shared.value = []
threadData.shared.files = []
threadData.shared.unique = set()
def tableExistsThread():
@ -128,7 +128,7 @@ def tableExists(tableFile, regex=None):
kb.locks.io.acquire()
if result and table.lower() not in threadData.shared.unique:
threadData.shared.value.append(table)
threadData.shared.files.append(table)
threadData.shared.unique.add(table.lower())
if conf.verbose in (1, 2) and not conf.api:
@ -152,17 +152,17 @@ def tableExists(tableFile, regex=None):
clearConsoleLine(True)
dataToStdout("\n")
if not threadData.shared.value:
if not threadData.shared.files:
warnMsg = "no table(s) found"
logger.warn(warnMsg)
else:
for item in threadData.shared.value:
for item in threadData.shared.files:
if conf.db not in kb.data.cachedTables:
kb.data.cachedTables[conf.db] = [item]
else:
kb.data.cachedTables[conf.db].append(item)
for _ in ((conf.db, item) for item in threadData.shared.value):
for _ in ((conf.db, item) for item in threadData.shared.files):
if _ not in kb.brute.tables:
kb.brute.tables.append(_)
@ -224,7 +224,7 @@ def columnExists(columnFile, regex=None):
threadData = getCurrentThreadData()
threadData.shared.count = 0
threadData.shared.limit = len(columns)
threadData.shared.value = []
threadData.shared.files = []
def columnExistsThread():
threadData = getCurrentThreadData()
@ -244,7 +244,7 @@ def columnExists(columnFile, regex=None):
kb.locks.io.acquire()
if result:
threadData.shared.value.append(column)
threadData.shared.files.append(column)
if conf.verbose in (1, 2) and not conf.api:
clearConsoleLine(True)
@ -269,13 +269,13 @@ def columnExists(columnFile, regex=None):
clearConsoleLine(True)
dataToStdout("\n")
if not threadData.shared.value:
if not threadData.shared.files:
warnMsg = "no column(s) found"
logger.warn(warnMsg)
else:
columns = {}
for column in threadData.shared.value:
for column in threadData.shared.files:
if Backend.getIdentifiedDbms() in (DBMS.MYSQL,):
result = not inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE %s REGEXP '[^0-9]')", (column, table, column)))
else:
@ -327,7 +327,7 @@ def fileExists(pathFile):
threadData = getCurrentThreadData()
threadData.shared.count = 0
threadData.shared.limit = len(paths)
threadData.shared.value = []
threadData.shared.files = []
def fileExistsThread():
threadData = getCurrentThreadData()
@ -350,9 +350,9 @@ def fileExists(pathFile):
kb.locks.io.acquire()
if not isNoneValue(result):
threadData.shared.value.append(result)
threadData.shared.files.append(result)
if conf.verbose in (1, 2) and not conf.api:
if not conf.api:
clearConsoleLine(True)
infoMsg = "[%s] [INFO] retrieved: '%s'\n" % (time.strftime("%X"), path)
dataToStdout(infoMsg, True)
@ -379,10 +379,10 @@ def fileExists(pathFile):
clearConsoleLine(True)
dataToStdout("\n")
if not threadData.shared.value:
if not threadData.shared.files:
warnMsg = "no file(s) found"
logger.warn(warnMsg)
else:
retVal = threadData.shared.value
retVal = threadData.shared.files
return retVal