Merge branch 'develop' of https://github.com/explosion/spaCy into develop

commit e257e66ab9
@@ -67,10 +67,7 @@ def evaluate(
     corpus = Corpus(data_path, data_path)
     nlp = util.load_model(model)
     dev_dataset = list(corpus.dev_dataset(nlp, gold_preproc=gold_preproc))
-    begin = timer()
     scores = nlp.evaluate(dev_dataset, verbose=False)
-    end = timer()
-    nwords = sum(len(ex.predicted) for ex in dev_dataset)
     metrics = {
         "TOK": "token_acc",
         "TAG": "tag_acc",
@@ -82,16 +79,20 @@ def evaluate(
         "NER P": "ents_p",
         "NER R": "ents_r",
         "NER F": "ents_f",
-        "Textcat": "cats_score",
-        "Sent P": "sents_p",
-        "Sent R": "sents_r",
-        "Sent F": "sents_f",
+        "TEXTCAT": "cats_score",
+        "SENT P": "sents_p",
+        "SENT R": "sents_r",
+        "SENT F": "sents_f",
+        "SPEED": "speed",
     }
     results = {}
     for metric, key in metrics.items():
         if key in scores:
             if key == "cats_score":
                 metric = metric + " (" + scores.get("cats_score_desc", "unk") + ")"
+            if key == "speed":
+                results[metric] = f"{scores[key]:.0f}"
+            else:
                 results[metric] = f"{scores[key]*100:.2f}"
     data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
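For context, a minimal sketch (not part of the commit) of what the formatting logic above does: accuracy-style scores are still scaled to percentages, while the new SPEED entry (words per second, now reported by nlp.evaluate) is printed as a plain integer. The score values below are made up for illustration.

# Illustrative only: mirrors the CLI formatting with hypothetical scores.
scores = {"token_acc": 0.991, "ents_f": 0.845, "speed": 15400.0}
metrics = {"TOK": "token_acc", "NER F": "ents_f", "SPEED": "speed"}
results = {}
for metric, key in metrics.items():
    if key in scores:
        if key == "speed":
            results[metric] = f"{scores[key]:.0f}"        # words per second, unscaled
        else:
            results[metric] = f"{scores[key] * 100:.2f}"  # accuracies as percentages
print(results)  # {'TOK': '99.10', 'NER F': '84.50', 'SPEED': '15400'}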
@@ -1,5 +1,4 @@
 from typing import Optional, Dict, Any, Tuple, Union, Callable, List
-from timeit import default_timer as timer
 import srsly
 import tqdm
 from pathlib import Path
@@ -248,14 +247,11 @@ def create_evaluation_callback(
         dev_examples = list(dev_examples)
         n_words = sum(len(ex.predicted) for ex in dev_examples)
         batch_size = cfg["eval_batch_size"]
-        start_time = timer()
         if optimizer.averages:
             with nlp.use_params(optimizer.averages):
                 scores = nlp.evaluate(dev_examples, batch_size=batch_size)
         else:
             scores = nlp.evaluate(dev_examples, batch_size=batch_size)
-        end_time = timer()
-        wps = n_words / (end_time - start_time)
         # Calculate a weighted sum based on score_weights for the main score
         weights = cfg["score_weights"]
         try:
@@ -264,7 +260,6 @@ def create_evaluation_callback(
             keys = list(scores.keys())
             err = Errors.E983.format(dict="score_weights", key=str(e), keys=keys)
             raise KeyError(err)
-        scores["speed"] = wps
         return weighted_score, scores

     return evaluate
@@ -446,7 +441,7 @@ def update_meta(
     training: Union[Dict[str, Any], Config], nlp: Language, info: Dict[str, Any]
 ) -> None:
     nlp.meta["performance"] = {}
-    for metric in training["scores_weights"]:
+    for metric in training["score_weights"]:
         nlp.meta["performance"][metric] = info["other_scores"][metric]
     for pipe_name in nlp.pipe_names:
         nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
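As a side note on the surrounding context lines: the main score is a weighted sum over the configured score_weights, and the "speed" entry now simply rides along in the scores dict returned by nlp.evaluate rather than being computed and weighted here. A rough sketch with made-up numbers (the real callback raises E983 via a KeyError when a weighted key is missing):

# Hypothetical scores/weights, only to show the weighted-sum idea used by the callback.
scores = {"ents_f": 0.84, "tag_acc": 0.95, "speed": 15400.0}
score_weights = {"ents_f": 0.6, "tag_acc": 0.4}
weighted_score = sum(scores[name] * weight for name, weight in score_weights.items())
print(round(weighted_score, 3))  # 0.884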
@@ -432,12 +432,12 @@ class Errors:
             "Current DocBin: {current}\nOther DocBin: {other}")
     E169 = ("Can't find module: {module}")
     E170 = ("Cannot apply transition {name}: invalid for the current state.")
-    E171 = ("Matcher.add received invalid on_match callback argument: expected "
+    E171 = ("Matcher.add received invalid 'on_match' callback argument: expected "
             "callable or None, but got: {arg_type}")
     E175 = ("Can't remove rule for unknown match pattern ID: {key}")
     E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
     E177 = ("Ill-formed IOB input detected: {tag}")
-    E178 = ("Invalid pattern. Expected list of dicts but got: {pat}. Maybe you "
+    E178 = ("Each pattern should be a list of dicts, but got: {pat}. Maybe you "
             "accidentally passed a single pattern to Matcher.add instead of a "
             "list of patterns? If you only want to add one pattern, make sure "
             "to wrap it in a list. For example: matcher.add('{key}', [pattern])")
@@ -483,6 +483,10 @@ class Errors:
     E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")

     # TODO: fix numbering after merging develop into master
+    E947 = ("Matcher.add received invalid 'greedy' argument: expected "
+            "a string value from {expected} but got: '{arg}'")
+    E948 = ("Matcher.add received invalid 'patterns' argument: expected "
+            "a List, but got: {arg_type}")
     E952 = ("The section '{name}' is not a valid section in the provided config.")
     E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
     E954 = ("The Tok2Vec listener did not receive a valid input.")
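To make the two new error codes concrete, here is a hedged sketch of calls that would trigger them under the Matcher.add signature introduced in this commit (the pattern content is arbitrary):

from spacy.lang.en import English
from spacy.matcher import Matcher

nlp = English()
matcher = Matcher(nlp.vocab)
pattern = [{"LOWER": "hello"}, {"LOWER": "world"}]

matcher.add("HELLO", [pattern])  # OK: the patterns argument is a list of patterns

# E948: patterns must be a list, so passing a callback (the old API) is rejected.
# matcher.add("HELLO", lambda m, doc, i, matches: None)

# E947: greedy only accepts "FIRST" or "LONGEST".
# matcher.add("HELLO", [pattern], greedy="GREEDIEST")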
@@ -14,6 +14,7 @@ from thinc.api import get_current_ops, Config, require_gpu, Optimizer
 import srsly
 import multiprocessing as mp
 from itertools import chain, cycle
+from timeit import default_timer as timer

 from .tokens.underscore import Underscore
 from .vocab import Vocab, create_vocab
@@ -1130,7 +1131,14 @@ class Language:
             kwargs.setdefault("verbose", verbose)
             kwargs.setdefault("nlp", self)
             scorer = Scorer(**kwargs)
-        docs = list(eg.predicted for eg in examples)
+        texts = [eg.reference.text for eg in examples]
+        docs = [eg.predicted for eg in examples]
+        start_time = timer()
+        # tokenize the texts only for timing purposes
+        if not hasattr(self.tokenizer, "pipe"):
+            _ = [self.tokenizer(text) for text in texts]
+        else:
+            _ = list(self.tokenizer.pipe(texts))
         for name, pipe in self.pipeline:
             kwargs = component_cfg.get(name, {})
             kwargs.setdefault("batch_size", batch_size)
@@ -1138,11 +1146,18 @@ class Language:
                 docs = _pipe(docs, pipe, kwargs)
             else:
                 docs = pipe.pipe(docs, **kwargs)
+        # iterate over the final generator
+        if len(self.pipeline):
+            docs = list(docs)
+        end_time = timer()
         for i, (doc, eg) in enumerate(zip(docs, examples)):
             if verbose:
                 print(doc)
             eg.predicted = doc
-        return scorer.score(examples)
+        results = scorer.score(examples)
+        n_words = sum(len(eg.predicted) for eg in examples)
+        results["speed"] = n_words / (end_time - start_time)
+        return results

     @contextmanager
     def use_params(self, params: dict):
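In other words, Language.evaluate now times tokenization of the reference texts plus the full pipeline run and reports it under a "speed" key (words per second) next to the accuracy scores. A hedged usage sketch, assuming the v3 develop API where evaluate takes Example objects; the Example import path is an assumption (it has lived in spacy.gold before moving to spacy.training):

from spacy.lang.en import English
from spacy.gold import Example  # assumed location on this branch

nlp = English()  # blank pipeline: tokenizer only
texts = ["Apple is looking at buying a startup.", "Berlin is a city."]
examples = [Example.from_dict(nlp.make_doc(t), {}) for t in texts]

scores = nlp.evaluate(examples)
print(scores["speed"])  # approximate words per second over tokenization + pipeline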
@@ -66,6 +66,7 @@ cdef class Matcher:
     cdef public object validate
     cdef public object _patterns
     cdef public object _callbacks
+    cdef public object _filter
     cdef public object _extensions
     cdef public object _extra_predicates
     cdef public object _seen_attrs
@@ -1,6 +1,9 @@
 # cython: infer_types=True, cython: profile=True
+from typing import List
+
 from libcpp.vector cimport vector
 from libc.stdint cimport int32_t
+from libc.string cimport memset, memcmp
 from cymem.cymem cimport Pool
 from murmurhash.mrmr cimport hash64

@@ -41,6 +44,7 @@ cdef class Matcher:
         self._extra_predicates = []
         self._patterns = {}
         self._callbacks = {}
+        self._filter = {}
         self._extensions = {}
         self._seen_attrs = set()
         self.vocab = vocab
@@ -68,7 +72,7 @@ cdef class Matcher:
         """
         return self._normalize_key(key) in self._patterns

-    def add(self, key, patterns, *_patterns, on_match=None):
+    def add(self, key, patterns, *, on_match=None, greedy: str=None):
         """Add a match-rule to the matcher. A match-rule consists of: an ID
         key, an on_match callback, and one or more patterns.

@@ -86,11 +90,10 @@ cdef class Matcher:
         '+': Require the pattern to match 1 or more times.
         '*': Allow the pattern to zero or more times.

-        The + and * operators are usually interpretted "greedily", i.e. longer
-        matches are returned where possible. However, if you specify two '+'
-        and '*' patterns in a row and their matches overlap, the first
-        operator will behave non-greedily. This quirk in the semantics makes
-        the matcher more efficient, by avoiding the need for back-tracking.
+        The + and * operators return all possible matches (not just the greedy
+        ones). However, the "greedy" argument can filter the final matches
+        by returning a non-overlapping set per key, either taking preference to
+        the first greedy match ("FIRST"), or the longest ("LONGEST").

         As of spaCy v2.2.2, Matcher.add supports the future API, which makes
         the patterns the second argument and a list (instead of a variable
@@ -100,16 +103,15 @@ cdef class Matcher:
         key (str): The match ID.
         patterns (list): The patterns to add for the given key.
         on_match (callable): Optional callback executed on match.
-        *_patterns (list): For backwards compatibility: list of patterns to add
-            as variable arguments. Will be ignored if a list of patterns is
-            provided as the second argument.
+        greedy (str): Optional filter: "FIRST" or "LONGEST".
         """
         errors = {}
         if on_match is not None and not hasattr(on_match, "__call__"):
             raise ValueError(Errors.E171.format(arg_type=type(on_match)))
-        if patterns is None or hasattr(patterns, "__call__"):  # old API
-            on_match = patterns
-            patterns = _patterns
+        if patterns is None or not isinstance(patterns, List):  # old API
+            raise ValueError(Errors.E948.format(arg_type=type(patterns)))
+        if greedy is not None and greedy not in ["FIRST", "LONGEST"]:
+            raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=greedy))
         for i, pattern in enumerate(patterns):
             if len(pattern) == 0:
                 raise ValueError(Errors.E012.format(key=key))
@@ -132,6 +134,7 @@ cdef class Matcher:
                 raise ValueError(Errors.E154.format())
         self._patterns.setdefault(key, [])
         self._callbacks[key] = on_match
+        self._filter[key] = greedy
         self._patterns[key].extend(patterns)

     def remove(self, key):
@@ -217,6 +220,7 @@ cdef class Matcher:
             length = doclike.end - doclike.start
         else:
             raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
+        cdef Pool tmp_pool = Pool()
         if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
           and not doc.is_tagged:
             raise ValueError(Errors.E155.format())
@@ -224,11 +228,42 @@ cdef class Matcher:
             raise ValueError(Errors.E156.format())
         matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
                                 extensions=self._extensions, predicates=self._extra_predicates)
-        for i, (key, start, end) in enumerate(matches):
+        final_matches = []
+        pairs_by_id = {}
+        # For each key, either add all matches, or only the filtered, non-overlapping ones
+        for (key, start, end) in matches:
+            span_filter = self._filter.get(key)
+            if span_filter is not None:
+                pairs = pairs_by_id.get(key, [])
+                pairs.append((start,end))
+                pairs_by_id[key] = pairs
+            else:
+                final_matches.append((key, start, end))
+        matched = <char*>tmp_pool.alloc(length, sizeof(char))
+        empty = <char*>tmp_pool.alloc(length, sizeof(char))
+        for key, pairs in pairs_by_id.items():
+            memset(matched, 0, length * sizeof(matched[0]))
+            span_filter = self._filter.get(key)
+            if span_filter == "FIRST":
+                sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]), reverse=False) # sort by start
+            elif span_filter == "LONGEST":
+                sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True) # reverse sort by length
+            else:
+                raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=span_filter))
+            for (start, end) in sorted_pairs:
+                assert 0 <= start < end  # Defend against segfaults
+                span_len = end-start
+                # If no tokens in the span have matched
+                if memcmp(&matched[start], &empty[start], span_len * sizeof(matched[0])) == 0:
+                    final_matches.append((key, start, end))
+                    # Mark tokens that have matched
+                    memset(&matched[start], 1, span_len * sizeof(matched[0]))
+        # perform the callbacks on the filtered set of results
+        for i, (key, start, end) in enumerate(final_matches):
             on_match = self._callbacks.get(key, None)
             if on_match is not None:
-                on_match(self, doc, i, matches)
-        return matches
+                on_match(self, doc, i, final_matches)
+        return final_matches

     def _normalize_key(self, key):
         if isinstance(key, basestring):
@@ -239,9 +274,9 @@ cdef class Matcher:

 def unpickle_matcher(vocab, patterns, callbacks):
     matcher = Matcher(vocab)
-    for key, specs in patterns.items():
+    for key, pattern in patterns.items():
         callback = callbacks.get(key, None)
-        matcher.add(key, callback, *specs)
+        matcher.add(key, pattern, on_match=callback)
     return matcher
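A short usage sketch of the new greedy filter (not part of the commit itself): without it the matcher returns every candidate span, while greedy="LONGEST" keeps a non-overlapping set per key, preferring longer spans. The token texts are illustrative.

from spacy.lang.en import English
from spacy.matcher import Matcher

nlp = English()
doc = nlp("A A A B")
pattern = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}]

matcher = Matcher(nlp.vocab)
matcher.add("AS", [pattern])  # default: all candidate matches, overlaps included
print([doc[s:e].text for _, s, e in matcher(doc)])

matcher = Matcher(nlp.vocab)
matcher.add("AS", [pattern], greedy="LONGEST")  # filtered: longest non-overlapping span
print([doc[s:e].text for _, s, e in matcher(doc)])  # ['A A A']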
@@ -58,7 +58,7 @@ def merge_subtokens(doc: Doc, label: str = "subtok") -> Doc:
     """
     # TODO: make stateful component with "label" config
     merger = Matcher(doc.vocab)
-    merger.add("SUBTOK", None, [{"DEP": label, "op": "+"}])
+    merger.add("SUBTOK", [[{"DEP": label, "op": "+"}]])
     matches = merger(doc)
     spans = filter_spans([doc[start : end + 1] for _, start, end in matches])
     with doc.retokenize() as retokenizer:
@@ -63,18 +63,11 @@ def test_matcher_len_contains(matcher):
     assert "TEST2" not in matcher


-def test_matcher_add_new_old_api(en_vocab):
+def test_matcher_add_new_api(en_vocab):
     doc = Doc(en_vocab, words=["a", "b"])
     patterns = [[{"TEXT": "a"}], [{"TEXT": "a"}, {"TEXT": "b"}]]
     matcher = Matcher(en_vocab)
-    matcher.add("OLD_API", None, *patterns)
-    assert len(matcher(doc)) == 2
-    matcher = Matcher(en_vocab)
     on_match = Mock()
-    matcher.add("OLD_API_CALLBACK", on_match, *patterns)
-    assert len(matcher(doc)) == 2
-    assert on_match.call_count == 2
-    # New API: add(key: str, patterns: List[List[dict]], on_match: Callable)
     matcher = Matcher(en_vocab)
     matcher.add("NEW_API", patterns)
     assert len(matcher(doc)) == 2
@@ -176,7 +169,7 @@ def test_matcher_match_zero_plus(matcher):

 def test_matcher_match_one_plus(matcher):
     control = Matcher(matcher.vocab)
-    control.add("BasicPhilippe", None, [{"ORTH": "Philippe"}])
+    control.add("BasicPhilippe", [[{"ORTH": "Philippe"}]])
     doc = Doc(control.vocab, words=["Philippe", "Philippe"])
     m = control(doc)
     assert len(m) == 2
@@ -7,18 +7,10 @@ from spacy.tokens import Doc, Span


 pattern1 = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}]
-pattern2 = [{"ORTH": "A"}, {"ORTH": "A"}]
+pattern2 = [{"ORTH": "A", "OP": "*"}, {"ORTH": "A"}]
 pattern3 = [{"ORTH": "A"}, {"ORTH": "A"}]
-pattern4 = [
-    {"ORTH": "B"},
-    {"ORTH": "A", "OP": "*"},
-    {"ORTH": "B"},
-]
-pattern5 = [
-    {"ORTH": "B", "OP": "*"},
-    {"ORTH": "A", "OP": "*"},
-    {"ORTH": "B"},
-]
+pattern4 = [{"ORTH": "B"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]
+pattern5 = [{"ORTH": "B", "OP": "*"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]

 re_pattern1 = "AA*"
 re_pattern2 = "A*A"
@@ -26,10 +18,16 @@ re_pattern3 = "AA"
 re_pattern4 = "BA*B"
 re_pattern5 = "B*A*B"

+longest1 = "A A A A A"
+longest2 = "A A A A A"
+longest3 = "A A"
+longest4 = "B A A A A A B"      # "FIRST" would be "B B"
+longest5 = "B B A A A A A B"
+

 @pytest.fixture
 def text():
-    return "(ABBAAAAAB)."
+    return "(BBAAAAAB)."


 @pytest.fixture
@@ -41,25 +39,63 @@ def doc(en_tokenizer, text):
 @pytest.mark.parametrize(
     "pattern,re_pattern",
     [
-        pytest.param(pattern1, re_pattern1, marks=pytest.mark.xfail()),
-        pytest.param(pattern2, re_pattern2, marks=pytest.mark.xfail()),
-        pytest.param(pattern3, re_pattern3, marks=pytest.mark.xfail()),
+        (pattern1, re_pattern1),
+        (pattern2, re_pattern2),
+        (pattern3, re_pattern3),
         (pattern4, re_pattern4),
-        pytest.param(pattern5, re_pattern5, marks=pytest.mark.xfail()),
+        (pattern5, re_pattern5),
     ],
 )
-def test_greedy_matching(doc, text, pattern, re_pattern):
-    """Test that the greedy matching behavior of the * op is consistant with
+def test_greedy_matching_first(doc, text, pattern, re_pattern):
+    """Test that the greedy matching behavior "FIRST" is consistent with
     other re implementations."""
     matcher = Matcher(doc.vocab)
-    matcher.add(re_pattern, [pattern])
+    matcher.add(re_pattern, [pattern], greedy="FIRST")
     matches = matcher(doc)
     re_matches = [m.span() for m in re.finditer(re_pattern, text)]
-    for match, re_match in zip(matches, re_matches):
-        assert match[1:] == re_match
+    for (key, m_s, m_e), (re_s, re_e) in zip(matches, re_matches):
+        # matching the string, not the exact position
+        assert doc[m_s:m_e].text == doc[re_s:re_e].text
+
+
+@pytest.mark.parametrize(
+    "pattern,longest",
+    [
+        (pattern1, longest1),
+        (pattern2, longest2),
+        (pattern3, longest3),
+        (pattern4, longest4),
+        (pattern5, longest5),
+    ],
+)
+def test_greedy_matching_longest(doc, text, pattern, longest):
+    """Test the "LONGEST" greedy matching behavior"""
+    matcher = Matcher(doc.vocab)
+    matcher.add("RULE", [pattern], greedy="LONGEST")
+    matches = matcher(doc)
+    for (key, s, e) in matches:
+        assert doc[s:e].text == longest
+
+
+def test_greedy_matching_longest_first(en_tokenizer):
+    """Test that "LONGEST" matching prefers the first of two equally long matches"""
+    doc = en_tokenizer(" ".join("CCC"))
+    matcher = Matcher(doc.vocab)
+    pattern = [{"ORTH": "C"}, {"ORTH": "C"}]
+    matcher.add("RULE", [pattern], greedy="LONGEST")
+    matches = matcher(doc)
+    # out of 0-2 and 1-3, the first should be picked
+    assert len(matches) == 1
+    assert matches[0][1] == 0
+    assert matches[0][2] == 2
+
+
+def test_invalid_greediness(doc, text):
+    matcher = Matcher(doc.vocab)
+    with pytest.raises(ValueError):
+        matcher.add("RULE", [pattern1], greedy="GREEDY")


-@pytest.mark.xfail
 @pytest.mark.parametrize(
     "pattern,re_pattern",
     [
@@ -74,7 +110,7 @@ def test_match_consuming(doc, text, pattern, re_pattern):
     """Test that matcher.__call__ consumes tokens on a match similar to
     re.findall."""
     matcher = Matcher(doc.vocab)
-    matcher.add(re_pattern, [pattern])
+    matcher.add(re_pattern, [pattern], greedy="FIRST")
     matches = matcher(doc)
     re_matches = [m.span() for m in re.finditer(re_pattern, text)]
     assert len(matches) == len(re_matches)