Minor refactoring (drei stuff)

This commit is contained in:
Miroslav Stampar 2019-03-29 02:28:16 +01:00
parent 89d13aaee4
commit dbd93e2670
14 changed files with 60 additions and 35 deletions

View File

@ -11,7 +11,6 @@ from __future__ import print_function
import os
import sys
import struct
from optparse import OptionError
from optparse import OptionParser

View File

@ -21,6 +21,7 @@ from lib.core.agent import agent
from lib.core.common import Backend
from lib.core.common import extractRegexResult
from lib.core.common import extractTextTagContent
from lib.core.common import filterNone
from lib.core.common import findDynamicContent
from lib.core.common import Format
from lib.core.common import getFilteredPageContent
@ -581,7 +582,7 @@ def checkSqlInjection(place, parameter, value):
else:
errorSet = set()
candidates = filter(None, (_.strip() if _.strip() in trueRawResponse and _.strip() not in falseRawResponse else None for _ in (trueSet - falseSet - errorSet)))
candidates = filterNone(_.strip() if _.strip() in trueRawResponse and _.strip() not in falseRawResponse else None for _ in (trueSet - falseSet - errorSet))
if candidates:
candidates = sorted(candidates, key=lambda _: len(_))
@ -595,7 +596,7 @@ def checkSqlInjection(place, parameter, value):
logger.info(infoMsg)
if not any((conf.string, conf.notString)):
candidates = filter(None, (_.strip() if _.strip() in falseRawResponse and _.strip() not in trueRawResponse else None for _ in (falseSet - trueSet)))
candidates = filterNone(_.strip() if _.strip() in falseRawResponse and _.strip() not in trueRawResponse else None for _ in (falseSet - trueSet))
if candidates:
candidates = sorted(candidates, key=lambda _: len(_))

View File

@ -9,6 +9,7 @@ import re
from lib.core.common import Backend
from lib.core.common import extractRegexResult
from lib.core.common import filterNone
from lib.core.common import getSQLSnippet
from lib.core.common import getUnicode
from lib.core.common import isDBMSVersionAtLeast
@ -106,7 +107,7 @@ class Agent(object):
if place == PLACE.URI:
origValue = origValue.split(kb.customInjectionMark)[0]
else:
origValue = filter(None, (re.search(_, origValue.split(BOUNDED_INJECTION_MARKER)[0]) for _ in (r"\w+\Z", r"[^\"'><]+\Z", r"[^ ]+\Z")))[0].group(0)
origValue = filterNone(re.search(_, origValue.split(BOUNDED_INJECTION_MARKER)[0]) for _ in (r"\w+\Z", r"[^\"'><]+\Z", r"[^ ]+\Z"))[0].group(0)
origValue = origValue[origValue.rfind('/') + 1:]
for char in ('?', '=', ':', ',', '&'):
if char in origValue:

View File

