some more refactoring

This commit is contained in:
Miroslav Stampar 2012-07-01 01:19:54 +02:00
parent f6509db31a
commit 21d9ae0a2c
3 changed files with 50 additions and 91 deletions

View File

@ -6,6 +6,7 @@ See the file 'doc/COPYING' for copying permission
"""
import codecs
import contextlib
import cookielib
import copy
import ctypes
@ -102,14 +103,10 @@ from lib.core.settings import USER_AGENT_ALIASES
from lib.core.settings import PARTIAL_VALUE_MARKER
from lib.core.settings import ERROR_PARSING_REGEXES
from lib.core.settings import PRINTABLE_CHAR_REGEX
from lib.core.settings import DUMP_DEL_MARKER
from lib.core.settings import SQL_STATEMENTS
from lib.core.settings import SUPPORTED_DBMS
from lib.core.settings import UNKNOWN_DBMS_VERSION
from lib.core.settings import DEFAULT_MSSQL_SCHEMA
from lib.core.settings import DUMP_NEWLINE_MARKER
from lib.core.settings import DUMP_CR_MARKER
from lib.core.settings import DUMP_TAB_MARKER
from lib.core.settings import PARAMETER_AMP_MARKER
from lib.core.settings import PARAMETER_SEMICOLON_MARKER
from lib.core.settings import LARGE_OUTPUT_THRESHOLD
@ -492,27 +489,17 @@ class Backend:
# Reference: http://code.activestate.com/recipes/325205-cache-decorator-in-python-24/
def cachedmethod(f, cache={}):
def g(*args, **kwargs):
def _(*args, **kwargs):
key = (f, tuple(args), frozenset(kwargs.items()))
if key not in cache:
cache[key] = f(*args, **kwargs)
return cache[key]
return g
return _
def paramToDict(place, parameters=None):
"""
Split the parameters into names and values, check if these parameters
are within the testable parameters and return in a dictionary.
@param place: where sqlmap has to work, can be GET, POST or Cookie.
@type place: C{str}
@param parameters: parameters string in the format for instance
'p1=v1&p2=v2' (GET and POST) or 'p1=v1;p2=v2' (Cookie).
@type parameters: C{str}
@return: the parameters in a dictionary.
@rtype: C{str}
"""
testableParameters = OrderedDict()
@ -706,6 +693,10 @@ def singleTimeLogMessage(message, level=logging.INFO, flag=None):
logger.log(level, message)
def dataToStdout(data, forceOutput=False):
"""
Writes text to the stdout (console) stream
"""
if not kb.get("threadException"):
if forceOutput or not getCurrentThreadData().disableStdOut:
try:
@ -793,13 +784,10 @@ def readInput(message, default=None, checkBatch=True):
else:
logging._acquireLock()
dataToStdout("\r%s" % message, True)
data = raw_input()
data = raw_input() or default
#data = raw_input(message.encode(sys.stdout.encoding or UNICODE_ENCODING))
logging._releaseLock()
if not data:
data = default
return data
def randomRange(start=0, stop=1000):
@ -845,19 +833,6 @@ def checkFile(filename):
if not os.path.exists(filename):
raise sqlmapFilePathException, "unable to read file '%s'" % filename
def replaceNewlineTabs(value, stdout=False):
if value is None:
return
if stdout:
retVal = value.replace("\n", " ").replace("\r", " ").replace("\t", " ")
else:
retVal = value.replace("\n", DUMP_NEWLINE_MARKER).replace("\r", DUMP_CR_MARKER).replace("\t", DUMP_TAB_MARKER)
retVal = retVal.replace(kb.chars.delimiter, DUMP_DEL_MARKER)
return retVal
def restoreDumpMarkedChars(value, onlyNewlineTab=False):
retVal = value
@ -1070,16 +1045,15 @@ def parseTargetUrl():
if CUSTOM_INJECTION_MARK_CHAR in conf.url:
conf.url = conf.url.replace('?', URI_QUESTION_MARKER)
__urlSplit = urlparse.urlsplit(conf.url)
__hostnamePort = __urlSplit[1].split(":") if not re.search("\[.+\]", __urlSplit[1]) else filter(None, (re.search("\[.+\]", __urlSplit[1]).group(0), re.search("\](:(?P<port>\d+))?", __urlSplit[1]).group("port")))
urlSplit = urlparse.urlsplit(conf.url)
hostnamePort = urlSplit[1].split(":") if not re.search("\[.+\]", urlSplit[1]) else filter(None, (re.search("\[.+\]", urlSplit[1]).group(0), re.search("\](:(?P<port>\d+))?", urlSplit[1]).group("port")))
conf.scheme = __urlSplit[0].strip().lower() if not conf.forceSSL else "https"
conf.path = __urlSplit[2].strip()
conf.hostname = __hostnamePort[0].strip()
conf.scheme = urlSplit[0].strip().lower() if not conf.forceSSL else "https"
conf.path = urlSplit[2].strip()
conf.hostname = hostnamePort[0].strip()
conf.ipv6 = conf.hostname != conf.hostname.strip("[]")
conf.hostname = conf.hostname.strip("[]")
try:
_ = conf.hostname.encode("idna")
@ -1090,9 +1064,9 @@ def parseTargetUrl():
errMsg = "invalid target url"
raise sqlmapSyntaxException, errMsg
if len(__hostnamePort) == 2:
if len(hostnamePort) == 2:
try:
conf.port = int(__hostnamePort[1])
conf.port = int(hostnamePort[1])
except:
errMsg = "invalid target url"
raise sqlmapSyntaxException, errMsg
@ -1101,8 +1075,8 @@ def parseTargetUrl():
else:
conf.port = 80
if __urlSplit[3]:
conf.parameters[PLACE.GET] = urldecode(__urlSplit[3]) if __urlSplit[3] and urlencode(DEFAULT_GET_POST_DELIMITER, None) not in __urlSplit[3] else __urlSplit[3]
if urlSplit[3]:
conf.parameters[PLACE.GET] = urldecode(urlSplit[3]) if urlSplit[3] and urlencode(DEFAULT_GET_POST_DELIMITER, None) not in urlSplit[3] else urlSplit[3]
conf.url = "%s://%s:%d%s" % (conf.scheme, ("[%s]" % conf.hostname) if conf.ipv6 else conf.hostname, conf.port, conf.path)
conf.url = conf.url.replace(URI_QUESTION_MARKER, '?')
@ -1349,24 +1323,14 @@ def directoryPath(filepath):
Returns directory path for a given filepath
"""
if isWindowsDriveLetterPath(filepath):
retVal = ntpath.dirname(filepath)
else:
retVal = posixpath.dirname(filepath)
return retVal
return ntpath.dirname(filepath) if isWindowsDriveLetterPath(filepath) else posixpath.dirname(filepath)
def normalizePath(filepath):
"""
Returns normalized string representation of a given filepath
"""
if isWindowsDriveLetterPath(filepath):
retVal = ntpath.normpath(filepath)
else:
retVal = posixpath.normpath(filepath)
return retVal
return ntpath.normpath(filepath) if isWindowsDriveLetterPath(filepath) else posixpath.normpath(filepath)
def safeStringFormat(format_, params):
"""
@ -1379,16 +1343,13 @@ def safeStringFormat(format_, params):
retVal = retVal.replace("%s", params)
else:
count, index = 0, 0
while index != -1:
index = retVal.find("%s")
if index != -1:
if count < len(params):
retVal = retVal[:index] + getUnicode(params[count]) + retVal[index + 2:]
else:
raise sqlmapNoneDataException, "wrong number of parameters during string formatting"
count += 1
return retVal
@ -1404,10 +1365,8 @@ def getFilteredPageContent(page, onlyText=True):
# only if the page's charset has been successfully identified
if isinstance(page, unicode):
retVal = re.sub(r"(?si)<script.+?</script>|<!--.+?-->|<style.+?</style>%s" % (r"|<[^>]+>|\t|\n|\r" if onlyText else ""), " ", page)
while retVal.find(" ") != -1:
retVal = retVal.replace(" ", " ")
retVal = htmlunescape(retVal)
return retVal
@ -1421,8 +1380,8 @@ def getPageWordSet(page):
# only if the page's charset has been successfully identified
if isinstance(page, unicode):
page = getFilteredPageContent(page)
retVal = set(re.findall(r"\w+", page))
_ = getFilteredPageContent(page)
retVal = set(re.findall(r"\w+", _))
return retVal
@ -1557,8 +1516,8 @@ def getConsoleWidth(default=80):
width = None
if 'COLUMNS' in os.environ and os.environ['COLUMNS'].isdigit():
width = int(os.environ['COLUMNS'])
if os.getenv("COLUMNS", "").isdigit():
width = int(os.getenv("COLUMNS"))
else:
output=execute('stty size', shell=True, stdout=PIPE, stderr=PIPE).stdout.read()
items = output.split()
@ -1576,7 +1535,7 @@ def getConsoleWidth(default=80):
except:
pass
return width if width else default
return width or default
def clearConsoleLine(forceOutput=False):
"""
@ -1591,7 +1550,7 @@ def parseXmlFile(xmlFile, handler):
Parses XML file by a given handler
"""
with StringIO(readCachedFileContent(xmlFile)) as stream:
with contextlib.closing(StringIO(readCachedFileContent(xmlFile))) as stream:
parse(stream, handler)
def getSPQLSnippet(dbms, name, **variables):
@ -1632,8 +1591,7 @@ def readCachedFileContent(filename, mode='rb'):
if filename not in kb.cache.content:
checkFile(filename)
with codecs.open(filename, mode, UNICODE_ENCODING) as f:
content = f.read()
kb.cache.content[filename] = content
kb.cache.content[filename] = f.read()
return kb.cache.content[filename]
@ -1675,12 +1633,7 @@ def average(values):
Computes the arithmetic mean of a list of numbers.
"""
retVal = None
if values:
retVal = sum(values) / len(values)
return retVal
return (sum(values) / len(values)) if values else None
def calculateDeltaSeconds(start):
"""
@ -1891,8 +1844,8 @@ def getUnicode(value, encoding=None, system=False, noneToNull=False):
def longestCommonPrefix(*sequences):
"""
Returns longest common prefix occuring in given sequences
Reference: http://boredzo.org/blog/archives/2007-01-06/longest-common-prefix-in-python-2
"""
# Reference: http://boredzo.org/blog/archives/2007-01-06/longest-common-prefix-in-python-2
if len(sequences) == 1:
return sequences[0]
@ -2356,6 +2309,10 @@ def isInferenceAvailable():
return any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.BOOLEAN, PAYLOAD.TECHNIQUE.STACKED, PAYLOAD.TECHNIQUE.TIME))
def setOptimize():
"""
Sets options turned on by switch '-o'
"""
#conf.predictOutput = True
conf.keepAlive = True
conf.threads = 3 if conf.threads < 3 else conf.threads
@ -2719,7 +2676,7 @@ def safeSQLIdentificatorNaming(name, isTable=False):
parts = name.split('.')
for i in xrange(len(parts)):
if not re.match(r"\A[A-Za-z0-9_@\$]+\Z", parts[i]): # reference: http://stackoverflow.com/questions/954884/what-special-characters-are-allowed-in-t-sql-column-name
if not re.match(r"\A[A-Za-z0-9_@\$]+\Z", parts[i]): # Reference: http://stackoverflow.com/questions/954884/what-special-characters-are-allowed-in-t-sql-column-name
if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.ACCESS):
parts[i] = "`%s`" % parts[i].strip("`")
elif Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.PGSQL, DBMS.DB2):
@ -3082,10 +3039,7 @@ def unserializeObject(value):
Unserializes object from given serialized form
"""
retVal = None
if value:
retVal = base64unpickle(value)
return retVal
return base64unpickle(value) if value else None
def resetCounter(technique):
"""
@ -3176,10 +3130,7 @@ def extractExpectedValue(value, expected):
value = None
elif expected == EXPECTED.INT:
if isinstance(value, basestring):
if value.isdigit():
value = int(value)
else:
value = None
value = int(value) if value.isdigit() else None
return value
@ -3211,6 +3162,10 @@ def hashDBRetrieve(key, unserialize=False, checkConf=False):
return _
def resetCookieJar(cookieJar):
"""
Cleans cookies from a given cookie jar
"""
if not conf.loC:
cookieJar.clear()
else:
@ -3223,5 +3178,9 @@ def resetCookieJar(cookieJar):
raise sqlmapGenericException, errMsg
def prioritySortColumns(columns):
"""
Sorts given column names by length in ascending order while those containing
string 'id' go first
"""
_ = lambda x: x and "id" in x.lower()
return sorted(sorted(columns, key=len), lambda x, y: -1 if _(x) and not _(y) else 1 if not _(x) and _(y) else 0)

View File

@ -7,7 +7,9 @@ See the file 'doc/COPYING' for copying permission
import logging
import os
import re
import subprocess
import string
import sys
from lib.core.enums import CUSTOM_LOGGING
@ -44,12 +46,6 @@ LOGGER_HANDLER.setFormatter(FORMATTER)
LOGGER.addHandler(LOGGER_HANDLER)
LOGGER.setLevel(logging.WARN)
# dump markers
DUMP_NEWLINE_MARKER = "__NEWLINE__"
DUMP_CR_MARKER = "__CARRIAGE_RETURN__"
DUMP_TAB_MARKER = "__TAB__"
DUMP_DEL_MARKER = "__DEL__"
# markers for special cases when parameter values contain html encoded characters
PARAMETER_AMP_MARKER = "__AMP__"
PARAMETER_SEMICOLON_MARKER = "__SEMICOLON__"
@ -475,6 +471,9 @@ MAX_TOTAL_REDIRECTIONS = 10
# Reference: http://www.tcpipguide.com/free/t_DNSLabelsNamesandSyntaxRules.htm
MAX_DNS_LABEL = 63
# Alphabet used for prefix and suffix strings of name resolution requests in DNS technique (excluding hexadecimal chars for not mixing with inner content)
DNS_BOUNDARIES_ALPHABET = re.sub("[a-fA-F]", "", string.letters)
# Connection chunk size (processing large responses in chunks to avoid MemoryError crashes - e.g. large table dump in full UNION/inband injections)
MAX_CONNECTION_CHUNK_SIZE = 10 * 1024 * 1024

View File

@ -29,6 +29,7 @@ from lib.core.data import kb
from lib.core.data import logger
from lib.core.data import queries
from lib.core.enums import DBMS
from lib.core.settings import DNS_BOUNDARIES_ALPHABET
from lib.core.settings import MAX_DNS_LABEL
from lib.core.settings import PARTIAL_VALUE_MARKER
from lib.core.unescaper import unescaper
@ -58,7 +59,7 @@ def dnsUse(payload, expression):
while True:
count += 1
prefix, suffix = ("%s" % randomStr(length=3, alphabet=re.sub("[a-fA-F]", "", string.letters)) for _ in xrange(2))
prefix, suffix = ("%s" % randomStr(length=3, alphabet=DNS_BOUNDARIES_ALPHABET) for _ in xrange(2))
chunk_length = MAX_DNS_LABEL / 2 if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.MYSQL, DBMS.PGSQL) else MAX_DNS_LABEL / 4 - 2
_, _, _, _, _, _, fieldToCastStr, _ = agent.getFields(expression)
nulledCastedField = agent.nullAndCastField(fieldToCastStr)