some more refactoring

This commit is contained in:
Miroslav Stampar 2012-07-01 01:19:54 +02:00
parent f6509db31a
commit 21d9ae0a2c
3 changed files with 50 additions and 91 deletions

View File

@ -6,6 +6,7 @@ See the file 'doc/COPYING' for copying permission
""" """
import codecs import codecs
import contextlib
import cookielib import cookielib
import copy import copy
import ctypes import ctypes
@ -102,14 +103,10 @@ from lib.core.settings import USER_AGENT_ALIASES
from lib.core.settings import PARTIAL_VALUE_MARKER from lib.core.settings import PARTIAL_VALUE_MARKER
from lib.core.settings import ERROR_PARSING_REGEXES from lib.core.settings import ERROR_PARSING_REGEXES
from lib.core.settings import PRINTABLE_CHAR_REGEX from lib.core.settings import PRINTABLE_CHAR_REGEX
from lib.core.settings import DUMP_DEL_MARKER
from lib.core.settings import SQL_STATEMENTS from lib.core.settings import SQL_STATEMENTS
from lib.core.settings import SUPPORTED_DBMS from lib.core.settings import SUPPORTED_DBMS
from lib.core.settings import UNKNOWN_DBMS_VERSION from lib.core.settings import UNKNOWN_DBMS_VERSION
from lib.core.settings import DEFAULT_MSSQL_SCHEMA from lib.core.settings import DEFAULT_MSSQL_SCHEMA
from lib.core.settings import DUMP_NEWLINE_MARKER
from lib.core.settings import DUMP_CR_MARKER
from lib.core.settings import DUMP_TAB_MARKER
from lib.core.settings import PARAMETER_AMP_MARKER from lib.core.settings import PARAMETER_AMP_MARKER
from lib.core.settings import PARAMETER_SEMICOLON_MARKER from lib.core.settings import PARAMETER_SEMICOLON_MARKER
from lib.core.settings import LARGE_OUTPUT_THRESHOLD from lib.core.settings import LARGE_OUTPUT_THRESHOLD
@ -492,27 +489,17 @@ class Backend:
# Reference: http://code.activestate.com/recipes/325205-cache-decorator-in-python-24/ # Reference: http://code.activestate.com/recipes/325205-cache-decorator-in-python-24/
def cachedmethod(f, cache={}): def cachedmethod(f, cache={}):
def g(*args, **kwargs): def _(*args, **kwargs):
key = (f, tuple(args), frozenset(kwargs.items())) key = (f, tuple(args), frozenset(kwargs.items()))
if key not in cache: if key not in cache:
cache[key] = f(*args, **kwargs) cache[key] = f(*args, **kwargs)
return cache[key] return cache[key]
return g return _
def paramToDict(place, parameters=None): def paramToDict(place, parameters=None):
""" """
Split the parameters into names and values, check if these parameters Split the parameters into names and values, check if these parameters
are within the testable parameters and return in a dictionary. are within the testable parameters and return in a dictionary.
@param place: where sqlmap has to work, can be GET, POST or Cookie.
@type place: C{str}
@param parameters: parameters string in the format for instance
'p1=v1&p2=v2' (GET and POST) or 'p1=v1;p2=v2' (Cookie).
@type parameters: C{str}
@return: the parameters in a dictionary.
@rtype: C{str}
""" """
testableParameters = OrderedDict() testableParameters = OrderedDict()
@ -706,6 +693,10 @@ def singleTimeLogMessage(message, level=logging.INFO, flag=None):
logger.log(level, message) logger.log(level, message)
def dataToStdout(data, forceOutput=False): def dataToStdout(data, forceOutput=False):
"""
Writes text to the stdout (console) stream
"""
if not kb.get("threadException"): if not kb.get("threadException"):
if forceOutput or not getCurrentThreadData().disableStdOut: if forceOutput or not getCurrentThreadData().disableStdOut:
try: try:
@ -793,13 +784,10 @@ def readInput(message, default=None, checkBatch=True):
else: else:
logging._acquireLock() logging._acquireLock()
dataToStdout("\r%s" % message, True) dataToStdout("\r%s" % message, True)
data = raw_input() data = raw_input() or default
#data = raw_input(message.encode(sys.stdout.encoding or UNICODE_ENCODING)) #data = raw_input(message.encode(sys.stdout.encoding or UNICODE_ENCODING))
logging._releaseLock() logging._releaseLock()
if not data:
data = default
return data return data
def randomRange(start=0, stop=1000): def randomRange(start=0, stop=1000):
@ -845,19 +833,6 @@ def checkFile(filename):
if not os.path.exists(filename): if not os.path.exists(filename):
raise sqlmapFilePathException, "unable to read file '%s'" % filename raise sqlmapFilePathException, "unable to read file '%s'" % filename
def replaceNewlineTabs(value, stdout=False):
if value is None:
return
if stdout:
retVal = value.replace("\n", " ").replace("\r", " ").replace("\t", " ")
else:
retVal = value.replace("\n", DUMP_NEWLINE_MARKER).replace("\r", DUMP_CR_MARKER).replace("\t", DUMP_TAB_MARKER)
retVal = retVal.replace(kb.chars.delimiter, DUMP_DEL_MARKER)
return retVal
def restoreDumpMarkedChars(value, onlyNewlineTab=False): def restoreDumpMarkedChars(value, onlyNewlineTab=False):
retVal = value retVal = value
@ -1070,16 +1045,15 @@ def parseTargetUrl():
if CUSTOM_INJECTION_MARK_CHAR in conf.url: if CUSTOM_INJECTION_MARK_CHAR in conf.url:
conf.url = conf.url.replace('?', URI_QUESTION_MARKER) conf.url = conf.url.replace('?', URI_QUESTION_MARKER)
__urlSplit = urlparse.urlsplit(conf.url) urlSplit = urlparse.urlsplit(conf.url)
__hostnamePort = __urlSplit[1].split(":") if not re.search("\[.+\]", __urlSplit[1]) else filter(None, (re.search("\[.+\]", __urlSplit[1]).group(0), re.search("\](:(?P<port>\d+))?", __urlSplit[1]).group("port"))) hostnamePort = urlSplit[1].split(":") if not re.search("\[.+\]", urlSplit[1]) else filter(None, (re.search("\[.+\]", urlSplit[1]).group(0), re.search("\](:(?P<port>\d+))?", urlSplit[1]).group("port")))
conf.scheme = __urlSplit[0].strip().lower() if not conf.forceSSL else "https" conf.scheme = urlSplit[0].strip().lower() if not conf.forceSSL else "https"
conf.path = __urlSplit[2].strip() conf.path = urlSplit[2].strip()
conf.hostname = __hostnamePort[0].strip() conf.hostname = hostnamePort[0].strip()
conf.ipv6 = conf.hostname != conf.hostname.strip("[]") conf.ipv6 = conf.hostname != conf.hostname.strip("[]")
conf.hostname = conf.hostname.strip("[]") conf.hostname = conf.hostname.strip("[]")
try: try:
_ = conf.hostname.encode("idna") _ = conf.hostname.encode("idna")
@ -1090,9 +1064,9 @@ def parseTargetUrl():
errMsg = "invalid target url" errMsg = "invalid target url"
raise sqlmapSyntaxException, errMsg raise sqlmapSyntaxException, errMsg
if len(__hostnamePort) == 2: if len(hostnamePort) == 2:
try: try:
conf.port = int(__hostnamePort[1]) conf.port = int(hostnamePort[1])
except: except:
errMsg = "invalid target url" errMsg = "invalid target url"
raise sqlmapSyntaxException, errMsg raise sqlmapSyntaxException, errMsg
@ -1101,8 +1075,8 @@ def parseTargetUrl():
else: else:
conf.port = 80 conf.port = 80
if __urlSplit[3]: if urlSplit[3]:
conf.parameters[PLACE.GET] = urldecode(__urlSplit[3]) if __urlSplit[3] and urlencode(DEFAULT_GET_POST_DELIMITER, None) not in __urlSplit[3] else __urlSplit[3] conf.parameters[PLACE.GET] = urldecode(urlSplit[3]) if urlSplit[3] and urlencode(DEFAULT_GET_POST_DELIMITER, None) not in urlSplit[3] else urlSplit[3]
conf.url = "%s://%s:%d%s" % (conf.scheme, ("[%s]" % conf.hostname) if conf.ipv6 else conf.hostname, conf.port, conf.path) conf.url = "%s://%s:%d%s" % (conf.scheme, ("[%s]" % conf.hostname) if conf.ipv6 else conf.hostname, conf.port, conf.path)
conf.url = conf.url.replace(URI_QUESTION_MARKER, '?') conf.url = conf.url.replace(URI_QUESTION_MARKER, '?')
@ -1349,24 +1323,14 @@ def directoryPath(filepath):
Returns directory path for a given filepath Returns directory path for a given filepath
""" """
if isWindowsDriveLetterPath(filepath): return ntpath.dirname(filepath) if isWindowsDriveLetterPath(filepath) else posixpath.dirname(filepath)
retVal = ntpath.dirname(filepath)
else:
retVal = posixpath.dirname(filepath)
return retVal
def normalizePath(filepath): def normalizePath(filepath):
""" """
Returns normalized string representation of a given filepath Returns normalized string representation of a given filepath
""" """
if isWindowsDriveLetterPath(filepath): return ntpath.normpath(filepath) if isWindowsDriveLetterPath(filepath) else posixpath.normpath(filepath)
retVal = ntpath.normpath(filepath)
else:
retVal = posixpath.normpath(filepath)
return retVal
def safeStringFormat(format_, params): def safeStringFormat(format_, params):
""" """
@ -1379,16 +1343,13 @@ def safeStringFormat(format_, params):
retVal = retVal.replace("%s", params) retVal = retVal.replace("%s", params)
else: else:
count, index = 0, 0 count, index = 0, 0
while index != -1: while index != -1:
index = retVal.find("%s") index = retVal.find("%s")
if index != -1: if index != -1:
if count < len(params): if count < len(params):
retVal = retVal[:index] + getUnicode(params[count]) + retVal[index + 2:] retVal = retVal[:index] + getUnicode(params[count]) + retVal[index + 2:]
else: else:
raise sqlmapNoneDataException, "wrong number of parameters during string formatting" raise sqlmapNoneDataException, "wrong number of parameters during string formatting"
count += 1 count += 1
return retVal return retVal
@ -1404,10 +1365,8 @@ def getFilteredPageContent(page, onlyText=True):
# only if the page's charset has been successfully identified # only if the page's charset has been successfully identified
if isinstance(page, unicode): if isinstance(page, unicode):
retVal = re.sub(r"(?si)<script.+?</script>|<!--.+?-->|<style.+?</style>%s" % (r"|<[^>]+>|\t|\n|\r" if onlyText else ""), " ", page) retVal = re.sub(r"(?si)<script.+?</script>|<!--.+?-->|<style.+?</style>%s" % (r"|<[^>]+>|\t|\n|\r" if onlyText else ""), " ", page)
while retVal.find(" ") != -1: while retVal.find(" ") != -1:
retVal = retVal.replace(" ", " ") retVal = retVal.replace(" ", " ")
retVal = htmlunescape(retVal) retVal = htmlunescape(retVal)
return retVal return retVal
@ -1421,8 +1380,8 @@ def getPageWordSet(page):
# only if the page's charset has been successfully identified # only if the page's charset has been successfully identified
if isinstance(page, unicode): if isinstance(page, unicode):
page = getFilteredPageContent(page) _ = getFilteredPageContent(page)
retVal = set(re.findall(r"\w+", page)) retVal = set(re.findall(r"\w+", _))
return retVal return retVal
@ -1557,8 +1516,8 @@ def getConsoleWidth(default=80):
width = None width = None
if 'COLUMNS' in os.environ and os.environ['COLUMNS'].isdigit(): if os.getenv("COLUMNS", "").isdigit():
width = int(os.environ['COLUMNS']) width = int(os.getenv("COLUMNS"))
else: else:
output=execute('stty size', shell=True, stdout=PIPE, stderr=PIPE).stdout.read() output=execute('stty size', shell=True, stdout=PIPE, stderr=PIPE).stdout.read()
items = output.split() items = output.split()
@ -1576,7 +1535,7 @@ def getConsoleWidth(default=80):
except: except:
pass pass
return width if width else default return width or default
def clearConsoleLine(forceOutput=False): def clearConsoleLine(forceOutput=False):
""" """
@ -1591,7 +1550,7 @@ def parseXmlFile(xmlFile, handler):
Parses XML file by a given handler Parses XML file by a given handler
""" """
with StringIO(readCachedFileContent(xmlFile)) as stream: with contextlib.closing(StringIO(readCachedFileContent(xmlFile))) as stream:
parse(stream, handler) parse(stream, handler)
def getSPQLSnippet(dbms, name, **variables): def getSPQLSnippet(dbms, name, **variables):
@ -1632,8 +1591,7 @@ def readCachedFileContent(filename, mode='rb'):
if filename not in kb.cache.content: if filename not in kb.cache.content:
checkFile(filename) checkFile(filename)
with codecs.open(filename, mode, UNICODE_ENCODING) as f: with codecs.open(filename, mode, UNICODE_ENCODING) as f:
content = f.read() kb.cache.content[filename] = f.read()
kb.cache.content[filename] = content
return kb.cache.content[filename] return kb.cache.content[filename]
@ -1675,12 +1633,7 @@ def average(values):
Computes the arithmetic mean of a list of numbers. Computes the arithmetic mean of a list of numbers.
""" """
retVal = None return (sum(values) / len(values)) if values else None
if values:
retVal = sum(values) / len(values)
return retVal
def calculateDeltaSeconds(start): def calculateDeltaSeconds(start):
""" """
@ -1891,8 +1844,8 @@ def getUnicode(value, encoding=None, system=False, noneToNull=False):
def longestCommonPrefix(*sequences): def longestCommonPrefix(*sequences):
""" """
Returns longest common prefix occuring in given sequences Returns longest common prefix occuring in given sequences
Reference: http://boredzo.org/blog/archives/2007-01-06/longest-common-prefix-in-python-2
""" """
# Reference: http://boredzo.org/blog/archives/2007-01-06/longest-common-prefix-in-python-2
if len(sequences) == 1: if len(sequences) == 1:
return sequences[0] return sequences[0]
@ -2356,6 +2309,10 @@ def isInferenceAvailable():
return any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.BOOLEAN, PAYLOAD.TECHNIQUE.STACKED, PAYLOAD.TECHNIQUE.TIME)) return any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.BOOLEAN, PAYLOAD.TECHNIQUE.STACKED, PAYLOAD.TECHNIQUE.TIME))
def setOptimize(): def setOptimize():
"""
Sets options turned on by switch '-o'
"""
#conf.predictOutput = True #conf.predictOutput = True
conf.keepAlive = True conf.keepAlive = True
conf.threads = 3 if conf.threads < 3 else conf.threads conf.threads = 3 if conf.threads < 3 else conf.threads
@ -2719,7 +2676,7 @@ def safeSQLIdentificatorNaming(name, isTable=False):
parts = name.split('.') parts = name.split('.')
for i in xrange(len(parts)): for i in xrange(len(parts)):
if not re.match(r"\A[A-Za-z0-9_@\$]+\Z", parts[i]): # reference: http://stackoverflow.com/questions/954884/what-special-characters-are-allowed-in-t-sql-column-name if not re.match(r"\A[A-Za-z0-9_@\$]+\Z", parts[i]): # Reference: http://stackoverflow.com/questions/954884/what-special-characters-are-allowed-in-t-sql-column-name
if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.ACCESS): if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.ACCESS):
parts[i] = "`%s`" % parts[i].strip("`") parts[i] = "`%s`" % parts[i].strip("`")
elif Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.PGSQL, DBMS.DB2): elif Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.PGSQL, DBMS.DB2):
@ -3082,10 +3039,7 @@ def unserializeObject(value):
Unserializes object from given serialized form Unserializes object from given serialized form
""" """
retVal = None return base64unpickle(value) if value else None
if value:
retVal = base64unpickle(value)
return retVal
def resetCounter(technique): def resetCounter(technique):
""" """
@ -3176,10 +3130,7 @@ def extractExpectedValue(value, expected):
value = None value = None
elif expected == EXPECTED.INT: elif expected == EXPECTED.INT:
if isinstance(value, basestring): if isinstance(value, basestring):
if value.isdigit(): value = int(value) if value.isdigit() else None
value = int(value)
else:
value = None
return value return value
@ -3211,6 +3162,10 @@ def hashDBRetrieve(key, unserialize=False, checkConf=False):
return _ return _
def resetCookieJar(cookieJar): def resetCookieJar(cookieJar):
"""
Cleans cookies from a given cookie jar
"""
if not conf.loC: if not conf.loC:
cookieJar.clear() cookieJar.clear()
else: else:
@ -3223,5 +3178,9 @@ def resetCookieJar(cookieJar):
raise sqlmapGenericException, errMsg raise sqlmapGenericException, errMsg
def prioritySortColumns(columns): def prioritySortColumns(columns):
"""
Sorts given column names by length in ascending order while those containing
string 'id' go first
"""
_ = lambda x: x and "id" in x.lower() _ = lambda x: x and "id" in x.lower()
return sorted(sorted(columns, key=len), lambda x, y: -1 if _(x) and not _(y) else 1 if not _(x) and _(y) else 0) return sorted(sorted(columns, key=len), lambda x, y: -1 if _(x) and not _(y) else 1 if not _(x) and _(y) else 0)

