From 34d9a91af1464a3f87689b8a29400ab0df85cb72 Mon Sep 17 00:00:00 2001 From: Miroslav Stampar Date: Sat, 2 Jul 2011 22:48:56 +0000 Subject: [PATCH] bulk of fixes --- lib/core/common.py | 15 +++++ lib/core/dump.py | 9 +-- lib/core/threads.py | 37 +++++++++-- lib/techniques/blind/inference.py | 102 ++++++++++-------------------- lib/techniques/error/use.py | 36 +++++------ lib/techniques/union/use.py | 36 +++++------ plugins/generic/enumeration.py | 53 ++++++++-------- plugins/generic/misc.py | 8 +-- 8 files changed, 145 insertions(+), 151 deletions(-) diff --git a/lib/core/common.py b/lib/core/common.py index 4df3aba11..3ebfe1618 100644 --- a/lib/core/common.py +++ b/lib/core/common.py @@ -2767,3 +2767,18 @@ def expandMnemonics(mnemonics, parser, args): else: errMsg = "mnemonic '%s' requires value of type '%s'" % (name, found.type) raise sqlmapSyntaxException, errMsg + +def safeCSValue(value): + """ + Returns value safe for CSV dumping. + Reference: http://stackoverflow.com/questions/769621/dealing-with-commas-in-a-csv-file + """ + + retVal = value + + if isinstance(retVal, basestring): + if not (retVal[0] == retVal[-1] == '"'): + if any(map(lambda x: x in retVal, ['"', ',', '\n'])): + retVal = '"%s"' % retVal.replace('"', '""') + + return retVal diff --git a/lib/core/dump.py b/lib/core/dump.py index 9a470d2cf..a66c7461f 100644 --- a/lib/core/dump.py +++ b/lib/core/dump.py @@ -18,6 +18,7 @@ from lib.core.common import getUnicode from lib.core.common import normalizeUnicode from lib.core.common import openFile from lib.core.common import restoreDumpMarkedChars +from lib.core.common import safeCSValue from lib.core.data import conf from lib.core.data import kb from lib.core.data import logger @@ -392,9 +393,9 @@ class Dump: if not conf.replicate: if not conf.multipleTargets and field == fields: - dataToDumpFile(dumpFP, "%s" % column) + dataToDumpFile(dumpFP, "%s" % safeCSValue(column)) elif not conf.multipleTargets: - dataToDumpFile(dumpFP, "%s," % column) + dataToDumpFile(dumpFP, "%s," % safeCSValue(column)) field += 1 @@ -432,9 +433,9 @@ class Dump: if not conf.replicate: if not conf.multipleTargets and field == fields: - dataToDumpFile(dumpFP, "\"%s\"" % value) + dataToDumpFile(dumpFP, "%s" % safeCSValue(value)) elif not conf.multipleTargets: - dataToDumpFile(dumpFP, "\"%s\"," % value) + dataToDumpFile(dumpFP, "%s," % safeCSValue(value)) field += 1 diff --git a/lib/core/threads.py b/lib/core/threads.py index e131f75db..4151ef259 100644 --- a/lib/core/threads.py +++ b/lib/core/threads.py @@ -11,11 +11,15 @@ import difflib import threading import time +from thread import error as threadError + from lib.core.data import kb from lib.core.data import logger from lib.core.datatype import advancedDict from lib.core.enums import PAYLOAD +from lib.core.exception import sqlmapConnectionException from lib.core.exception import sqlmapThreadException +from lib.core.exception import sqlmapValueException from lib.core.settings import MAX_NUMBER_OF_THREADS from lib.core.settings import PYVERSION @@ -68,7 +72,7 @@ def exceptionHandledFunction(threadFunction): print logger.error("thread %s: %s" % (threading.currentThread().getName(), errMsg)) -def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True, threadChoice=False): +def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardException=True, threadChoice=False, startThreadMsg=True): threads = [] kb.multiThreadMode = True @@ -92,8 +96,9 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio logger.warn(warnMsg) if numThreads > 1: - infoMsg = "starting %d threads" % numThreads - logger.info(infoMsg) + if startThreadMsg: + infoMsg = "starting %d threads" % numThreads + logger.info(infoMsg) else: threadFunction() return @@ -108,7 +113,13 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio else: thread.setDaemon(True) - thread.start() + try: + thread.start() + except threadError, errMsg: + errMsg = "error occured while starting new thread ('%s')" % errMsg + logger.critical(errMsg) + break + threads.append(thread) # And wait for them to all finish @@ -122,11 +133,9 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio time.sleep(1) except KeyboardInterrupt: + print kb.threadContinue = False kb.threadException = True - - print '\r', - logger.info("waiting for threads to finish (Ctrl+C was pressed)") try: @@ -139,6 +148,20 @@ def runThreads(numThreads, threadFunction, cleanupFunction=None, forwardExceptio if forwardException: raise + except (sqlmapConnectionException, sqlmapValueException), errMsg: + print + kb.threadException = True + logger.error("thread %s: %s" % (threading.currentThread().getName(), errMsg)) + + except: + from lib.core.common import unhandledExceptionMessage + + print + kb.threadException = True + errMsg = unhandledExceptionMessage() + logger.error("thread %s: %s" % (threading.currentThread().getName(), errMsg)) + traceback.print_exc() + finally: kb.multiThreadMode = False kb.bruteMode = False diff --git a/lib/techniques/blind/inference.py b/lib/techniques/blind/inference.py index 07b568726..f11855cd4 100644 --- a/lib/techniques/blind/inference.py +++ b/lib/techniques/blind/inference.py @@ -44,6 +44,8 @@ from lib.core.settings import INFERENCE_GREATER_CHAR from lib.core.settings import INFERENCE_EQUALS_CHAR from lib.core.settings import INFERENCE_NOT_EQUALS_CHAR from lib.core.settings import PYVERSION +from lib.core.threads import getCurrentThreadData +from lib.core.threads import runThreads from lib.core.unescaper import unescaper from lib.request.connect import Connect as Request @@ -301,26 +303,31 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None # Go multi-threading (--threads > 1) if conf.threads > 1 and isinstance(length, int) and length > 1: - value = [ None ] * length - index = [ firstChar ] # As list for python nested function scoping - idxlock = threading.Lock() - iolock = threading.Lock() - valuelock = threading.Lock() - kb.threadContinue = True + value = [] + threadData = getCurrentThreadData() + + threadData.shared.value = [ None ] * length + threadData.shared.index = [ firstChar ] # As list for python nested function scoping + + lockNames = ('iolock', 'idxlock', 'valuelock') + for lock in lockNames: + kb.locks[lock] = threading.Lock() + + try: + def blindThread(): + threadData = getCurrentThreadData() - def downloadThread(): - try: while kb.threadContinue: - idxlock.acquire() + kb.locks.idxlock.acquire() - if index[0] >= length: - idxlock.release() + if threadData.shared.index[0] >= length: + kb.locks.idxlock.release() return - index[0] += 1 - curidx = index[0] - idxlock.release() + threadData.shared.index[0] += 1 + curidx = threadData.shared.index[0] + kb.locks.idxlock.release() if kb.threadContinue: charStart = time.time() @@ -330,14 +337,14 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None else: break - valuelock.acquire() - value[curidx-1] = val - currentValue = list(value) - valuelock.release() + kb.locks.valuelock.acquire() + threadData.shared.value[curidx-1] = val + currentValue = list(threadData.shared.value) + kb.locks.valuelock.release() if kb.threadContinue: if showEta: - etaProgressUpdate(time.time() - charStart, index[0]) + etaProgressUpdate(time.time() - charStart, threadData.shared.index[0]) elif conf.verbose >= 1: startCharIndex = 0 endCharIndex = 0 @@ -370,14 +377,14 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None status = ' %d/%d (%d%s)' % (count, length, round(100.0*count/length), '%') output += status if count != length else " "*len(status) - iolock.acquire() + kb.locks.iolock.acquire() dataToStdout("\r[%s] [INFO] retrieved: %s" % (time.strftime("%X"), filterControlChars(output))) - iolock.release() + kb.locks.iolock.release() if not kb.threadContinue: if int(threading.currentThread().getName()) == numThreads - 1: partialValue = unicode() - for v in value: + for v in threadData.shared.value: if v is None: break elif isinstance(v, basestring): @@ -386,57 +393,14 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None if len(partialValue) > 0: dataToSessionFile(replaceNewlineTabs(partialValue)) - except (sqlmapConnectionException, sqlmapValueException), errMsg: - print - kb.threadException = True - logger.error("thread %s: %s" % (threading.currentThread().getName(), errMsg)) - - except KeyboardInterrupt: - kb.threadException = True - - print - logger.debug("waiting for threads to finish") - - try: - while (threading.activeCount() > 1): - pass - - except KeyboardInterrupt: - raise sqlmapThreadException, "user aborted" - - except: - print - kb.threadException = True - errMsg = unhandledExceptionMessage() - logger.error("thread %s: %s" % (threading.currentThread().getName(), errMsg)) - traceback.print_exc() - - # Start the threads - for numThread in range(numThreads): - thread = threading.Thread(target=downloadThread, name=str(numThread)) - - if PYVERSION >= "2.6": - thread.daemon = True - else: - thread.setDaemon(True) - - thread.start() - threads.append(thread) - - # And wait for them to all finish - try: - alive = True - while alive: - alive = False - for thread in threads: - if thread.isAlive(): - alive = True - time.sleep(1) + runThreads(numThreads, blindThread, startThreadMsg=False) except KeyboardInterrupt: - kb.threadContinue = False raise + finally: + value = threadData.shared.value + infoMsg = None # If we have got one single character not correctly fetched it diff --git a/lib/techniques/error/use.py b/lib/techniques/error/use.py index ce7be2cec..47828f474 100644 --- a/lib/techniques/error/use.py +++ b/lib/techniques/error/use.py @@ -318,22 +318,22 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False): stopLimit = 1 + threadData = getCurrentThreadData() + threadData.shared.limits = range(startLimit, stopLimit) + numThreads = min(conf.threads, len(threadData.shared.limits)) + threadData.shared.outputs = [] + + if stopLimit > TURN_OFF_RESUME_INFO_LIMIT: + kb.suppressResumeInfo = True + debugMsg = "suppressing possible resume console info because of " + debugMsg += "large number of rows. It might take too long" + logger.debug(debugMsg) + + lockNames = ('limits', 'outputs') + for lock in lockNames: + kb.locks[lock] = threading.Lock() + try: - threadData = getCurrentThreadData() - threadData.shared.limits = range(startLimit, stopLimit) - numThreads = min(conf.threads, len(threadData.shared.limits)) - threadData.shared.outputs = [] - - if stopLimit > TURN_OFF_RESUME_INFO_LIMIT: - kb.suppressResumeInfo = True - debugMsg = "suppressing possible resume console info because of " - debugMsg += "large number of rows. It might take too long" - logger.debug(debugMsg) - - lockNames = ('limits', 'outputs') - for lock in lockNames: - kb.locks[lock] = threading.Lock() - def errorThread(): threadData = getCurrentThreadData() @@ -366,12 +366,6 @@ def errorUse(expression, expected=None, resumeValue=True, dump=False): warnMsg += "will display partial output" logger.warn(warnMsg) - except sqlmapConnectionException, e: - errMsg = "connection exception detected. sqlmap " - errMsg += "will display partial output" - errMsg += "'%s'" % e - logger.critical(errMsg) - finally: outputs = threadData.shared.outputs kb.suppressResumeInfo = False diff --git a/lib/techniques/union/use.py b/lib/techniques/union/use.py index 745c942ba..4aa1a3a76 100644 --- a/lib/techniques/union/use.py +++ b/lib/techniques/union/use.py @@ -259,22 +259,22 @@ def unionUse(expression, unpack=True, dump=False): stopLimit = 1 + threadData = getCurrentThreadData() + threadData.shared.limits = range(startLimit, stopLimit) + numThreads = min(conf.threads, len(threadData.shared.limits)) + threadData.shared.value = "" + + if stopLimit > TURN_OFF_RESUME_INFO_LIMIT: + kb.suppressResumeInfo = True + debugMsg = "suppressing possible resume console info because of " + debugMsg += "large number of rows. It might take too long" + logger.debug(debugMsg) + + lockNames = ('limits', 'value') + for lock in lockNames: + kb.locks[lock] = threading.Lock() + try: - threadData = getCurrentThreadData() - threadData.shared.limits = range(startLimit, stopLimit) - numThreads = min(conf.threads, len(threadData.shared.limits)) - threadData.shared.value = "" - - if stopLimit > TURN_OFF_RESUME_INFO_LIMIT: - kb.suppressResumeInfo = True - debugMsg = "suppressing possible resume console info because of " - debugMsg += "large number of rows. It might take too long" - logger.debug(debugMsg) - - lockNames = ('limits', 'value') - for lock in lockNames: - kb.locks[lock] = threading.Lock() - def unionThread(): threadData = getCurrentThreadData() @@ -334,12 +334,6 @@ def unionUse(expression, unpack=True, dump=False): warnMsg += "will display partial output" logger.warn(warnMsg) - except sqlmapConnectionException, e: - errMsg = "connection exception detected. sqlmap " - errMsg += "will display partial output" - errMsg += "'%s'" % e - logger.critical(errMsg) - finally: value = threadData.shared.value kb.suppressResumeInfo = False diff --git a/plugins/generic/enumeration.py b/plugins/generic/enumeration.py index d54d98874..0fb2206a3 100644 --- a/plugins/generic/enumeration.py +++ b/plugins/generic/enumeration.py @@ -13,6 +13,7 @@ import time from lib.core.agent import agent from lib.core.common import arrayizeValue from lib.core.common import Backend +from lib.core.common import clearConsoleLine from lib.core.common import dataToStdout from lib.core.common import getRange from lib.core.common import getCompiledRegex @@ -463,7 +464,7 @@ class Enumeration: query += " WHERE " if Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema: - query += " OR ".join("%s LIKE '%%%s%%'" % (condition, user) for user in sorted(users)) + query += " OR ".join("%s LIKE '%s'" % (condition, user) for user in sorted(users)) else: query += " OR ".join("%s = '%s'" % (condition, user) for user in sorted(users)) @@ -1161,7 +1162,7 @@ class Enumeration: infoMsg = "fetching columns " if len(colList) > 0: - condQuery = " AND (%s)" % " OR ".join("%s LIKE '%%%s%%'" % (condition, unsafeSQLIdentificatorNaming(col)) for col in sorted(colList)) + condQuery = " AND (%s)" % " OR ".join("%s LIKE '%s'" % (condition, unsafeSQLIdentificatorNaming(col)) for col in sorted(colList)) likeMsg = "like '%s' " % ", ".join(unsafeSQLIdentificatorNaming(col) for col in sorted(colList)) infoMsg += likeMsg else: @@ -1703,30 +1704,36 @@ class Enumeration: plusOne = False indexRange = getRange(count, dump=True, plusOne=plusOne) - for index in indexRange: - for column in colList: - if column not in lengths: - lengths[column] = 0 + try: + for index in indexRange: + for column in colList: + if column not in lengths: + lengths[column] = 0 - if column not in entries: - entries[column] = [] + if column not in entries: + entries[column] = [] - if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ): - query = rootQuery.blind.query % (column, conf.db, conf.tbl, index) - elif Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): - query = rootQuery.blind.query % (column, column, - tbl.upper() if not conf.db else ("%s.%s" % (conf.db.upper(), tbl.upper())), - index) - elif Backend.isDbms(DBMS.SQLITE): - query = rootQuery.blind.query % (column, tbl, index) + if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ): + query = rootQuery.blind.query % (column, conf.db, conf.tbl, index) + elif Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): + query = rootQuery.blind.query % (column, column, + tbl.upper() if not conf.db else ("%s.%s" % (conf.db.upper(), tbl.upper())), + index) + elif Backend.isDbms(DBMS.SQLITE): + query = rootQuery.blind.query % (column, tbl, index) - elif Backend.isDbms(DBMS.FIREBIRD): - query = rootQuery.blind.query % (index, column, tbl) + elif Backend.isDbms(DBMS.FIREBIRD): + query = rootQuery.blind.query % (index, column, tbl) - value = inject.getValue(query, inband=False, error=False, dump=True) + value = inject.getValue(query, inband=False, error=False, dump=True) - lengths[column] = max(lengths[column], len(value) if value else 0) - entries[column].append(value) + lengths[column] = max(lengths[column], len(value) if value else 0) + entries[column].append(value) + + except KeyboardInterrupt: + clearConsoleLine() + warnMsg = "Ctrl+C detected in dumping phase" + logger.warn(warnMsg) for column, columnEntries in entries.items(): length = max(lengths[column], len(column)) @@ -1750,10 +1757,6 @@ class Enumeration: warnMsg += "on database '%s'" % unsafeSQLIdentificatorNaming(conf.db) logger.warn(warnMsg) - except KeyboardInterrupt: - warnMsg = "Ctrl+C detected in dumping phase" - logger.warn(warnMsg) - except sqlmapConnectionException, e: errMsg = "connection exception detected in dumping phase: " errMsg += "'%s'" % e diff --git a/plugins/generic/misc.py b/plugins/generic/misc.py index c8c6edc64..bd118a698 100644 --- a/plugins/generic/misc.py +++ b/plugins/generic/misc.py @@ -156,14 +156,14 @@ class Miscellaneous: def likeOrExact(self, what): message = "do you want sqlmap to consider provided %s(s):\n" % what - message += "[1] as LIKE %s names (default)\n" % what - message += "[2] as exact %s names" % what + message += "[1] as LIKE %s names\n" % what + message += "[2] as exact %s names (default)" % what - choice = readInput(message, default="1") + choice = readInput(message, default="2") if not choice or choice == "1": choice = "1" - condParam = " LIKE '%%%s%%'" + condParam = " LIKE '%%%s%%'" # this doesn't work, neither not sure it ever did elif choice.isdigit() and choice == "2": condParam = "='%s'" else: