bulk of fixes

This commit is contained in:
Miroslav Stampar 2011-07-02 22:48:56 +00:00
parent 861cdb1b14
commit 34d9a91af1
8 changed files with 145 additions and 151 deletions

View File

@ -2767,3 +2767,18 @@ def expandMnemonics(mnemonics, parser, args):
else: else:
errMsg = "mnemonic '%s' requires value of type '%s'" % (name, found.type) errMsg = "mnemonic '%s' requires value of type '%s'" % (name, found.type)
raise sqlmapSyntaxException, errMsg raise sqlmapSyntaxException, errMsg
def safeCSValue(value):
"""
Returns value safe for CSV dumping.
Reference: http://stackoverflow.com/questions/769621/dealing-with-commas-in-a-csv-file
"""
retVal = value
if isinstance(retVal, basestring):
if not (retVal[0] == retVal[-1] == '"'):
if any(map(lambda x: x in retVal, ['"', ',', '\n'])):
retVal = '"%s"' % retVal.replace('"', '""')
return retVal

View File

@ -18,6 +18,7 @@ from lib.core.common import getUnicode
from lib.core.common import normalizeUnicode from lib.core.common import normalizeUnicode
from lib.core.common import openFile from lib.core.common import openFile
from lib.core.common import restoreDumpMarkedChars from lib.core.common import restoreDumpMarkedChars
from lib.core.common import safeCSValue
from lib.core.data import conf from lib.core.data import conf
from lib.core.data import kb from lib.core.data import kb
from lib.core.data import logger from lib.core.data import logger
@ -392,9 +393,9 @@ class Dump:
if not conf.replicate: if not conf.replicate:
if not conf.multipleTargets and field == fields: if not conf.multipleTargets and field == fields:
dataToDumpFile(dumpFP, "%s" % column) dataToDumpFile(dumpFP, "%s" % safeCSValue(column))
elif not conf.multipleTargets: elif not conf.multipleTargets:
dataToDumpFile(dumpFP, "%s," % column) dataToDumpFile(dumpFP, "%s," % safeCSValue(column))
field += 1 field += 1
@ -432,9 +433,9 @@ class Dump:
if not conf.replicate: if not conf.replicate:
if not conf.multipleTargets and field == fields: if not conf.multipleTargets and field == fields:
dataToDumpFile(dumpFP, "\"%s\"" % value) dataToDumpFile(dumpFP, "%s" % safeCSValue(value))
elif not conf.multipleTargets: elif not conf.multipleTargets:
dataToDumpFile(dumpFP, "\"%s\"," % value) dataToDumpFile(dumpFP, "%s," % safeCSValue(value))
field += 1 field += 1

View File

