some more refactorings

This commit is contained in:
Miroslav Stampar 2012-02-16 14:42:28 +00:00
parent 6632aa7308
commit dcf7277a0f
12 changed files with 245 additions and 237 deletions

View File

@ -1192,11 +1192,14 @@ def expandAsteriskForColumns(expression):
return expression return expression
def getRange(count, dump=False, plusOne=False): def getLimitRange(count, dump=False, plusOne=False):
"""
Returns range of values used in limit/offset constructs
"""
retVal = None
count = int(count) count = int(count)
indexRange = None limitStart, limitStop = 1, count
limitStart = 1
limitStop = count
if dump: if dump:
if isinstance(conf.limitStop, int) and conf.limitStop > 0 and conf.limitStop < limitStop: if isinstance(conf.limitStop, int) and conf.limitStop > 0 and conf.limitStop < limitStop:
@ -1205,11 +1208,15 @@ def getRange(count, dump=False, plusOne=False):
if isinstance(conf.limitStart, int) and conf.limitStart > 0 and conf.limitStart <= limitStop: if isinstance(conf.limitStart, int) and conf.limitStart > 0 and conf.limitStart <= limitStop:
limitStart = conf.limitStart limitStart = conf.limitStart
indexRange = xrange(limitStart, limitStop + 1) if plusOne else xrange(limitStart - 1, limitStop) retVal = xrange(limitStart, limitStop + 1) if plusOne else xrange(limitStart - 1, limitStop)
return indexRange return retVal
def parseUnionPage(output, unique=True): def parseUnionPage(output, unique=True):
"""
Returns resulting items from inband query inside provided page content
"""
if output is None: if output is None:
return None return None
@ -1250,7 +1257,7 @@ def parseUnionPage(output, unique=True):
def parseFilePaths(page): def parseFilePaths(page):
""" """
Detect (possible) absolute system paths inside the provided page content Detects (possible) absolute system paths inside the provided page content
""" """
if page: if page:
@ -1265,32 +1272,6 @@ def parseFilePaths(page):
if absFilePath not in kb.absFilePaths: if absFilePath not in kb.absFilePaths:
kb.absFilePaths.add(absFilePath) kb.absFilePaths.add(absFilePath)
def getDelayQuery(andCond=False):
query = None
if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL):
if not kb.data.banner:
conf.dbmsHandler.getVersionFromBanner()
banVer = kb.bannerFp["dbmsVersion"] if 'dbmsVersion' in kb.bannerFp else None
if banVer is None or (Backend.isDbms(DBMS.MYSQL) and banVer >= "5.0.12") or (Backend.isDbms(DBMS.PGSQL) and banVer >= "8.2"):
query = queries[Backend.getIdentifiedDbms()].timedelay.query % conf.timeSec
else:
query = queries[Backend.getIdentifiedDbms()].timedelay.query2 % conf.timeSec
elif Backend.isDbms(DBMS.FIREBIRD):
query = queries[Backend.getIdentifiedDbms()].timedelay.query
else:
query = queries[Backend.getIdentifiedDbms()].timedelay.query % conf.timeSec
if andCond:
if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.SQLITE ):
query = query.replace("SELECT ", "")
elif Backend.isDbms(DBMS.FIREBIRD):
query = "(%s)>0" % query
return query
def getLocalIP(): def getLocalIP():
retVal = None retVal = None
try: try:
@ -1310,11 +1291,11 @@ def getRemoteIP():
def getFileType(filePath): def getFileType(filePath):
try: try:
magicFileType = magic.from_file(filePath) _ = magic.from_file(filePath)
except: except:
return "unknown" return "unknown"
return "text" if "ASCII" in magicFileType or "text" in magicFileType else "binary" return "text" if "ASCII" in _ or "text" in _ else "binary"
def getCharset(charsetType=None): def getCharset(charsetType=None):
asciiTbl = [] asciiTbl = []
@ -1354,15 +1335,14 @@ def getCharset(charsetType=None):
return asciiTbl return asciiTbl
def searchEnvPath(fileName): def searchEnvPath(filename):
envPaths = os.environ["PATH"]
result = None result = None
path = os.environ.get("PATH", "")
paths = path.split(";") if IS_WIN else path.split(":")
envPaths = envPaths.split(";") if IS_WIN else envPaths.split(":") for _ in paths:
_ = _.replace(";", "")
for envPath in envPaths: result = os.path.exists(os.path.normpath(os.path.join(_, filename)))
envPath = envPath.replace(";", "")
result = os.path.exists(os.path.normpath(os.path.join(envPath, fileName)))
if result: if result:
break break
@ -1394,28 +1374,40 @@ def urlEncodeCookieValues(cookieStr):
else: else:
return None return None
def directoryPath(path): def directoryPath(filepath):
"""
Returns directory path for a given filepath
"""
retVal = None retVal = None
if isWindowsDriveLetterPath(path): if isWindowsDriveLetterPath(filepath):
retVal = ntpath.dirname(path) retVal = ntpath.dirname(filepath)
else: else:
retVal = posixpath.dirname(path) retVal = posixpath.dirname(filepath)
return retVal return retVal
def normalizePath(path): def normalizePath(filepath):
"""
Returns normalized string representation of a given filepath
"""
retVal = None retVal = None
if isWindowsDriveLetterPath(path): if isWindowsDriveLetterPath(filepath):
retVal = ntpath.normpath(path) retVal = ntpath.normpath(filepath)
else: else:
retVal = posixpath.normpath(path) retVal = posixpath.normpath(filepath)
return retVal return retVal
def safeStringFormat(formatStr, params): def safeStringFormat(format_, params):
retVal = formatStr.replace("%d", "%s") """
Avoids problems with inappropriate string format strings
"""
retVal = format_.replace("%d", "%s")
if isinstance(params, basestring): if isinstance(params, basestring):
retVal = retVal.replace("%s", params) retVal = retVal.replace("%s", params)
@ -1435,23 +1427,12 @@ def safeStringFormat(formatStr, params):
return retVal return retVal
def sanitizeAsciiString(subject):
if subject:
index = None
for i in xrange(len(subject)):
if ord(subject[i]) >= 128:
index = i
break
if index is None:
return subject
else:
return subject[:index] + "".join(subject[i] if ord(subject[i]) < 128 else '?' for i in xrange(index, len(subject)))
else:
return None
def getFilteredPageContent(page, onlyText=True): def getFilteredPageContent(page, onlyText=True):
"""
Returns filtered page content without script, style and/or comments
or all HTML tags
"""
retVal = page retVal = page
# only if the page's charset has been successfully identified # only if the page's charset has been successfully identified
@ -2402,6 +2383,10 @@ def isTechniqueAvailable(technique):
return getTechniqueData(technique) is not None return getTechniqueData(technique) is not None
def isInferenceAvailable(): def isInferenceAvailable():
"""
Returns True whether techniques using inference technique are available
"""
return any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.BOOLEAN, PAYLOAD.TECHNIQUE.STACKED, PAYLOAD.TECHNIQUE.TIME)) return any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.BOOLEAN, PAYLOAD.TECHNIQUE.STACKED, PAYLOAD.TECHNIQUE.TIME))
def setOptimize(): def setOptimize():
@ -2619,7 +2604,7 @@ def listToStrValue(value):
def getExceptionFrameLocals(): def getExceptionFrameLocals():
""" """
Returns dictionary with local variable content from frame Returns dictionary with local variable content from frame
where exception was raised where exception has been raised
""" """
retVal = {} retVal = {}
@ -2793,7 +2778,7 @@ def isNullValue(value):
def expandMnemonics(mnemonics, parser, args): def expandMnemonics(mnemonics, parser, args):
""" """
Expand mnemonic options Expands mnemonic options
""" """
class MnemonicNode: class MnemonicNode:
@ -2876,7 +2861,7 @@ def expandMnemonics(mnemonics, parser, args):
def safeCSValue(value): def safeCSValue(value):
""" """
Returns value safe for CSV dumping. Returns value safe for CSV dumping
Reference: http://tools.ietf.org/html/rfc4180 Reference: http://tools.ietf.org/html/rfc4180
""" """
@ -2890,6 +2875,10 @@ def safeCSValue(value):
return retVal return retVal
def filterPairValues(values): def filterPairValues(values):
"""
Returns only list-like values with length 2
"""
retVal = [] retVal = []
if not isNoneValue(values) and hasattr(values, '__iter__'): if not isNoneValue(values) and hasattr(values, '__iter__'):
@ -2973,6 +2962,10 @@ def asciifyUrl(url, forceQuote=False):
return urlparse.urlunsplit([parts.scheme, netloc, path, query, parts.fragment]) return urlparse.urlunsplit([parts.scheme, netloc, path, query, parts.fragment])
def findPageForms(content, url, raise_=False, addToTargets=False): def findPageForms(content, url, raise_=False, addToTargets=False):
"""
Parses given page content for possible forms
"""
class _(StringIO): class _(StringIO):
def __init__(self, content, url): def __init__(self, content, url):
StringIO.__init__(self, unicodeencode(content, kb.pageEncoding) if isinstance(content, unicode) else content) StringIO.__init__(self, unicodeencode(content, kb.pageEncoding) if isinstance(content, unicode) else content)
@ -3016,15 +3009,18 @@ def findPageForms(content, url, raise_=False, addToTargets=False):
if not item.selected: if not item.selected:
item.selected = True item.selected = True
break break
request = form.click() request = form.click()
url = urldecode(request.get_full_url(), kb.pageEncoding) url = urldecode(request.get_full_url(), kb.pageEncoding)
method = request.get_method() method = request.get_method()
data = request.get_data() if request.has_data() else None data = request.get_data() if request.has_data() else None
data = urldecode(data, kb.pageEncoding) if data and urlencode(DEFAULT_GET_POST_DELIMITER, None) not in data else data data = urldecode(data, kb.pageEncoding) if data and urlencode(DEFAULT_GET_POST_DELIMITER, None) not in data else data
if not data and method and method.upper() == HTTPMETHOD.POST: if not data and method and method.upper() == HTTPMETHOD.POST:
debugMsg = "invalid POST form with blank data detected" debugMsg = "invalid POST form with blank data detected"
logger.debug(debugMsg) logger.debug(debugMsg)
continue continue
target = (url, method, data, conf.cookie) target = (url, method, data, conf.cookie)
retVal.add(target) retVal.add(target)
else: else:
@ -3041,6 +3037,10 @@ def findPageForms(content, url, raise_=False, addToTargets=False):
return retVal return retVal
def getHostHeader(url): def getHostHeader(url):
"""
Returns proper Host header value for a given target URL
"""
retVal = urlparse.urlparse(url).netloc retVal = urlparse.urlparse(url).netloc
if any(retVal.endswith(':%d' % _) for _ in [80, 443]): if any(retVal.endswith(':%d' % _) for _ in [80, 443]):
@ -3048,7 +3048,11 @@ def getHostHeader(url):
return retVal return retVal
def executeCode(code, variables=None): def evaluateCode(code, variables=None):
"""
Executes given python code given in a string form
"""
try: try:
exec(code, variables) exec(code, variables)
except Exception, ex: except Exception, ex:
@ -3056,21 +3060,39 @@ def executeCode(code, variables=None):
raise sqlmapGenericException, errMsg raise sqlmapGenericException, errMsg
def serializeObject(object_): def serializeObject(object_):
"""
Serializes given object
"""
return pickle.dumps(object_) return pickle.dumps(object_)
def unserializeObject(value): def unserializeObject(value):
"""
Unserializes object from given serialized form
"""
retVal = None retVal = None
if value: if value:
retVal = pickle.loads(value.encode(UNICODE_ENCODING)) # pickle has problems with Unicode retVal = pickle.loads(value.encode(UNICODE_ENCODING)) # pickle has problems with Unicode
return retVal return retVal
def resetCounter(counter): def resetCounter(technique):
kb.counters[counter] = 0 """
Resets query counter for a given technique
"""
def incrementCounter(counter): kb.counters[technique] = 0
if counter not in kb.counters:
resetCounter(counter)
kb.counters[counter] += 1
def getCounter(counter): def incrementCounter(technique):
return kb.counters.get(counter, 0) """
Increments query counter for a given technique
"""
kb.counters[technique] = getCounter(technique) + 1
def getCounter(technique):
"""
Returns query counter for a given technique
"""
return kb.counters.get(technique, 0)