View File

@ -7,7 +7,9 @@ See the file 'doc/COPYING' for copying permission
import logging import logging
import os import os
import re
import subprocess import subprocess
import string
import sys import sys
from lib.core.enums import CUSTOM_LOGGING from lib.core.enums import CUSTOM_LOGGING
@ -44,12 +46,6 @@ LOGGER_HANDLER.setFormatter(FORMATTER)
LOGGER.addHandler(LOGGER_HANDLER) LOGGER.addHandler(LOGGER_HANDLER)
LOGGER.setLevel(logging.WARN) LOGGER.setLevel(logging.WARN)
# dump markers
DUMP_NEWLINE_MARKER = "__NEWLINE__"
DUMP_CR_MARKER = "__CARRIAGE_RETURN__"
DUMP_TAB_MARKER = "__TAB__"
DUMP_DEL_MARKER = "__DEL__"
# markers for special cases when parameter values contain html encoded characters # markers for special cases when parameter values contain html encoded characters
PARAMETER_AMP_MARKER = "__AMP__" PARAMETER_AMP_MARKER = "__AMP__"
PARAMETER_SEMICOLON_MARKER = "__SEMICOLON__" PARAMETER_SEMICOLON_MARKER = "__SEMICOLON__"
@ -475,6 +471,9 @@ MAX_TOTAL_REDIRECTIONS = 10
# Reference: http://www.tcpipguide.com/free/t_DNSLabelsNamesandSyntaxRules.htm # Reference: http://www.tcpipguide.com/free/t_DNSLabelsNamesandSyntaxRules.htm
MAX_DNS_LABEL = 63 MAX_DNS_LABEL = 63
# Alphabet used for prefix and suffix strings of name resolution requests in DNS technique (excluding hexadecimal chars for not mixing with inner content)
DNS_BOUNDARIES_ALPHABET = re.sub("[a-fA-F]", "", string.letters)
# Connection chunk size (processing large responses in chunks to avoid MemoryError crashes - e.g. large table dump in full UNION/inband injections) # Connection chunk size (processing large responses in chunks to avoid MemoryError crashes - e.g. large table dump in full UNION/inband injections)
MAX_CONNECTION_CHUNK_SIZE = 10 * 1024 * 1024 MAX_CONNECTION_CHUNK_SIZE = 10 * 1024 * 1024

