some more refactoring

This commit is contained in:
Miroslav Stampar 2011-12-21 19:40:42 +00:00
parent d9d4e3ea9b
commit 0a039d84e0
9 changed files with 93 additions and 130 deletions

View File

@ -215,16 +215,13 @@ class Agent:
randStr = randomStr() randStr = randomStr()
randStr1 = randomStr() randStr1 = randomStr()
payload = payload.replace("[RANDNUM]", str(randInt)) _ = (
payload = payload.replace("[RANDNUM1]", str(randInt1)) ("[RANDNUM]", str(randInt)), ("[RANDNUM1]", str(randInt1)), ("[RANDSTR]", randStr),\
payload = payload.replace("[RANDSTR]", randStr) ("[RANDSTR1]", randStr1), ("[DELIMITER_START]", kb.chars.start), ("[DELIMITER_STOP]", kb.chars.stop),\
payload = payload.replace("[RANDSTR1]", randStr1) ("[AT_REPLACE]", kb.chars.at), ("[SPACE_REPLACE]", kb.chars.space), ("[DOLLAR_REPLACE]", kb.chars.dollar),\
payload = payload.replace("[DELIMITER_START]", kb.chars.start) ("[SLEEPTIME]", str(conf.timeSec))
payload = payload.replace("[DELIMITER_STOP]", kb.chars.stop) )
payload = payload.replace("[AT_REPLACE]", kb.chars.at) payload = reduce(lambda x, y: x.replace(y[0], y[1]), _, payload)
payload = payload.replace("[SPACE_REPLACE]", kb.chars.space)
payload = payload.replace("[DOLLAR_REPLACE]", kb.chars.dollar)
payload = payload.replace("[SLEEPTIME]", str(conf.timeSec))
if origValue is not None: if origValue is not None:
payload = payload.replace("[ORIGVALUE]", origValue) payload = payload.replace("[ORIGVALUE]", origValue)
@ -249,13 +246,8 @@ class Agent:
return payload return payload
def getComment(self, reqObj): def getComment(self, request):
if "comment" in reqObj: return request.comment if "comment" in request else ""
comment = reqObj.comment
else:
comment = ""
return comment
def nullAndCastField(self, field): def nullAndCastField(self, field):
""" """

View File

