Merge branch 'develop' of https://github.com/explosion/spaCy into develop

Ines Montani 2020-07-29 11:36:45 +02:00
commit e257e66ab9
9 changed files with 148 additions and 68 deletions

View File: spacy/cli/evaluate.py

@@ -67,10 +67,7 @@ def evaluate(
corpus = Corpus(data_path, data_path)
nlp = util.load_model(model)
dev_dataset = list(corpus.dev_dataset(nlp, gold_preproc=gold_preproc))
begin = timer()
scores = nlp.evaluate(dev_dataset, verbose=False)
end = timer()
nwords = sum(len(ex.predicted) for ex in dev_dataset)
metrics = {
"TOK": "token_acc",
"TAG": "tag_acc",
@@ -82,17 +79,21 @@ def evaluate(
"NER P": "ents_p",
"NER R": "ents_r",
"NER F": "ents_f",
"Textcat": "cats_score",
"Sent P": "sents_p",
"Sent R": "sents_r",
"Sent F": "sents_f",
"TEXTCAT": "cats_score",
"SENT P": "sents_p",
"SENT R": "sents_r",
"SENT F": "sents_f",
"SPEED": "speed",
}
results = {}
for metric, key in metrics.items():
if key in scores:
if key == "cats_score":
metric = metric + " (" + scores.get("cats_score_desc", "unk") + ")"
results[metric] = f"{scores[key]*100:.2f}"
if key == "speed":
results[metric] = f"{scores[key]:.0f}"
else:
results[metric] = f"{scores[key]*100:.2f}"
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
msg.table(results, title="Results")
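
A minimal standalone sketch of the formatting rule above, with a hypothetical scores dict: "speed" is reported as raw words per second, while every other metric is scaled to a percentage.

scores = {"tag_acc": 0.9234, "cats_score": 0.8711, "cats_score_desc": "macro F", "speed": 15230.4}
metrics = {"TAG": "tag_acc", "TEXTCAT": "cats_score", "SPEED": "speed"}
results = {}
for metric, key in metrics.items():
    if key in scores:
        if key == "cats_score":
            metric = metric + " (" + scores.get("cats_score_desc", "unk") + ")"
        if key == "speed":
            results[metric] = f"{scores[key]:.0f}"  # words per second, unscaled
        else:
            results[metric] = f"{scores[key]*100:.2f}"  # fraction -> percentage
print(results)  # {'TAG': '92.34', 'TEXTCAT (macro F)': '87.11', 'SPEED': '15230'}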

View File: spacy/cli/train.py

@@ -1,5 +1,4 @@
from typing import Optional, Dict, Any, Tuple, Union, Callable, List
from timeit import default_timer as timer
import srsly
import tqdm
from pathlib import Path
@@ -248,14 +247,11 @@ def create_evaluation_callback(
dev_examples = list(dev_examples)
n_words = sum(len(ex.predicted) for ex in dev_examples)
batch_size = cfg["eval_batch_size"]
start_time = timer()
if optimizer.averages:
with nlp.use_params(optimizer.averages):
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
else:
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
end_time = timer()
wps = n_words / (end_time - start_time)
# Calculate a weighted sum based on score_weights for the main score
weights = cfg["score_weights"]
try:
@@ -264,7 +260,6 @@ def create_evaluation_callback(
keys = list(scores.keys())
err = Errors.E983.format(dict="score_weights", key=str(e), keys=keys)
raise KeyError(err)
scores["speed"] = wps
return weighted_score, scores
return evaluate
@@ -446,7 +441,7 @@ def update_meta(
training: Union[Dict[str, Any], Config], nlp: Language, info: Dict[str, Any]
) -> None:
nlp.meta["performance"] = {}
for metric in training["scores_weights"]:
for metric in training["score_weights"]:
nlp.meta["performance"][metric] = info["other_scores"][metric]
for pipe_name in nlp.pipe_names:
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
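
For illustration, a minimal sketch of the weighted main-score calculation referenced by the comment above; the keys and weights here are hypothetical, and the weights are assumed to sum to 1.0.

score_weights = {"dep_uas": 0.5, "dep_las": 0.25, "tag_acc": 0.25}
scores = {"dep_uas": 0.91, "dep_las": 0.88, "tag_acc": 0.95, "speed": 14000.0}
try:
    weighted_score = sum(scores[key] * weight for key, weight in score_weights.items())
except KeyError as e:
    raise KeyError(f"key {e} in score_weights is missing from scores {list(scores)}")
print(f"{weighted_score:.4f}")  # 0.5*0.91 + 0.25*0.88 + 0.25*0.95 = 0.9125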

View File: spacy/errors.py

@@ -432,12 +432,12 @@ class Errors:
"Current DocBin: {current}\nOther DocBin: {other}")
E169 = ("Can't find module: {module}")
E170 = ("Cannot apply transition {name}: invalid for the current state.")
E171 = ("Matcher.add received invalid on_match callback argument: expected "
E171 = ("Matcher.add received invalid 'on_match' callback argument: expected "
"callable or None, but got: {arg_type}")
E175 = ("Can't remove rule for unknown match pattern ID: {key}")
E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
E177 = ("Ill-formed IOB input detected: {tag}")
E178 = ("Invalid pattern. Expected list of dicts but got: {pat}. Maybe you "
E178 = ("Each pattern should be a list of dicts, but got: {pat}. Maybe you "
"accidentally passed a single pattern to Matcher.add instead of a "
"list of patterns? If you only want to add one pattern, make sure "
"to wrap it in a list. For example: matcher.add('{key}', [pattern])")
@@ -483,6 +483,10 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
"a string value from {expected} but got: '{arg}'")
E948 = ("Matcher.add received invalid 'patterns' argument: expected "
"a List, but got: {arg_type}")
E952 = ("The section '{name}' is not a valid section in the provided config.")
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
E954 = ("The Tok2Vec listener did not receive a valid input.")
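
A hedged usage sketch of the two new validations: E948 fires when the "patterns" argument is not a list, E947 when "greedy" is neither "FIRST" nor "LONGEST". Only a blank Vocab is assumed.

from spacy.matcher import Matcher
from spacy.vocab import Vocab

matcher = Matcher(Vocab())
patterns = [[{"ORTH": "a"}]]
matcher.add("KEY", patterns)  # OK: a list of pattern lists
# matcher.add("KEY", {"ORTH": "a"})              # ValueError (E948): not a list
# matcher.add("KEY", patterns, greedy="GREEDY")  # ValueError (E947): invalid filter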

View File: spacy/language.py

@@ -14,6 +14,7 @@ from thinc.api import get_current_ops, Config, require_gpu, Optimizer
import srsly
import multiprocessing as mp
from itertools import chain, cycle
from timeit import default_timer as timer
from .tokens.underscore import Underscore
from .vocab import Vocab, create_vocab
@@ -1130,7 +1131,14 @@ class Language:
kwargs.setdefault("verbose", verbose)
kwargs.setdefault("nlp", self)
scorer = Scorer(**kwargs)
docs = list(eg.predicted for eg in examples)
texts = [eg.reference.text for eg in examples]
docs = [eg.predicted for eg in examples]
start_time = timer()
# tokenize the texts only for timing purposes
if not hasattr(self.tokenizer, "pipe"):
_ = [self.tokenizer(text) for text in texts]
else:
_ = list(self.tokenizer.pipe(texts))
for name, pipe in self.pipeline:
kwargs = component_cfg.get(name, {})
kwargs.setdefault("batch_size", batch_size)
@@ -1138,11 +1146,18 @@ class Language:
docs = _pipe(docs, pipe, kwargs)
else:
docs = pipe.pipe(docs, **kwargs)
# iterate over the final generator
if len(self.pipeline):
docs = list(docs)
end_time = timer()
for i, (doc, eg) in enumerate(zip(docs, examples)):
if verbose:
print(doc)
eg.predicted = doc
return scorer.score(examples)
results = scorer.score(examples)
n_words = sum(len(eg.predicted) for eg in examples)
results["speed"] = n_words / (end_time - start_time)
return results
@contextmanager
def use_params(self, params: dict):
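
Hedged usage sketch: after this change, the dict returned by Language.evaluate carries throughput under the "speed" key (words per second, timing the tokenizer plus all pipes); nlp and dev_examples are assumed to exist.

scores = nlp.evaluate(dev_examples, batch_size=32)
print(f"{scores['speed']:.0f} words per second")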

View File: spacy/matcher/matcher.pxd

@@ -66,6 +66,7 @@ cdef class Matcher:
cdef public object validate
cdef public object _patterns
cdef public object _callbacks
cdef public object _filter
cdef public object _extensions
cdef public object _extra_predicates
cdef public object _seen_attrs

View File: spacy/matcher/matcher.pyx

@@ -1,6 +1,9 @@
# cython: infer_types=True, profile=True
from typing import List
from libcpp.vector cimport vector
from libc.stdint cimport int32_t
from libc.string cimport memset, memcmp
from cymem.cymem cimport Pool
from murmurhash.mrmr cimport hash64
@@ -41,6 +44,7 @@ cdef class Matcher:
self._extra_predicates = []
self._patterns = {}
self._callbacks = {}
self._filter = {}
self._extensions = {}
self._seen_attrs = set()
self.vocab = vocab
@@ -68,7 +72,7 @@
"""
return self._normalize_key(key) in self._patterns
def add(self, key, patterns, *_patterns, on_match=None):
def add(self, key, patterns, *, on_match=None, greedy: str=None):
"""Add a match-rule to the matcher. A match-rule consists of: an ID
key, an on_match callback, and one or more patterns.
@@ -86,11 +90,10 @@
'+': Require the pattern to match 1 or more times.
'*': Allow the pattern to match zero or more times.
The + and * operators are usually interpreted "greedily", i.e. longer
matches are returned where possible. However, if you specify two '+'
and '*' patterns in a row and their matches overlap, the first
operator will behave non-greedily. This quirk in the semantics makes
the matcher more efficient, by avoiding the need for back-tracking.
The + and * operators return all possible matches (not just the greedy
ones). However, the "greedy" argument can filter the final matches by
returning a non-overlapping set per key, giving preference either to the
first match ("FIRST") or to the longest ("LONGEST").
As of spaCy v2.2.2, Matcher.add supports the future API, which makes
the patterns the second argument and a list (instead of a variable
@@ -100,16 +103,15 @@
key (str): The match ID.
patterns (list): The patterns to add for the given key.
on_match (callable): Optional callback executed on match.
*_patterns (list): For backwards compatibility: list of patterns to add
as variable arguments. Will be ignored if a list of patterns is
provided as the second argument.
greedy (str): Optional filter: "FIRST" or "LONGEST".
"""
errors = {}
if on_match is not None and not hasattr(on_match, "__call__"):
raise ValueError(Errors.E171.format(arg_type=type(on_match)))
if patterns is None or hasattr(patterns, "__call__"): # old API
on_match = patterns
patterns = _patterns
if patterns is None or not isinstance(patterns, List): # old API
raise ValueError(Errors.E948.format(arg_type=type(patterns)))
if greedy is not None and greedy not in ["FIRST", "LONGEST"]:
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=greedy))
for i, pattern in enumerate(patterns):
if len(pattern) == 0:
raise ValueError(Errors.E012.format(key=key))
@@ -132,6 +134,7 @@
raise ValueError(Errors.E154.format())
self._patterns.setdefault(key, [])
self._callbacks[key] = on_match
self._filter[key] = greedy
self._patterns[key].extend(patterns)
def remove(self, key):
@@ -217,6 +220,7 @@
length = doclike.end - doclike.start
else:
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
cdef Pool tmp_pool = Pool()
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
and not doc.is_tagged:
raise ValueError(Errors.E155.format())
@@ -224,11 +228,42 @@
raise ValueError(Errors.E156.format())
matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
extensions=self._extensions, predicates=self._extra_predicates)
for i, (key, start, end) in enumerate(matches):
final_matches = []
pairs_by_id = {}
# For each key, either add all matches, or only the filtered, non-overlapping ones
for (key, start, end) in matches:
span_filter = self._filter.get(key)
if span_filter is not None:
pairs = pairs_by_id.get(key, [])
pairs.append((start, end))
pairs_by_id[key] = pairs
else:
final_matches.append((key, start, end))
matched = <char*>tmp_pool.alloc(length, sizeof(char))
empty = <char*>tmp_pool.alloc(length, sizeof(char))
for key, pairs in pairs_by_id.items():
memset(matched, 0, length * sizeof(matched[0]))
span_filter = self._filter.get(key)
if span_filter == "FIRST":
sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]), reverse=False)  # sort by start; for equal starts, longest first
elif span_filter == "LONGEST":
sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True)  # sort longest first; for equal lengths, earliest start first
else:
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=span_filter))
for (start, end) in sorted_pairs:
assert 0 <= start < end # Defend against segfaults
span_len = end-start
# If no tokens in the span have matched
if memcmp(&matched[start], &empty[start], span_len * sizeof(matched[0])) == 0:
final_matches.append((key, start, end))
# Mark tokens that have matched
memset(&matched[start], 1, span_len * sizeof(matched[0]))
# perform the callbacks on the filtered set of results
for i, (key, start, end) in enumerate(final_matches):
on_match = self._callbacks.get(key, None)
if on_match is not None:
on_match(self, doc, i, matches)
return matches
on_match(self, doc, i, final_matches)
return final_matches
def _normalize_key(self, key):
if isinstance(key, basestring):
@@ -239,9 +274,9 @@
def unpickle_matcher(vocab, patterns, callbacks):
matcher = Matcher(vocab)
for key, specs in patterns.items():
for key, pattern in patterns.items():
callback = callbacks.get(key, None)
matcher.add(key, callback, *specs)
matcher.add(key, pattern, on_match=callback)
return matcher
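
For readers skimming the Cython above, a pure-Python sketch of the FIRST/LONGEST span filter: a list of booleans stands in for the memset/memcmp token mask, and the function name and signature are illustrative only.

def filter_matches(pairs, doc_length, strategy):
    # pairs: candidate (start, end) token offsets for one match ID
    if strategy == "FIRST":
        # ascending start; for equal starts, prefer the longer match
        sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]))
    elif strategy == "LONGEST":
        # longest first; for equal lengths, prefer the earliest start
        sorted_pairs = sorted(pairs, key=lambda x: (x[1] - x[0], -x[0]), reverse=True)
    else:
        raise ValueError(f"expected 'FIRST' or 'LONGEST', got: {strategy}")
    matched = [False] * doc_length  # token-occupancy mask
    final = []
    for start, end in sorted_pairs:
        if not any(matched[start:end]):  # keep the span only if no token is taken yet
            final.append((start, end))
            matched[start:end] = [True] * (end - start)  # mark its tokens as used
    return final

pairs = [(0, 2), (1, 5)]
print(filter_matches(pairs, 5, "FIRST"))    # [(0, 2)]
print(filter_matches(pairs, 5, "LONGEST"))  # [(1, 5)]

Sorting the candidates up front means each span needs only a single occupancy check, which is how the implementation avoids backtracking.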

View File: spacy/pipeline/functions.py

@@ -58,7 +58,7 @@ def merge_subtokens(doc: Doc, label: str = "subtok") -> Doc:
"""
# TODO: make stateful component with "label" config
merger = Matcher(doc.vocab)
merger.add("SUBTOK", None, [{"DEP": label, "op": "+"}])
merger.add("SUBTOK", [[{"DEP": label, "op": "+"}]])
matches = merger(doc)
spans = filter_spans([doc[start : end + 1] for _, start, end in matches])
with doc.retokenize() as retokenizer:

View File

@@ -63,18 +63,11 @@ def test_matcher_len_contains(matcher):
assert "TEST2" not in matcher
def test_matcher_add_new_old_api(en_vocab):
def test_matcher_add_new_api(en_vocab):
doc = Doc(en_vocab, words=["a", "b"])
patterns = [[{"TEXT": "a"}], [{"TEXT": "a"}, {"TEXT": "b"}]]
matcher = Matcher(en_vocab)
matcher.add("OLD_API", None, *patterns)
assert len(matcher(doc)) == 2
matcher = Matcher(en_vocab)
on_match = Mock()
matcher.add("OLD_API_CALLBACK", on_match, *patterns)
assert len(matcher(doc)) == 2
assert on_match.call_count == 2
# New API: add(key: str, patterns: List[List[dict]], on_match: Callable)
matcher = Matcher(en_vocab)
matcher.add("NEW_API", patterns)
assert len(matcher(doc)) == 2
@@ -176,7 +169,7 @@ def test_matcher_match_zero_plus(matcher):
def test_matcher_match_one_plus(matcher):
control = Matcher(matcher.vocab)
control.add("BasicPhilippe", None, [{"ORTH": "Philippe"}])
control.add("BasicPhilippe", [[{"ORTH": "Philippe"}]])
doc = Doc(control.vocab, words=["Philippe", "Philippe"])
m = control(doc)
assert len(m) == 2

View File: spacy/tests/matcher/test_matcher_logic.py

@@ -7,18 +7,10 @@ from spacy.tokens import Doc, Span
pattern1 = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}]
pattern2 = [{"ORTH": "A"}, {"ORTH": "A"}]
pattern2 = [{"ORTH": "A", "OP": "*"}, {"ORTH": "A"}]
pattern3 = [{"ORTH": "A"}, {"ORTH": "A"}]
pattern4 = [
{"ORTH": "B"},
{"ORTH": "A", "OP": "*"},
{"ORTH": "B"},
]
pattern5 = [
{"ORTH": "B", "OP": "*"},
{"ORTH": "A", "OP": "*"},
{"ORTH": "B"},
]
pattern4 = [{"ORTH": "B"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]
pattern5 = [{"ORTH": "B", "OP": "*"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]
re_pattern1 = "AA*"
re_pattern2 = "A*A"
@@ -26,10 +18,16 @@ re_pattern3 = "AA"
re_pattern4 = "BA*B"
re_pattern5 = "B*A*B"
longest1 = "A A A A A"
longest2 = "A A A A A"
longest3 = "A A"
longest4 = "B A A A A A B" # "FIRST" would be "B B"
longest5 = "B B A A A A A B"
@pytest.fixture
def text():
return "(ABBAAAAAB)."
return "(BBAAAAAB)."
@pytest.fixture
@@ -41,25 +39,63 @@ def doc(en_tokenizer, text):
@pytest.mark.parametrize(
"pattern,re_pattern",
[
pytest.param(pattern1, re_pattern1, marks=pytest.mark.xfail()),
pytest.param(pattern2, re_pattern2, marks=pytest.mark.xfail()),
pytest.param(pattern3, re_pattern3, marks=pytest.mark.xfail()),
(pattern1, re_pattern1),
(pattern2, re_pattern2),
(pattern3, re_pattern3),
(pattern4, re_pattern4),
pytest.param(pattern5, re_pattern5, marks=pytest.mark.xfail()),
(pattern5, re_pattern5),
],
)
def test_greedy_matching(doc, text, pattern, re_pattern):
"""Test that the greedy matching behavior of the * op is consistant with
def test_greedy_matching_first(doc, text, pattern, re_pattern):
"""Test that the greedy matching behavior "FIRST" is consistent with
other re implementations."""
matcher = Matcher(doc.vocab)
matcher.add(re_pattern, [pattern])
matcher.add(re_pattern, [pattern], greedy="FIRST")
matches = matcher(doc)
re_matches = [m.span() for m in re.finditer(re_pattern, text)]
for match, re_match in zip(matches, re_matches):
assert match[1:] == re_match
for (key, m_s, m_e), (re_s, re_e) in zip(matches, re_matches):
# matching the string, not the exact position
assert doc[m_s:m_e].text == doc[re_s:re_e].text
@pytest.mark.parametrize(
"pattern,longest",
[
(pattern1, longest1),
(pattern2, longest2),
(pattern3, longest3),
(pattern4, longest4),
(pattern5, longest5),
],
)
def test_greedy_matching_longest(doc, text, pattern, longest):
"""Test the "LONGEST" greedy matching behavior"""
matcher = Matcher(doc.vocab)
matcher.add("RULE", [pattern], greedy="LONGEST")
matches = matcher(doc)
for (key, s, e) in matches:
assert doc[s:e].text == longest
def test_greedy_matching_longest_first(en_tokenizer):
"""Test that "LONGEST" matching prefers the first of two equally long matches"""
doc = en_tokenizer(" ".join("CCC"))
matcher = Matcher(doc.vocab)
pattern = [{"ORTH": "C"}, {"ORTH": "C"}]
matcher.add("RULE", [pattern], greedy="LONGEST")
matches = matcher(doc)
# out of 0-2 and 1-3, the first should be picked
assert len(matches) == 1
assert matches[0][1] == 0
assert matches[0][2] == 2
def test_invalid_greediness(doc, text):
matcher = Matcher(doc.vocab)
with pytest.raises(ValueError):
matcher.add("RULE", [pattern1], greedy="GREEDY")
@pytest.mark.xfail
@pytest.mark.parametrize(
"pattern,re_pattern",
[
@@ -74,7 +110,7 @@ def test_match_consuming(doc, text, pattern, re_pattern):
"""Test that matcher.__call__ consumes tokens on a match similar to
re.findall."""
matcher = Matcher(doc.vocab)
matcher.add(re_pattern, [pattern])
matcher.add(re_pattern, [pattern], greedy="FIRST")
matches = matcher(doc)
re_matches = [m.span() for m in re.finditer(re_pattern, text)]
assert len(matches) == len(re_matches)