@ -11,11 +11,15 @@ import difflib
import threading import threading
import time import time
from thread import error as threadError
from lib.core.data import kb from lib.core.data import kb
from lib.core.data import logger from lib.core.data import logger
from lib.core.datatype import advancedDict from lib.core.datatype import advancedDict
from lib.core.enums import PAYLOAD from lib.core.enums import PAYLOAD
from lib.core.exception import sqlmapConnectionException
from lib.core.exception import sqlmapThreadException from lib.core.exception import sqlmapThreadException
from lib.core.exception import sqlmapValueException
from lib.core.settings import MAX_NUMBER_OF_THREADS from lib.core.settings import MAX_NUMBER_OF_THREADS
from lib.core.settings import PYVERSION from lib.core.settings import PYVERSION
@ -68,7 +72,7 @@ def exceptionHandledFunction(threadFunction):
print print
logger.error("thread %s: %s" % (threading.currentThread().getName(), errMsg)) logger.error("thread %s: %s" % (threading.currentThread().getName(), errMsg))
def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True, threadChoice=False): def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True, threadChoice=False, startThreadMsg=True):
threads = [] threads = []
kb.multiThreadMode = True kb.multiThreadMode = True
@ -92,8 +96,9 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
logger.warn(warnMsg) logger.warn(warnMsg)
if numThreads > 1: if numThreads > 1:
infoMsg = "starting %d threads" % numThreads if startThreadMsg:
logger.info(infoMsg) infoMsg = "starting %d threads" % numThreads
logger.info(infoMsg)
else: else:
threadFunction() threadFunction()
return return
@ -108,7 +113,13 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
else: else:
thread.setDaemon(True) thread.setDaemon(True)
thread.start() try:
thread.start()
except threadError, errMsg:
errMsg = "error occured while starting new thread ('%s')" % errMsg
logger.critical(errMsg)
break
threads.append(thread) threads.append(thread)
# And wait for them to all finish # And wait for them to all finish
@ -122,11 +133,9 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
time.sleep(1) time.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
print
kb.threadContinue = False kb.threadContinue = False
kb.threadException = True kb.threadException = True
print '\r',
logger.info("waiting for threads to finish (Ctrl+C was pressed)") logger.info("waiting for threads to finish (Ctrl+C was pressed)")
try: try:
@ -139,6 +148,20 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio
if forwardException: if forwardException:
raise raise
except (sqlmapConnectionException, sqlmapValueException), errMsg:
print
kb.threadException = True
logger.error("thread %s: %s" % (threading.currentThread().getName(), errMsg))
except:
from lib.core.common import unhandledExceptionMessage
print
kb.threadException = True
errMsg = unhandledExceptionMessage()
logger.error("thread %s: %s" % (threading.currentThread().getName(), errMsg))
traceback.print_exc()
finally: finally:
kb.multiThreadMode = False kb.multiThreadMode = False
kb.bruteMode = False kb.bruteMode = False

View File