View File

@ -91,34 +91,32 @@ def urlencode(value, safe="%&=", convall=False, limit=False):
return value return value
count = 0 count = 0
result = None result = None if value is None else ""
if value is None: if value:
return result if convall or safe is None:
safe = ""
if convall or safe is None: # corner case when character % really needs to be
safe = "" # encoded (when not representing url encoded char)
# except in cases when tampering scripts are used
if all(map(lambda x: '%' in x, [safe, value])) and not kb.tamperFunctions:
value = re.sub("%(?![0-9a-fA-F]{2})", "%25", value, re.DOTALL | re.IGNORECASE)
# corner case when character % really needs to be while True:
# encoded (when not representing url encoded char) result = urllib.quote(utf8encode(value), safe)
# except in cases when tampering scripts are used
if all(map(lambda x: '%' in x, [safe, value])) and not kb.tamperFunctions:
value = re.sub("%(?![0-9a-fA-F]{2})", "%25", value, re.DOTALL | re.IGNORECASE)
while True: if limit and len(result) > URLENCODE_CHAR_LIMIT:
result = urllib.quote(utf8encode(value), safe) if count >= len(URLENCODE_FAILSAFE_CHARS):
if limit and len(result) > URLENCODE_CHAR_LIMIT:
if count >= len(URLENCODE_FAILSAFE_CHARS):
break
while count < len(URLENCODE_FAILSAFE_CHARS):
safe += URLENCODE_FAILSAFE_CHARS[count]
count += 1
if safe[-1] in value:
break break
else:
break while count < len(URLENCODE_FAILSAFE_CHARS):
safe += URLENCODE_FAILSAFE_CHARS[count]
count += 1
if safe[-1] in value:
break
else:
break
return result return result