@ -7,6 +7,7 @@ See the file 'LICENSE' for copying permission
import binascii
import codecs
import collections
import contextlib
import copy
import distutils
@ -228,7 +229,7 @@ class Format(object):
if versions is None and Backend.getVersionList():
versions = Backend.getVersionList()
return Backend.getDbms() if versions is None else "%s %s" % (Backend.getDbms(), " and ".join(filter(None, versions)))
return Backend.getDbms() if versions is None else "%s %s" % (Backend.getDbms(), " and ".join(filterNone(versions)))
@staticmethod
def getErrorParsedDBMSes():
@ -501,7 +502,7 @@ class Backend:
@staticmethod
def getVersion():
versions = filter(None, flattenValue(kb.dbmsVersion))
versions = filterNone(flattenValue(kb.dbmsVersion))
if not isNoneValue(versions):
return versions[0]
else:
@ -509,7 +510,7 @@ class Backend:
@staticmethod
def getVersionList():
versions = filter(None, flattenValue(kb.dbmsVersion))
versions = filterNone(flattenValue(kb.dbmsVersion))
if not isNoneValue(versions):
return versions
else:
@ -787,7 +788,7 @@ def getManualDirectories():
else:
targets.add('.'.join(_[:-1]))
targets = filter(None, targets)
targets = filterNone(targets)
for prefix in BRUTE_DOC_ROOT_PREFIXES.get(Backend.getOs(), DEFAULT_DOC_ROOTS[OS.LINUX]):
if BRUTE_DOC_ROOT_TARGET_MARK in prefix and re.match(IP_ADDRESS_REGEX, conf.hostname):
@ -1473,7 +1474,7 @@ def parseTargetUrl():
errMsg += "in the hostname part"
raise SqlmapGenericException(errMsg)
hostnamePort = urlSplit.netloc.split(":") if not re.search(r"\[.+\]", urlSplit.netloc) else filter(None, (re.search(r"\[.+\]", urlSplit.netloc).group(0), re.search(r"\](:(?P<port>\d+))?", urlSplit.netloc).group("port")))
hostnamePort = urlSplit.netloc.split(":") if not re.search(r"\[.+\]", urlSplit.netloc) else filterNone((re.search(r"\[.+\]", urlSplit.netloc).group(0), re.search(r"\](:(?P<port>\d+))?", urlSplit.netloc).group("port")))
conf.scheme = (urlSplit.scheme.strip().lower() or "http") if not conf.forceSSL else "https"
conf.path = urlSplit.path.strip()
@ -2389,13 +2390,13 @@ def getUnicode(value, encoding=None, noneToNull=False):
return value
elif isinstance(value, six.binary_type):
# Heuristics (if encoding not explicitly specified)
candidates = filter(None, (encoding, kb.get("pageEncoding") if kb.get("originalPage") else None, conf.get("encoding"), UNICODE_ENCODING, sys.getfilesystemencoding()))
candidates = filterNone((encoding, kb.get("pageEncoding") if kb.get("originalPage") else None, conf.get("encoding"), UNICODE_ENCODING, sys.getfilesystemencoding()))
if all(_ in value for _ in ('<', '>')):
pass
elif any(_ in value for _ in (":\\", '/', '.')) and '\n' not in value:
candidates = filter(None, (encoding, sys.getfilesystemencoding(), kb.get("pageEncoding") if kb.get("originalPage") else None, UNICODE_ENCODING, conf.get("encoding")))
candidates = filterNone((encoding, sys.getfilesystemencoding(), kb.get("pageEncoding") if kb.get("originalPage") else None, UNICODE_ENCODING, conf.get("encoding")))
elif conf.get("encoding") and '\n' not in value:
candidates = filter(None, (encoding, conf.get("encoding"), kb.get("pageEncoding") if kb.get("originalPage") else None, sys.getfilesystemencoding(), UNICODE_ENCODING))
candidates = filterNone((encoding, conf.get("encoding"), kb.get("pageEncoding") if kb.get("originalPage") else None, sys.getfilesystemencoding(), UNICODE_ENCODING))
for candidate in candidates:
try:
@ -2837,7 +2838,7 @@ def extractTextTagContent(page):
except MemoryError:
page = page.replace(REFLECTED_VALUE_MARKER, "")
return filter(None, (_.group("result").strip() for _ in re.finditer(TEXT_TAG_REGEX, page)))
return filterNone(_.group("result").strip() for _ in re.finditer(TEXT_TAG_REGEX, page))
def trimAlphaNum(value):
"""
@ -2996,6 +2997,21 @@ def filterControlChars(value, replacement=' '):
return filterStringValue(value, PRINTABLE_CHAR_REGEX, replacement)
def filterNone(values):
"""
Emulates filterNone([...]) functionality
>>> filterNone([1, 2, "", None, 3])
[1, 2, 3]
"""
retVal = values
if isinstance(values, collections.Iterable):
retVal = [_ for _ in values if _]
return retVal
def isDBMSVersionAtLeast(version):
"""
Checks if the recognized DBMS version is at least the version
@ -3537,7 +3553,7 @@ def maskSensitiveData(msg):
retVal = getUnicode(msg)
for item in filter(None, (conf.get(_) for _ in SENSITIVE_OPTIONS)):
for item in filterNone(conf.get(_) for _ in SENSITIVE_OPTIONS):
regex = SENSITIVE_DATA_REGEX % re.sub(r"(\W)", r"\\\1", getUnicode(item))
while extractRegexResult(regex, retVal):
value = extractRegexResult(regex, retVal)
@ -3640,14 +3656,14 @@ def removeReflectiveValues(content, payload, suppressWarning=False):
regex = _(filterStringValue(payload, r"[A-Za-z0-9]", REFLECTED_REPLACEMENT_REGEX.encode("string_escape")))
if regex != payload:
if all(part.lower() in content.lower() for part in filter(None, regex.split(REFLECTED_REPLACEMENT_REGEX))[1:]): # fast optimization check
if all(part.lower() in content.lower() for part in filterNone(regex.split(REFLECTED_REPLACEMENT_REGEX))[1:]): # fast optimization check
parts = regex.split(REFLECTED_REPLACEMENT_REGEX)
retVal = content.replace(payload, REFLECTED_VALUE_MARKER) # dummy approach
if len(parts) > REFLECTED_MAX_REGEX_PARTS: # preventing CPU hogs
regex = _("%s%s%s" % (REFLECTED_REPLACEMENT_REGEX.join(parts[:REFLECTED_MAX_REGEX_PARTS // 2]), REFLECTED_REPLACEMENT_REGEX, REFLECTED_REPLACEMENT_REGEX.join(parts[-REFLECTED_MAX_REGEX_PARTS // 2:])))
parts = filter(None, regex.split(REFLECTED_REPLACEMENT_REGEX))
parts = filterNone(regex.split(REFLECTED_REPLACEMENT_REGEX))
if regex.startswith(REFLECTED_REPLACEMENT_REGEX):
regex = r"%s%s" % (REFLECTED_BORDER_REGEX, regex[len(REFLECTED_REPLACEMENT_REGEX):])
@ -4482,7 +4498,7 @@ def resetCookieJar(cookieJar):
logger.info(infoMsg)
content = readCachedFileContent(conf.loadCookies)
lines = filter(None, (line.strip() for line in content.split("\n") if not line.startswith('#')))
lines = filterNone(line.strip() for line in content.split("\n") if not line.startswith('#'))
handle, filename = tempfile.mkstemp(prefix=MKSTEMP_PREFIX.COOKIE_JAR)
os.close(handle)

View File

@ -33,6 +33,7 @@ from lib.core.common import decodeStringEscape
from lib.core.common import getPublicTypeMembers
from lib.core.common import getSafeExString
from lib.core.common import getUnicode
from lib.core.common import filterNone
from lib.core.common import findLocalPort
from lib.core.common import findPageForms
from lib.core.common import getConsoleWidth
@ -784,7 +785,7 @@ def _setTamperingFunctions():
if name == "tamper" and inspect.getargspec(function).args and inspect.getargspec(function).keywords == "kwargs":
found = True
kb.tamperFunctions.append(function)
function.func_name = module.__name__
function.__name__ = module.__name__
if check_priority and priority > last_priority:
message = "it appears that you might have mixed "
@ -880,7 +881,7 @@ def _setPreprocessFunctions():
found = True
kb.preprocessFunctions.append(function)
function.func_name = module.__name__
function.__name__ = module.__name__
break
@ -1113,7 +1114,7 @@ def _setHTTPHandlers():
debugMsg = "creating HTTP requests opener object"
logger.debug(debugMsg)
handlers = filter(None, [multipartPostHandler, proxyHandler if proxyHandler.proxies else None, authHandler, redirectHandler, rangeHandler, chunkedHandler if conf.chunked else None, httpsHandler])
handlers = filterNone([multipartPostHandler, proxyHandler if proxyHandler.proxies else None, authHandler, redirectHandler, rangeHandler, chunkedHandler if conf.chunked else None, httpsHandler])
if not conf.dropSetCookie:
if not conf.loadCookies:

View File

@ -17,7 +17,7 @@ from lib.core.enums import DBMS_DIRECTORY_NAME
from lib.core.enums import OS
# sqlmap version (<major>.<minor>.<month>.<monthly commit>)
VERSION = "1.3.3.77"
VERSION = "1.3.3.78"
TYPE = "dev" if VERSION.count('.') > 2 and VERSION.split('.')[-1] != '0' else "stable"
TYPE_COLORS = {"dev": 33, "stable": 90, "pip": 34}
VERSION_STRING = "sqlmap/%s#%s" % ('.'.join(VERSION.split('.')[:-1]) if VERSION.count('.') > 2 and VERSION.split('.')[-1] == '0' else VERSION, TYPE)

View File

@ -16,6 +16,7 @@ import zlib
from lib.core.common import Backend
from lib.core.common import extractErrorMessage
from lib.core.common import extractRegexResult
from lib.core.common import filterNone
from lib.core.common import getPublicTypeMembers
from lib.core.common import getSafeExString
from lib.core.common import getUnicode
@ -100,7 +101,7 @@ def forgeHeaders(items=None, base=None):
if ("%s=" % getUnicode(cookie.name)) in getUnicode(headers[HTTP_HEADER.COOKIE]):
if conf.loadCookies:
conf.httpHeaders = filter(None, ((item if item[0] != HTTP_HEADER.COOKIE else None) for item in conf.httpHeaders))
conf.httpHeaders = filterNone((item if item[0] != HTTP_HEADER.COOKIE else None) for item in conf.httpHeaders)
elif kb.mergeCookies is None:
message = "you provided a HTTP %s header value. " % HTTP_HEADER.COOKIE
message += "The target URL provided its own cookies within "

View File

@ -32,6 +32,7 @@ from lib.core.common import dataToStdout
from lib.core.common import escapeJsonValue
from lib.core.common import evaluateCode
from lib.core.common import extractRegexResult
from lib.core.common import filterNone
from lib.core.common import findMultipartPostBoundary
from lib.core.common import getCurrentThreadData
from lib.core.common import getHeader
@ -600,7 +601,7 @@ class Connect(object):
except:
pass
finally:
page = page if isinstance(page, unicode) else getUnicode(page)
page = getUnicode(page)
code = ex.code
status = getSafeExString(ex)
@ -758,7 +759,7 @@ class Connect(object):
page, responseHeaders, code = function(page, responseHeaders, code)
except Exception as ex:
errMsg = "error occurred while running preprocess "
errMsg += "function '%s' ('%s')" % (function.func_name, getSafeExString(ex))
errMsg += "function '%s' ('%s')" % (function.__name__, getSafeExString(ex))
raise SqlmapGenericException(errMsg)
threadData.lastPage = page
@ -857,11 +858,11 @@ class Connect(object):
payload = function(payload=payload, headers=auxHeaders, delimiter=delimiter, hints=hints)
except Exception as ex:
errMsg = "error occurred while running tamper "
errMsg += "function '%s' ('%s')" % (function.func_name, getSafeExString(ex))
errMsg += "function '%s' ('%s')" % (function.__name__, getSafeExString(ex))
raise SqlmapGenericException(errMsg)
if not isinstance(payload, six.string_types):
errMsg = "tamper function '%s' returns " % function.func_name
errMsg = "tamper function '%s' returns " % function.__name__
errMsg += "invalid payload type ('%s')" % type(payload)
raise SqlmapValueException(errMsg)
@ -1095,7 +1096,7 @@ class Connect(object):
else:
query = None
for item in filter(None, (get, post if not kb.postHint else None, query)):
for item in filterNone((get, post if not kb.postHint else None, query)):
for part in item.split(delimiter):
if '=' in part:
name, value = part.split('=', 1)

View File

@ -9,6 +9,7 @@ import distutils.version
import re
import socket
from lib.core.common import filterNone
from lib.core.common import getSafeExString
from lib.core.data import conf
from lib.core.data import kb
@ -25,7 +26,7 @@ try:
except ImportError:
pass
_protocols = filter(None, (getattr(ssl, _, None) for _ in ("PROTOCOL_TLSv1_2", "PROTOCOL_TLSv1_1", "PROTOCOL_TLSv1", "PROTOCOL_SSLv3", "PROTOCOL_SSLv23", "PROTOCOL_SSLv2")))
_protocols = filterNone(getattr(ssl, _, None) for _ in ("PROTOCOL_TLSv1_2", "PROTOCOL_TLSv1_1", "PROTOCOL_TLSv1", "PROTOCOL_SSLv3", "PROTOCOL_SSLv23", "PROTOCOL_SSLv2"))
class HTTPSConnection(_http_client.HTTPSConnection):
"""

View File

@ -17,6 +17,7 @@ from lib.core.common import calculateDeltaSeconds
from lib.core.common import cleanQuery
from lib.core.common import expandAsteriskForColumns
from lib.core.common import extractExpectedValue
from lib.core.common import filterNone
from lib.core.common import getPublicTypeMembers
from lib.core.common import getTechniqueData
from lib.core.common import hashDBRetrieve
@ -431,7 +432,7 @@ def getValue(expression, blind=True, union=True, error=True, time=True, fromUser
found = (value is not None) or (value is None and expectingNone) or count >= MAX_TECHNIQUES_PER_VALUE
if found and conf.dnsDomain:
_ = "".join(filter(None, (key if isTechniqueAvailable(value) else None for key, value in {'E': PAYLOAD.TECHNIQUE.ERROR, 'Q': PAYLOAD.TECHNIQUE.QUERY, 'U': PAYLOAD.TECHNIQUE.UNION}.items())))
_ = "".join(filterNone(key if isTechniqueAvailable(value) else None for key, value in {'E': PAYLOAD.TECHNIQUE.ERROR, 'Q': PAYLOAD.TECHNIQUE.QUERY, 'U': PAYLOAD.TECHNIQUE.UNION}.items()))
warnMsg = "option '--dns-domain' will be ignored "
warnMsg += "as faster techniques are usable "
warnMsg += "(%s) " % _

View File

@ -956,7 +956,7 @@ def dictionaryAttack(attack_dict):
if regex and regex not in hash_regexes:
hash_regexes.append(regex)
infoMsg = "using hash method '%s'" % __functions__[regex].func_name
infoMsg = "using hash method '%s'" % __functions__[regex].__name__
logger.info(infoMsg)
for hash_regex in hash_regexes:
@ -1084,7 +1084,7 @@ def dictionaryAttack(attack_dict):
if readInput(message, default='N', boolean=True):
suffix_list += COMMON_PASSWORD_SUFFIXES
infoMsg = "starting dictionary-based cracking (%s)" % __functions__[hash_regex].func_name
infoMsg = "starting dictionary-based cracking (%s)" % __functions__[hash_regex].__name__
logger.info(infoMsg)
for item in attack_info:

View File

@ -11,6 +11,7 @@ from extra.safe2bin.safe2bin import safechardecode
from lib.core.agent import agent
from lib.core.bigarray import BigArray
from lib.core.common import Backend
from lib.core.common import filterNone
from lib.core.common import getSafeExString
from lib.core.common import getUnicode
from lib.core.common import isNoneValue
@ -67,7 +68,7 @@ def pivotDumpTable(table, colList, count=None, blind=True, alias=None):
lengths[column] = 0
entries[column] = BigArray()
colList = filter(None, sorted(colList, key=lambda x: len(x) if x else MAX_INT))
colList = filterNone(sorted(colList, key=lambda x: len(x) if x else MAX_INT))
if conf.pivotColumn:
for _ in colList:
@ -141,7 +142,7 @@ def pivotDumpTable(table, colList, count=None, blind=True, alias=None):
if column == colList[0]:
if isNoneValue(value):
try:
for pivotValue in filter(None, (" " if pivotValue == " " else None, "%s%s" % (pivotValue[0], unichr(ord(pivotValue[1]) + 1)) if len(pivotValue) > 1 else None, unichr(ord(pivotValue[0]) + 1))):
for pivotValue in filterNone((" " if pivotValue == " " else None, "%s%s" % (pivotValue[0], unichr(ord(pivotValue[1]) + 1)) if len(pivotValue) > 1 else None, unichr(ord(pivotValue[0]) + 1))):
value = _(column, pivotValue)
if not isNoneValue(value):
break

View File

@ -9,6 +9,7 @@ from lib.core.agent import agent
from lib.core.common import arrayizeValue
from lib.core.common import Backend
from lib.core.common import extractRegexResult
from lib.core.common import filterNone
from lib.core.common import filterPairValues
from lib.core.common import flattenValue
from lib.core.common import getLimitRange
@ -490,7 +491,7 @@ class Databases:
else:
return kb.data.cachedColumns
tblList = filter(None, (safeSQLIdentificatorNaming(_, True) for _ in tblList))
tblList = filterNone(safeSQLIdentificatorNaming(_, True) for _ in tblList)
if bruteForce is None:
if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema:

View File

@ -42,6 +42,7 @@ try:
from lib.core.common import checkPipedInput
from lib.core.common import createGithubIssue
from lib.core.common import dataToStdout
from lib.core.common import filterNone
from lib.core.common import getSafeExString
from lib.core.common import getUnicode
from lib.core.common import maskSensitiveData
@ -362,7 +363,7 @@ def main():
os.remove(filepath)
except OSError:
pass
if not filter(None, (filepath for filepath in glob.glob(os.path.join(kb.tempDir, '*')) if not any(filepath.endswith(_) for _ in ('.lock', '.exe', '_')))):
if not filterNone(filepath for filepath in glob.glob(os.path.join(kb.tempDir, '*')) if not any(filepath.endswith(_) for _ in ('.lock', '.exe', '_'))):
shutil.rmtree(kb.tempDir, ignore_errors=True)
if conf.get("hashDB"):