From dbd93e267000e4fce886dab8abc3fd336a48f73b Mon Sep 17 00:00:00 2001 From: Miroslav Stampar Date: Fri, 29 Mar 2019 02:28:16 +0100 Subject: [PATCH] Minor refactoring (drei stuff) --- extra/dbgtool/dbgtool.py | 1 - lib/controller/checks.py | 5 +++-- lib/core/agent.py | 3 ++- lib/core/common.py | 42 +++++++++++++++++++++++++----------- lib/core/option.py | 7 +++--- lib/core/settings.py | 2 +- lib/request/basic.py | 3 ++- lib/request/connect.py | 11 +++++----- lib/request/httpshandler.py | 3 ++- lib/request/inject.py | 3 ++- lib/utils/hash.py | 4 ++-- lib/utils/pivotdumptable.py | 5 +++-- plugins/generic/databases.py | 3 ++- sqlmap.py | 3 ++- 14 files changed, 60 insertions(+), 35 deletions(-) diff --git a/extra/dbgtool/dbgtool.py b/extra/dbgtool/dbgtool.py index 2e7bd095d..30ae5e837 100644 --- a/extra/dbgtool/dbgtool.py +++ b/extra/dbgtool/dbgtool.py @@ -11,7 +11,6 @@ from __future__ import print_function import os import sys -import struct from optparse import OptionError from optparse import OptionParser diff --git a/lib/controller/checks.py b/lib/controller/checks.py index cb46ef879..f8a4e5fc1 100644 --- a/lib/controller/checks.py +++ b/lib/controller/checks.py @@ -21,6 +21,7 @@ from lib.core.agent import agent from lib.core.common import Backend from lib.core.common import extractRegexResult from lib.core.common import extractTextTagContent +from lib.core.common import filterNone from lib.core.common import findDynamicContent from lib.core.common import Format from lib.core.common import getFilteredPageContent @@ -581,7 +582,7 @@ def checkSqlInjection(place, parameter, value): else: errorSet = set() - candidates = filter(None, (_.strip() if _.strip() in trueRawResponse and _.strip() not in falseRawResponse else None for _ in (trueSet - falseSet - errorSet))) + candidates = filterNone(_.strip() if _.strip() in trueRawResponse and _.strip() not in falseRawResponse else None for _ in (trueSet - falseSet - errorSet)) if candidates: candidates = sorted(candidates, key=lambda _: len(_)) @@ -595,7 +596,7 @@ def checkSqlInjection(place, parameter, value): logger.info(infoMsg) if not any((conf.string, conf.notString)): - candidates = filter(None, (_.strip() if _.strip() in falseRawResponse and _.strip() not in trueRawResponse else None for _ in (falseSet - trueSet))) + candidates = filterNone(_.strip() if _.strip() in falseRawResponse and _.strip() not in trueRawResponse else None for _ in (falseSet - trueSet)) if candidates: candidates = sorted(candidates, key=lambda _: len(_)) diff --git a/lib/core/agent.py b/lib/core/agent.py index 9fab4004c..fd79a2170 100644 --- a/lib/core/agent.py +++ b/lib/core/agent.py @@ -9,6 +9,7 @@ import re from lib.core.common import Backend from lib.core.common import extractRegexResult +from lib.core.common import filterNone from lib.core.common import getSQLSnippet from lib.core.common import getUnicode from lib.core.common import isDBMSVersionAtLeast @@ -106,7 +107,7 @@ class Agent(object): if place == PLACE.URI: origValue = origValue.split(kb.customInjectionMark)[0] else: - origValue = filter(None, (re.search(_, origValue.split(BOUNDED_INJECTION_MARKER)[0]) for _ in (r"\w+\Z", r"[^\"'><]+\Z", r"[^ ]+\Z")))[0].group(0) + origValue = filterNone(re.search(_, origValue.split(BOUNDED_INJECTION_MARKER)[0]) for _ in (r"\w+\Z", r"[^\"'><]+\Z", r"[^ ]+\Z"))[0].group(0) origValue = origValue[origValue.rfind('/') + 1:] for char in ('?', '=', ':', ',', '&'): if char in origValue: diff --git a/lib/core/common.py b/lib/core/common.py index 7bdec65e0..2458bf4d1 100644 --- a/lib/core/common.py +++ b/lib/core/common.py @@ -7,6 +7,7 @@ See the file 'LICENSE' for copying permission import binascii import codecs +import collections import contextlib import copy import distutils @@ -228,7 +229,7 @@ class Format(object): if versions is None and Backend.getVersionList(): versions = Backend.getVersionList() - return Backend.getDbms() if versions is None else "%s %s" % (Backend.getDbms(), " and ".join(filter(None, versions))) + return Backend.getDbms() if versions is None else "%s %s" % (Backend.getDbms(), " and ".join(filterNone(versions))) @staticmethod def getErrorParsedDBMSes(): @@ -501,7 +502,7 @@ class Backend: @staticmethod def getVersion(): - versions = filter(None, flattenValue(kb.dbmsVersion)) + versions = filterNone(flattenValue(kb.dbmsVersion)) if not isNoneValue(versions): return versions[0] else: @@ -509,7 +510,7 @@ class Backend: @staticmethod def getVersionList(): - versions = filter(None, flattenValue(kb.dbmsVersion)) + versions = filterNone(flattenValue(kb.dbmsVersion)) if not isNoneValue(versions): return versions else: @@ -787,7 +788,7 @@ def getManualDirectories(): else: targets.add('.'.join(_[:-1])) - targets = filter(None, targets) + targets = filterNone(targets) for prefix in BRUTE_DOC_ROOT_PREFIXES.get(Backend.getOs(), DEFAULT_DOC_ROOTS[OS.LINUX]): if BRUTE_DOC_ROOT_TARGET_MARK in prefix and re.match(IP_ADDRESS_REGEX, conf.hostname): @@ -1473,7 +1474,7 @@ def parseTargetUrl(): errMsg += "in the hostname part" raise SqlmapGenericException(errMsg) - hostnamePort = urlSplit.netloc.split(":") if not re.search(r"\[.+\]", urlSplit.netloc) else filter(None, (re.search(r"\[.+\]", urlSplit.netloc).group(0), re.search(r"\](:(?P\d+))?", urlSplit.netloc).group("port"))) + hostnamePort = urlSplit.netloc.split(":") if not re.search(r"\[.+\]", urlSplit.netloc) else filterNone((re.search(r"\[.+\]", urlSplit.netloc).group(0), re.search(r"\](:(?P\d+))?", urlSplit.netloc).group("port"))) conf.scheme = (urlSplit.scheme.strip().lower() or "http") if not conf.forceSSL else "https" conf.path = urlSplit.path.strip() @@ -2389,13 +2390,13 @@ def getUnicode(value, encoding=None, noneToNull=False): return value elif isinstance(value, six.binary_type): # Heuristics (if encoding not explicitly specified) - candidates = filter(None, (encoding, kb.get("pageEncoding") if kb.get("originalPage") else None, conf.get("encoding"), UNICODE_ENCODING, sys.getfilesystemencoding())) + candidates = filterNone((encoding, kb.get("pageEncoding") if kb.get("originalPage") else None, conf.get("encoding"), UNICODE_ENCODING, sys.getfilesystemencoding())) if all(_ in value for _ in ('<', '>')): pass elif any(_ in value for _ in (":\\", '/', '.')) and '\n' not in value: - candidates = filter(None, (encoding, sys.getfilesystemencoding(), kb.get("pageEncoding") if kb.get("originalPage") else None, UNICODE_ENCODING, conf.get("encoding"))) + candidates = filterNone((encoding, sys.getfilesystemencoding(), kb.get("pageEncoding") if kb.get("originalPage") else None, UNICODE_ENCODING, conf.get("encoding"))) elif conf.get("encoding") and '\n' not in value: - candidates = filter(None, (encoding, conf.get("encoding"), kb.get("pageEncoding") if kb.get("originalPage") else None, sys.getfilesystemencoding(), UNICODE_ENCODING)) + candidates = filterNone((encoding, conf.get("encoding"), kb.get("pageEncoding") if kb.get("originalPage") else None, sys.getfilesystemencoding(), UNICODE_ENCODING)) for candidate in candidates: try: @@ -2837,7 +2838,7 @@ def extractTextTagContent(page): except MemoryError: page = page.replace(REFLECTED_VALUE_MARKER, "") - return filter(None, (_.group("result").strip() for _ in re.finditer(TEXT_TAG_REGEX, page))) + return filterNone(_.group("result").strip() for _ in re.finditer(TEXT_TAG_REGEX, page)) def trimAlphaNum(value): """ @@ -2996,6 +2997,21 @@ def filterControlChars(value, replacement=' '): return filterStringValue(value, PRINTABLE_CHAR_REGEX, replacement) +def filterNone(values): + """ + Emulates filterNone([...]) functionality + + >>> filterNone([1, 2, "", None, 3]) + [1, 2, 3] + """ + + retVal = values + + if isinstance(values, collections.Iterable): + retVal = [_ for _ in values if _] + + return retVal + def isDBMSVersionAtLeast(version): """ Checks if the recognized DBMS version is at least the version @@ -3537,7 +3553,7 @@ def maskSensitiveData(msg): retVal = getUnicode(msg) - for item in filter(None, (conf.get(_) for _ in SENSITIVE_OPTIONS)): + for item in filterNone(conf.get(_) for _ in SENSITIVE_OPTIONS): regex = SENSITIVE_DATA_REGEX % re.sub(r"(\W)", r"\\\1", getUnicode(item)) while extractRegexResult(regex, retVal): value = extractRegexResult(regex, retVal) @@ -3640,14 +3656,14 @@ def removeReflectiveValues(content, payload, suppressWarning=False): regex = _(filterStringValue(payload, r"[A-Za-z0-9]", REFLECTED_REPLACEMENT_REGEX.encode("string_escape"))) if regex != payload: - if all(part.lower() in content.lower() for part in filter(None, regex.split(REFLECTED_REPLACEMENT_REGEX))[1:]): # fast optimization check + if all(part.lower() in content.lower() for part in filterNone(regex.split(REFLECTED_REPLACEMENT_REGEX))[1:]): # fast optimization check parts = regex.split(REFLECTED_REPLACEMENT_REGEX) retVal = content.replace(payload, REFLECTED_VALUE_MARKER) # dummy approach if len(parts) > REFLECTED_MAX_REGEX_PARTS: # preventing CPU hogs regex = _("%s%s%s" % (REFLECTED_REPLACEMENT_REGEX.join(parts[:REFLECTED_MAX_REGEX_PARTS // 2]), REFLECTED_REPLACEMENT_REGEX, REFLECTED_REPLACEMENT_REGEX.join(parts[-REFLECTED_MAX_REGEX_PARTS // 2:]))) - parts = filter(None, regex.split(REFLECTED_REPLACEMENT_REGEX)) + parts = filterNone(regex.split(REFLECTED_REPLACEMENT_REGEX)) if regex.startswith(REFLECTED_REPLACEMENT_REGEX): regex = r"%s%s" % (REFLECTED_BORDER_REGEX, regex[len(REFLECTED_REPLACEMENT_REGEX):]) @@ -4482,7 +4498,7 @@ def resetCookieJar(cookieJar): logger.info(infoMsg) content = readCachedFileContent(conf.loadCookies) - lines = filter(None, (line.strip() for line in content.split("\n") if not line.startswith('#'))) + lines = filterNone(line.strip() for line in content.split("\n") if not line.startswith('#')) handle, filename = tempfile.mkstemp(prefix=MKSTEMP_PREFIX.COOKIE_JAR) os.close(handle) diff --git a/lib/core/option.py b/lib/core/option.py index 8744c1951..62cb66cc0 100644 --- a/lib/core/option.py +++ b/lib/core/option.py @@ -33,6 +33,7 @@ from lib.core.common import decodeStringEscape from lib.core.common import getPublicTypeMembers from lib.core.common import getSafeExString from lib.core.common import getUnicode +from lib.core.common import filterNone from lib.core.common import findLocalPort from lib.core.common import findPageForms from lib.core.common import getConsoleWidth @@ -784,7 +785,7 @@ def _setTamperingFunctions(): if name == "tamper" and inspect.getargspec(function).args and inspect.getargspec(function).keywords == "kwargs": found = True kb.tamperFunctions.append(function) - function.func_name = module.__name__ + function.__name__ = module.__name__ if check_priority and priority > last_priority: message = "it appears that you might have mixed " @@ -880,7 +881,7 @@ def _setPreprocessFunctions(): found = True kb.preprocessFunctions.append(function) - function.func_name = module.__name__ + function.__name__ = module.__name__ break @@ -1113,7 +1114,7 @@ def _setHTTPHandlers(): debugMsg = "creating HTTP requests opener object" logger.debug(debugMsg) - handlers = filter(None, [multipartPostHandler, proxyHandler if proxyHandler.proxies else None, authHandler, redirectHandler, rangeHandler, chunkedHandler if conf.chunked else None, httpsHandler]) + handlers = filterNone([multipartPostHandler, proxyHandler if proxyHandler.proxies else None, authHandler, redirectHandler, rangeHandler, chunkedHandler if conf.chunked else None, httpsHandler]) if not conf.dropSetCookie: if not conf.loadCookies: diff --git a/lib/core/settings.py b/lib/core/settings.py index 8f62d27ed..1b8865782 100644 --- a/lib/core/settings.py +++ b/lib/core/settings.py @@ -17,7 +17,7 @@ from lib.core.enums import DBMS_DIRECTORY_NAME from lib.core.enums import OS # sqlmap version (...) -VERSION = "1.3.3.77" +VERSION = "1.3.3.78" TYPE = "dev" if VERSION.count('.') > 2 and VERSION.split('.')[-1] != '0' else "stable" TYPE_COLORS = {"dev": 33, "stable": 90, "pip": 34} VERSION_STRING = "sqlmap/%s#%s" % ('.'.join(VERSION.split('.')[:-1]) if VERSION.count('.') > 2 and VERSION.split('.')[-1] == '0' else VERSION, TYPE) diff --git a/lib/request/basic.py b/lib/request/basic.py index 255c318af..2f112d6ed 100644 --- a/lib/request/basic.py +++ b/lib/request/basic.py @@ -16,6 +16,7 @@ import zlib from lib.core.common import Backend from lib.core.common import extractErrorMessage from lib.core.common import extractRegexResult +from lib.core.common import filterNone from lib.core.common import getPublicTypeMembers from lib.core.common import getSafeExString from lib.core.common import getUnicode @@ -100,7 +101,7 @@ def forgeHeaders(items=None, base=None): if ("%s=" % getUnicode(cookie.name)) in getUnicode(headers[HTTP_HEADER.COOKIE]): if conf.loadCookies: - conf.httpHeaders = filter(None, ((item if item[0] != HTTP_HEADER.COOKIE else None) for item in conf.httpHeaders)) + conf.httpHeaders = filterNone((item if item[0] != HTTP_HEADER.COOKIE else None) for item in conf.httpHeaders) elif kb.mergeCookies is None: message = "you provided a HTTP %s header value. " % HTTP_HEADER.COOKIE message += "The target URL provided its own cookies within " diff --git a/lib/request/connect.py b/lib/request/connect.py index 3d2e354b9..ed782d65c 100644 --- a/lib/request/connect.py +++ b/lib/request/connect.py @@ -32,6 +32,7 @@ from lib.core.common import dataToStdout from lib.core.common import escapeJsonValue from lib.core.common import evaluateCode from lib.core.common import extractRegexResult +from lib.core.common import filterNone from lib.core.common import findMultipartPostBoundary from lib.core.common import getCurrentThreadData from lib.core.common import getHeader @@ -600,7 +601,7 @@ class Connect(object): except: pass finally: - page = page if isinstance(page, unicode) else getUnicode(page) + page = getUnicode(page) code = ex.code status = getSafeExString(ex) @@ -758,7 +759,7 @@ class Connect(object): page, responseHeaders, code = function(page, responseHeaders, code) except Exception as ex: errMsg = "error occurred while running preprocess " - errMsg += "function '%s' ('%s')" % (function.func_name, getSafeExString(ex)) + errMsg += "function '%s' ('%s')" % (function.__name__, getSafeExString(ex)) raise SqlmapGenericException(errMsg) threadData.lastPage = page @@ -857,11 +858,11 @@ class Connect(object): payload = function(payload=payload, headers=auxHeaders, delimiter=delimiter, hints=hints) except Exception as ex: errMsg = "error occurred while running tamper " - errMsg += "function '%s' ('%s')" % (function.func_name, getSafeExString(ex)) + errMsg += "function '%s' ('%s')" % (function.__name__, getSafeExString(ex)) raise SqlmapGenericException(errMsg) if not isinstance(payload, six.string_types): - errMsg = "tamper function '%s' returns " % function.func_name + errMsg = "tamper function '%s' returns " % function.__name__ errMsg += "invalid payload type ('%s')" % type(payload) raise SqlmapValueException(errMsg) @@ -1095,7 +1096,7 @@ class Connect(object): else: query = None - for item in filter(None, (get, post if not kb.postHint else None, query)): + for item in filterNone((get, post if not kb.postHint else None, query)): for part in item.split(delimiter): if '=' in part: name, value = part.split('=', 1) diff --git a/lib/request/httpshandler.py b/lib/request/httpshandler.py index dcb71abb0..2a98d6aba 100644 --- a/lib/request/httpshandler.py +++ b/lib/request/httpshandler.py @@ -9,6 +9,7 @@ import distutils.version import re import socket +from lib.core.common import filterNone from lib.core.common import getSafeExString from lib.core.data import conf from lib.core.data import kb @@ -25,7 +26,7 @@ try: except ImportError: pass -_protocols = filter(None, (getattr(ssl, _, None) for _ in ("PROTOCOL_TLSv1_2", "PROTOCOL_TLSv1_1", "PROTOCOL_TLSv1", "PROTOCOL_SSLv3", "PROTOCOL_SSLv23", "PROTOCOL_SSLv2"))) +_protocols = filterNone(getattr(ssl, _, None) for _ in ("PROTOCOL_TLSv1_2", "PROTOCOL_TLSv1_1", "PROTOCOL_TLSv1", "PROTOCOL_SSLv3", "PROTOCOL_SSLv23", "PROTOCOL_SSLv2")) class HTTPSConnection(_http_client.HTTPSConnection): """ diff --git a/lib/request/inject.py b/lib/request/inject.py index 5555e962d..eddab9b7f 100644 --- a/lib/request/inject.py +++ b/lib/request/inject.py @@ -17,6 +17,7 @@ from lib.core.common import calculateDeltaSeconds from lib.core.common import cleanQuery from lib.core.common import expandAsteriskForColumns from lib.core.common import extractExpectedValue +from lib.core.common import filterNone from lib.core.common import getPublicTypeMembers from lib.core.common import getTechniqueData from lib.core.common import hashDBRetrieve @@ -431,7 +432,7 @@ def getValue(expression, blind=True, union=True, error=True, time=True, fromUser found = (value is not None) or (value is None and expectingNone) or count >= MAX_TECHNIQUES_PER_VALUE if found and conf.dnsDomain: - _ = "".join(filter(None, (key if isTechniqueAvailable(value) else None for key, value in {'E': PAYLOAD.TECHNIQUE.ERROR, 'Q': PAYLOAD.TECHNIQUE.QUERY, 'U': PAYLOAD.TECHNIQUE.UNION}.items()))) + _ = "".join(filterNone(key if isTechniqueAvailable(value) else None for key, value in {'E': PAYLOAD.TECHNIQUE.ERROR, 'Q': PAYLOAD.TECHNIQUE.QUERY, 'U': PAYLOAD.TECHNIQUE.UNION}.items())) warnMsg = "option '--dns-domain' will be ignored " warnMsg += "as faster techniques are usable " warnMsg += "(%s) " % _ diff --git a/lib/utils/hash.py b/lib/utils/hash.py index 5a2a9f084..03c5c4db1 100644 --- a/lib/utils/hash.py +++ b/lib/utils/hash.py @@ -956,7 +956,7 @@ def dictionaryAttack(attack_dict): if regex and regex not in hash_regexes: hash_regexes.append(regex) - infoMsg = "using hash method '%s'" % __functions__[regex].func_name + infoMsg = "using hash method '%s'" % __functions__[regex].__name__ logger.info(infoMsg) for hash_regex in hash_regexes: @@ -1084,7 +1084,7 @@ def dictionaryAttack(attack_dict): if readInput(message, default='N', boolean=True): suffix_list += COMMON_PASSWORD_SUFFIXES - infoMsg = "starting dictionary-based cracking (%s)" % __functions__[hash_regex].func_name + infoMsg = "starting dictionary-based cracking (%s)" % __functions__[hash_regex].__name__ logger.info(infoMsg) for item in attack_info: diff --git a/lib/utils/pivotdumptable.py b/lib/utils/pivotdumptable.py index 902e4d51c..0b07907d8 100644 --- a/lib/utils/pivotdumptable.py +++ b/lib/utils/pivotdumptable.py @@ -11,6 +11,7 @@ from extra.safe2bin.safe2bin import safechardecode from lib.core.agent import agent from lib.core.bigarray import BigArray from lib.core.common import Backend +from lib.core.common import filterNone from lib.core.common import getSafeExString from lib.core.common import getUnicode from lib.core.common import isNoneValue @@ -67,7 +68,7 @@ def pivotDumpTable(table, colList, count=None, blind=True, alias=None): lengths[column] = 0 entries[column] = BigArray() - colList = filter(None, sorted(colList, key=lambda x: len(x) if x else MAX_INT)) + colList = filterNone(sorted(colList, key=lambda x: len(x) if x else MAX_INT)) if conf.pivotColumn: for _ in colList: @@ -141,7 +142,7 @@ def pivotDumpTable(table, colList, count=None, blind=True, alias=None): if column == colList[0]: if isNoneValue(value): try: - for pivotValue in filter(None, (" " if pivotValue == " " else None, "%s%s" % (pivotValue[0], unichr(ord(pivotValue[1]) + 1)) if len(pivotValue) > 1 else None, unichr(ord(pivotValue[0]) + 1))): + for pivotValue in filterNone((" " if pivotValue == " " else None, "%s%s" % (pivotValue[0], unichr(ord(pivotValue[1]) + 1)) if len(pivotValue) > 1 else None, unichr(ord(pivotValue[0]) + 1))): value = _(column, pivotValue) if not isNoneValue(value): break diff --git a/plugins/generic/databases.py b/plugins/generic/databases.py index 76f125c1f..8e2ee5514 100644 --- a/plugins/generic/databases.py +++ b/plugins/generic/databases.py @@ -9,6 +9,7 @@ from lib.core.agent import agent from lib.core.common import arrayizeValue from lib.core.common import Backend from lib.core.common import extractRegexResult +from lib.core.common import filterNone from lib.core.common import filterPairValues from lib.core.common import flattenValue from lib.core.common import getLimitRange @@ -490,7 +491,7 @@ class Databases: else: return kb.data.cachedColumns - tblList = filter(None, (safeSQLIdentificatorNaming(_, True) for _ in tblList)) + tblList = filterNone(safeSQLIdentificatorNaming(_, True) for _ in tblList) if bruteForce is None: if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema: diff --git a/sqlmap.py b/sqlmap.py index f5976c550..bb119e06a 100755 --- a/sqlmap.py +++ b/sqlmap.py @@ -42,6 +42,7 @@ try: from lib.core.common import checkPipedInput from lib.core.common import createGithubIssue from lib.core.common import dataToStdout + from lib.core.common import filterNone from lib.core.common import getSafeExString from lib.core.common import getUnicode from lib.core.common import maskSensitiveData @@ -362,7 +363,7 @@ def main(): os.remove(filepath) except OSError: pass - if not filter(None, (filepath for filepath in glob.glob(os.path.join(kb.tempDir, '*')) if not any(filepath.endswith(_) for _ in ('.lock', '.exe', '_')))): + if not filterNone(filepath for filepath in glob.glob(os.path.join(kb.tempDir, '*')) if not any(filepath.endswith(_) for _ in ('.lock', '.exe', '_'))): shutil.rmtree(kb.tempDir, ignore_errors=True) if conf.get("hashDB"):