Update for an Issue #342 and #372

This commit is contained in:
Miroslav Stampar 2013-01-31 10:01:52 +01:00
parent 9b4eaa9272
commit 2420a4b626
2 changed files with 10 additions and 8 deletions

View File

@ -888,23 +888,17 @@ class Agent(object):
lengthQuery = queries[Backend.getIdentifiedDbms()].length.query lengthQuery = queries[Backend.getIdentifiedDbms()].length.query
select = re.search("\ASELECT\s+", expression, re.I) select = re.search("\ASELECT\s+", expression, re.I)
selectTopExpr = re.search("\ASELECT\s+TOP\s+[\d]+\s+(.+?)\s+FROM", expression, re.I) selectTopExpr = re.search("\ASELECT\s+TOP\s+[\d]+\s+(.+?)\s+FROM", expression, re.I)
selectDistinctExpr = re.search("\ASELECT\s+DISTINCT\((.+?)\)\s+FROM", expression, re.I)
selectFromExpr = re.search("\ASELECT\s+(.+?)\s+FROM", expression, re.I) selectFromExpr = re.search("\ASELECT\s+(.+?)\s+FROM", expression, re.I)
selectExpr = re.search("\ASELECT\s+(.+)$", expression, re.I) selectExpr = re.search("\ASELECT\s+(.+)$", expression, re.I)
_, _, _, _, _, _, fieldsStr, _ = self.getFields(expression) _, _, _, _, _, _, fieldsStr, _ = self.getFields(expression)
if any((selectTopExpr, selectDistinctExpr, selectFromExpr, selectExpr)): if any((selectTopExpr, selectFromExpr, selectExpr)):
query = fieldsStr query = fieldsStr
else: else:
query = expression query = expression
if selectDistinctExpr: if select:
lengthExpr = "SELECT %s FROM (%s)" % (lengthQuery % query, expression)
if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL):
lengthExpr += " AS %s" % randomStr(lowercase=True)
elif select:
lengthExpr = expression.replace(query, lengthQuery % query, 1) lengthExpr = expression.replace(query, lengthQuery % query, 1)
else: else:
lengthExpr = lengthQuery % expression lengthExpr = lengthQuery % expression

View File

@ -25,6 +25,7 @@ from lib.core.common import isTechniqueAvailable
from lib.core.common import parseUnionPage from lib.core.common import parseUnionPage
from lib.core.common import popValue from lib.core.common import popValue
from lib.core.common import pushValue from lib.core.common import pushValue
from lib.core.common import randomStr
from lib.core.common import readInput from lib.core.common import readInput
from lib.core.common import singleTimeWarnMessage from lib.core.common import singleTimeWarnMessage
from lib.core.data import conf from lib.core.data import conf
@ -76,6 +77,13 @@ def _goInference(payload, expression, charsetType=None, firstChar=None, lastChar
if not (timeBasedCompare and kb.dnsTest): if not (timeBasedCompare and kb.dnsTest):
if (conf.eta or conf.threads > 1) and Backend.getIdentifiedDbms() and not re.search("(COUNT|LTRIM)\(", expression, re.I) and not timeBasedCompare: if (conf.eta or conf.threads > 1) and Backend.getIdentifiedDbms() and not re.search("(COUNT|LTRIM)\(", expression, re.I) and not timeBasedCompare:
if field and re.search("\ASELECT\s+DISTINCT\((.+?)\)\s+FROM", expression, re.I):
expression = "SELECT %s FROM (%s)" % (field, expression)
if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL):
expression += " AS %s" % randomStr(lowercase=True)
if field and conf.hexConvert: if field and conf.hexConvert:
nulledCastedField = agent.nullAndCastField(field) nulledCastedField = agent.nullAndCastField(field)
injExpression = expression.replace(field, nulledCastedField, 1) injExpression = expression.replace(field, nulledCastedField, 1)