Update for an Issue #189 (code refactoring of ProgressBar so it could be ready for usage in non-inference cases out of box)

This commit is contained in:
stamparm 2013-05-09 15:52:18 +02:00
parent fc57b7565d
commit 9fe5a8832f
2 changed files with 28 additions and 20 deletions

View File

@ -8,6 +8,7 @@ See the file 'doc/COPYING' for copying permission
from lib.core.common import getUnicode from lib.core.common import getUnicode
from lib.core.common import dataToStdout from lib.core.common import dataToStdout
from lib.core.data import conf from lib.core.data import conf
from lib.core.data import kb
class ProgressBar(object): class ProgressBar(object):
""" """
@ -22,6 +23,7 @@ class ProgressBar(object):
self._span = self._max - self._min self._span = self._max - self._min
self._width = totalWidth if totalWidth else conf.progressWidth self._width = totalWidth if totalWidth else conf.progressWidth
self._amount = 0 self._amount = 0
self._times = []
self.update() self.update()
def _convertSeconds(self, value): def _convertSeconds(self, value):
@ -50,7 +52,7 @@ class ProgressBar(object):
percentDone = int(percentDone) percentDone = int(percentDone)
# Figure out how many hash bars the percentage should be # Figure out how many hash bars the percentage should be
allFull = self._width - 2 allFull = self._width - len("100%% [] %s/%s ETA 00:00" % (self._max, self._max))
numHashes = (percentDone / 100.0) * allFull numHashes = (percentDone / 100.0) * allFull
numHashes = int(round(numHashes)) numHashes = int(round(numHashes))
@ -67,6 +69,22 @@ class ProgressBar(object):
percentString = getUnicode(percentDone) + "%" percentString = getUnicode(percentDone) + "%"
self._progBar = "%s %s" % (percentString, self._progBar) self._progBar = "%s %s" % (percentString, self._progBar)
def progress(self, deltaTime, newAmount, threads=1):
"""
This method saves item delta time and shows updated progress bar with calculated eta
"""
if len(self._times) <= ((self._max * 3) / 100) or newAmount == self._max:
eta = 0
else:
midTime = sum(self._times) / len(self._times)
midTimeWithLatest = (midTime + deltaTime) / 2
eta = midTimeWithLatest * (self._max - newAmount) / threads
self._times.append(deltaTime)
self.update(newAmount)
self.draw(eta)
def draw(self, eta=0): def draw(self, eta=0):
""" """
This method draws the progress bar if it has changed This method draws the progress bar if it has changed
@ -78,8 +96,11 @@ class ProgressBar(object):
if eta and self._amount < self._max: if eta and self._amount < self._max:
dataToStdout("\r%s %d/%d ETA %s" % (self._progBar, self._amount, self._max, self._convertSeconds(int(eta)))) dataToStdout("\r%s %d/%d ETA %s" % (self._progBar, self._amount, self._max, self._convertSeconds(int(eta))))
else: else:
blank = " " * (80 - len("\r%s %d/%d" % (self._progBar, self._amount, self._max))) dataToStdout("\r%s\r" % (" " * (self._width - 1)))
dataToStdout("\r%s %d/%d%s" % (self._progBar, self._amount, self._max, blank)) if self._amount < self._max:
dataToStdout("%s %d/%d" % (self._progBar, self._amount, self._max))
else:
kb.prependFlag = False
def __str__(self): def __str__(self):
""" """

View File

@ -140,7 +140,6 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
if showEta: if showEta:
progress = ProgressBar(maxValue=length) progress = ProgressBar(maxValue=length)
progressTime = []
if timeBasedCompare and conf.threads > 1: if timeBasedCompare and conf.threads > 1:
warnMsg = "multi-threading is considered unsafe in time-based data retrieval. Going to switch it off automatically" warnMsg = "multi-threading is considered unsafe in time-based data retrieval. Going to switch it off automatically"
@ -354,18 +353,6 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
return None return None
def etaProgressUpdate(charTime, index):
if len(progressTime) <= ((length * 3) / 100):
eta = 0
else:
midTime = sum(progressTime) / len(progressTime)
midTimeWithLatest = (midTime + charTime) / 2
eta = midTimeWithLatest * (length - index) / conf.threads
progressTime.append(charTime)
progress.update(index)
progress.draw(eta)
# Go multi-threading (--threads > 1) # Go multi-threading (--threads > 1)
if conf.threads > 1 and isinstance(length, int) and length > 1: if conf.threads > 1 and isinstance(length, int) and length > 1:
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
@ -404,7 +391,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
if kb.threadContinue: if kb.threadContinue:
if showEta: if showEta:
etaProgressUpdate(time.time() - charStart, threadData.shared.index[0]) progress.progress(time.time() - charStart, threadData.shared.index[0], numThreads)
elif conf.verbose >= 1: elif conf.verbose >= 1:
startCharIndex = 0 startCharIndex = 0
endCharIndex = 0 endCharIndex = 0
@ -496,7 +483,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
# Did we have luck? # Did we have luck?
if result: if result:
if showEta: if showEta:
etaProgressUpdate(time.time() - charStart, len(commonValue)) progress.progress(time.time() - charStart, len(commonValue))
elif conf.verbose in (1, 2) or hasattr(conf, "api"): elif conf.verbose in (1, 2) or hasattr(conf, "api"):
dataToStdout(filterControlChars(commonValue[index - 1:])) dataToStdout(filterControlChars(commonValue[index - 1:]))
@ -546,7 +533,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
partialValue += val partialValue += val
if showEta: if showEta:
etaProgressUpdate(time.time() - charStart, index) progress.progress(time.time() - charStart, index)
elif conf.verbose in (1, 2) or hasattr(conf, "api"): elif conf.verbose in (1, 2) or hasattr(conf, "api"):
dataToStdout(filterControlChars(val)) dataToStdout(filterControlChars(val))
@ -578,7 +565,7 @@ def bisection(payload, expression, length=None, charsetType=None, firstChar=None
infoMsg = "\r[%s] [INFO] retrieved: %s %s\n" % (time.strftime("%X"), filterControlChars(finalValue), " " * retrievedLength) infoMsg = "\r[%s] [INFO] retrieved: %s %s\n" % (time.strftime("%X"), filterControlChars(finalValue), " " * retrievedLength)
dataToStdout(infoMsg) dataToStdout(infoMsg)
else: else:
if conf.verbose in (1, 2) or showEta and not hasattr(conf, "api"): if conf.verbose in (1, 2) and not showEta and not hasattr(conf, "api"):
dataToStdout("\n") dataToStdout("\n")
if (conf.verbose in (1, 2) and showEta) or conf.verbose >= 3: if (conf.verbose in (1, 2) and showEta) or conf.verbose >= 3: