mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	only use a single spans_key like in spancat
This commit is contained in:
		
							parent
							
								
									90af16af76
								
							
						
					
					
						commit
						6f750d0da6
					
				|  | @ -1,4 +1,3 @@ | |||
| from functools import partial | ||||
| from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast | ||||
| 
 | ||||
| from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate | ||||
|  | @ -41,7 +40,6 @@ depth = 4 | |||
| """ | ||||
| 
 | ||||
| DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"] | ||||
| DEFAULT_PREDICTED_KEY = "span_candidates" | ||||
| 
 | ||||
| 
 | ||||
| @Language.factory( | ||||
|  | @ -50,21 +48,15 @@ DEFAULT_PREDICTED_KEY = "span_candidates" | |||
|     default_config={ | ||||
|         "threshold": 0.5, | ||||
|         "model": DEFAULT_SPAN_FINDER_MODEL, | ||||
|         "predicted_key": DEFAULT_PREDICTED_KEY, | ||||
|         "training_key": DEFAULT_SPANS_KEY, | ||||
|         # XXX Doesn't 0 seem bad compared to None instead? | ||||
|         "spans_key": DEFAULT_SPANS_KEY, | ||||
|         "max_length": None, | ||||
|         "min_length": None, | ||||
|         "scorer": { | ||||
|             "@scorers": "spacy.span_finder_scorer.v1", | ||||
|             "predicted_key": DEFAULT_PREDICTED_KEY, | ||||
|             "training_key": DEFAULT_SPANS_KEY, | ||||
|         }, | ||||
|         "scorer": {"@scorers": "spacy.span_finder_scorer.v1"}, | ||||
|     }, | ||||
|     default_score_weights={ | ||||
|         f"span_finder_{DEFAULT_PREDICTED_KEY}_f": 1.0, | ||||
|         f"span_finder_{DEFAULT_PREDICTED_KEY}_p": 0.0, | ||||
|         f"span_finder_{DEFAULT_PREDICTED_KEY}_r": 0.0, | ||||
|         f"span_finder_{DEFAULT_SPANS_KEY}_f": 1.0, | ||||
|         f"span_finder_{DEFAULT_SPANS_KEY}_p": 0.0, | ||||
|         f"span_finder_{DEFAULT_SPANS_KEY}_r": 0.0, | ||||
|     }, | ||||
| ) | ||||
| def make_span_finder( | ||||
|  | @ -75,8 +67,7 @@ def make_span_finder( | |||
|     threshold: float, | ||||
|     max_length: Optional[int], | ||||
|     min_length: Optional[int], | ||||
|     predicted_key: str = DEFAULT_PREDICTED_KEY, | ||||
|     training_key: str = DEFAULT_SPANS_KEY, | ||||
|     spans_key: str, | ||||
| ) -> "SpanFinder": | ||||
|     """Create a SpanFinder component. The component predicts whether a token is | ||||
|     the start or the end of a potential span. | ||||
|  | @ -84,10 +75,9 @@ def make_span_finder( | |||
|     model (Model[List[Doc], Floats2d]): A model instance that | ||||
|         is given a list of documents and predicts a probability for each token. | ||||
|     threshold (float): Minimum probability to consider a prediction positive. | ||||
|     predicted_key (str): Name of the span group the predicted spans are saved | ||||
|         to | ||||
|     training_key (str): Name of the span group the training spans are read | ||||
|         from | ||||
|     spans_key (str): Key of the doc.spans dict to save the spans under. During | ||||
|         initialization and training, the component will look for spans on the | ||||
|         reference document under the same key. | ||||
|     max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length. | ||||
|     min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1. | ||||
|     """ | ||||
|  | @ -99,51 +89,26 @@ def make_span_finder( | |||
|         scorer=scorer, | ||||
|         max_length=max_length, | ||||
|         min_length=min_length, | ||||
|         predicted_key=predicted_key, | ||||
|         training_key=training_key, | ||||
|         spans_key=spans_key, | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| @registry.scorers("spacy.span_finder_scorer.v1") | ||||
| def make_span_finder_scorer( | ||||
|     predicted_key: str = DEFAULT_PREDICTED_KEY, | ||||
|     training_key: str = DEFAULT_SPANS_KEY, | ||||
| ): | ||||
|     return partial( | ||||
|         span_finder_score, predicted_key=predicted_key, training_key=training_key | ||||
|     ) | ||||
| def make_span_finder_scorer(): | ||||
|     return span_finder_score | ||||
| 
 | ||||
| 
 | ||||
| def span_finder_score( | ||||
|     examples: Iterable[Example], | ||||
|     *, | ||||
|     predicted_key: str = DEFAULT_PREDICTED_KEY, | ||||
|     training_key: str = DEFAULT_SPANS_KEY, | ||||
|     **kwargs, | ||||
| ) -> Dict[str, Any]: | ||||
| def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: | ||||
|     kwargs = dict(kwargs) | ||||
|     print(kwargs) | ||||
|     attr_prefix = "span_finder_" | ||||
|     kwargs.setdefault("attr", f"{attr_prefix}{predicted_key}") | ||||
|     kwargs.setdefault("allow_overlap", True) | ||||
|     key = kwargs["spans_key"] | ||||
|     kwargs.setdefault("attr", f"{attr_prefix}{key}") | ||||
|     kwargs.setdefault( | ||||
|         "getter", lambda doc, key: doc.spans.get(key[len(attr_prefix) :], []) | ||||
|     ) | ||||
|     kwargs.setdefault("labeled", False) | ||||
|     kwargs.setdefault("has_annotation", lambda doc: predicted_key in doc.spans) | ||||
|     # score_spans can only score spans with the same key in both the reference | ||||
|     # and predicted docs, so temporarily copy the reference spans from the | ||||
|     # reference key to the candidates key in the reference docs, restoring the | ||||
|     # original span groups afterwards | ||||
|     orig_span_groups = [] | ||||
|     for eg in examples: | ||||
|         orig_span_groups.append(eg.reference.spans.get(predicted_key)) | ||||
|         if training_key in eg.reference.spans: | ||||
|             eg.reference.spans[predicted_key] = eg.reference.spans[training_key] | ||||
|     scores = Scorer.score_spans(examples, **kwargs) | ||||
|     for orig_span_group, eg in zip(orig_span_groups, examples): | ||||
|         if orig_span_group is not None: | ||||
|             eg.reference.spans[predicted_key] = orig_span_group | ||||
|     return scores | ||||
|     kwargs.setdefault("has_annotation", lambda doc: key in doc.spans) | ||||
|     return Scorer.score_spans(examples, **kwargs) | ||||
| 
 | ||||
| 
 | ||||
| class _MaxInt(int): | ||||
|  | @ -179,13 +144,8 @@ class SpanFinder(TrainablePipe): | |||
|         max_length: Optional[int] = None, | ||||
|         min_length: Optional[int] = None, | ||||
|         # XXX I think this is weird and should be just None like in | ||||
|         scorer: Optional[Callable] = partial( | ||||
|             span_finder_score, | ||||
|             predicted_key=DEFAULT_PREDICTED_KEY, | ||||
|             training_key=DEFAULT_SPANS_KEY, | ||||
|         ), | ||||
|         predicted_key: str = DEFAULT_PREDICTED_KEY, | ||||
|         training_key: str = DEFAULT_SPANS_KEY, | ||||
|         scorer: Optional[Callable] = span_finder_score, | ||||
|         spans_key: str = DEFAULT_SPANS_KEY, | ||||
|     ) -> None: | ||||
|         """Initialize the span boundary detector. | ||||
|         model (thinc.api.Model): The Thinc Model powering the pipeline component. | ||||
|  | @ -194,8 +154,9 @@ class SpanFinder(TrainablePipe): | |||
|         threshold (float): Minimum probability to consider a prediction | ||||
|             positive. | ||||
|         scorer (Optional[Callable]): The scoring method. | ||||
|         predicted_key (str): Name of the span group the candidate spans are saved to | ||||
|         training_key (str): Name of the span group the training spans are read from | ||||
|         spans_key (str): Key of the doc.spans dict to save the spans under. During | ||||
|             initialization and training, the component will look for spans on the | ||||
|             reference document under the same key. | ||||
|         max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length. | ||||
|         min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1. | ||||
|         """ | ||||
|  | @ -211,11 +172,11 @@ class SpanFinder(TrainablePipe): | |||
|             ) | ||||
|         self.min_length = min_length | ||||
|         self.max_length = max_length | ||||
|         self.predicted_key = predicted_key | ||||
|         self.training_key = training_key | ||||
|         self.spans_key = spans_key | ||||
|         self.model = model | ||||
|         self.name = name | ||||
|         self.scorer = scorer | ||||
|         self.cfg = {"spans_key": spans_key} | ||||
| 
 | ||||
|     def predict(self, docs: Iterable[Doc]): | ||||
|         """Apply the pipeline's model to a batch of docs, without modifying them. | ||||
|  | @ -232,7 +193,7 @@ class SpanFinder(TrainablePipe): | |||
|         """ | ||||
|         offset = 0 | ||||
|         for i, doc in enumerate(docs): | ||||
|             doc.spans[self.predicted_key] = [] | ||||
|             doc.spans[self.spans_key] = [] | ||||
|             starts = [] | ||||
|             ends = [] | ||||
|             doc_scores = scores[offset : offset + len(doc)] | ||||
|  | @ -249,7 +210,7 @@ class SpanFinder(TrainablePipe): | |||
|                     if span_length > self.max_length: | ||||
|                         break | ||||
|                     elif self.min_length <= span_length: | ||||
|                         doc.spans[self.predicted_key].append(doc[start : end + 1]) | ||||
|                         doc.spans[self.spans_key].append(doc[start : end + 1]) | ||||
| 
 | ||||
|     def update( | ||||
|         self, | ||||
|  | @ -304,8 +265,8 @@ class SpanFinder(TrainablePipe): | |||
|             n_tokens = len(eg.predicted) | ||||
|             truth = ops.xp.zeros((n_tokens, 2), dtype="float32") | ||||
|             mask = ops.xp.ones((n_tokens, 2), dtype="float32") | ||||
|             if self.training_key in eg.reference.spans: | ||||
|                 for span in eg.reference.spans[self.training_key]: | ||||
|             if self.spans_key in eg.reference.spans: | ||||
|                 for span in eg.reference.spans[self.spans_key]: | ||||
|                     ref_start_char, ref_end_char = _char_indices(span) | ||||
|                     pred_span = eg.predicted.char_span( | ||||
|                         ref_start_char, ref_end_char, alignment_mode="expand" | ||||
|  | @ -342,8 +303,8 @@ class SpanFinder(TrainablePipe): | |||
|             start_indices = set() | ||||
|             end_indices = set() | ||||
| 
 | ||||
|             if self.training_key in doc.spans: | ||||
|                 for span in doc.spans[self.training_key]: | ||||
|             if self.spans_key in doc.spans: | ||||
|                 for span in doc.spans[self.spans_key]: | ||||
|                     start_indices.add(span.start) | ||||
|                     end_indices.add(span.end - 1) | ||||
| 
 | ||||
|  |  | |||
|  | @ -4,7 +4,7 @@ from thinc.types import Ragged | |||
| 
 | ||||
| from spacy.language import Language | ||||
| from spacy.lang.en import English | ||||
| from spacy.pipeline.span_finder import DEFAULT_PREDICTED_KEY, span_finder_default_config | ||||
| from spacy.pipeline.span_finder import span_finder_default_config | ||||
| from spacy.tokens import Doc | ||||
| from spacy.training import Example | ||||
| from spacy import util | ||||
|  | @ -12,22 +12,22 @@ from spacy.util import registry | |||
| from spacy.util import fix_random_seed, make_tempdir | ||||
| 
 | ||||
| 
 | ||||
| TRAINING_KEY = "pytest" | ||||
| SPANS_KEY = "pytest" | ||||
| TRAIN_DATA = [ | ||||
|     ("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}), | ||||
|     ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}), | ||||
|     ( | ||||
|         "I like London and Berlin.", | ||||
|         {"spans": {TRAINING_KEY: [(7, 13, "LOC"), (18, 24)]}}, | ||||
|         {"spans": {SPANS_KEY: [(7, 13), (18, 24)]}}, | ||||
|     ), | ||||
| ] | ||||
| 
 | ||||
| TRAIN_DATA_OVERLAPPING = [ | ||||
|     ("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}), | ||||
|     ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}), | ||||
|     ( | ||||
|         "I like London and Berlin", | ||||
|         {"spans": {TRAINING_KEY: [(7, 13), (18, 24), (7, 24)]}}, | ||||
|         {"spans": {SPANS_KEY: [(7, 13), (18, 24), (7, 24)]}}, | ||||
|     ), | ||||
|     ("", {"spans": {TRAINING_KEY: []}}), | ||||
|     ("", {"spans": {SPANS_KEY: []}}), | ||||
| ] | ||||
| 
 | ||||
| 
 | ||||
|  | @ -88,8 +88,8 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr | |||
|         nlp.vocab, words=tokens_reference, spaces=[False] * len(tokens_reference) | ||||
|     ) | ||||
|     example = Example(predicted, reference) | ||||
|     example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)] | ||||
|     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) | ||||
|     example.reference.spans[SPANS_KEY] = [example.reference.char_span(5, 9)] | ||||
|     span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY}) | ||||
|     nlp.initialize() | ||||
|     ops = span_finder.model.ops | ||||
|     if predicted.text != reference.text: | ||||
|  | @ -107,8 +107,8 @@ def test_span_finder_model(): | |||
|     nlp = Language() | ||||
| 
 | ||||
|     docs = [nlp("This is an example."), nlp("This is the second example.")] | ||||
|     docs[0].spans[TRAINING_KEY] = [docs[0][3:4]] | ||||
|     docs[1].spans[TRAINING_KEY] = [docs[1][3:5]] | ||||
|     docs[0].spans[SPANS_KEY] = [docs[0][3:4]] | ||||
|     docs[1].spans[SPANS_KEY] = [docs[1][3:5]] | ||||
| 
 | ||||
|     total_tokens = 0 | ||||
|     for doc in docs: | ||||
|  | @ -128,15 +128,15 @@ def test_span_finder_component(): | |||
|     nlp = Language() | ||||
| 
 | ||||
|     docs = [nlp("This is an example."), nlp("This is the second example.")] | ||||
|     docs[0].spans[TRAINING_KEY] = [docs[0][3:4]] | ||||
|     docs[1].spans[TRAINING_KEY] = [docs[1][3:5]] | ||||
|     docs[0].spans[SPANS_KEY] = [docs[0][3:4]] | ||||
|     docs[1].spans[SPANS_KEY] = [docs[1][3:5]] | ||||
| 
 | ||||
|     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) | ||||
|     span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY}) | ||||
|     nlp.initialize() | ||||
|     docs = list(span_finder.pipe(docs)) | ||||
| 
 | ||||
|     # TODO: update hard-coded name | ||||
|     assert "span_candidates" in docs[0].spans | ||||
|     assert SPANS_KEY in docs[0].spans | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|  | @ -153,7 +153,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count): | |||
|                 config={ | ||||
|                     "max_length": max_length, | ||||
|                     "min_length": min_length, | ||||
|                     "training_key": TRAINING_KEY, | ||||
|                     "spans_key": SPANS_KEY, | ||||
|                 }, | ||||
|             ) | ||||
|         return | ||||
|  | @ -162,7 +162,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count): | |||
|         config={ | ||||
|             "max_length": max_length, | ||||
|             "min_length": min_length, | ||||
|             "training_key": TRAINING_KEY, | ||||
|             "spans_key": SPANS_KEY, | ||||
|         }, | ||||
|     ) | ||||
|     nlp.initialize() | ||||
|  | @ -182,8 +182,8 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count): | |||
|     ] | ||||
|     span_finder.set_annotations([doc], scores) | ||||
| 
 | ||||
|     assert doc.spans[DEFAULT_PREDICTED_KEY] | ||||
|     assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count | ||||
|     assert doc.spans[SPANS_KEY] | ||||
|     assert len(doc.spans[SPANS_KEY]) == span_count | ||||
| 
 | ||||
|     # Assert below will fail when max_length is set to 0 | ||||
|     if max_length is None: | ||||
|  | @ -193,40 +193,39 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count): | |||
| 
 | ||||
|     assert all( | ||||
|         min_length <= len(span) <= max_length | ||||
|         for span in doc.spans[DEFAULT_PREDICTED_KEY] | ||||
|         for span in doc.spans[SPANS_KEY] | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def test_span_finder_suggester(): | ||||
|     nlp = Language() | ||||
|     docs = [nlp("This is an example."), nlp("This is the second example.")] | ||||
|     docs[0].spans[TRAINING_KEY] = [docs[0][3:4]] | ||||
|     docs[1].spans[TRAINING_KEY] = [docs[1][3:5]] | ||||
|     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) | ||||
|     docs[0].spans[SPANS_KEY] = [docs[0][3:4]] | ||||
|     docs[1].spans[SPANS_KEY] = [docs[1][3:5]] | ||||
|     span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY}) | ||||
|     nlp.initialize() | ||||
|     span_finder.set_annotations(docs, span_finder.predict(docs)) | ||||
| 
 | ||||
|     suggester = registry.misc.get("spacy.span_finder_suggester.v1")( | ||||
|         candidates_key="span_candidates" | ||||
|         candidates_key=SPANS_KEY | ||||
|     ) | ||||
| 
 | ||||
|     candidates = suggester(docs) | ||||
| 
 | ||||
|     span_length = 0 | ||||
|     for doc in docs: | ||||
|         span_length += len(doc.spans["span_candidates"]) | ||||
|         span_length += len(doc.spans[SPANS_KEY]) | ||||
| 
 | ||||
|     assert span_length == len(candidates.dataXd) | ||||
|     assert type(candidates) == Ragged | ||||
|     assert len(candidates.dataXd[0]) == 2 | ||||
| 
 | ||||
| 
 | ||||
| # XXX Fails because i think the suggester is not correctly implemented?  | ||||
| def test_overfitting_IO(): | ||||
|     # Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly | ||||
|     fix_random_seed(0) | ||||
|     nlp = English() | ||||
|     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) | ||||
|     span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY}) | ||||
|     train_examples = make_examples(nlp) | ||||
|     optimizer = nlp.initialize(get_examples=lambda: train_examples) | ||||
|     assert span_finder.model.get_dim("nO") == 2 | ||||
|  | @ -239,30 +238,27 @@ def test_overfitting_IO(): | |||
|     # test the trained model | ||||
|     test_text = "I like London and Berlin" | ||||
|     doc = nlp(test_text) | ||||
|     spans = doc.spans[span_finder.predicted_key] | ||||
|     assert len(spans) == 2 | ||||
|     assert len(spans.attrs["scores"]) == 2 | ||||
|     assert min(spans.attrs["scores"]) > 0.9 | ||||
|     assert set([span.text for span in spans]) == {"London", "Berlin"} | ||||
|     spans = doc.spans[span_finder.spans_key] | ||||
|     assert len(spans) == 3 | ||||
|     assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"} | ||||
| 
 | ||||
|     # Also test the results are still the same after IO | ||||
|     with make_tempdir() as tmp_dir: | ||||
|         nlp.to_disk(tmp_dir) | ||||
|         nlp2 = util.load_model_from_path(tmp_dir) | ||||
|         doc2 = nlp2(test_text) | ||||
|         spans2 = doc2.spans[TRAINING_KEY] | ||||
|         assert len(spans2) == 2 | ||||
|         assert len(spans2.attrs["scores"]) == 2 | ||||
|         assert min(spans2.attrs["scores"]) > 0.9 | ||||
|         assert set([span.text for span in spans2]) == {"London", "Berlin"} | ||||
|         spans2 = doc2.spans[span_finder.spans_key] | ||||
|         assert len(spans2) == 3 | ||||
|         assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"} | ||||
| 
 | ||||
|     # Test scoring | ||||
|     scores = nlp.evaluate(train_examples) | ||||
|     assert f"spans_{TRAINING_KEY}_f" in scores | ||||
|     assert scores[f"spans_{TRAINING_KEY}_p"] == 1.0 | ||||
|     assert scores[f"spans_{TRAINING_KEY}_r"] == 1.0 | ||||
|     assert scores[f"spans_{TRAINING_KEY}_f"] == 1.0 | ||||
|     sf = nlp.get_pipe("span_finder") | ||||
|     print(sf.spans_key) | ||||
|     assert f"span_finder_{span_finder.spans_key}_f" in scores | ||||
|     # XXX Its not perfect 1.0 F1 because we want it to overgenerate for now. | ||||
|     assert scores[f"span_finder_{span_finder.spans_key}_f"] == 0.4 | ||||
| 
 | ||||
|     # also test that the spancat works for just a single entity in a sentence | ||||
|     doc = nlp("London") | ||||
|     assert len(doc.spans[span_finder.predicted_key]) == 1 | ||||
|     assert len(doc.spans[span_finder.spans_key]) == 1 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user