View File

@ -41,45 +41,45 @@ class Dump:
""" """
def __init__(self): def __init__(self):
self.__outputFile = None self._outputFile = None
self.__outputFP = None self._outputFP = None
self.__outputBP = None self._outputBP = None
self.__lock = threading.Lock() self._lock = threading.Lock()
def __write(self, data, n=True, console=True): def _write(self, data, n=True, console=True):
text = "%s%s" % (data, "\n" if n else " ") text = "%s%s" % (data, "\n" if n else " ")
if console: if console:
dataToStdout(text) dataToStdout(text)
if kb.get("multiThreadMode"): if kb.get("multiThreadMode"):
self.__lock.acquire() self._lock.acquire()
self.__outputBP.write(text) self._outputBP.write(text)
if self.__outputBP.tell() > BUFFERED_LOG_SIZE: if self._outputBP.tell() > BUFFERED_LOG_SIZE:
self.flush() self.flush()
if kb.get("multiThreadMode"): if kb.get("multiThreadMode"):
self.__lock.release() self._lock.release()
kb.dataOutputFlag = True kb.dataOutputFlag = True
def flush(self): def flush(self):
if self.__outputBP and self.__outputFP and self.__outputBP.tell() > 0: if self._outputBP and self._outputFP and self._outputBP.tell() > 0:
_ = self.__outputBP.getvalue() _ = self._outputBP.getvalue()
self.__outputBP.truncate(0) self._outputBP.truncate(0)
self.__outputFP.write(_) self._outputFP.write(_)
def __formatString(self, inpStr): def _formatString(self, inpStr):
return restoreDumpMarkedChars(getUnicode(inpStr)) return restoreDumpMarkedChars(getUnicode(inpStr))
def setOutputFile(self): def setOutputFile(self):
self.__outputFile = "%s%slog" % (conf.outputPath, os.sep) self._outputFile = "%s%slog" % (conf.outputPath, os.sep)
self.__outputFP = codecs.open(self.__outputFile, "ab", UNICODE_ENCODING) self._outputFP = codecs.open(self._outputFile, "ab", UNICODE_ENCODING)
self.__outputBP = StringIO.StringIO() self._outputBP = StringIO.StringIO()
def getOutputFile(self): def getOutputFile(self):
return self.__outputFile return self._outputFile
def string(self, header, data, sort=True): def string(self, header, data, sort=True):
if isinstance(data, (list, tuple, set)): if isinstance(data, (list, tuple, set)):
@ -90,21 +90,21 @@ class Dump:
data = getUnicode(data) data = getUnicode(data)
if data: if data:
data = self.__formatString(data) data = self._formatString(data)
if data[-1] == '\n': if data[-1] == '\n':
data = data[:-1] data = data[:-1]
if "\n" in data: if "\n" in data:
self.__write("%s:\n---\n%s\n---\n" % (header, data)) self._write("%s:\n---\n%s\n---\n" % (header, data))
else: else:
self.__write("%s: '%s'\n" % (header, data)) self._write("%s: '%s'\n" % (header, data))
else: else:
self.__write("%s:\tNone\n" % header) self._write("%s:\tNone\n" % header)
def lister(self, header, elements, sort=True): def lister(self, header, elements, sort=True):
if elements: if elements:
self.__write("%s [%d]:" % (header, len(elements))) self._write("%s [%d]:" % (header, len(elements)))
if sort: if sort:
try: try:
@ -116,12 +116,12 @@ class Dump:
for element in elements: for element in elements:
if isinstance(element, basestring): if isinstance(element, basestring):
self.__write("[*] %s" % element) self._write("[*] %s" % element)
elif isinstance(element, (list, tuple, set)): elif isinstance(element, (list, tuple, set)):
self.__write("[*] " + ", ".join(getUnicode(e) for e in element)) self._write("[*] " + ", ".join(getUnicode(e) for e in element))
if elements: if elements:
self.__write("") self._write("")
def technic(self, header, data): def technic(self, header, data):
self.string(header, data) self.string(header, data)
@ -147,13 +147,13 @@ class Dump:
self.lister("database management system users", users) self.lister("database management system users", users)
def userSettings(self, header, userSettings, subHeader): def userSettings(self, header, userSettings, subHeader):
self.__areAdmins = set() self._areAdmins = set()
if userSettings: if userSettings:
self.__write("%s:" % header) self._write("%s:" % header)
if isinstance(userSettings, (tuple, list, set)): if isinstance(userSettings, (tuple, list, set)):
self.__areAdmins = userSettings[1] self._areAdmins = userSettings[1]
userSettings = userSettings[0] userSettings = userSettings[0]
users = userSettings.keys() users = userSettings.keys()
@ -167,16 +167,16 @@ class Dump:
else: else:
stringSettings = " [%d]:" % len(settings) stringSettings = " [%d]:" % len(settings)
if user in self.__areAdmins: if user in self._areAdmins:
self.__write("[*] %s (administrator)%s" % (user, stringSettings)) self._write("[*] %s (administrator)%s" % (user, stringSettings))
else: else:
self.__write("[*] %s%s" % (user, stringSettings)) self._write("[*] %s%s" % (user, stringSettings))
if settings: if settings:
settings.sort() settings.sort()
for setting in settings: for setting in settings:
self.__write(" %s: %s" % (subHeader, setting)) self._write(" %s: %s" % (subHeader, setting))
print print
def dbs(self,dbs): def dbs(self,dbs):
@ -198,23 +198,23 @@ class Dump:
for db, tables in dbTables.items(): for db, tables in dbTables.items():
tables.sort() tables.sort()
self.__write("Database: %s" % db if db else "Current database") self._write("Database: %s" % db if db else "Current database")
if len(tables) == 1: if len(tables) == 1:
self.__write("[1 table]") self._write("[1 table]")
else: else:
self.__write("[%d tables]" % len(tables)) self._write("[%d tables]" % len(tables))
self.__write("+%s+" % lines) self._write("+%s+" % lines)
for table in tables: for table in tables:
if isinstance(table, (list, tuple, set)): if isinstance(table, (list, tuple, set)):
table = table[0] table = table[0]
blank = " " * (maxlength - len(normalizeUnicode(table) or str(table))) blank = " " * (maxlength - len(normalizeUnicode(table) or str(table)))
self.__write("| %s%s |" % (table, blank)) self._write("| %s%s |" % (table, blank))
self.__write("+%s+\n" % lines) self._write("+%s+\n" % lines)
else: else:
self.string("tables", dbTables) self.string("tables", dbTables)
@ -246,17 +246,17 @@ class Dump:
maxlength2 = max(maxlength2, len("TYPE")) maxlength2 = max(maxlength2, len("TYPE"))
lines2 = "-" * (maxlength2 + 2) lines2 = "-" * (maxlength2 + 2)
self.__write("Database: %s\nTable: %s" % (db if db else "Current database", table)) self._write("Database: %s\nTable: %s" % (db if db else "Current database", table))
if len(columns) == 1: if len(columns) == 1:
self.__write("[1 column]") self._write("[1 column]")
else: else:
self.__write("[%d columns]" % len(columns)) self._write("[%d columns]" % len(columns))
if colType is not None: if colType is not None:
self.__write("+%s+%s+" % (lines1, lines2)) self._write("+%s+%s+" % (lines1, lines2))
else: else:
self.__write("+%s+" % lines1) self._write("+%s+" % lines1)
blank1 = " " * (maxlength1 - len("COLUMN")) blank1 = " " * (maxlength1 - len("COLUMN"))
@ -264,11 +264,11 @@ class Dump:
blank2 = " " * (maxlength2 - len("TYPE")) blank2 = " " * (maxlength2 - len("TYPE"))
if colType is not None: if colType is not None:
self.__write("| Column%s | Type%s |" % (blank1, blank2)) self._write("| Column%s | Type%s |" % (blank1, blank2))
self.__write("+%s+%s+" % (lines1, lines2)) self._write("+%s+%s+" % (lines1, lines2))
else: else:
self.__write("| Column%s |" % blank1) self._write("| Column%s |" % blank1)
self.__write("+%s+" % lines1) self._write("+%s+" % lines1)
for column in colList: for column in colList:
colType = columns[column] colType = columns[column]
@ -276,14 +276,14 @@ class Dump:
if colType is not None: if colType is not None:
blank2 = " " * (maxlength2 - len(colType)) blank2 = " " * (maxlength2 - len(colType))
self.__write("| %s%s | %s%s |" % (column, blank1, colType, blank2)) self._write("| %s%s | %s%s |" % (column, blank1, colType, blank2))
else: else:
self.__write("| %s%s |" % (column, blank1)) self._write("| %s%s |" % (column, blank1))
if colType is not None: if colType is not None:
self.__write("+%s+%s+\n" % (lines1, lines2)) self._write("+%s+%s+\n" % (lines1, lines2))
else: else:
self.__write("+%s+\n" % lines1) self._write("+%s+\n" % lines1)
def dbTablesCount(self, dbTables): def dbTablesCount(self, dbTables):
if isinstance(dbTables, dict) and len(dbTables) > 0: if isinstance(dbTables, dict) and len(dbTables) > 0:
@ -296,16 +296,16 @@ class Dump:
maxlength1 = max(maxlength1, len(normalizeUnicode(table) or str(table))) maxlength1 = max(maxlength1, len(normalizeUnicode(table) or str(table)))
for db, counts in dbTables.items(): for db, counts in dbTables.items():
self.__write("Database: %s" % db if db else "Current database") self._write("Database: %s" % db if db else "Current database")
lines1 = "-" * (maxlength1 + 2) lines1 = "-" * (maxlength1 + 2)
blank1 = " " * (maxlength1 - len("Table")) blank1 = " " * (maxlength1 - len("Table"))
lines2 = "-" * (maxlength2 + 2) lines2 = "-" * (maxlength2 + 2)
blank2 = " " * (maxlength2 - len("Entries")) blank2 = " " * (maxlength2 - len("Entries"))
self.__write("+%s+%s+" % (lines1, lines2)) self._write("+%s+%s+" % (lines1, lines2))
self.__write("| Table%s | Entries%s |" % (blank1, blank2)) self._write("| Table%s | Entries%s |" % (blank1, blank2))
self.__write("+%s+%s+" % (lines1, lines2)) self._write("+%s+%s+" % (lines1, lines2))
sortedCounts = counts.keys() sortedCounts = counts.keys()
sortedCounts.sort(reverse=True) sortedCounts.sort(reverse=True)
@ -321,9 +321,9 @@ class Dump:
for table in tables: for table in tables:
blank1 = " " * (maxlength1 - len(normalizeUnicode(table) or str(table))) blank1 = " " * (maxlength1 - len(normalizeUnicode(table) or str(table)))
blank2 = " " * (maxlength2 - len(str(count))) blank2 = " " * (maxlength2 - len(str(count)))
self.__write("| %s%s | %d%s |" % (table, blank1, count, blank2)) self._write("| %s%s | %d%s |" % (table, blank1, count, blank2))
self.__write("+%s+%s+\n" % (lines1, lines2)) self._write("+%s+%s+\n" % (lines1, lines2))
else: else:
logger.error("unable to retrieve the number of entries for any table") logger.error("unable to retrieve the number of entries for any table")
@ -365,7 +365,7 @@ class Dump:
separator += "+%s" % lines separator += "+%s" % lines
separator += "+" separator += "+"
self.__write("Database: %s\nTable: %s" % (db if db else "Current database", table)) self._write("Database: %s\nTable: %s" % (db if db else "Current database", table))
if conf.replicate: if conf.replicate:
cols = [] cols = []
@ -402,11 +402,11 @@ class Dump:
rtable = replication.createTable(table, cols) rtable = replication.createTable(table, cols)
if count == 1: if count == 1:
self.__write("[1 entry]") self._write("[1 entry]")
else: else:
self.__write("[%d entries]" % count) self._write("[%d entries]" % count)
self.__write(separator) self._write(separator)
for column in columns: for column in columns:
if column != "__infos__": if column != "__infos__":
@ -414,7 +414,7 @@ class Dump:
maxlength = int(info["length"]) maxlength = int(info["length"])
blank = " " * (maxlength - len(column)) blank = " " * (maxlength - len(column))
self.__write("| %s%s" % (column, blank), n=False) self._write("| %s%s" % (column, blank), n=False)
if not conf.replicate: if not conf.replicate:
if field == fields: if field == fields:
@ -424,7 +424,7 @@ class Dump:
field += 1 field += 1
self.__write("|\n%s" % separator) self._write("|\n%s" % separator)
if not conf.replicate: if not conf.replicate:
dataToDumpFile(dumpFP, "\n") dataToDumpFile(dumpFP, "\n")
@ -461,7 +461,7 @@ class Dump:
values.append(value) values.append(value)
maxlength = int(info["length"]) maxlength = int(info["length"])
blank = " " * (maxlength - len(value)) blank = " " * (maxlength - len(value))
self.__write("| %s%s" % (value, blank), n=False, console=console) self._write("| %s%s" % (value, blank), n=False, console=console)
if not conf.replicate: if not conf.replicate:
if field == fields: if field == fields:
@ -477,12 +477,12 @@ class Dump:
except sqlmapValueException: except sqlmapValueException:
pass pass
self.__write("|", console=console) self._write("|", console=console)
if not conf.replicate: if not conf.replicate:
dataToDumpFile(dumpFP, "\n") dataToDumpFile(dumpFP, "\n")
self.__write("%s\n" % separator) self._write("%s\n" % separator)
if conf.replicate: if conf.replicate:
rtable.endTransaction() rtable.endTransaction()
@ -502,26 +502,26 @@ class Dump:
msg = "Column%s found in the " % colConsiderStr msg = "Column%s found in the " % colConsiderStr
msg += "following databases:" msg += "following databases:"
self.__write(msg) self._write(msg)
printDbs = {} _ = {}
for db, tblData in dbs.items(): for db, tblData in dbs.items():
for tbl, colData in tblData.items(): for tbl, colData in tblData.items():
for col, dataType in colData.items(): for col, dataType in colData.items():
if column.lower() in col.lower(): if column.lower() in col.lower():
if db in printDbs: if db in _:
if tbl in printDbs[db]: if tbl in _[db]:
printDbs[db][tbl][col] = dataType _[db][tbl][col] = dataType
else: else:
printDbs[db][tbl] = { col: dataType } _[db][tbl] = { col: dataType }
else: else:
printDbs[db] = {} _[db] = {}
printDbs[db][tbl] = { col: dataType } _[db][tbl] = { col: dataType }
continue continue
self.dbTableColumns(printDbs) self.dbTableColumns(_)
def query(self, query, queryRes): def query(self, query, queryRes):
self.string(query, queryRes) self.string(query, queryRes)

