diff --git a/lib/core/common.py b/lib/core/common.py index 74e78265d..aee413235 100644 --- a/lib/core/common.py +++ b/lib/core/common.py @@ -6,6 +6,7 @@ See the file 'doc/COPYING' for copying permission """ import codecs +import contextlib import cookielib import copy import ctypes @@ -102,14 +103,10 @@ from lib.core.settings import USER_AGENT_ALIASES from lib.core.settings import PARTIAL_VALUE_MARKER from lib.core.settings import ERROR_PARSING_REGEXES from lib.core.settings import PRINTABLE_CHAR_REGEX -from lib.core.settings import DUMP_DEL_MARKER from lib.core.settings import SQL_STATEMENTS from lib.core.settings import SUPPORTED_DBMS from lib.core.settings import UNKNOWN_DBMS_VERSION from lib.core.settings import DEFAULT_MSSQL_SCHEMA -from lib.core.settings import DUMP_NEWLINE_MARKER -from lib.core.settings import DUMP_CR_MARKER -from lib.core.settings import DUMP_TAB_MARKER from lib.core.settings import PARAMETER_AMP_MARKER from lib.core.settings import PARAMETER_SEMICOLON_MARKER from lib.core.settings import LARGE_OUTPUT_THRESHOLD @@ -492,27 +489,17 @@ class Backend: # Reference: http://code.activestate.com/recipes/325205-cache-decorator-in-python-24/ def cachedmethod(f, cache={}): - def g(*args, **kwargs): + def _(*args, **kwargs): key = (f, tuple(args), frozenset(kwargs.items())) if key not in cache: cache[key] = f(*args, **kwargs) return cache[key] - return g + return _ def paramToDict(place, parameters=None): """ Split the parameters into names and values, check if these parameters are within the testable parameters and return in a dictionary. - - @param place: where sqlmap has to work, can be GET, POST or Cookie. - @type place: C{str} - - @param parameters: parameters string in the format for instance - 'p1=v1&p2=v2' (GET and POST) or 'p1=v1;p2=v2' (Cookie). - @type parameters: C{str} - - @return: the parameters in a dictionary. - @rtype: C{str} """ testableParameters = OrderedDict() @@ -706,6 +693,10 @@ def singleTimeLogMessage(message, level=logging.INFO, flag=None): logger.log(level, message) def dataToStdout(data, forceOutput=False): + """ + Writes text to the stdout (console) stream + """ + if not kb.get("threadException"): if forceOutput or not getCurrentThreadData().disableStdOut: try: @@ -793,13 +784,10 @@ def readInput(message, default=None, checkBatch=True): else: logging._acquireLock() dataToStdout("\r%s" % message, True) - data = raw_input() + data = raw_input() or default #data = raw_input(message.encode(sys.stdout.encoding or UNICODE_ENCODING)) logging._releaseLock() - if not data: - data = default - return data def randomRange(start=0, stop=1000): @@ -845,19 +833,6 @@ def checkFile(filename): if not os.path.exists(filename): raise sqlmapFilePathException, "unable to read file '%s'" % filename -def replaceNewlineTabs(value, stdout=False): - if value is None: - return - - if stdout: - retVal = value.replace("\n", " ").replace("\r", " ").replace("\t", " ") - else: - retVal = value.replace("\n", DUMP_NEWLINE_MARKER).replace("\r", DUMP_CR_MARKER).replace("\t", DUMP_TAB_MARKER) - - retVal = retVal.replace(kb.chars.delimiter, DUMP_DEL_MARKER) - - return retVal - def restoreDumpMarkedChars(value, onlyNewlineTab=False): retVal = value @@ -1070,16 +1045,15 @@ def parseTargetUrl(): if CUSTOM_INJECTION_MARK_CHAR in conf.url: conf.url = conf.url.replace('?', URI_QUESTION_MARKER) - __urlSplit = urlparse.urlsplit(conf.url) - __hostnamePort = __urlSplit[1].split(":") if not re.search("\[.+\]", __urlSplit[1]) else filter(None, (re.search("\[.+\]", __urlSplit[1]).group(0), re.search("\](:(?P\d+))?", __urlSplit[1]).group("port"))) + urlSplit = urlparse.urlsplit(conf.url) + hostnamePort = urlSplit[1].split(":") if not re.search("\[.+\]", urlSplit[1]) else filter(None, (re.search("\[.+\]", urlSplit[1]).group(0), re.search("\](:(?P\d+))?", urlSplit[1]).group("port"))) - conf.scheme = __urlSplit[0].strip().lower() if not conf.forceSSL else "https" - conf.path = __urlSplit[2].strip() - conf.hostname = __hostnamePort[0].strip() + conf.scheme = urlSplit[0].strip().lower() if not conf.forceSSL else "https" + conf.path = urlSplit[2].strip() + conf.hostname = hostnamePort[0].strip() conf.ipv6 = conf.hostname != conf.hostname.strip("[]") conf.hostname = conf.hostname.strip("[]") - try: _ = conf.hostname.encode("idna") @@ -1090,9 +1064,9 @@ def parseTargetUrl(): errMsg = "invalid target url" raise sqlmapSyntaxException, errMsg - if len(__hostnamePort) == 2: + if len(hostnamePort) == 2: try: - conf.port = int(__hostnamePort[1]) + conf.port = int(hostnamePort[1]) except: errMsg = "invalid target url" raise sqlmapSyntaxException, errMsg @@ -1101,8 +1075,8 @@ def parseTargetUrl(): else: conf.port = 80 - if __urlSplit[3]: - conf.parameters[PLACE.GET] = urldecode(__urlSplit[3]) if __urlSplit[3] and urlencode(DEFAULT_GET_POST_DELIMITER, None) not in __urlSplit[3] else __urlSplit[3] + if urlSplit[3]: + conf.parameters[PLACE.GET] = urldecode(urlSplit[3]) if urlSplit[3] and urlencode(DEFAULT_GET_POST_DELIMITER, None) not in urlSplit[3] else urlSplit[3] conf.url = "%s://%s:%d%s" % (conf.scheme, ("[%s]" % conf.hostname) if conf.ipv6 else conf.hostname, conf.port, conf.path) conf.url = conf.url.replace(URI_QUESTION_MARKER, '?') @@ -1349,24 +1323,14 @@ def directoryPath(filepath): Returns directory path for a given filepath """ - if isWindowsDriveLetterPath(filepath): - retVal = ntpath.dirname(filepath) - else: - retVal = posixpath.dirname(filepath) - - return retVal + return ntpath.dirname(filepath) if isWindowsDriveLetterPath(filepath) else posixpath.dirname(filepath) def normalizePath(filepath): """ Returns normalized string representation of a given filepath """ - if isWindowsDriveLetterPath(filepath): - retVal = ntpath.normpath(filepath) - else: - retVal = posixpath.normpath(filepath) - - return retVal + return ntpath.normpath(filepath) if isWindowsDriveLetterPath(filepath) else posixpath.normpath(filepath) def safeStringFormat(format_, params): """ @@ -1379,16 +1343,13 @@ def safeStringFormat(format_, params): retVal = retVal.replace("%s", params) else: count, index = 0, 0 - while index != -1: index = retVal.find("%s") - if index != -1: if count < len(params): retVal = retVal[:index] + getUnicode(params[count]) + retVal[index + 2:] else: raise sqlmapNoneDataException, "wrong number of parameters during string formatting" - count += 1 return retVal @@ -1404,10 +1365,8 @@ def getFilteredPageContent(page, onlyText=True): # only if the page's charset has been successfully identified if isinstance(page, unicode): retVal = re.sub(r"(?si)||%s" % (r"|<[^>]+>|\t|\n|\r" if onlyText else ""), " ", page) - while retVal.find(" ") != -1: retVal = retVal.replace(" ", " ") - retVal = htmlunescape(retVal) return retVal @@ -1421,8 +1380,8 @@ def getPageWordSet(page): # only if the page's charset has been successfully identified if isinstance(page, unicode): - page = getFilteredPageContent(page) - retVal = set(re.findall(r"\w+", page)) + _ = getFilteredPageContent(page) + retVal = set(re.findall(r"\w+", _)) return retVal @@ -1557,8 +1516,8 @@ def getConsoleWidth(default=80): width = None - if 'COLUMNS' in os.environ and os.environ['COLUMNS'].isdigit(): - width = int(os.environ['COLUMNS']) + if os.getenv("COLUMNS", "").isdigit(): + width = int(os.getenv("COLUMNS")) else: output=execute('stty size', shell=True, stdout=PIPE, stderr=PIPE).stdout.read() items = output.split() @@ -1576,7 +1535,7 @@ def getConsoleWidth(default=80): except: pass - return width if width else default + return width or default def clearConsoleLine(forceOutput=False): """ @@ -1591,7 +1550,7 @@ def parseXmlFile(xmlFile, handler): Parses XML file by a given handler """ - with StringIO(readCachedFileContent(xmlFile)) as stream: + with contextlib.closing(StringIO(readCachedFileContent(xmlFile))) as stream: parse(stream, handler) def getSPQLSnippet(dbms, name, **variables): @@ -1632,8 +1591,7 @@ def readCachedFileContent(filename, mode='rb'): if filename not in kb.cache.content: checkFile(filename) with codecs.open(filename, mode, UNICODE_ENCODING) as f: - content = f.read() - kb.cache.content[filename] = content + kb.cache.content[filename] = f.read() return kb.cache.content[filename] @@ -1675,12 +1633,7 @@ def average(values): Computes the arithmetic mean of a list of numbers. """ - retVal = None - - if values: - retVal = sum(values) / len(values) - - return retVal + return (sum(values) / len(values)) if values else None def calculateDeltaSeconds(start): """ @@ -1891,8 +1844,8 @@ def getUnicode(value, encoding=None, system=False, noneToNull=False): def longestCommonPrefix(*sequences): """ Returns longest common prefix occuring in given sequences + Reference: http://boredzo.org/blog/archives/2007-01-06/longest-common-prefix-in-python-2 """ - # Reference: http://boredzo.org/blog/archives/2007-01-06/longest-common-prefix-in-python-2 if len(sequences) == 1: return sequences[0] @@ -2356,6 +2309,10 @@ def isInferenceAvailable(): return any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.BOOLEAN, PAYLOAD.TECHNIQUE.STACKED, PAYLOAD.TECHNIQUE.TIME)) def setOptimize(): + """ + Sets options turned on by switch '-o' + """ + #conf.predictOutput = True conf.keepAlive = True conf.threads = 3 if conf.threads < 3 else conf.threads @@ -2719,7 +2676,7 @@ def safeSQLIdentificatorNaming(name, isTable=False): parts = name.split('.') for i in xrange(len(parts)): - if not re.match(r"\A[A-Za-z0-9_@\$]+\Z", parts[i]): # reference: http://stackoverflow.com/questions/954884/what-special-characters-are-allowed-in-t-sql-column-name + if not re.match(r"\A[A-Za-z0-9_@\$]+\Z", parts[i]): # Reference: http://stackoverflow.com/questions/954884/what-special-characters-are-allowed-in-t-sql-column-name if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.ACCESS): parts[i] = "`%s`" % parts[i].strip("`") elif Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.PGSQL, DBMS.DB2): @@ -3082,10 +3039,7 @@ def unserializeObject(value): Unserializes object from given serialized form """ - retVal = None - if value: - retVal = base64unpickle(value) - return retVal + return base64unpickle(value) if value else None def resetCounter(technique): """ @@ -3176,10 +3130,7 @@ def extractExpectedValue(value, expected): value = None elif expected == EXPECTED.INT: if isinstance(value, basestring): - if value.isdigit(): - value = int(value) - else: - value = None + value = int(value) if value.isdigit() else None return value @@ -3211,6 +3162,10 @@ def hashDBRetrieve(key, unserialize=False, checkConf=False): return _ def resetCookieJar(cookieJar): + """ + Cleans cookies from a given cookie jar + """ + if not conf.loC: cookieJar.clear() else: @@ -3223,5 +3178,9 @@ def resetCookieJar(cookieJar): raise sqlmapGenericException, errMsg def prioritySortColumns(columns): + """ + Sorts given column names by length in ascending order while those containing + string 'id' go first + """ _ = lambda x: x and "id" in x.lower() return sorted(sorted(columns, key=len), lambda x, y: -1 if _(x) and not _(y) else 1 if not _(x) and _(y) else 0) diff --git a/lib/core/settings.py b/lib/core/settings.py index fc38fa69b..cf84542bf 100644 --- a/lib/core/settings.py +++ b/lib/core/settings.py @@ -7,7 +7,9 @@ See the file 'doc/COPYING' for copying permission import logging import os +import re import subprocess +import string import sys from lib.core.enums import CUSTOM_LOGGING @@ -44,12 +46,6 @@ LOGGER_HANDLER.setFormatter(FORMATTER) LOGGER.addHandler(LOGGER_HANDLER) LOGGER.setLevel(logging.WARN) -# dump markers -DUMP_NEWLINE_MARKER = "__NEWLINE__" -DUMP_CR_MARKER = "__CARRIAGE_RETURN__" -DUMP_TAB_MARKER = "__TAB__" -DUMP_DEL_MARKER = "__DEL__" - # markers for special cases when parameter values contain html encoded characters PARAMETER_AMP_MARKER = "__AMP__" PARAMETER_SEMICOLON_MARKER = "__SEMICOLON__" @@ -475,6 +471,9 @@ MAX_TOTAL_REDIRECTIONS = 10 # Reference: http://www.tcpipguide.com/free/t_DNSLabelsNamesandSyntaxRules.htm MAX_DNS_LABEL = 63 +# Alphabet used for prefix and suffix strings of name resolution requests in DNS technique (excluding hexadecimal chars for not mixing with inner content) +DNS_BOUNDARIES_ALPHABET = re.sub("[a-fA-F]", "", string.letters) + # Connection chunk size (processing large responses in chunks to avoid MemoryError crashes - e.g. large table dump in full UNION/inband injections) MAX_CONNECTION_CHUNK_SIZE = 10 * 1024 * 1024 diff --git a/lib/techniques/dns/use.py b/lib/techniques/dns/use.py index 1c0ddeba8..2e2ddaf0d 100644 --- a/lib/techniques/dns/use.py +++ b/lib/techniques/dns/use.py @@ -29,6 +29,7 @@ from lib.core.data import kb from lib.core.data import logger from lib.core.data import queries from lib.core.enums import DBMS +from lib.core.settings import DNS_BOUNDARIES_ALPHABET from lib.core.settings import MAX_DNS_LABEL from lib.core.settings import PARTIAL_VALUE_MARKER from lib.core.unescaper import unescaper @@ -58,7 +59,7 @@ def dnsUse(payload, expression): while True: count += 1 - prefix, suffix = ("%s" % randomStr(length=3, alphabet=re.sub("[a-fA-F]", "", string.letters)) for _ in xrange(2)) + prefix, suffix = ("%s" % randomStr(length=3, alphabet=DNS_BOUNDARIES_ALPHABET) for _ in xrange(2)) chunk_length = MAX_DNS_LABEL / 2 if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL / 4 - 2 _, _, _, _, _, _, fieldToCastStr, _ = agent.getFields(expression) nulledCastedField = agent.nullAndCastField(fieldToCastStr)