@ -65,7 +65,7 @@ from lib.core.enums import OS
from lib.core.enums import PLACE from lib.core.enums import PLACE
from lib.core.enums import PAYLOAD from lib.core.enums import PAYLOAD
from lib.core.enums import REFLECTIVE_COUNTER from lib.core.enums import REFLECTIVE_COUNTER
from lib.core.enums import SORTORDER from lib.core.enums import SORT_ORDER
from lib.core.exception import sqlmapDataException from lib.core.exception import sqlmapDataException
from lib.core.exception import sqlmapFilePathException from lib.core.exception import sqlmapFilePathException
from lib.core.exception import sqlmapGenericException from lib.core.exception import sqlmapGenericException
@ -1001,20 +1001,17 @@ def randomStr(length=4, lowercase=False, alphabet=None):
return rndStr return rndStr
def sanitizeStr(inpStr): def sanitizeStr(value):
""" """
@param inpStr: inpStr to sanitize: cast to str datatype and replace @param value: value to sanitize: cast to str datatype and replace
newlines with one space and strip carriage returns. newlines with one space and strip carriage returns.
@type inpStr: C{str} @type value: C{str}
@return: sanitized inpStr @return: sanitized value
@rtype: C{str} @rtype: C{str}
""" """
cleanString = getUnicode(inpStr) return getUnicode(value).replace("\n", " ").replace("\r", "")
cleanString = cleanString.replace("\n", " ").replace("\r", "")
return cleanString
def checkFile(filename): def checkFile(filename):
""" """
@ -1025,45 +1022,37 @@ def checkFile(filename):
if not os.path.exists(filename): if not os.path.exists(filename):
raise sqlmapFilePathException, "unable to read file '%s'" % filename raise sqlmapFilePathException, "unable to read file '%s'" % filename
def replaceNewlineTabs(inpStr, stdout=False): def replaceNewlineTabs(value, stdout=False):
if inpStr is None: if value is None:
return return
if stdout: if stdout:
replacedString = inpStr.replace("\n", " ").replace("\r", " ").replace("\t", " ") retVal = value.replace("\n", " ").replace("\r", " ").replace("\t", " ")
else: else:
replacedString = inpStr.replace("\n", DUMP_NEWLINE_MARKER).replace("\r", DUMP_CR_MARKER).replace("\t", DUMP_TAB_MARKER) retVal = value.replace("\n", DUMP_NEWLINE_MARKER).replace("\r", DUMP_CR_MARKER).replace("\t", DUMP_TAB_MARKER)
replacedString = replacedString.replace(kb.chars.delimiter, DUMP_DEL_MARKER) retVal = retVal.replace(kb.chars.delimiter, DUMP_DEL_MARKER)
return replacedString return retVal
def restoreDumpMarkedChars(inpStr, onlyNewlineTab=False): def restoreDumpMarkedChars(value, onlyNewlineTab=False):
replacedString = inpStr retVal = value
if isinstance(replacedString, basestring): if isinstance(retVal, basestring):
replacedString = replacedString.replace(DUMP_NEWLINE_MARKER, "\n").replace(DUMP_CR_MARKER, "\r").replace(DUMP_TAB_MARKER, "\t") retVal = retVal.replace(DUMP_NEWLINE_MARKER, "\n").replace(DUMP_CR_MARKER, "\r").replace(DUMP_TAB_MARKER, "\t")
if not onlyNewlineTab: if not onlyNewlineTab:
replacedString = replacedString.replace(DUMP_DEL_MARKER, ", ") retVal = retVal.replace(DUMP_DEL_MARKER, ", ")
return replacedString return retVal
def banner(): def banner():
""" """
This function prints sqlmap banner with its version This function prints sqlmap banner with its version
""" """
ban = """ _ = """\n %s - %s\n %s\n\n""" % (VERSION_STRING, DESCRIPTION, SITE)
%s - %s dataToStdout(_, forceOutput=True)
%s\n
""" % (VERSION_STRING, DESCRIPTION, SITE)
# Reference: http://www.frexx.de/xterm-256-notes/
#if not any([IS_WIN, os.getenv('ANSI_COLORS_DISABLED')]):
# ban = "\033[1;34m%s\033[0m" % ban
dataToStdout(ban, forceOutput=True)
def parsePasswordHash(password): def parsePasswordHash(password):
blank = " " * 8 blank = " " * 8
@ -1335,10 +1324,7 @@ def getRange(count, dump=False, plusOne=False):
if isinstance(conf.limitStart, int) and conf.limitStart > 0 and conf.limitStart <= limitStop: if isinstance(conf.limitStart, int) and conf.limitStart > 0 and conf.limitStart <= limitStop:
limitStart = conf.limitStart limitStart = conf.limitStart
if plusOne: indexRange = xrange(limitStart, limitStop + 1) if plusOne else xrange(limitStart - 1, limitStop)
indexRange = xrange(limitStart, limitStop + 1)
else:
indexRange = xrange(limitStart - 1, limitStop)
return indexRange return indexRange
@ -1445,46 +1431,43 @@ def getFileType(filePath):
except: except:
return "unknown" return "unknown"
if "ASCII" in magicFileType or "text" in magicFileType: return "text" if "ASCII" in magicFileType or "text" in magicFileType else "binary"
return "text"
else:
return "binary"
def getCharset(charsetType=None): def getCharset(charsetType=None):
asciiTbl = [] asciiTbl = []
if charsetType is None: if charsetType is None:
asciiTbl = range(0, 128) asciiTbl.extend(xrange(0, 128))
# 0 or 1 # 0 or 1
elif charsetType == 1: elif charsetType == 1:
asciiTbl.extend([ 0, 1 ]) asciiTbl.extend([ 0, 1 ])
asciiTbl.extend(range(47, 50)) asciiTbl.extend(xrange(47, 50))
# Digits # Digits
elif charsetType == 2: elif charsetType == 2:
asciiTbl.extend([ 0, 1 ]) asciiTbl.extend([ 0, 1 ])
asciiTbl.extend(range(47, 58)) asciiTbl.extend(xrange(47, 58))
# Hexadecimal # Hexadecimal
elif charsetType == 3: elif charsetType == 3:
asciiTbl.extend([ 0, 1 ]) asciiTbl.extend([ 0, 1 ])
asciiTbl.extend(range(47, 58)) asciiTbl.extend(xrange(47, 58))
asciiTbl.extend(range(64, 71)) asciiTbl.extend(xrange(64, 71))
asciiTbl.extend(range(96, 103)) asciiTbl.extend(xrange(96, 103))
# Characters # Characters
elif charsetType == 4: elif charsetType == 4:
asciiTbl.extend([ 0, 1 ]) asciiTbl.extend([ 0, 1 ])
asciiTbl.extend(range(64, 91)) asciiTbl.extend(xrange(64, 91))
asciiTbl.extend(range(96, 123)) asciiTbl.extend(xrange(96, 123))
# Characters and digits # Characters and digits
elif charsetType == 5: elif charsetType == 5:
asciiTbl.extend([ 0, 1 ]) asciiTbl.extend([ 0, 1 ])
asciiTbl.extend(range(47, 58)) asciiTbl.extend(xrange(47, 58))
asciiTbl.extend(range(64, 91)) asciiTbl.extend(xrange(64, 91))
asciiTbl.extend(range(96, 123)) asciiTbl.extend(xrange(96, 123))
return asciiTbl return asciiTbl
@ -1492,10 +1475,7 @@ def searchEnvPath(fileName):
envPaths = os.environ["PATH"] envPaths = os.environ["PATH"]
result = None result = None
if IS_WIN: envPaths = envPaths.split(";") if IS_WIN else envPaths.split(":")
envPaths = envPaths.split(";")
else:
envPaths = envPaths.split(":")
for envPath in envPaths: for envPath in envPaths:
envPath = envPath.replace(";", "") envPath = envPath.replace(";", "")
@ -1557,8 +1537,7 @@ def safeStringFormat(formatStr, params):
if isinstance(params, basestring): if isinstance(params, basestring):
retVal = retVal.replace("%s", params) retVal = retVal.replace("%s", params)
else: else:
count = 0 count, index = 0, 0
index = 0
while index != -1: while index != -1:
index = retVal.find("%s") index = retVal.find("%s")
@ -1791,10 +1770,9 @@ def readCachedFileContent(filename, mode='rb'):
if filename not in kb.cache.content: if filename not in kb.cache.content:
checkFile(filename) checkFile(filename)
xfile = codecs.open(filename, mode, UNICODE_ENCODING) with codecs.open(filename, mode, UNICODE_ENCODING) as f:
content = xfile.read() content = f.read()
kb.cache.content[filename] = content kb.cache.content[filename] = content
xfile.close()
kb.locks.cacheLock.release() kb.locks.cacheLock.release()
@ -1807,10 +1785,8 @@ def readXmlFile(xmlFile):
checkFile(xmlFile) checkFile(xmlFile)
xfile = codecs.open(xmlFile, 'r', UNICODE_ENCODING) with codecs.open(xmlFile, 'r', UNICODE_ENCODING) as f:
retVal = minidom.parse(xfile).documentElement retVal = minidom.parse(f).documentElement
xfile.close()
return retVal return retVal
@ -1825,8 +1801,10 @@ def stdev(values):
key = (values[0], values[-1], len(values)) key = (values[0], values[-1], len(values))
retVal = None
if key in kb.cache.stdev: if key in kb.cache.stdev:
return kb.cache.stdev[key] retVal = kb.cache.stdev[key]
else: else:
summa = 0.0 summa = 0.0
avg = average(values) avg = average(values)
@ -1837,6 +1815,7 @@ def stdev(values):
retVal = sqrt(summa/(len(values) - 1)) retVal = sqrt(summa/(len(values) - 1))
kb.cache.stdev[key] = retVal kb.cache.stdev[key] = retVal
return retVal return retVal
def average(values): def average(values):
@ -1866,9 +1845,8 @@ def initCommonOutputs():
kb.commonOutputs = {} kb.commonOutputs = {}
key = None key = None
cfile = codecs.open(paths.COMMON_OUTPUTS, 'r', UNICODE_ENCODING) with codecs.open(paths.COMMON_OUTPUTS, 'r', UNICODE_ENCODING) as f:
for line in f.readlines(): # xreadlines doesn't return unicode strings when codec.open() is used
for line in cfile.readlines(): # xreadlines doesn't return unicode strings when codec.open() is used
if line.find('#') != -1: if line.find('#') != -1:
line = line[:line.find('#')] line = line[:line.find('#')]
@ -1884,8 +1862,6 @@ def initCommonOutputs():
if line not in kb.commonOutputs[key]: if line not in kb.commonOutputs[key]:
kb.commonOutputs[key].add(line) kb.commonOutputs[key].add(line)
cfile.close()
def getFileItems(filename, commentPrefix='#', unicode_=True, lowercase=False, unique=False): def getFileItems(filename, commentPrefix='#', unicode_=True, lowercase=False, unique=False):
""" """
Returns newline delimited items contained inside file Returns newline delimited items contained inside file
@ -1896,11 +1872,11 @@ def getFileItems(filename, commentPrefix='#', unicode_=True, lowercase=False, un
checkFile(filename) checkFile(filename)
if unicode_: if unicode_:
ifile = codecs.open(filename, 'r', UNICODE_ENCODING) f = codecs.open(filename, 'r', UNICODE_ENCODING)
else: else:
ifile = open(filename, 'r') f = open(filename, 'r')
for line in ifile.readlines(): # xreadlines doesn't return unicode strings when codec.open() is used for line in f.readlines(): # xreadlines doesn't return unicode strings when codec.open() is used
if commentPrefix: if commentPrefix:
if line.find(commentPrefix) != -1: if line.find(commentPrefix) != -1:
line = line[:line.find(commentPrefix)] line = line[:line.find(commentPrefix)]
@ -1921,6 +1897,8 @@ def getFileItems(filename, commentPrefix='#', unicode_=True, lowercase=False, un
retVal.append(line) retVal.append(line)
f.close()
return retVal return retVal
def goGoodSamaritan(prevValue, originalCharset): def goGoodSamaritan(prevValue, originalCharset):
@ -2001,7 +1979,7 @@ def getCompiledRegex(regex, flags=0):
""" """
if (regex, flags) in kb.cache.regex: if (regex, flags) in kb.cache.regex:
return kb.cache.regex[(regex, flags)] retVal = kb.cache.regex[(regex, flags)]
else: else:
retVal = re.compile(regex, flags) retVal = re.compile(regex, flags)
kb.cache.regex[(regex, flags)] = retVal kb.cache.regex[(regex, flags)] = retVal
@ -2590,16 +2568,16 @@ def getSortedInjectionTests():
retVal = conf.tests retVal = conf.tests
def priorityFunction(test): def priorityFunction(test):
retVal = SORTORDER.FIRST retVal = SORT_ORDER.FIRST
if test.stype == PAYLOAD.TECHNIQUE.UNION: if test.stype == PAYLOAD.TECHNIQUE.UNION:
retVal = SORTORDER.LAST retVal = SORT_ORDER.LAST
elif 'details' in test and 'dbms' in test.details: elif 'details' in test and 'dbms' in test.details:
if test.details.dbms in Backend.getErrorParsedDBMSes(): if test.details.dbms in Backend.getErrorParsedDBMSes():
retVal = SORTORDER.SECOND retVal = SORT_ORDER.SECOND
else: else:
retVal = SORTORDER.THIRD retVal = SORT_ORDER.THIRD
return retVal return retVal
@ -2615,7 +2593,7 @@ def filterListValue(value, regex):
""" """
if isinstance(value, list) and regex: if isinstance(value, list) and regex:
retVal = filter(lambda word: getCompiledRegex(regex, re.I).search(word), value) retVal = filter(lambda x: getCompiledRegex(regex, re.I).search(x), value)
else: else:
retVal = value retVal = value
@ -2688,6 +2666,7 @@ def unhandledExceptionMessage():
errMsg += "Command line: %s\n" % " ".join(sys.argv) errMsg += "Command line: %s\n" % " ".join(sys.argv)
errMsg += "Technique: %s\n" % (enumValueToNameLookup(PAYLOAD.TECHNIQUE, kb.technique) if kb and kb.technique else None) errMsg += "Technique: %s\n" % (enumValueToNameLookup(PAYLOAD.TECHNIQUE, kb.technique) if kb and kb.technique else None)
errMsg += "Back-end DBMS: %s" % ("%s (fingerprinted)" % Backend.getDbms() if Backend.getDbms() is not None else "%s (identified)" % Backend.getIdentifiedDbms()) errMsg += "Back-end DBMS: %s" % ("%s (fingerprinted)" % Backend.getDbms() if Backend.getDbms() is not None else "%s (identified)" % Backend.getIdentifiedDbms())
return maskSensitiveData(errMsg) return maskSensitiveData(errMsg)
def maskSensitiveData(msg): def maskSensitiveData(msg):
@ -2751,8 +2730,8 @@ def intersect(valueA, valueB, lowerCase=False):
valueB = arrayizeValue(valueB) valueB = arrayizeValue(valueB)
if lowerCase: if lowerCase:
valueA = [val.lower() if isinstance(val, basestring) else val for val in valueA] valueA = (val.lower() if isinstance(val, basestring) else val for val in valueA)
valueB = [val.lower() if isinstance(val, basestring) else val for val in valueB] valueB = (val.lower() if isinstance(val, basestring) else val for val in valueB)
retVal = [val for val in valueA if val in valueB] retVal = [val for val in valueA if val in valueB]
@ -2957,8 +2936,8 @@ def expandMnemonics(mnemonics, parser, args):
logger.debug(debugMsg) logger.debug(debugMsg)
else: else:
found = sorted(options.keys(), key=lambda x: len(x))[0] found = sorted(options.keys(), key=lambda x: len(x))[0]
warnMsg = "detected ambiguity (mnemonic '%s' can be resolved to %s). " % (name, ", ".join("'%s'" % key for key in options.keys())) warnMsg = "detected ambiguity (mnemonic '%s' can be resolved to: %s). " % (name, ", ".join("'%s'" % key for key in options.keys()))
warnMsg += "resolved to shortest of those available ('%s')" % found warnMsg += "Resolved to shortest of those ('%s')" % found
logger.warn(warnMsg) logger.warn(warnMsg)
found = options[found] found = options[found]
@ -2988,7 +2967,7 @@ def safeCSValue(value):
if retVal and isinstance(retVal, basestring): if retVal and isinstance(retVal, basestring):
if not (retVal[0] == retVal[-1] == '"'): if not (retVal[0] == retVal[-1] == '"'):
if any(map(lambda x: x in retVal, [conf.csvDel, '"', '\n'])): if any(map(lambda x: x in retVal, (conf.csvDel, '"', '\n'))):
retVal = '"%s"' % retVal.replace('"', '""') retVal = '"%s"' % retVal.replace('"', '""')
return retVal return retVal

View File

@ -45,11 +45,7 @@ def base64unpickle(value):
def hexdecode(value): def hexdecode(value):
value = value.lower() value = value.lower()
return (value[2:] if value.startswith("0x") else value).decode("hex")
if value.startswith("0x"):
value = value[2:]
return value.decode("hex")
def hexencode(value): def hexencode(value):
return value.encode("hex") return value.encode("hex")
@ -149,12 +145,13 @@ def utf8decode(value):
return value.decode("utf-8") return value.decode("utf-8")
def htmlescape(value): def htmlescape(value):
return value.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;').replace('"', '&quot;').replace("'", '&#39;').replace(' ', '&nbsp;') _ = (('&', '&amp;'), ('<', '&lt;'), ('>', '&gt;'), ('"', '&quot;'), ("'", '&#39;'), (' ', '&nbsp;'))
return reduce(lambda x, y: x.replace(y[0], y[1]), _, value)
def htmlunescape(value): def htmlunescape(value):
retVal = value retVal = value
if value and isinstance(value, basestring): if value and isinstance(value, basestring):
if '&' in retVal: _ = (('&amp;', '&'), ('&lt;', '<'), ('&gt;', '>'), ('&quot;', '"'), ('&nbsp;', ' '))
retVal = retVal.replace('&amp;', '&').replace('&lt;', '<').replace('&gt;', '>').replace('&quot;', '"').replace('&nbsp;', ' ') retVal = reduce(lambda x, y: x.replace(y[0], y[1]), _, retVal)
retVal = re.sub('&#(\d+);', lambda x: unichr(int(x.group(1))), retVal) retVal = re.sub('&#(\d+);', lambda x: unichr(int(x.group(1))), retVal)
return retVal return retVal

View File

@ -16,7 +16,7 @@ class PRIORITY:
HIGHER = 50 HIGHER = 50
HIGHEST = 100 HIGHEST = 100
class SORTORDER: class SORT_ORDER:
FIRST = 0 FIRST = 0
SECOND = 1 SECOND = 1
THIRD = 2 THIRD = 2

View File

@ -8,7 +8,6 @@ See the file 'doc/COPYING' for copying permission
""" """
from lib.core.exception import sqlmapSyntaxException from lib.core.exception import sqlmapSyntaxException
from plugins.generic.syntax import Syntax as GenericSyntax from plugins.generic.syntax import Syntax as GenericSyntax
class Syntax(GenericSyntax): class Syntax(GenericSyntax):

View File

@ -9,7 +9,6 @@ See the file 'doc/COPYING' for copying permission
from lib.core.data import logger from lib.core.data import logger
from lib.core.exception import sqlmapSyntaxException from lib.core.exception import sqlmapSyntaxException
from plugins.generic.syntax import Syntax as GenericSyntax from plugins.generic.syntax import Syntax as GenericSyntax
class Syntax(GenericSyntax): class Syntax(GenericSyntax):

View File

@ -9,7 +9,6 @@ See the file 'doc/COPYING' for copying permission
from lib.core.common import isDBMSVersionAtLeast from lib.core.common import isDBMSVersionAtLeast
from lib.core.exception import sqlmapSyntaxException from lib.core.exception import sqlmapSyntaxException
from plugins.generic.syntax import Syntax as GenericSyntax from plugins.generic.syntax import Syntax as GenericSyntax
class Syntax(GenericSyntax): class Syntax(GenericSyntax):

View File

@ -8,7 +8,6 @@ See the file 'doc/COPYING' for copying permission
""" """
from lib.core.exception import sqlmapSyntaxException from lib.core.exception import sqlmapSyntaxException
from plugins.generic.syntax import Syntax as GenericSyntax from plugins.generic.syntax import Syntax as GenericSyntax
class Syntax(GenericSyntax): class Syntax(GenericSyntax):

View File

@ -8,7 +8,6 @@ See the file 'doc/COPYING' for copying permission
""" """
from lib.core.exception import sqlmapSyntaxException from lib.core.exception import sqlmapSyntaxException
from plugins.generic.syntax import Syntax as GenericSyntax from plugins.generic.syntax import Syntax as GenericSyntax
class Syntax(GenericSyntax): class Syntax(GenericSyntax):