@ -44,6 +44,8 @@ from lib.core.settings import INFERENCE_GREATER_CHAR
from lib.core.settings import INFERENCE_EQUALS_CHAR from lib.core.settings import INFERENCE_EQUALS_CHAR
from lib.core.settings import INFERENCE_NOT_EQUALS_CHAR from lib.core.settings import INFERENCE_NOT_EQUALS_CHAR
from lib.core.settings import PYVERSION from lib.core.settings import PYVERSION
from lib.core.threads import getCurrentThreadData
from lib.core.threads import runThreads
from lib.core.unescaper import unescaper from lib.core.unescaper import unescaper
from lib.request.connect import Connect as Request from lib.request.connect import Connect as Request
@ -301,26 +303,31 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
# Go multi-threading (--threads > 1) # Go multi-threading (--threads > 1)
if conf.threads > 1 and isinstance(length, int) and length > 1: if conf.threads > 1 and isinstance(length, int) and length > 1:
value = [ None ] * length value = []
index = [ firstChar ] # As list for python nested function scoping threadData = getCurrentThreadData()
idxlock = threading.Lock()
iolock = threading.Lock() threadData.shared.value = [ None ] * length
valuelock = threading.Lock() threadData.shared.index = [ firstChar ] # As list for python nested function scoping
kb.threadContinue = True
lockNames = ('iolock', 'idxlock', 'valuelock')
for lock in lockNames:
kb.locks[lock] = threading.Lock()
try:
def blindThread():
threadData = getCurrentThreadData()
def downloadThread():
try:
while kb.threadContinue: while kb.threadContinue:
idxlock.acquire() kb.locks.idxlock.acquire()
if index[0] >= length: if threadData.shared.index[0] >= length:
idxlock.release() kb.locks.idxlock.release()
return return
index[0] += 1 threadData.shared.index[0] += 1
curidx = index[0] curidx = threadData.shared.index[0]
idxlock.release() kb.locks.idxlock.release()
if kb.threadContinue: if kb.threadContinue:
charStart = time.time() charStart = time.time()
@ -330,14 +337,14 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
else: else:
break break
valuelock.acquire() kb.locks.valuelock.acquire()
value[curidx-1] = val threadData.shared.value[curidx-1] = val
currentValue = list(value) currentValue = list(threadData.shared.value)
valuelock.release() kb.locks.valuelock.release()
if kb.threadContinue: if kb.threadContinue:
if showEta: if showEta:
etaProgressUpdate(time.time() - charStart, index[0]) etaProgressUpdate(time.time() - charStart, threadData.shared.index[0])
elif conf.verbose >= 1: elif conf.verbose >= 1:
startCharIndex = 0 startCharIndex = 0
endCharIndex = 0 endCharIndex = 0
@ -370,14 +377,14 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
status = ' %d/%d (%d%s)' % (count, length, round(100.0*count/length), '%') status = ' %d/%d (%d%s)' % (count, length, round(100.0*count/length), '%')
output += status if count != length else " "*len(status) output += status if count != length else " "*len(status)
iolock.acquire() kb.locks.iolock.acquire()
dataToStdout("\r[%s] [INFO] retrieved: %s" % (time.strftime("%X"), filterControlChars(output))) dataToStdout("\r[%s] [INFO] retrieved: %s" % (time.strftime("%X"), filterControlChars(output)))
iolock.release() kb.locks.iolock.release()
if not kb.threadContinue: if not kb.threadContinue:
if int(threading.currentThread().getName()) == numThreads - 1: if int(threading.currentThread().getName()) == numThreads - 1:
partialValue = unicode() partialValue = unicode()
for v in value: for v in threadData.shared.value:
if v is None: if v is None:
break break
elif isinstance(v, basestring): elif isinstance(v, basestring):
@ -386,57 +393,14 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
if len(partialValue) > 0: if len(partialValue) > 0:
dataToSessionFile(replaceNewlineTabs(partialValue)) dataToSessionFile(replaceNewlineTabs(partialValue))
except (sqlmapConnectionException, sqlmapValueException), errMsg: runThreads(numThreads, blindThread, startThreadMsg=False)
print
kb.threadException = True
logger.error("thread %s: %s" % (threading.currentThread().getName(), errMsg))
except KeyboardInterrupt:
kb.threadException = True
print
logger.debug("waiting for threads to finish")
try:
while (threading.activeCount() > 1):
pass
except KeyboardInterrupt:
raise sqlmapThreadException, "user aborted"
except:
print
kb.threadException = True
errMsg = unhandledExceptionMessage()
logger.error("thread %s: %s" % (threading.currentThread().getName(), errMsg))
traceback.print_exc()
# Start the threads
for numThread in range(numThreads):
thread = threading.Thread(target=downloadThread, name=str(numThread))
if PYVERSION >= "2.6":
thread.daemon = True
else:
thread.setDaemon(True)
thread.start()
threads.append(thread)
# And wait for them to all finish
try:
alive = True
while alive:
alive = False
for thread in threads:
if thread.isAlive():
alive = True
time.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
kb.threadContinue = False
raise raise
finally:
value = threadData.shared.value
infoMsg = None infoMsg = None
# If we have got one single character not correctly fetched it # If we have got one single character not correctly fetched it

View File

@ -318,22 +318,22 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False):
stopLimit = 1 stopLimit = 1
threadData = getCurrentThreadData()
threadData.shared.limits = range(startLimit, stopLimit)
numThreads = min(conf.threads, len(threadData.shared.limits))
threadData.shared.outputs = []
if stopLimit > TURN_OFF_RESUME_INFO_LIMIT:
kb.suppressResumeInfo = True
debugMsg = "suppressing possible resume console info because of "
debugMsg += "large number of rows. It might take too long"
logger.debug(debugMsg)
lockNames = ('limits', 'outputs')
for lock in lockNames:
kb.locks[lock] = threading.Lock()
try: try:
threadData = getCurrentThreadData()
threadData.shared.limits = range(startLimit, stopLimit)
numThreads = min(conf.threads, len(threadData.shared.limits))
threadData.shared.outputs = []
if stopLimit > TURN_OFF_RESUME_INFO_LIMIT:
kb.suppressResumeInfo = True
debugMsg = "suppressing possible resume console info because of "
debugMsg += "large number of rows. It might take too long"
logger.debug(debugMsg)
lockNames = ('limits', 'outputs')
for lock in lockNames:
kb.locks[lock] = threading.Lock()
def errorThread(): def errorThread():
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
@ -366,12 +366,6 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False):
warnMsg += "will display partial output" warnMsg += "will display partial output"
logger.warn(warnMsg) logger.warn(warnMsg)
except sqlmapConnectionException, e:
errMsg = "connection exception detected. sqlmap "
errMsg += "will display partial output"
errMsg += "'%s'" % e
logger.critical(errMsg)
finally: finally:
outputs = threadData.shared.outputs outputs = threadData.shared.outputs
kb.suppressResumeInfo = False kb.suppressResumeInfo = False