View File

@ -249,6 +249,9 @@ SQL_STATEMENTS = {
# string representation for NULL value # string representation for NULL value
NULL = "NULL" NULL = "NULL"
# string representation for current database
CURRENT_DB = "CD"
# Regular expressions used for parsing error messages (--parse-errors) # Regular expressions used for parsing error messages (--parse-errors)
ERROR_PARSING_REGEXES = ( ERROR_PARSING_REGEXES = (
r"<b>[^<]*(fatal|error|warning|exception)[^<]*</b>:?\s*(?P<result>.+?)<br\s*/?\s*>", r"<b>[^<]*(fatal|error|warning|exception)[^<]*</b>:?\s*(?P<result>.+?)<br\s*/?\s*>",

View File

@ -22,7 +22,6 @@ from lib.core.common import getUnicode
from lib.core.common import isWindowsDriveLetterPath from lib.core.common import isWindowsDriveLetterPath
from lib.core.common import posixToNtSlashes from lib.core.common import posixToNtSlashes
from lib.core.common import readInput from lib.core.common import readInput
from lib.core.common import sanitizeAsciiString
from lib.core.common import singleTimeLogMessage from lib.core.common import singleTimeLogMessage
from lib.core.data import conf from lib.core.data import conf
from lib.core.data import kb from lib.core.data import kb

View File

@ -24,7 +24,7 @@ from lib.core.common import average
from lib.core.common import calculateDeltaSeconds from lib.core.common import calculateDeltaSeconds
from lib.core.common import clearConsoleLine from lib.core.common import clearConsoleLine
from lib.core.common import cpuThrottle from lib.core.common import cpuThrottle
from lib.core.common import executeCode from lib.core.common import evaluateCode
from lib.core.common import extractRegexResult from lib.core.common import extractRegexResult
from lib.core.common import getCurrentThreadData from lib.core.common import getCurrentThreadData
from lib.core.common import getFilteredPageContent from lib.core.common import getFilteredPageContent
@ -636,10 +636,10 @@ class Connect:
for part in item.split(delimiter): for part in item.split(delimiter):
if '=' in part: if '=' in part:
name, value = part.split('=', 1) name, value = part.split('=', 1)
executeCode("%s='%s'" % (name, value), variables) evaluateCode("%s='%s'" % (name, value), variables)
originals.update(variables) originals.update(variables)
executeCode(conf.evalCode, variables) evaluateCode(conf.evalCode, variables)
for name, value in variables.items(): for name, value in variables.items():
if name != "__builtins__" and originals.get(name, "") != value: if name != "__builtins__" and originals.get(name, "") != value:

View File

@ -20,6 +20,7 @@ from lib.core.data import queries
from lib.core.enums import PAYLOAD from lib.core.enums import PAYLOAD
from lib.core.exception import sqlmapMissingMandatoryOptionException from lib.core.exception import sqlmapMissingMandatoryOptionException
from lib.core.exception import sqlmapNoneDataException from lib.core.exception import sqlmapNoneDataException
from lib.core.settings import CURRENT_DB
from plugins.generic.enumeration import Enumeration as GenericEnumeration from plugins.generic.enumeration import Enumeration as GenericEnumeration
class Enumeration(GenericEnumeration): class Enumeration(GenericEnumeration):
@ -60,7 +61,7 @@ class Enumeration(GenericEnumeration):
self.forceDbmsEnum() self.forceDbmsEnum()
if conf.db == "CD": if conf.db == CURRENT_DB:
conf.db = self.getCurrentDb() conf.db = self.getCurrentDb()
if conf.db: if conf.db:
@ -97,7 +98,7 @@ class Enumeration(GenericEnumeration):
def getColumns(self, onlyColNames=False): def getColumns(self, onlyColNames=False):
self.forceDbmsEnum() self.forceDbmsEnum()
if conf.db is None or conf.db == "CD": if conf.db is None or conf.db == CURRENT_DB:
if conf.db is None: if conf.db is None:
warnMsg = "missing database parameter, sqlmap is going " warnMsg = "missing database parameter, sqlmap is going "
warnMsg += "to use the current database to enumerate " warnMsg += "to use the current database to enumerate "

View File

@ -10,7 +10,7 @@ See the file 'doc/COPYING' for copying permission
from lib.core.agent import agent from lib.core.agent import agent
from lib.core.common import arrayizeValue from lib.core.common import arrayizeValue
from lib.core.common import Backend from lib.core.common import Backend
from lib.core.common import getRange from lib.core.common import getLimitRange
from lib.core.common import isInferenceAvailable from lib.core.common import isInferenceAvailable
from lib.core.common import isNoneValue from lib.core.common import isNoneValue
from lib.core.common import isNumPosStrValue from lib.core.common import isNumPosStrValue
@ -25,6 +25,7 @@ from lib.core.data import queries
from lib.core.enums import EXPECTED from lib.core.enums import EXPECTED
from lib.core.enums import PAYLOAD from lib.core.enums import PAYLOAD
from lib.core.exception import sqlmapNoneDataException from lib.core.exception import sqlmapNoneDataException
from lib.core.settings import CURRENT_DB
from lib.request import inject from lib.request import inject
from plugins.generic.enumeration import Enumeration as GenericEnumeration from plugins.generic.enumeration import Enumeration as GenericEnumeration
@ -68,7 +69,7 @@ class Enumeration(GenericEnumeration):
self.forceDbmsEnum() self.forceDbmsEnum()
if conf.db == "CD": if conf.db == CURRENT_DB:
conf.db = self.getCurrentDb() conf.db = self.getCurrentDb()
if conf.db: if conf.db:
@ -230,7 +231,7 @@ class Enumeration(GenericEnumeration):
continue continue
indexRange = getRange(count) indexRange = getLimitRange(count)
for index in indexRange: for index in indexRange:
query = rootQuery.blind.query query = rootQuery.blind.query
@ -347,7 +348,7 @@ class Enumeration(GenericEnumeration):
continue continue
indexRange = getRange(count) indexRange = getLimitRange(count)
for index in indexRange: for index in indexRange:
query = rootQuery.blind.query query = rootQuery.blind.query

View File

@ -11,7 +11,7 @@ import codecs
import ntpath import ntpath
import os import os
from lib.core.common import getRange from lib.core.common import getLimitRange
from lib.core.common import isNumPosStrValue from lib.core.common import isNumPosStrValue
from lib.core.common import isTechniqueAvailable from lib.core.common import isTechniqueAvailable
from lib.core.common import posixToNtSlashes from lib.core.common import posixToNtSlashes
@ -105,7 +105,7 @@ class Filesystem(GenericFilesystem):
errMsg += "file '%s'" % rFile errMsg += "file '%s'" % rFile
raise sqlmapNoneDataException(errMsg) raise sqlmapNoneDataException(errMsg)
indexRange = getRange(count) indexRange = getLimitRange(count)
for index in indexRange: for index in indexRange:
chunk = inject.getValue("SELECT TOP 1 %s FROM %s WHERE %s NOT IN (SELECT TOP %d %s FROM %s ORDER BY id ASC) ORDER BY id ASC" % (self.tblField, hexTbl, self.tblField, index, self.tblField, hexTbl), unpack=False, resumeValue=False, unique=False, charsetType=3) chunk = inject.getValue("SELECT TOP 1 %s FROM %s WHERE %s NOT IN (SELECT TOP %d %s FROM %s ORDER BY id ASC) ORDER BY id ASC" % (self.tblField, hexTbl, self.tblField, index, self.tblField, hexTbl), unpack=False, resumeValue=False, unique=False, charsetType=3)

View File

@ -9,7 +9,7 @@ See the file 'doc/COPYING' for copying permission
from lib.core.agent import agent from lib.core.agent import agent
from lib.core.common import Backend from lib.core.common import Backend
from lib.core.common import getRange from lib.core.common import getLimitRange
from lib.core.common import isInferenceAvailable from lib.core.common import isInferenceAvailable
from lib.core.common import isNoneValue from lib.core.common import isNoneValue
from lib.core.common import isNumPosStrValue from lib.core.common import isNumPosStrValue
@ -142,7 +142,7 @@ class Enumeration(GenericEnumeration):
roles = set() roles = set()
indexRange = getRange(count, plusOne=True) indexRange = getLimitRange(count, plusOne=True)
for index in indexRange: for index in indexRange:
if query2: if query2:

View File

@ -22,6 +22,7 @@ from lib.core.dicts import sybaseTypes
from lib.core.enums import PAYLOAD from lib.core.enums import PAYLOAD
from lib.core.exception import sqlmapMissingMandatoryOptionException from lib.core.exception import sqlmapMissingMandatoryOptionException
from lib.core.exception import sqlmapNoneDataException from lib.core.exception import sqlmapNoneDataException
from lib.core.settings import CURRENT_DB
from plugins.generic.enumeration import Enumeration as GenericEnumeration from plugins.generic.enumeration import Enumeration as GenericEnumeration
class Enumeration(GenericEnumeration): class Enumeration(GenericEnumeration):
@ -114,7 +115,7 @@ class Enumeration(GenericEnumeration):
self.forceDbmsEnum() self.forceDbmsEnum()
if conf.db == "CD": if conf.db == CURRENT_DB:
conf.db = self.getCurrentDb() conf.db = self.getCurrentDb()
if conf.db: if conf.db:
@ -160,7 +161,7 @@ class Enumeration(GenericEnumeration):
def getColumns(self, onlyColNames=False): def getColumns(self, onlyColNames=False):
self.forceDbmsEnum() self.forceDbmsEnum()
if conf.db is None or conf.db == "CD": if conf.db is None or conf.db == CURRENT_DB:
if conf.db is None: if conf.db is None:
warnMsg = "missing database parameter, sqlmap is going " warnMsg = "missing database parameter, sqlmap is going "
warnMsg += "to use the current database to enumerate " warnMsg += "to use the current database to enumerate "

View File

@ -17,7 +17,7 @@ from lib.core.common import Backend
from lib.core.common import clearConsoleLine from lib.core.common import clearConsoleLine
from lib.core.common import dataToStdout from lib.core.common import dataToStdout
from lib.core.common import filterPairValues from lib.core.common import filterPairValues
from lib.core.common import getRange from lib.core.common import getLimitRange
from lib.core.common import getCompiledRegex from lib.core.common import getCompiledRegex
from lib.core.common import getUnicode from lib.core.common import getUnicode
from lib.core.common import isInferenceAvailable from lib.core.common import isInferenceAvailable
@ -59,6 +59,7 @@ from lib.core.exception import sqlmapUserQuitException
from lib.core.session import setOs from lib.core.session import setOs
from lib.core.settings import CONCAT_ROW_DELIMITER from lib.core.settings import CONCAT_ROW_DELIMITER
from lib.core.settings import CONCAT_VALUE_DELIMITER from lib.core.settings import CONCAT_VALUE_DELIMITER
from lib.core.settings import CURRENT_DB
from lib.core.settings import DEFAULT_MSSQL_SCHEMA from lib.core.settings import DEFAULT_MSSQL_SCHEMA
from lib.core.settings import MAX_INT from lib.core.settings import MAX_INT
from lib.core.settings import SQL_STATEMENTS from lib.core.settings import SQL_STATEMENTS
@ -200,11 +201,8 @@ class Enumeration:
errMsg = "unable to retrieve the number of database users" errMsg = "unable to retrieve the number of database users"
raise sqlmapNoneDataException, errMsg raise sqlmapNoneDataException, errMsg
if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2)
plusOne = True indexRange = getLimitRange(count, plusOne=plusOne)
else:
plusOne = False
indexRange = getRange(count, plusOne=plusOne)
for index in indexRange: for index in indexRange:
if Backend.getIdentifiedDbms() in (DBMS.SYBASE, DBMS.MAXDB): if Backend.getIdentifiedDbms() in (DBMS.SYBASE, DBMS.MAXDB):
@ -350,11 +348,8 @@ class Enumeration:
passwords = [] passwords = []
if Backend.isDbms(DBMS.ORACLE): plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2)
plusOne = True indexRange = getLimitRange(count, plusOne=plusOne)
else:
plusOne = False
indexRange = getRange(count, plusOne=plusOne)
for index in indexRange: for index in indexRange:
if Backend.isDbms(DBMS.MSSQL): if Backend.isDbms(DBMS.MSSQL):
@ -593,11 +588,8 @@ class Enumeration:
privileges = set() privileges = set()
if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2)
plusOne = True indexRange = getLimitRange(count, plusOne=plusOne)
else:
plusOne = False
indexRange = getRange(count, plusOne=plusOne)
for index in indexRange: for index in indexRange:
if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema: if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema:
@ -760,11 +752,8 @@ class Enumeration:
errMsg = "unable to retrieve the number of databases" errMsg = "unable to retrieve the number of databases"
logger.error(errMsg) logger.error(errMsg)
else: else:
if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2)
plusOne = True indexRange = getLimitRange(count, plusOne=plusOne)
else:
plusOne = False
indexRange = getRange(count, plusOne=plusOne)
for index in indexRange: for index in indexRange:
if Backend.isDbms(DBMS.SYBASE): if Backend.isDbms(DBMS.SYBASE):
@ -820,7 +809,7 @@ class Enumeration:
else: else:
return tables return tables
if conf.db == "CD": if conf.db == CURRENT_DB:
conf.db = self.getCurrentDb() conf.db = self.getCurrentDb()
if conf.db and Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): if conf.db and Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2):
@ -930,11 +919,8 @@ class Enumeration:
tables = [] tables = []
if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2)
plusOne = True indexRange = getLimitRange(count, plusOne=plusOne)
else:
plusOne = False
indexRange = getRange(count, plusOne=plusOne)
for index in indexRange: for index in indexRange:
if Backend.isDbms(DBMS.SYBASE): if Backend.isDbms(DBMS.SYBASE):
@ -977,7 +963,7 @@ class Enumeration:
def getColumns(self, onlyColNames=False, colTuple=None, bruteForce=None): def getColumns(self, onlyColNames=False, colTuple=None, bruteForce=None):
self.forceDbmsEnum() self.forceDbmsEnum()
if conf.db is None or conf.db == "CD": if conf.db is None or conf.db == CURRENT_DB:
if conf.db is None: if conf.db is None:
warnMsg = "missing database parameter, sqlmap is going " warnMsg = "missing database parameter, sqlmap is going "
warnMsg += "to use the current database to enumerate " warnMsg += "to use the current database to enumerate "
@ -1226,7 +1212,7 @@ class Enumeration:
table = {} table = {}
columns = {} columns = {}
indexRange = getRange(count) indexRange = getLimitRange(count)
for index in indexRange: for index in indexRange:
if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ): if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ):
@ -1506,7 +1492,7 @@ class Enumeration:
def dumpTable(self, foundData=None): def dumpTable(self, foundData=None):
self.forceDbmsEnum() self.forceDbmsEnum()
if conf.db is None or conf.db == "CD": if conf.db is None or conf.db == CURRENT_DB:
if conf.db is None: if conf.db is None:
warnMsg = "missing database parameter, sqlmap is going " warnMsg = "missing database parameter, sqlmap is going "
warnMsg += "to use the current database to enumerate " warnMsg += "to use the current database to enumerate "
@ -1719,11 +1705,8 @@ class Enumeration:
entries, lengths = retVal entries, lengths = retVal
else: else:
if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): plusOne = Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2)
plusOne = True indexRange = getLimitRange(count, dump=True, plusOne=plusOne)
else:
plusOne = False
indexRange = getRange(count, dump=True, plusOne=plusOne)
try: try:
for index in indexRange: for index in indexRange:
@ -1967,7 +1950,7 @@ class Enumeration:
continue continue
indexRange = getRange(count) indexRange = getLimitRange(count)
for index in indexRange: for index in indexRange:
if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema: if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema:
@ -2032,7 +2015,7 @@ class Enumeration:
infoMsg += " '%s'" % unsafeSQLIdentificatorNaming(tbl) infoMsg += " '%s'" % unsafeSQLIdentificatorNaming(tbl)
logger.info(infoMsg) logger.info(infoMsg)
if conf.db and conf.db != "CD": if conf.db and conf.db != CURRENT_DB:
_ = conf.db.split(",") _ = conf.db.split(",")
whereDbsQuery = "".join(" AND '%s' = %s" % (unsafeSQLIdentificatorNaming(db), dbCond) for db in _) whereDbsQuery = "".join(" AND '%s' = %s" % (unsafeSQLIdentificatorNaming(db), dbCond) for db in _)
infoMsg = "for database%s '%s'" % ("s" if len(_) > 1 else "", ", ".join(db for db in _)) infoMsg = "for database%s '%s'" % ("s" if len(_) > 1 else "", ", ".join(db for db in _))
@ -2085,7 +2068,7 @@ class Enumeration:
continue continue
indexRange = getRange(count) indexRange = getLimitRange(count)
for index in indexRange: for index in indexRange:
query = rootQuery.blind.query query = rootQuery.blind.query
@ -2130,7 +2113,7 @@ class Enumeration:
continue continue
indexRange = getRange(count) indexRange = getLimitRange(count)
for index in indexRange: for index in indexRange:
query = rootQuery.blind.query2 query = rootQuery.blind.query2
@ -2201,7 +2184,7 @@ class Enumeration:
foundCols[column] = {} foundCols[column] = {}
if conf.db and conf.db != "CD": if conf.db and conf.db != CURRENT_DB:
_ = conf.db.split(",") _ = conf.db.split(",")
whereDbsQuery = "".join(" AND '%s' = %s" % (unsafeSQLIdentificatorNaming(db), dbCond) for db in _) whereDbsQuery = "".join(" AND '%s' = %s" % (unsafeSQLIdentificatorNaming(db), dbCond) for db in _)
infoMsg = "for database%s '%s'" % ("s" if len(_) > 1 else "", ", ".join(db for db in _)) infoMsg = "for database%s '%s'" % ("s" if len(_) > 1 else "", ", ".join(db for db in _))
@ -2277,7 +2260,7 @@ class Enumeration:
continue continue
indexRange = getRange(count) indexRange = getLimitRange(count)
for index in indexRange: for index in indexRange:
query = rootQuery.blind.query query = rootQuery.blind.query
@ -2328,7 +2311,7 @@ class Enumeration:
continue continue
indexRange = getRange(count) indexRange = getLimitRange(count)
for index in indexRange: for index in indexRange:
query = rootQuery.blind.query2 query = rootQuery.blind.query2