mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Merge branch 'develop' of https://github.com/explosion/spaCy into develop
This commit is contained in:
commit
e257e66ab9
|
@ -67,10 +67,7 @@ def evaluate(
|
||||||
corpus = Corpus(data_path, data_path)
|
corpus = Corpus(data_path, data_path)
|
||||||
nlp = util.load_model(model)
|
nlp = util.load_model(model)
|
||||||
dev_dataset = list(corpus.dev_dataset(nlp, gold_preproc=gold_preproc))
|
dev_dataset = list(corpus.dev_dataset(nlp, gold_preproc=gold_preproc))
|
||||||
begin = timer()
|
|
||||||
scores = nlp.evaluate(dev_dataset, verbose=False)
|
scores = nlp.evaluate(dev_dataset, verbose=False)
|
||||||
end = timer()
|
|
||||||
nwords = sum(len(ex.predicted) for ex in dev_dataset)
|
|
||||||
metrics = {
|
metrics = {
|
||||||
"TOK": "token_acc",
|
"TOK": "token_acc",
|
||||||
"TAG": "tag_acc",
|
"TAG": "tag_acc",
|
||||||
|
@ -82,16 +79,20 @@ def evaluate(
|
||||||
"NER P": "ents_p",
|
"NER P": "ents_p",
|
||||||
"NER R": "ents_r",
|
"NER R": "ents_r",
|
||||||
"NER F": "ents_f",
|
"NER F": "ents_f",
|
||||||
"Textcat": "cats_score",
|
"TEXTCAT": "cats_score",
|
||||||
"Sent P": "sents_p",
|
"SENT P": "sents_p",
|
||||||
"Sent R": "sents_r",
|
"SENT R": "sents_r",
|
||||||
"Sent F": "sents_f",
|
"SENT F": "sents_f",
|
||||||
|
"SPEED": "speed",
|
||||||
}
|
}
|
||||||
results = {}
|
results = {}
|
||||||
for metric, key in metrics.items():
|
for metric, key in metrics.items():
|
||||||
if key in scores:
|
if key in scores:
|
||||||
if key == "cats_score":
|
if key == "cats_score":
|
||||||
metric = metric + " (" + scores.get("cats_score_desc", "unk") + ")"
|
metric = metric + " (" + scores.get("cats_score_desc", "unk") + ")"
|
||||||
|
if key == "speed":
|
||||||
|
results[metric] = f"{scores[key]:.0f}"
|
||||||
|
else:
|
||||||
results[metric] = f"{scores[key]*100:.2f}"
|
results[metric] = f"{scores[key]*100:.2f}"
|
||||||
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
|
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from typing import Optional, Dict, Any, Tuple, Union, Callable, List
|
from typing import Optional, Dict, Any, Tuple, Union, Callable, List
|
||||||
from timeit import default_timer as timer
|
|
||||||
import srsly
|
import srsly
|
||||||
import tqdm
|
import tqdm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -248,14 +247,11 @@ def create_evaluation_callback(
|
||||||
dev_examples = list(dev_examples)
|
dev_examples = list(dev_examples)
|
||||||
n_words = sum(len(ex.predicted) for ex in dev_examples)
|
n_words = sum(len(ex.predicted) for ex in dev_examples)
|
||||||
batch_size = cfg["eval_batch_size"]
|
batch_size = cfg["eval_batch_size"]
|
||||||
start_time = timer()
|
|
||||||
if optimizer.averages:
|
if optimizer.averages:
|
||||||
with nlp.use_params(optimizer.averages):
|
with nlp.use_params(optimizer.averages):
|
||||||
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
|
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
|
||||||
else:
|
else:
|
||||||
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
|
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
|
||||||
end_time = timer()
|
|
||||||
wps = n_words / (end_time - start_time)
|
|
||||||
# Calculate a weighted sum based on score_weights for the main score
|
# Calculate a weighted sum based on score_weights for the main score
|
||||||
weights = cfg["score_weights"]
|
weights = cfg["score_weights"]
|
||||||
try:
|
try:
|
||||||
|
@ -264,7 +260,6 @@ def create_evaluation_callback(
|
||||||
keys = list(scores.keys())
|
keys = list(scores.keys())
|
||||||
err = Errors.E983.format(dict="score_weights", key=str(e), keys=keys)
|
err = Errors.E983.format(dict="score_weights", key=str(e), keys=keys)
|
||||||
raise KeyError(err)
|
raise KeyError(err)
|
||||||
scores["speed"] = wps
|
|
||||||
return weighted_score, scores
|
return weighted_score, scores
|
||||||
|
|
||||||
return evaluate
|
return evaluate
|
||||||
|
@ -446,7 +441,7 @@ def update_meta(
|
||||||
training: Union[Dict[str, Any], Config], nlp: Language, info: Dict[str, Any]
|
training: Union[Dict[str, Any], Config], nlp: Language, info: Dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
nlp.meta["performance"] = {}
|
nlp.meta["performance"] = {}
|
||||||
for metric in training["scores_weights"]:
|
for metric in training["score_weights"]:
|
||||||
nlp.meta["performance"][metric] = info["other_scores"][metric]
|
nlp.meta["performance"][metric] = info["other_scores"][metric]
|
||||||
for pipe_name in nlp.pipe_names:
|
for pipe_name in nlp.pipe_names:
|
||||||
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
|
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
|
||||||
|
|
|
@ -432,12 +432,12 @@ class Errors:
|
||||||
"Current DocBin: {current}\nOther DocBin: {other}")
|
"Current DocBin: {current}\nOther DocBin: {other}")
|
||||||
E169 = ("Can't find module: {module}")
|
E169 = ("Can't find module: {module}")
|
||||||
E170 = ("Cannot apply transition {name}: invalid for the current state.")
|
E170 = ("Cannot apply transition {name}: invalid for the current state.")
|
||||||
E171 = ("Matcher.add received invalid on_match callback argument: expected "
|
E171 = ("Matcher.add received invalid 'on_match' callback argument: expected "
|
||||||
"callable or None, but got: {arg_type}")
|
"callable or None, but got: {arg_type}")
|
||||||
E175 = ("Can't remove rule for unknown match pattern ID: {key}")
|
E175 = ("Can't remove rule for unknown match pattern ID: {key}")
|
||||||
E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
|
E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
|
||||||
E177 = ("Ill-formed IOB input detected: {tag}")
|
E177 = ("Ill-formed IOB input detected: {tag}")
|
||||||
E178 = ("Invalid pattern. Expected list of dicts but got: {pat}. Maybe you "
|
E178 = ("Each pattern should be a list of dicts, but got: {pat}. Maybe you "
|
||||||
"accidentally passed a single pattern to Matcher.add instead of a "
|
"accidentally passed a single pattern to Matcher.add instead of a "
|
||||||
"list of patterns? If you only want to add one pattern, make sure "
|
"list of patterns? If you only want to add one pattern, make sure "
|
||||||
"to wrap it in a list. For example: matcher.add('{key}', [pattern])")
|
"to wrap it in a list. For example: matcher.add('{key}', [pattern])")
|
||||||
|
@ -483,6 +483,10 @@ class Errors:
|
||||||
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
|
||||||
|
"a string value from {expected} but got: '{arg}'")
|
||||||
|
E948 = ("Matcher.add received invalid 'patterns' argument: expected "
|
||||||
|
"a List, but got: {arg_type}")
|
||||||
E952 = ("The section '{name}' is not a valid section in the provided config.")
|
E952 = ("The section '{name}' is not a valid section in the provided config.")
|
||||||
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
|
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
|
||||||
E954 = ("The Tok2Vec listener did not receive a valid input.")
|
E954 = ("The Tok2Vec listener did not receive a valid input.")
|
||||||
|
|
|
@ -14,6 +14,7 @@ from thinc.api import get_current_ops, Config, require_gpu, Optimizer
|
||||||
import srsly
|
import srsly
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
from itertools import chain, cycle
|
from itertools import chain, cycle
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
from .tokens.underscore import Underscore
|
from .tokens.underscore import Underscore
|
||||||
from .vocab import Vocab, create_vocab
|
from .vocab import Vocab, create_vocab
|
||||||
|
@ -1130,7 +1131,14 @@ class Language:
|
||||||
kwargs.setdefault("verbose", verbose)
|
kwargs.setdefault("verbose", verbose)
|
||||||
kwargs.setdefault("nlp", self)
|
kwargs.setdefault("nlp", self)
|
||||||
scorer = Scorer(**kwargs)
|
scorer = Scorer(**kwargs)
|
||||||
docs = list(eg.predicted for eg in examples)
|
texts = [eg.reference.text for eg in examples]
|
||||||
|
docs = [eg.predicted for eg in examples]
|
||||||
|
start_time = timer()
|
||||||
|
# tokenize the texts only for timing purposes
|
||||||
|
if not hasattr(self.tokenizer, "pipe"):
|
||||||
|
_ = [self.tokenizer(text) for text in texts]
|
||||||
|
else:
|
||||||
|
_ = list(self.tokenizer.pipe(texts))
|
||||||
for name, pipe in self.pipeline:
|
for name, pipe in self.pipeline:
|
||||||
kwargs = component_cfg.get(name, {})
|
kwargs = component_cfg.get(name, {})
|
||||||
kwargs.setdefault("batch_size", batch_size)
|
kwargs.setdefault("batch_size", batch_size)
|
||||||
|
@ -1138,11 +1146,18 @@ class Language:
|
||||||
docs = _pipe(docs, pipe, kwargs)
|
docs = _pipe(docs, pipe, kwargs)
|
||||||
else:
|
else:
|
||||||
docs = pipe.pipe(docs, **kwargs)
|
docs = pipe.pipe(docs, **kwargs)
|
||||||
|
# iterate over the final generator
|
||||||
|
if len(self.pipeline):
|
||||||
|
docs = list(docs)
|
||||||
|
end_time = timer()
|
||||||
for i, (doc, eg) in enumerate(zip(docs, examples)):
|
for i, (doc, eg) in enumerate(zip(docs, examples)):
|
||||||
if verbose:
|
if verbose:
|
||||||
print(doc)
|
print(doc)
|
||||||
eg.predicted = doc
|
eg.predicted = doc
|
||||||
return scorer.score(examples)
|
results = scorer.score(examples)
|
||||||
|
n_words = sum(len(eg.predicted) for eg in examples)
|
||||||
|
results["speed"] = n_words / (end_time - start_time)
|
||||||
|
return results
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def use_params(self, params: dict):
|
def use_params(self, params: dict):
|
||||||
|
|
|
@ -66,6 +66,7 @@ cdef class Matcher:
|
||||||
cdef public object validate
|
cdef public object validate
|
||||||
cdef public object _patterns
|
cdef public object _patterns
|
||||||
cdef public object _callbacks
|
cdef public object _callbacks
|
||||||
|
cdef public object _filter
|
||||||
cdef public object _extensions
|
cdef public object _extensions
|
||||||
cdef public object _extra_predicates
|
cdef public object _extra_predicates
|
||||||
cdef public object _seen_attrs
|
cdef public object _seen_attrs
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
# cython: infer_types=True, cython: profile=True
|
# cython: infer_types=True, cython: profile=True
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from libcpp.vector cimport vector
|
from libcpp.vector cimport vector
|
||||||
from libc.stdint cimport int32_t
|
from libc.stdint cimport int32_t
|
||||||
|
from libc.string cimport memset, memcmp
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
from murmurhash.mrmr cimport hash64
|
from murmurhash.mrmr cimport hash64
|
||||||
|
|
||||||
|
@ -41,6 +44,7 @@ cdef class Matcher:
|
||||||
self._extra_predicates = []
|
self._extra_predicates = []
|
||||||
self._patterns = {}
|
self._patterns = {}
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
|
self._filter = {}
|
||||||
self._extensions = {}
|
self._extensions = {}
|
||||||
self._seen_attrs = set()
|
self._seen_attrs = set()
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
|
@ -68,7 +72,7 @@ cdef class Matcher:
|
||||||
"""
|
"""
|
||||||
return self._normalize_key(key) in self._patterns
|
return self._normalize_key(key) in self._patterns
|
||||||
|
|
||||||
def add(self, key, patterns, *_patterns, on_match=None):
|
def add(self, key, patterns, *, on_match=None, greedy: str=None):
|
||||||
"""Add a match-rule to the matcher. A match-rule consists of: an ID
|
"""Add a match-rule to the matcher. A match-rule consists of: an ID
|
||||||
key, an on_match callback, and one or more patterns.
|
key, an on_match callback, and one or more patterns.
|
||||||
|
|
||||||
|
@ -86,11 +90,10 @@ cdef class Matcher:
|
||||||
'+': Require the pattern to match 1 or more times.
|
'+': Require the pattern to match 1 or more times.
|
||||||
'*': Allow the pattern to zero or more times.
|
'*': Allow the pattern to zero or more times.
|
||||||
|
|
||||||
The + and * operators are usually interpretted "greedily", i.e. longer
|
The + and * operators return all possible matches (not just the greedy
|
||||||
matches are returned where possible. However, if you specify two '+'
|
ones). However, the "greedy" argument can filter the final matches
|
||||||
and '*' patterns in a row and their matches overlap, the first
|
by returning a non-overlapping set per key, either taking preference to
|
||||||
operator will behave non-greedily. This quirk in the semantics makes
|
the first greedy match ("FIRST"), or the longest ("LONGEST").
|
||||||
the matcher more efficient, by avoiding the need for back-tracking.
|
|
||||||
|
|
||||||
As of spaCy v2.2.2, Matcher.add supports the future API, which makes
|
As of spaCy v2.2.2, Matcher.add supports the future API, which makes
|
||||||
the patterns the second argument and a list (instead of a variable
|
the patterns the second argument and a list (instead of a variable
|
||||||
|
@ -100,16 +103,15 @@ cdef class Matcher:
|
||||||
key (str): The match ID.
|
key (str): The match ID.
|
||||||
patterns (list): The patterns to add for the given key.
|
patterns (list): The patterns to add for the given key.
|
||||||
on_match (callable): Optional callback executed on match.
|
on_match (callable): Optional callback executed on match.
|
||||||
*_patterns (list): For backwards compatibility: list of patterns to add
|
greedy (str): Optional filter: "FIRST" or "LONGEST".
|
||||||
as variable arguments. Will be ignored if a list of patterns is
|
|
||||||
provided as the second argument.
|
|
||||||
"""
|
"""
|
||||||
errors = {}
|
errors = {}
|
||||||
if on_match is not None and not hasattr(on_match, "__call__"):
|
if on_match is not None and not hasattr(on_match, "__call__"):
|
||||||
raise ValueError(Errors.E171.format(arg_type=type(on_match)))
|
raise ValueError(Errors.E171.format(arg_type=type(on_match)))
|
||||||
if patterns is None or hasattr(patterns, "__call__"): # old API
|
if patterns is None or not isinstance(patterns, List): # old API
|
||||||
on_match = patterns
|
raise ValueError(Errors.E948.format(arg_type=type(patterns)))
|
||||||
patterns = _patterns
|
if greedy is not None and greedy not in ["FIRST", "LONGEST"]:
|
||||||
|
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=greedy))
|
||||||
for i, pattern in enumerate(patterns):
|
for i, pattern in enumerate(patterns):
|
||||||
if len(pattern) == 0:
|
if len(pattern) == 0:
|
||||||
raise ValueError(Errors.E012.format(key=key))
|
raise ValueError(Errors.E012.format(key=key))
|
||||||
|
@ -132,6 +134,7 @@ cdef class Matcher:
|
||||||
raise ValueError(Errors.E154.format())
|
raise ValueError(Errors.E154.format())
|
||||||
self._patterns.setdefault(key, [])
|
self._patterns.setdefault(key, [])
|
||||||
self._callbacks[key] = on_match
|
self._callbacks[key] = on_match
|
||||||
|
self._filter[key] = greedy
|
||||||
self._patterns[key].extend(patterns)
|
self._patterns[key].extend(patterns)
|
||||||
|
|
||||||
def remove(self, key):
|
def remove(self, key):
|
||||||
|
@ -217,6 +220,7 @@ cdef class Matcher:
|
||||||
length = doclike.end - doclike.start
|
length = doclike.end - doclike.start
|
||||||
else:
|
else:
|
||||||
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
|
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
|
||||||
|
cdef Pool tmp_pool = Pool()
|
||||||
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
|
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
|
||||||
and not doc.is_tagged:
|
and not doc.is_tagged:
|
||||||
raise ValueError(Errors.E155.format())
|
raise ValueError(Errors.E155.format())
|
||||||
|
@ -224,11 +228,42 @@ cdef class Matcher:
|
||||||
raise ValueError(Errors.E156.format())
|
raise ValueError(Errors.E156.format())
|
||||||
matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
|
matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
|
||||||
extensions=self._extensions, predicates=self._extra_predicates)
|
extensions=self._extensions, predicates=self._extra_predicates)
|
||||||
for i, (key, start, end) in enumerate(matches):
|
final_matches = []
|
||||||
|
pairs_by_id = {}
|
||||||
|
# For each key, either add all matches, or only the filtered, non-overlapping ones
|
||||||
|
for (key, start, end) in matches:
|
||||||
|
span_filter = self._filter.get(key)
|
||||||
|
if span_filter is not None:
|
||||||
|
pairs = pairs_by_id.get(key, [])
|
||||||
|
pairs.append((start,end))
|
||||||
|
pairs_by_id[key] = pairs
|
||||||
|
else:
|
||||||
|
final_matches.append((key, start, end))
|
||||||
|
matched = <char*>tmp_pool.alloc(length, sizeof(char))
|
||||||
|
empty = <char*>tmp_pool.alloc(length, sizeof(char))
|
||||||
|
for key, pairs in pairs_by_id.items():
|
||||||
|
memset(matched, 0, length * sizeof(matched[0]))
|
||||||
|
span_filter = self._filter.get(key)
|
||||||
|
if span_filter == "FIRST":
|
||||||
|
sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]), reverse=False) # sort by start
|
||||||
|
elif span_filter == "LONGEST":
|
||||||
|
sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True) # reverse sort by length
|
||||||
|
else:
|
||||||
|
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=span_filter))
|
||||||
|
for (start, end) in sorted_pairs:
|
||||||
|
assert 0 <= start < end # Defend against segfaults
|
||||||
|
span_len = end-start
|
||||||
|
# If no tokens in the span have matched
|
||||||
|
if memcmp(&matched[start], &empty[start], span_len * sizeof(matched[0])) == 0:
|
||||||
|
final_matches.append((key, start, end))
|
||||||
|
# Mark tokens that have matched
|
||||||
|
memset(&matched[start], 1, span_len * sizeof(matched[0]))
|
||||||
|
# perform the callbacks on the filtered set of results
|
||||||
|
for i, (key, start, end) in enumerate(final_matches):
|
||||||
on_match = self._callbacks.get(key, None)
|
on_match = self._callbacks.get(key, None)
|
||||||
if on_match is not None:
|
if on_match is not None:
|
||||||
on_match(self, doc, i, matches)
|
on_match(self, doc, i, final_matches)
|
||||||
return matches
|
return final_matches
|
||||||
|
|
||||||
def _normalize_key(self, key):
|
def _normalize_key(self, key):
|
||||||
if isinstance(key, basestring):
|
if isinstance(key, basestring):
|
||||||
|
@ -239,9 +274,9 @@ cdef class Matcher:
|
||||||
|
|
||||||
def unpickle_matcher(vocab, patterns, callbacks):
|
def unpickle_matcher(vocab, patterns, callbacks):
|
||||||
matcher = Matcher(vocab)
|
matcher = Matcher(vocab)
|
||||||
for key, specs in patterns.items():
|
for key, pattern in patterns.items():
|
||||||
callback = callbacks.get(key, None)
|
callback = callbacks.get(key, None)
|
||||||
matcher.add(key, callback, *specs)
|
matcher.add(key, pattern, on_match=callback)
|
||||||
return matcher
|
return matcher
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -58,7 +58,7 @@ def merge_subtokens(doc: Doc, label: str = "subtok") -> Doc:
|
||||||
"""
|
"""
|
||||||
# TODO: make stateful component with "label" config
|
# TODO: make stateful component with "label" config
|
||||||
merger = Matcher(doc.vocab)
|
merger = Matcher(doc.vocab)
|
||||||
merger.add("SUBTOK", None, [{"DEP": label, "op": "+"}])
|
merger.add("SUBTOK", [[{"DEP": label, "op": "+"}]])
|
||||||
matches = merger(doc)
|
matches = merger(doc)
|
||||||
spans = filter_spans([doc[start : end + 1] for _, start, end in matches])
|
spans = filter_spans([doc[start : end + 1] for _, start, end in matches])
|
||||||
with doc.retokenize() as retokenizer:
|
with doc.retokenize() as retokenizer:
|
||||||
|
|
|
@ -63,18 +63,11 @@ def test_matcher_len_contains(matcher):
|
||||||
assert "TEST2" not in matcher
|
assert "TEST2" not in matcher
|
||||||
|
|
||||||
|
|
||||||
def test_matcher_add_new_old_api(en_vocab):
|
def test_matcher_add_new_api(en_vocab):
|
||||||
doc = Doc(en_vocab, words=["a", "b"])
|
doc = Doc(en_vocab, words=["a", "b"])
|
||||||
patterns = [[{"TEXT": "a"}], [{"TEXT": "a"}, {"TEXT": "b"}]]
|
patterns = [[{"TEXT": "a"}], [{"TEXT": "a"}, {"TEXT": "b"}]]
|
||||||
matcher = Matcher(en_vocab)
|
matcher = Matcher(en_vocab)
|
||||||
matcher.add("OLD_API", None, *patterns)
|
|
||||||
assert len(matcher(doc)) == 2
|
|
||||||
matcher = Matcher(en_vocab)
|
|
||||||
on_match = Mock()
|
on_match = Mock()
|
||||||
matcher.add("OLD_API_CALLBACK", on_match, *patterns)
|
|
||||||
assert len(matcher(doc)) == 2
|
|
||||||
assert on_match.call_count == 2
|
|
||||||
# New API: add(key: str, patterns: List[List[dict]], on_match: Callable)
|
|
||||||
matcher = Matcher(en_vocab)
|
matcher = Matcher(en_vocab)
|
||||||
matcher.add("NEW_API", patterns)
|
matcher.add("NEW_API", patterns)
|
||||||
assert len(matcher(doc)) == 2
|
assert len(matcher(doc)) == 2
|
||||||
|
@ -176,7 +169,7 @@ def test_matcher_match_zero_plus(matcher):
|
||||||
|
|
||||||
def test_matcher_match_one_plus(matcher):
|
def test_matcher_match_one_plus(matcher):
|
||||||
control = Matcher(matcher.vocab)
|
control = Matcher(matcher.vocab)
|
||||||
control.add("BasicPhilippe", None, [{"ORTH": "Philippe"}])
|
control.add("BasicPhilippe", [[{"ORTH": "Philippe"}]])
|
||||||
doc = Doc(control.vocab, words=["Philippe", "Philippe"])
|
doc = Doc(control.vocab, words=["Philippe", "Philippe"])
|
||||||
m = control(doc)
|
m = control(doc)
|
||||||
assert len(m) == 2
|
assert len(m) == 2
|
||||||
|
|
|
@ -7,18 +7,10 @@ from spacy.tokens import Doc, Span
|
||||||
|
|
||||||
|
|
||||||
pattern1 = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}]
|
pattern1 = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}]
|
||||||
pattern2 = [{"ORTH": "A"}, {"ORTH": "A"}]
|
pattern2 = [{"ORTH": "A", "OP": "*"}, {"ORTH": "A"}]
|
||||||
pattern3 = [{"ORTH": "A"}, {"ORTH": "A"}]
|
pattern3 = [{"ORTH": "A"}, {"ORTH": "A"}]
|
||||||
pattern4 = [
|
pattern4 = [{"ORTH": "B"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]
|
||||||
{"ORTH": "B"},
|
pattern5 = [{"ORTH": "B", "OP": "*"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]
|
||||||
{"ORTH": "A", "OP": "*"},
|
|
||||||
{"ORTH": "B"},
|
|
||||||
]
|
|
||||||
pattern5 = [
|
|
||||||
{"ORTH": "B", "OP": "*"},
|
|
||||||
{"ORTH": "A", "OP": "*"},
|
|
||||||
{"ORTH": "B"},
|
|
||||||
]
|
|
||||||
|
|
||||||
re_pattern1 = "AA*"
|
re_pattern1 = "AA*"
|
||||||
re_pattern2 = "A*A"
|
re_pattern2 = "A*A"
|
||||||
|
@ -26,10 +18,16 @@ re_pattern3 = "AA"
|
||||||
re_pattern4 = "BA*B"
|
re_pattern4 = "BA*B"
|
||||||
re_pattern5 = "B*A*B"
|
re_pattern5 = "B*A*B"
|
||||||
|
|
||||||
|
longest1 = "A A A A A"
|
||||||
|
longest2 = "A A A A A"
|
||||||
|
longest3 = "A A"
|
||||||
|
longest4 = "B A A A A A B" # "FIRST" would be "B B"
|
||||||
|
longest5 = "B B A A A A A B"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def text():
|
def text():
|
||||||
return "(ABBAAAAAB)."
|
return "(BBAAAAAB)."
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -41,25 +39,63 @@ def doc(en_tokenizer, text):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"pattern,re_pattern",
|
"pattern,re_pattern",
|
||||||
[
|
[
|
||||||
pytest.param(pattern1, re_pattern1, marks=pytest.mark.xfail()),
|
(pattern1, re_pattern1),
|
||||||
pytest.param(pattern2, re_pattern2, marks=pytest.mark.xfail()),
|
(pattern2, re_pattern2),
|
||||||
pytest.param(pattern3, re_pattern3, marks=pytest.mark.xfail()),
|
(pattern3, re_pattern3),
|
||||||
(pattern4, re_pattern4),
|
(pattern4, re_pattern4),
|
||||||
pytest.param(pattern5, re_pattern5, marks=pytest.mark.xfail()),
|
(pattern5, re_pattern5),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_greedy_matching(doc, text, pattern, re_pattern):
|
def test_greedy_matching_first(doc, text, pattern, re_pattern):
|
||||||
"""Test that the greedy matching behavior of the * op is consistant with
|
"""Test that the greedy matching behavior "FIRST" is consistent with
|
||||||
other re implementations."""
|
other re implementations."""
|
||||||
matcher = Matcher(doc.vocab)
|
matcher = Matcher(doc.vocab)
|
||||||
matcher.add(re_pattern, [pattern])
|
matcher.add(re_pattern, [pattern], greedy="FIRST")
|
||||||
matches = matcher(doc)
|
matches = matcher(doc)
|
||||||
re_matches = [m.span() for m in re.finditer(re_pattern, text)]
|
re_matches = [m.span() for m in re.finditer(re_pattern, text)]
|
||||||
for match, re_match in zip(matches, re_matches):
|
for (key, m_s, m_e), (re_s, re_e) in zip(matches, re_matches):
|
||||||
assert match[1:] == re_match
|
# matching the string, not the exact position
|
||||||
|
assert doc[m_s:m_e].text == doc[re_s:re_e].text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"pattern,longest",
|
||||||
|
[
|
||||||
|
(pattern1, longest1),
|
||||||
|
(pattern2, longest2),
|
||||||
|
(pattern3, longest3),
|
||||||
|
(pattern4, longest4),
|
||||||
|
(pattern5, longest5),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_greedy_matching_longest(doc, text, pattern, longest):
|
||||||
|
"""Test the "LONGEST" greedy matching behavior"""
|
||||||
|
matcher = Matcher(doc.vocab)
|
||||||
|
matcher.add("RULE", [pattern], greedy="LONGEST")
|
||||||
|
matches = matcher(doc)
|
||||||
|
for (key, s, e) in matches:
|
||||||
|
assert doc[s:e].text == longest
|
||||||
|
|
||||||
|
|
||||||
|
def test_greedy_matching_longest_first(en_tokenizer):
|
||||||
|
"""Test that "LONGEST" matching prefers the first of two equally long matches"""
|
||||||
|
doc = en_tokenizer(" ".join("CCC"))
|
||||||
|
matcher = Matcher(doc.vocab)
|
||||||
|
pattern = [{"ORTH": "C"}, {"ORTH": "C"}]
|
||||||
|
matcher.add("RULE", [pattern], greedy="LONGEST")
|
||||||
|
matches = matcher(doc)
|
||||||
|
# out of 0-2 and 1-3, the first should be picked
|
||||||
|
assert len(matches) == 1
|
||||||
|
assert matches[0][1] == 0
|
||||||
|
assert matches[0][2] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_greediness(doc, text):
|
||||||
|
matcher = Matcher(doc.vocab)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
matcher.add("RULE", [pattern1], greedy="GREEDY")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"pattern,re_pattern",
|
"pattern,re_pattern",
|
||||||
[
|
[
|
||||||
|
@ -74,7 +110,7 @@ def test_match_consuming(doc, text, pattern, re_pattern):
|
||||||
"""Test that matcher.__call__ consumes tokens on a match similar to
|
"""Test that matcher.__call__ consumes tokens on a match similar to
|
||||||
re.findall."""
|
re.findall."""
|
||||||
matcher = Matcher(doc.vocab)
|
matcher = Matcher(doc.vocab)
|
||||||
matcher.add(re_pattern, [pattern])
|
matcher.add(re_pattern, [pattern], greedy="FIRST")
|
||||||
matches = matcher(doc)
|
matches = matcher(doc)
|
||||||
re_matches = [m.span() for m in re.finditer(re_pattern, text)]
|
re_matches = [m.span() for m in re.finditer(re_pattern, text)]
|
||||||
assert len(matches) == len(re_matches)
|
assert len(matches) == len(re_matches)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user