View File

@ -29,6 +29,7 @@ from lib.core.data import kb
from lib.core.data import logger from lib.core.data import logger
from lib.core.data import queries from lib.core.data import queries
from lib.core.enums import DBMS from lib.core.enums import DBMS
from lib.core.settings import DNS_BOUNDARIES_ALPHABET
from lib.core.settings import MAX_DNS_LABEL from lib.core.settings import MAX_DNS_LABEL
from lib.core.settings import PARTIAL_VALUE_MARKER from lib.core.settings import PARTIAL_VALUE_MARKER
from lib.core.unescaper import unescaper from lib.core.unescaper import unescaper
@ -58,7 +59,7 @@ def dnsUse(payload, expression):
while True: while True:
count += 1 count += 1
prefix, suffix = ("%s" % randomStr(length=3, alphabet=re.sub("[a-fA-F]", "", string.letters)) for _ in xrange(2)) prefix, suffix = ("%s" % randomStr(length=3, alphabet=DNS_BOUNDARIES_ALPHABET) for _ in xrange(2))
chunk_length = MAX_DNS_LABEL / 2 if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL / 4 - 2 chunk_length = MAX_DNS_LABEL / 2 if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL / 4 - 2
_, _, _, _, _, _, fieldToCastStr, _ = agent.getFields(expression) _, _, _, _, _, _, fieldToCastStr, _ = agent.getFields(expression)
nulledCastedField = agent.nullAndCastField(fieldToCastStr) nulledCastedField = agent.nullAndCastField(fieldToCastStr)