View File

@ -259,22 +259,22 @@ def unionUse(expression, unpack=True, dump=False):
stopLimit = 1 stopLimit = 1
threadData = getCurrentThreadData()
threadData.shared.limits = range(startLimit, stopLimit)
numThreads = min(conf.threads, len(threadData.shared.limits))
threadData.shared.value = ""
if stopLimit > TURN_OFF_RESUME_INFO_LIMIT:
kb.suppressResumeInfo = True
debugMsg = "suppressing possible resume console info because of "
debugMsg += "large number of rows. It might take too long"
logger.debug(debugMsg)
lockNames = ('limits', 'value')
for lock in lockNames:
kb.locks[lock] = threading.Lock()
try: try:
threadData = getCurrentThreadData()
threadData.shared.limits = range(startLimit, stopLimit)
numThreads = min(conf.threads, len(threadData.shared.limits))
threadData.shared.value = ""
if stopLimit > TURN_OFF_RESUME_INFO_LIMIT:
kb.suppressResumeInfo = True
debugMsg = "suppressing possible resume console info because of "
debugMsg += "large number of rows. It might take too long"
logger.debug(debugMsg)
lockNames = ('limits', 'value')
for lock in lockNames:
kb.locks[lock] = threading.Lock()
def unionThread(): def unionThread():
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
@ -334,12 +334,6 @@ def unionUse(expression, unpack=True, dump=False):
warnMsg += "will display partial output" warnMsg += "will display partial output"
logger.warn(warnMsg) logger.warn(warnMsg)
except sqlmapConnectionException, e:
errMsg = "connection exception detected. sqlmap "
errMsg += "will display partial output"
errMsg += "'%s'" % e
logger.critical(errMsg)
finally: finally:
value = threadData.shared.value value = threadData.shared.value
kb.suppressResumeInfo = False kb.suppressResumeInfo = False

View File

