diff --git a/lib/core/common.py b/lib/core/common.py index 20ad267da..51aff3807 100644 --- a/lib/core/common.py +++ b/lib/core/common.py @@ -1192,11 +1192,14 @@ def expandAsteriskForColumns(expression): return expression -def getRange(count, dump=False, plusOne=False): +def getLimitRange(count, dump=False, plusOne=False): + """ + Returns range of values used in limit/offset constructs + """ + + retVal = None count = int(count) - indexRange = None - limitStart = 1 - limitStop = count + limitStart, limitStop = 1, count if dump: if isinstance(conf.limitStop, int) and conf.limitStop > 0 and conf.limitStop < limitStop: @@ -1205,11 +1208,15 @@ def getRange(count, dump=False, plusOne=False): if isinstance(conf.limitStart, int) and conf.limitStart > 0 and conf.limitStart <= limitStop: limitStart = conf.limitStart - indexRange = xrange(limitStart, limitStop + 1) if plusOne else xrange(limitStart - 1, limitStop) + retVal = xrange(limitStart, limitStop + 1) if plusOne else xrange(limitStart - 1, limitStop) - return indexRange + return retVal def parseUnionPage(output, unique=True): + """ + Returns resulting items from inband query inside provided page content + """ + if output is None: return None @@ -1250,7 +1257,7 @@ def parseUnionPage(output, unique=True): def parseFilePaths(page): """ - Detect (possible) absolute system paths inside the provided page content + Detects (possible) absolute system paths inside the provided page content """ if page: @@ -1265,32 +1272,6 @@ def parseFilePaths(page): if absFilePath not in kb.absFilePaths: kb.absFilePaths.add(absFilePath) -def getDelayQuery(andCond=False): - query = None - - if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL): - if not kb.data.banner: - conf.dbmsHandler.getVersionFromBanner() - - banVer = kb.bannerFp["dbmsVersion"] if 'dbmsVersion' in kb.bannerFp else None - - if banVer is None or (Backend.isDbms(DBMS.MYSQL) and banVer >= "5.0.12") or (Backend.isDbms(DBMS.PGSQL) and banVer >= "8.2"): - query = queries[Backend.getIdentifiedDbms()].timedelay.query % conf.timeSec - else: - query = queries[Backend.getIdentifiedDbms()].timedelay.query2 % conf.timeSec - elif Backend.isDbms(DBMS.FIREBIRD): - query = queries[Backend.getIdentifiedDbms()].timedelay.query - else: - query = queries[Backend.getIdentifiedDbms()].timedelay.query % conf.timeSec - - if andCond: - if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.SQLITE ): - query = query.replace("SELECT ", "") - elif Backend.isDbms(DBMS.FIREBIRD): - query = "(%s)>0" % query - - return query - def getLocalIP(): retVal = None try: @@ -1310,11 +1291,11 @@ def getRemoteIP(): def getFileType(filePath): try: - magicFileType = magic.from_file(filePath) + _ = magic.from_file(filePath) except: return "unknown" - return "text" if "ASCII" in magicFileType or "text" in magicFileType else "binary" + return "text" if "ASCII" in _ or "text" in _ else "binary" def getCharset(charsetType=None): asciiTbl = [] @@ -1354,15 +1335,14 @@ def getCharset(charsetType=None): return asciiTbl -def searchEnvPath(fileName): - envPaths = os.environ["PATH"] +def searchEnvPath(filename): result = None + path = os.environ.get("PATH", "") + paths = path.split(";") if IS_WIN else path.split(":") - envPaths = envPaths.split(";") if IS_WIN else envPaths.split(":") - - for envPath in envPaths: - envPath = envPath.replace(";", "") - result = os.path.exists(os.path.normpath(os.path.join(envPath, fileName))) + for _ in paths: + _ = _.replace(";", "") + result = os.path.exists(os.path.normpath(os.path.join(_, filename))) if result: break @@ -1394,28 +1374,40 @@ def urlEncodeCookieValues(cookieStr): else: return None -def directoryPath(path): +def directoryPath(filepath): + """ + Returns directory path for a given filepath + """ + retVal = None - if isWindowsDriveLetterPath(path): - retVal = ntpath.dirname(path) + if isWindowsDriveLetterPath(filepath): + retVal = ntpath.dirname(filepath) else: - retVal = posixpath.dirname(path) + retVal = posixpath.dirname(filepath) return retVal -def normalizePath(path): +def normalizePath(filepath): + """ + Returns normalized string representation of a given filepath + """ + retVal = None - if isWindowsDriveLetterPath(path): - retVal = ntpath.normpath(path) + if isWindowsDriveLetterPath(filepath): + retVal = ntpath.normpath(filepath) else: - retVal = posixpath.normpath(path) + retVal = posixpath.normpath(filepath) return retVal -def safeStringFormat(formatStr, params): - retVal = formatStr.replace("%d", "%s") +def safeStringFormat(format_, params): + """ + Avoids problems with inappropriate string format strings + """ + + retVal = format_.replace("%d", "%s") if isinstance(params, basestring): retVal = retVal.replace("%s", params) @@ -1435,23 +1427,12 @@ def safeStringFormat(formatStr, params): return retVal -def sanitizeAsciiString(subject): - if subject: - index = None - - for i in xrange(len(subject)): - if ord(subject[i]) >= 128: - index = i - break - - if index is None: - return subject - else: - return subject[:index] + "".join(subject[i] if ord(subject[i]) < 128 else '?' for i in xrange(index, len(subject))) - else: - return None - def getFilteredPageContent(page, onlyText=True): + """ + Returns filtered page content without script, style and/or comments + or all HTML tags + """ + retVal = page # only if the page's charset has been successfully identified @@ -2402,6 +2383,10 @@ def isTechniqueAvailable(technique): return getTechniqueData(technique) is not None def isInferenceAvailable(): + """ + Returns True whether techniques using inference technique are available + """ + return any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.BOOLEAN, PAYLOAD.TECHNIQUE.STACKED, PAYLOAD.TECHNIQUE.TIME)) def setOptimize(): @@ -2619,7 +2604,7 @@ def listToStrValue(value): def getExceptionFrameLocals(): """ Returns dictionary with local variable content from frame - where exception was raised + where exception has been raised """ retVal = {} @@ -2793,7 +2778,7 @@ def isNullValue(value): def expandMnemonics(mnemonics, parser, args): """ - Expand mnemonic options + Expands mnemonic options """ class MnemonicNode: @@ -2876,7 +2861,7 @@ def expandMnemonics(mnemonics, parser, args): def safeCSValue(value): """ - Returns value safe for CSV dumping. + Returns value safe for CSV dumping Reference: http://tools.ietf.org/html/rfc4180 """ @@ -2890,6 +2875,10 @@ def safeCSValue(value): return retVal def filterPairValues(values): + """ + Returns only list-like values with length 2 + """ + retVal = [] if not isNoneValue(values) and hasattr(values, '__iter__'): @@ -2973,6 +2962,10 @@ def asciifyUrl(url, forceQuote=False): return urlparse.urlunsplit([parts.scheme, netloc, path, query, parts.fragment]) def findPageForms(content, url, raise_=False, addToTargets=False): + """ + Parses given page content for possible forms + """ + class _(StringIO): def __init__(self, content, url): StringIO.__init__(self, unicodeencode(content, kb.pageEncoding) if isinstance(content, unicode) else content) @@ -3016,15 +3009,18 @@ def findPageForms(content, url, raise_=False, addToTargets=False): if not item.selected: item.selected = True break + request = form.click() url = urldecode(request.get_full_url(), kb.pageEncoding) method = request.get_method() data = request.get_data() if request.has_data() else None data = urldecode(data, kb.pageEncoding) if data and urlencode(DEFAULT_GET_POST_DELIMITER, None) not in data else data + if not data and method and method.upper() == HTTPMETHOD.POST: debugMsg = "invalid POST form with blank data detected" logger.debug(debugMsg) continue + target = (url, method, data, conf.cookie) retVal.add(target) else: @@ -3041,6 +3037,10 @@ def findPageForms(content, url, raise_=False, addToTargets=False): return retVal def getHostHeader(url): + """ + Returns proper Host header value for a given target URL + """ + retVal = urlparse.urlparse(url).netloc if any(retVal.endswith(':%d' % _) for _ in [80, 443]): @@ -3048,7 +3048,11 @@ def getHostHeader(url): return retVal -def executeCode(code, variables=None): +def evaluateCode(code, variables=None): + """ + Executes given python code given in a string form + """ + try: exec(code, variables) except Exception, ex: @@ -3056,21 +3060,39 @@ def executeCode(code, variables=None): raise sqlmapGenericException, errMsg def serializeObject(object_): + """ + Serializes given object + """ + return pickle.dumps(object_) def unserializeObject(value): + """ + Unserializes object from given serialized form + """ + retVal = None if value: retVal = pickle.loads(value.encode(UNICODE_ENCODING)) # pickle has problems with Unicode return retVal -def resetCounter(counter): - kb.counters[counter] = 0 +def resetCounter(technique): + """ + Resets query counter for a given technique + """ -def incrementCounter(counter): - if counter not in kb.counters: - resetCounter(counter) - kb.counters[counter] += 1 + kb.counters[technique] = 0 -def getCounter(counter): - return kb.counters.get(counter, 0) +def incrementCounter(technique): + """ + Increments query counter for a given technique + """ + + kb.counters[technique] = getCounter(technique) + 1 + +def getCounter(technique): + """ + Returns query counter for a given technique + """ + + return kb.counters.get(technique, 0) diff --git a/lib/core/convert.py b/lib/core/convert.py index 7aa5175cd..518960a51 100644 --- a/lib/core/convert.py +++ b/lib/core/convert.py @@ -91,34 +91,32 @@ def urlencode(value, safe="%&=", convall=False, limit=False): return value count = 0 - result = None + result = None if value is None else "" - if value is None: - return result + if value: + if convall or safe is None: + safe = "" - if convall or safe is None: - safe = "" + # corner case when character % really needs to be + # encoded (when not representing url encoded char) + # except in cases when tampering scripts are used + if all(map(lambda x: '%' in x, [safe, value])) and not kb.tamperFunctions: + value = re.sub("%(?![0-9a-fA-F]{2})", "%25", value, re.DOTALL | re.IGNORECASE) - # corner case when character % really needs to be - # encoded (when not representing url encoded char) - # except in cases when tampering scripts are used - if all(map(lambda x: '%' in x, [safe, value])) and not kb.tamperFunctions: - value = re.sub("%(?![0-9a-fA-F]{2})", "%25", value, re.DOTALL | re.IGNORECASE) + while True: + result = urllib.quote(utf8encode(value), safe) - while True: - result = urllib.quote(utf8encode(value), safe) - - if limit and len(result) > URLENCODE_CHAR_LIMIT: - if count >= len(URLENCODE_FAILSAFE_CHARS): - break - - while count < len(URLENCODE_FAILSAFE_CHARS): - safe += URLENCODE_FAILSAFE_CHARS[count] - count += 1 - if safe[-1] in value: + if limit and len(result) > URLENCODE_CHAR_LIMIT: + if count >= len(URLENCODE_FAILSAFE_CHARS): break - else: - break + + while count < len(URLENCODE_FAILSAFE_CHARS): + safe += URLENCODE_FAILSAFE_CHARS[count] + count += 1 + if safe[-1] in value: + break + else: + break return result diff --git a/lib/core/dump.py b/lib/core/dump.py index 9d314620a..39310bad0 100644 --- a/lib/core/dump.py +++ b/lib/core/dump.py @@ -41,45 +41,45 @@ class Dump: """ def __init__(self): - self.__outputFile = None - self.__outputFP = None - self.__outputBP = None - self.__lock = threading.Lock() + self._outputFile = None + self._outputFP = None + self._outputBP = None + self._lock = threading.Lock() - def __write(self, data, n=True, console=True): + def _write(self, data, n=True, console=True): text = "%s%s" % (data, "\n" if n else " ") if console: dataToStdout(text) if kb.get("multiThreadMode"): - self.__lock.acquire() + self._lock.acquire() - self.__outputBP.write(text) + self._outputBP.write(text) - if self.__outputBP.tell() > BUFFERED_LOG_SIZE: + if self._outputBP.tell() > BUFFERED_LOG_SIZE: self.flush() if kb.get("multiThreadMode"): - self.__lock.release() + self._lock.release() kb.dataOutputFlag = True def flush(self): - if self.__outputBP and self.__outputFP and self.__outputBP.tell() > 0: - _ = self.__outputBP.getvalue() - self.__outputBP.truncate(0) - self.__outputFP.write(_) + if self._outputBP and self._outputFP and self._outputBP.tell() > 0: + _ = self._outputBP.getvalue() + self._outputBP.truncate(0) + self._outputFP.write(_) - def __formatString(self, inpStr): + def _formatString(self, inpStr): return restoreDumpMarkedChars(getUnicode(inpStr)) def setOutputFile(self): - self.__outputFile = "%s%slog" % (conf.outputPath, os.sep) - self.__outputFP = codecs.open(self.__outputFile, "ab", UNICODE_ENCODING) - self.__outputBP = StringIO.StringIO() + self._outputFile = "%s%slog" % (conf.outputPath, os.sep) + self._outputFP = codecs.open(self._outputFile, "ab", UNICODE_ENCODING) + self._outputBP = StringIO.StringIO() def getOutputFile(self): - return self.__outputFile + return self._outputFile def string(self, header, data, sort=True): if isinstance(data, (list, tuple, set)): @@ -90,21 +90,21 @@ class Dump: data = getUnicode(data) if data: - data = self.__formatString(data) + data = self._formatString(data) if data[-1] == '\n': data = data[:-1] if "\n" in data: - self.__write("%s:\n---\n%s\n---\n" % (header, data)) + self._write("%s:\n---\n%s\n---\n" % (header, data)) else: - self.__write("%s: '%s'\n" % (header, data)) + self._write("%s: '%s'\n" % (header, data)) else: - self.__write("%s:\tNone\n" % header) + self._write("%s:\tNone\n" % header) def lister(self, header, elements, sort=True): if elements: - self.__write("%s [%d]:" % (header, len(elements))) + self._write("%s [%d]:" % (header, len(elements))) if sort: try: @@ -116,12 +116,12 @@ class Dump: for element in elements: if isinstance(element, basestring): - self.__write("[*] %s" % element) + self._write("[*] %s" % element) elif isinstance(element, (list, tuple, set)): - self.__write("[*] " + ", ".join(getUnicode(e) for e in element)) + self._write("[*] " + ", ".join(getUnicode(e) for e in element)) if elements: - self.__write("") + self._write("") def technic(self, header, data): self.string(header, data) @@ -147,13 +147,13 @@ class Dump: self.lister("database management system users", users) def userSettings(self, header, userSettings, subHeader): - self.__areAdmins = set() + self._areAdmins = set() if userSettings: - self.__write("%s:" % header) + self._write("%s:" % header) if isinstance(userSettings, (tuple, list, set)): - self.__areAdmins = userSettings[1] + self._areAdmins = userSettings[1] userSettings = userSettings[0] users = userSettings.keys() @@ -167,16 +167,16 @@ class Dump: else: stringSettings = " [%d]:" % len(settings) - if user in self.__areAdmins: - self.__write("[*] %s (administrator)%s" % (user, stringSettings)) + if user in self._areAdmins: + self._write("[*] %s (administrator)%s" % (user, stringSettings)) else: - self.__write("[*] %s%s" % (user, stringSettings)) + self._write("[*] %s%s" % (user, stringSettings)) if settings: settings.sort() for setting in settings: - self.__write(" %s: %s" % (subHeader, setting)) + self._write(" %s: %s" % (subHeader, setting)) print def dbs(self,dbs): @@ -198,23 +198,23 @@ class Dump: for db, tables in dbTables.items(): tables.sort() - self.__write("Database: %s" % db if db else "Current database") + self._write("Database: %s" % db if db else "Current database") if len(tables) == 1: - self.__write("[1 table]") + self._write("[1 table]") else: - self.__write("[%d tables]" % len(tables)) + self._write("[%d tables]" % len(tables)) - self.__write("+%s+" % lines) + self._write("+%s+" % lines) for table in tables: if isinstance(table, (list, tuple, set)): table = table[0] blank = " " * (maxlength - len(normalizeUnicode(table) or str(table))) - self.__write("| %s%s |" % (table, blank)) + self._write("| %s%s |" % (table, blank)) - self.__write("+%s+\n" % lines) + self._write("+%s+\n" % lines) else: self.string("tables", dbTables) @@ -246,17 +246,17 @@ class Dump: maxlength2 = max(maxlength2, len("TYPE")) lines2 = "-" * (maxlength2 + 2) - self.__write("Database: %s\nTable: %s" % (db if db else "Current database", table)) + self._write("Database: %s\nTable: %s" % (db if db else "Current database", table)) if len(columns) == 1: - self.__write("[1 column]") + self._write("[1 column]") else: - self.__write("[%d columns]" % len(columns)) + self._write("[%d columns]" % len(columns)) if colType is not None: - self.__write("+%s+%s+" % (lines1, lines2)) + self._write("+%s+%s+" % (lines1, lines2)) else: - self.__write("+%s+" % lines1) + self._write("+%s+" % lines1) blank1 = " " * (maxlength1 - len("COLUMN")) @@ -264,11 +264,11 @@ class Dump: blank2 = " " * (maxlength2 - len("TYPE")) if colType is not None: - self.__write("| Column%s | Type%s |" % (blank1, blank2)) - self.__write("+%s+%s+" % (lines1, lines2)) + self._write("| Column%s | Type%s |" % (blank1, blank2)) + self._write("+%s+%s+" % (lines1, lines2)) else: - self.__write("| Column%s |" % blank1) - self.__write("+%s+" % lines1) + self._write("| Column%s |" % blank1) + self._write("+%s+" % lines1) for column in colList: colType = columns[column] @@ -276,14 +276,14 @@ class Dump: if colType is not None: blank2 = " " * (maxlength2 - len(colType)) - self.__write("| %s%s | %s%s |" % (column, blank1, colType, blank2)) + self._write("| %s%s | %s%s |" % (column, blank1, colType, blank2)) else: - self.__write("| %s%s |" % (column, blank1)) + self._write("| %s%s |" % (column, blank1)) if colType is not None: - self.__write("+%s+%s+\n" % (lines1, lines2)) + self._write("+%s+%s+\n" % (lines1, lines2)) else: - self.__write("+%s+\n" % lines1) + self._write("+%s+\n" % lines1) def dbTablesCount(self, dbTables): if isinstance(dbTables, dict) and len(dbTables) > 0: @@ -296,16 +296,16 @@ class Dump: maxlength1 = max(maxlength1, len(normalizeUnicode(table) or str(table))) for db, counts in dbTables.items(): - self.__write("Database: %s" % db if db else "Current database") + self._write("Database: %s" % db if db else "Current database") lines1 = "-" * (maxlength1 + 2) blank1 = " " * (maxlength1 - len("Table")) lines2 = "-" * (maxlength2 + 2) blank2 = " " * (maxlength2 - len("Entries")) - self.__write("+%s+%s+" % (lines1, lines2)) - self.__write("| Table%s | Entries%s |" % (blank1, blank2)) - self.__write("+%s+%s+" % (lines1, lines2)) + self._write("+%s+%s+" % (lines1, lines2)) + self._write("| Table%s | Entries%s |" % (blank1, blank2)) + self._write("+%s+%s+" % (lines1, lines2)) sortedCounts = counts.keys() sortedCounts.sort(reverse=True) @@ -321,9 +321,9 @@ class Dump: for table in tables: blank1 = " " * (maxlength1 - len(normalizeUnicode(table) or str(table))) blank2 = " " * (maxlength2 - len(str(count))) - self.__write("| %s%s | %d%s |" % (table, blank1, count, blank2)) + self._write("| %s%s | %d%s |" % (table, blank1, count, blank2)) - self.__write("+%s+%s+\n" % (lines1, lines2)) + self._write("+%s+%s+\n" % (lines1, lines2)) else: logger.error("unable to retrieve the number of entries for any table") @@ -365,7 +365,7 @@ class Dump: separator += "+%s" % lines separator += "+" - self.__write("Database: %s\nTable: %s" % (db if db else "Current database", table)) + self._write("Database: %s\nTable: %s" % (db if db else "Current database", table)) if conf.replicate: cols = [] @@ -402,11 +402,11 @@ class Dump: rtable = replication.createTable(table, cols) if count == 1: - self.__write("[1 entry]") + self._write("[1 entry]") else: - self.__write("[%d entries]" % count) + self._write("[%d entries]" % count) - self.__write(separator) + self._write(separator) for column in columns: if column != "__infos__": @@ -414,7 +414,7 @@ class Dump: maxlength = int(info["length"]) blank = " " * (maxlength - len(column)) - self.__write("| %s%s" % (column, blank), n=False) + self._write("| %s%s" % (column, blank), n=False) if not conf.replicate: if field == fields: @@ -424,7 +424,7 @@ class Dump: field += 1 - self.__write("|\n%s" % separator) + self._write("|\n%s" % separator) if not conf.replicate: dataToDumpFile(dumpFP, "\n") @@ -461,7 +461,7 @@ class Dump: values.append(value) maxlength = int(info["length"]) blank = " " * (maxlength - len(value)) - self.__write("| %s%s" % (value, blank), n=False, console=console) + self._write("| %s%s" % (value, blank), n=False, console=console) if not conf.replicate: if field == fields: @@ -477,12 +477,12 @@ class Dump: except sqlmapValueException: pass - self.__write("|", console=console) + self._write("|", console=console) if not conf.replicate: dataToDumpFile(dumpFP, "\n") - self.__write("%s\n" % separator) + self._write("%s\n" % separator) if conf.replicate: rtable.endTransaction() @@ -502,26 +502,26 @@ class Dump: msg = "Column%s found in the " % colConsiderStr msg += "following databases:" - self.__write(msg) + self._write(msg) - printDbs = {} + _ = {} for db, tblData in dbs.items(): for tbl, colData in tblData.items(): for col, dataType in colData.items(): if column.lower() in col.lower(): - if db in printDbs: - if tbl in printDbs[db]: - printDbs[db][tbl][col] = dataType + if db in _: + if tbl in _[db]: + _[db][tbl][col] = dataType else: - printDbs[db][tbl] = { col: dataType } + _[db][tbl] = { col: dataType } else: - printDbs[db] = {} - printDbs[db][tbl] = { col: dataType } + _[db] = {} + _[db][tbl] = { col: dataType } continue - self.dbTableColumns(printDbs) + self.dbTableColumns(_) def query(self, query, queryRes): self.string(query, queryRes) diff --git a/lib/core/settings.py b/lib/core/settings.py index 5a9f4ca30..8b4144202 100644 --- a/lib/core/settings.py +++ b/lib/core/settings.py @@ -249,6 +249,9 @@ SQL_STATEMENTS = { # string representation for NULL value NULL = "NULL" +# string representation for current database +CURRENT_DB = "CD" + # Regular expressions used for parsing error messages (--parse-errors) ERROR_PARSING_REGEXES = ( r"[^<]*(fatal|error|warning|exception)[^<]*:?\s*(?P.+?)", diff --git a/lib/request/basic.py b/lib/request/basic.py index 496466364..a5ae3c97d 100644 --- a/lib/request/basic.py +++ b/lib/request/basic.py @@ -22,7 +22,6 @@ from lib.core.common import getUnicode from lib.core.common import isWindowsDriveLetterPath from lib.core.common import posixToNtSlashes from lib.core.common import readInput -from lib.core.common import sanitizeAsciiString from lib.core.common import singleTimeLogMessage from lib.core.data import conf from lib.core.data import kb diff --git a/lib/request/connect.py b/lib/request/connect.py index faf228c48..5e30f0007 100644 --- a/lib/request/connect.py +++ b/lib/request/connect.py @@ -24,7 +24,7 @@ from lib.core.common import average from lib.core.common import calculateDeltaSeconds from lib.core.common import clearConsoleLine from lib.core.common import cpuThrottle -from lib.core.common import executeCode +from lib.core.common import evaluateCode from lib.core.common import extractRegexResult from lib.core.common import getCurrentThreadData from lib.core.common import getFilteredPageContent @@ -636,10 +636,10 @@ class Connect: for part in item.split(delimiter): if '=' in part: name, value = part.split('=', 1) - executeCode("%s='%s'" % (name, value), variables) + evaluateCode("%s='%s'" % (name, value), variables) originals.update(variables) - executeCode(conf.evalCode, variables) + evaluateCode(conf.evalCode, variables) for name, value in variables.items(): if name != "__builtins__" and originals.get(name, "") != value: diff --git a/plugins/dbms/maxdb/enumeration.py b/plugins/dbms/maxdb/enumeration.py index 43d7d6fa1..5d1d2a4b4 100644 --- a/plugins/dbms/maxdb/enumeration.py +++ b/plugins/dbms/maxdb/enumeration.py @@ -20,6 +20,7 @@ from lib.core.data import queries from lib.core.enums import PAYLOAD from lib.core.exception import sqlmapMissingMandatoryOptionException from lib.core.exception import sqlmapNoneDataException +from lib.core.settings import CURRENT_DB from plugins.generic.enumeration import Enumeration as GenericEnumeration class Enumeration(GenericEnumeration): @@ -60,7 +61,7 @@ class Enumeration(GenericEnumeration): self.forceDbmsEnum() - if conf.db == "CD": + if conf.db == CURRENT_DB: conf.db = self.getCurrentDb() if conf.db: @@ -97,7 +98,7 @@ class Enumeration(GenericEnumeration): def getColumns(self, onlyColNames=False): self.forceDbmsEnum() - if conf.db is None or conf.db == "CD": + if conf.db is None or conf.db == CURRENT_DB: if conf.db is None: warnMsg = "missing database parameter, sqlmap is going " warnMsg += "to use the current database to enumerate " diff --git a/plugins/dbms/mssqlserver/enumeration.py b/plugins/dbms/mssqlserver/enumeration.py index 1f30d4e42..fe7b60237 100644 --- a/plugins/dbms/mssqlserver/enumeration.py +++ b/plugins/dbms/mssqlserver/enumeration.py @@ -10,7 +10,7 @@ See the file 'doc/COPYING' for copying permission from lib.core.agent import agent from lib.core.common import arrayizeValue from lib.core.common import Backend -from lib.core.common import getRange +from lib.core.common import getLimitRange from lib.core.common import isInferenceAvailable from lib.core.common import isNoneValue from lib.core.common import isNumPosStrValue @@ -25,6 +25,7 @@ from lib.core.data import queries from lib.core.enums import EXPECTED from lib.core.enums import PAYLOAD from lib.core.exception import sqlmapNoneDataException +from lib.core.settings import CURRENT_DB from lib.request import inject from plugins.generic.enumeration import Enumeration as GenericEnumeration @@ -68,7 +69,7 @@ class Enumeration(GenericEnumeration): self.forceDbmsEnum() - if conf.db == "CD": + if conf.db == CURRENT_DB: conf.db = self.getCurrentDb() if conf.db: @@ -230,7 +231,7 @@ class Enumeration(GenericEnumeration): continue - indexRange = getRange(count) + indexRange = getLimitRange(count) for index in indexRange: query = rootQuery.blind.query @@ -347,7 +348,7 @@ class Enumeration(GenericEnumeration): continue - indexRange = getRange(count) + indexRange = getLimitRange(count) for index in indexRange: query = rootQuery.blind.query diff --git a/plugins/dbms/mssqlserver/filesystem.py b/plugins/dbms/mssqlserver/filesystem.py index 81d499a0f..f27de444c 100644 --- a/plugins/dbms/mssqlserver/filesystem.py +++ b/plugins/dbms/mssqlserver/filesystem.py @@ -11,7 +11,7 @@ import codecs import ntpath import os -from lib.core.common import getRange +from lib.core.common import getLimitRange from lib.core.common import isNumPosStrValue from lib.core.common import isTechniqueAvailable from lib.core.common import posixToNtSlashes @@ -105,7 +105,7 @@ class Filesystem(GenericFilesystem): errMsg += "file '%s'" % rFile raise sqlmapNoneDataException(errMsg) - indexRange = getRange(count) + indexRange = getLimitRange(count) for index in indexRange: chunk = inject.getValue("SELECT TOP 1 %s FROM %s WHERE %s NOT IN (SELECT TOP %d %s FROM %s ORDER BY id ASC) ORDER BY id ASC" % (self.tblField, hexTbl, self.tblField, index, self.tblField, hexTbl), unpack=False, resumeValue=False, unique=False, charsetType=3) diff --git a/plugins/dbms/oracle/enumeration.py b/plugins/dbms/oracle/enumeration.py index 57e39fab5..7add219e2 100644 --- a/plugins/dbms/oracle/enumeration.py +++ b/plugins/dbms/oracle/enumeration.py @@ -9,7 +9,7 @@ See the file 'doc/COPYING' for copying permission from lib.core.agent import agent from lib.core.common import Backend -from lib.core.common import getRange +from lib.core.common import getLimitRange from lib.core.common import isInferenceAvailable from lib.core.common import isNoneValue from lib.core.common import isNumPosStrValue @@ -142,7 +142,7 @@ class Enumeration(GenericEnumeration): roles = set() - indexRange = getRange(count, plusOne=True) + indexRange = getLimitRange(count, plusOne=True) for index in indexRange: if query2: diff --git a/plugins/dbms/sybase/enumeration.py b/plugins/dbms/sybase/enumeration.py index 068c6f497..b006142a9 100644 --- a/plugins/dbms/sybase/enumeration.py +++ b/plugins/dbms/sybase/enumeration.py @@ -22,6 +22,7 @@ from lib.core.dicts import sybaseTypes from lib.core.enums import PAYLOAD from lib.core.exception import sqlmapMissingMandatoryOptionException from lib.core.exception import sqlmapNoneDataException +from lib.core.settings import CURRENT_DB from plugins.generic.enumeration import Enumeration as GenericEnumeration class Enumeration(GenericEnumeration): @@ -114,7 +115,7 @@ class Enumeration(GenericEnumeration): self.forceDbmsEnum() - if conf.db == "CD": + if conf.db == CURRENT_DB: conf.db = self.getCurrentDb() if conf.db: @@ -160,7 +161,7 @@ class Enumeration(GenericEnumeration): def getColumns(self, onlyColNames=False): self.forceDbmsEnum() - if conf.db is None or conf.db == "CD": + if conf.db is None or conf.db == CURRENT_DB: if conf.db is None: warnMsg = "missing database parameter, sqlmap is going " warnMsg += "to use the current database to enumerate " diff --git a/plugins/generic/enumeration.py b/plugins/generic/enumeration.py index bdf632f58..5e42954bf 100644 --- a/plugins/generic/enumeration.py +++ b/plugins/generic/enumeration.py @@ -17,7 +17,7 @@ from lib.core.common import Backend from lib.core.common import clearConsoleLine from lib.core.common import dataToStdout from lib.core.common import filterPairValues -from lib.core.common import getRange +from lib.core.common import getLimitRange from lib.core.common import getCompiledRegex from lib.core.common import getUnicode from lib.core.common import isInferenceAvailable @@ -59,6 +59,7 @@ from lib.core.exception import sqlmapUserQuitException from lib.core.session import setOs from lib.core.settings import CONCAT_ROW_DELIMITER from lib.core.settings import CONCAT_VALUE_DELIMITER +from lib.core.settings import CURRENT_DB from lib.core.settings import DEFAULT_MSSQL_SCHEMA from lib.core.settings import MAX_INT from lib.core.settings import SQL_STATEMENTS @@ -200,11 +201,8 @@ class Enumeration: errMsg = "unable to retrieve the number of database users" raise sqlmapNoneDataException, errMsg - if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): - plusOne = True - else: - plusOne = False - indexRange = getRange(count, plusOne=plusOne) + plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2) + indexRange = getLimitRange(count, plusOne=plusOne) for index in indexRange: if Backend.getIdentifiedDbms() in (DBMS.SYBASE, DBMS.MAXDB): @@ -350,11 +348,8 @@ class Enumeration: passwords = [] - if Backend.isDbms(DBMS.ORACLE): - plusOne = True - else: - plusOne = False - indexRange = getRange(count, plusOne=plusOne) + plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2) + indexRange = getLimitRange(count, plusOne=plusOne) for index in indexRange: if Backend.isDbms(DBMS.MSSQL): @@ -593,11 +588,8 @@ class Enumeration: privileges = set() - if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): - plusOne = True - else: - plusOne = False - indexRange = getRange(count, plusOne=plusOne) + plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2) + indexRange = getLimitRange(count, plusOne=plusOne) for index in indexRange: if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema: @@ -760,11 +752,8 @@ class Enumeration: errMsg = "unable to retrieve the number of databases" logger.error(errMsg) else: - if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): - plusOne = True - else: - plusOne = False - indexRange = getRange(count, plusOne=plusOne) + plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2) + indexRange = getLimitRange(count, plusOne=plusOne) for index in indexRange: if Backend.isDbms(DBMS.SYBASE): @@ -820,7 +809,7 @@ class Enumeration: else: return tables - if conf.db == "CD": + if conf.db == CURRENT_DB: conf.db = self.getCurrentDb() if conf.db and Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): @@ -930,11 +919,8 @@ class Enumeration: tables = [] - if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): - plusOne = True - else: - plusOne = False - indexRange = getRange(count, plusOne=plusOne) + plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2) + indexRange = getLimitRange(count, plusOne=plusOne) for index in indexRange: if Backend.isDbms(DBMS.SYBASE): @@ -977,7 +963,7 @@ class Enumeration: def getColumns(self, onlyColNames=False, colTuple=None, bruteForce=None): self.forceDbmsEnum() - if conf.db is None or conf.db == "CD": + if conf.db is None or conf.db == CURRENT_DB: if conf.db is None: warnMsg = "missing database parameter, sqlmap is going " warnMsg += "to use the current database to enumerate " @@ -1226,7 +1212,7 @@ class Enumeration: table = {} columns = {} - indexRange = getRange(count) + indexRange = getLimitRange(count) for index in indexRange: if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ): @@ -1506,7 +1492,7 @@ class Enumeration: def dumpTable(self, foundData=None): self.forceDbmsEnum() - if conf.db is None or conf.db == "CD": + if conf.db is None or conf.db == CURRENT_DB: if conf.db is None: warnMsg = "missing database parameter, sqlmap is going " warnMsg += "to use the current database to enumerate " @@ -1719,11 +1705,8 @@ class Enumeration: entries, lengths = retVal else: - if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): - plusOne = True - else: - plusOne = False - indexRange = getRange(count, dump=True, plusOne=plusOne) + plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2) + indexRange = getLimitRange(count, dump=True, plusOne=plusOne) try: for index in indexRange: @@ -1967,7 +1950,7 @@ class Enumeration: continue - indexRange = getRange(count) + indexRange = getLimitRange(count) for index in indexRange: if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema: @@ -2032,7 +2015,7 @@ class Enumeration: infoMsg += " '%s'" % unsafeSQLIdentificatorNaming(tbl) logger.info(infoMsg) - if conf.db and conf.db != "CD": + if conf.db and conf.db != CURRENT_DB: _ = conf.db.split(",") whereDbsQuery = "".join(" AND '%s' = %s" % (unsafeSQLIdentificatorNaming(db), dbCond) for db in _) infoMsg = "for database%s '%s'" % ("s" if len(_) > 1 else "", ", ".join(db for db in _)) @@ -2085,7 +2068,7 @@ class Enumeration: continue - indexRange = getRange(count) + indexRange = getLimitRange(count) for index in indexRange: query = rootQuery.blind.query @@ -2130,7 +2113,7 @@ class Enumeration: continue - indexRange = getRange(count) + indexRange = getLimitRange(count) for index in indexRange: query = rootQuery.blind.query2 @@ -2201,7 +2184,7 @@ class Enumeration: foundCols[column] = {} - if conf.db and conf.db != "CD": + if conf.db and conf.db != CURRENT_DB: _ = conf.db.split(",") whereDbsQuery = "".join(" AND '%s' = %s" % (unsafeSQLIdentificatorNaming(db), dbCond) for db in _) infoMsg = "for database%s '%s'" % ("s" if len(_) > 1 else "", ", ".join(db for db in _)) @@ -2277,7 +2260,7 @@ class Enumeration: continue - indexRange = getRange(count) + indexRange = getLimitRange(count) for index in indexRange: query = rootQuery.blind.query @@ -2328,7 +2311,7 @@ class Enumeration: continue - indexRange = getRange(count) + indexRange = getLimitRange(count) for index in indexRange: query = rootQuery.blind.query2