@ -13,6 +13,7 @@ import time
from lib.core.agent import agent from lib.core.agent import agent
from lib.core.common import arrayizeValue from lib.core.common import arrayizeValue
from lib.core.common import Backend from lib.core.common import Backend
from lib.core.common import clearConsoleLine
from lib.core.common import dataToStdout from lib.core.common import dataToStdout
from lib.core.common import getRange from lib.core.common import getRange
from lib.core.common import getCompiledRegex from lib.core.common import getCompiledRegex
@ -463,7 +464,7 @@ class Enumeration:
query += " WHERE " query += " WHERE "
if Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema: if Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema:
query += " OR ".join("%s LIKE '%%%s%%'" % (condition, user) for user in sorted(users)) query += " OR ".join("%s LIKE '%s'" % (condition, user) for user in sorted(users))
else: else:
query += " OR ".join("%s = '%s'" % (condition, user) for user in sorted(users)) query += " OR ".join("%s = '%s'" % (condition, user) for user in sorted(users))
@ -1161,7 +1162,7 @@ class Enumeration:
infoMsg = "fetching columns " infoMsg = "fetching columns "
if len(colList) > 0: if len(colList) > 0:
condQuery = " AND (%s)" % " OR ".join("%s LIKE '%%%s%%'" % (condition, unsafeSQLIdentificatorNaming(col)) for col in sorted(colList)) condQuery = " AND (%s)" % " OR ".join("%s LIKE '%s'" % (condition, unsafeSQLIdentificatorNaming(col)) for col in sorted(colList))
likeMsg = "like '%s' " % ", ".join(unsafeSQLIdentificatorNaming(col) for col in sorted(colList)) likeMsg = "like '%s' " % ", ".join(unsafeSQLIdentificatorNaming(col) for col in sorted(colList))
infoMsg += likeMsg infoMsg += likeMsg
else: else:
@ -1703,30 +1704,36 @@ class Enumeration:
plusOne = False plusOne = False
indexRange = getRange(count, dump=True, plusOne=plusOne) indexRange = getRange(count, dump=True, plusOne=plusOne)
for index in indexRange: try:
for column in colList: for index in indexRange:
if column not in lengths: for column in colList:
lengths[column] = 0 if column not in lengths:
lengths[column] = 0
if column not in entries: if column not in entries:
entries[column] = [] entries[column] = []
if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ): if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ):
query = rootQuery.blind.query % (column, conf.db, conf.tbl, index) query = rootQuery.blind.query % (column, conf.db, conf.tbl, index)
elif Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): elif Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2):
query = rootQuery.blind.query % (column, column, query = rootQuery.blind.query % (column, column,
tbl.upper() if not conf.db else ("%s.%s" % (conf.db.upper(), tbl.upper())), tbl.upper() if not conf.db else ("%s.%s" % (conf.db.upper(), tbl.upper())),
index) index)
elif Backend.isDbms(DBMS.SQLITE): elif Backend.isDbms(DBMS.SQLITE):
query = rootQuery.blind.query % (column, tbl, index) query = rootQuery.blind.query % (column, tbl, index)
elif Backend.isDbms(DBMS.FIREBIRD): elif Backend.isDbms(DBMS.FIREBIRD):
query = rootQuery.blind.query % (index, column, tbl) query = rootQuery.blind.query % (index, column, tbl)
value = inject.getValue(query, inband=False, error=False, dump=True) value = inject.getValue(query, inband=False, error=False, dump=True)
lengths[column] = max(lengths[column], len(value) if value else 0) lengths[column] = max(lengths[column], len(value) if value else 0)
entries[column].append(value) entries[column].append(value)
except KeyboardInterrupt:
clearConsoleLine()
warnMsg = "Ctrl+C detected in dumping phase"
logger.warn(warnMsg)
for column, columnEntries in entries.items(): for column, columnEntries in entries.items():
length = max(lengths[column], len(column)) length = max(lengths[column], len(column))
@ -1750,10 +1757,6 @@ class Enumeration:
warnMsg += "on database '%s'" % unsafeSQLIdentificatorNaming(conf.db) warnMsg += "on database '%s'" % unsafeSQLIdentificatorNaming(conf.db)
logger.warn(warnMsg) logger.warn(warnMsg)
except KeyboardInterrupt:
warnMsg = "Ctrl+C detected in dumping phase"
logger.warn(warnMsg)
except sqlmapConnectionException, e: except sqlmapConnectionException, e:
errMsg = "connection exception detected in dumping phase: " errMsg = "connection exception detected in dumping phase: "
errMsg += "'%s'" % e errMsg += "'%s'" % e

View File

@ -156,14 +156,14 @@ class Miscellaneous:
def likeOrExact(self, what): def likeOrExact(self, what):
message = "do you want sqlmap to consider provided %s(s):\n" % what message = "do you want sqlmap to consider provided %s(s):\n" % what
message += "[1] as LIKE %s names (default)\n" % what message += "[1] as LIKE %s names\n" % what
message += "[2] as exact %s names" % what message += "[2] as exact %s names (default)" % what
choice = readInput(message, default="1") choice = readInput(message, default="2")
if not choice or choice == "1": if not choice or choice == "1":
choice = "1" choice = "1"
condParam = " LIKE '%%%s%%'" condParam = " LIKE '%%%s%%'" # this doesn't work, neither not sure it ever did
elif choice.isdigit() and choice == "2": elif choice.isdigit() and choice == "2":
condParam = "='%s'" condParam = "='%s'"
else: else: