mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	only use a single spans_key like in spancat
This commit is contained in:
		
							parent
							
								
									90af16af76
								
							
						
					
					
						commit
						6f750d0da6
					
				|  | @ -1,4 +1,3 @@ | ||||||
| from functools import partial |  | ||||||
| from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast | ||||||
| 
 | 
 | ||||||
| from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate | from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate | ||||||
|  | @ -41,7 +40,6 @@ depth = 4 | ||||||
| """ | """ | ||||||
| 
 | 
 | ||||||
| DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"] | DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model"] | ||||||
| DEFAULT_PREDICTED_KEY = "span_candidates" |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @Language.factory( | @Language.factory( | ||||||
|  | @ -50,21 +48,15 @@ DEFAULT_PREDICTED_KEY = "span_candidates" | ||||||
|     default_config={ |     default_config={ | ||||||
|         "threshold": 0.5, |         "threshold": 0.5, | ||||||
|         "model": DEFAULT_SPAN_FINDER_MODEL, |         "model": DEFAULT_SPAN_FINDER_MODEL, | ||||||
|         "predicted_key": DEFAULT_PREDICTED_KEY, |         "spans_key": DEFAULT_SPANS_KEY, | ||||||
|         "training_key": DEFAULT_SPANS_KEY, |  | ||||||
|         # XXX Doesn't 0 seem bad compared to None instead? |  | ||||||
|         "max_length": None, |         "max_length": None, | ||||||
|         "min_length": None, |         "min_length": None, | ||||||
|         "scorer": { |         "scorer": {"@scorers": "spacy.span_finder_scorer.v1"}, | ||||||
|             "@scorers": "spacy.span_finder_scorer.v1", |  | ||||||
|             "predicted_key": DEFAULT_PREDICTED_KEY, |  | ||||||
|             "training_key": DEFAULT_SPANS_KEY, |  | ||||||
|         }, |  | ||||||
|     }, |     }, | ||||||
|     default_score_weights={ |     default_score_weights={ | ||||||
|         f"span_finder_{DEFAULT_PREDICTED_KEY}_f": 1.0, |         f"span_finder_{DEFAULT_SPANS_KEY}_f": 1.0, | ||||||
|         f"span_finder_{DEFAULT_PREDICTED_KEY}_p": 0.0, |         f"span_finder_{DEFAULT_SPANS_KEY}_p": 0.0, | ||||||
|         f"span_finder_{DEFAULT_PREDICTED_KEY}_r": 0.0, |         f"span_finder_{DEFAULT_SPANS_KEY}_r": 0.0, | ||||||
|     }, |     }, | ||||||
| ) | ) | ||||||
| def make_span_finder( | def make_span_finder( | ||||||
|  | @ -75,8 +67,7 @@ def make_span_finder( | ||||||
|     threshold: float, |     threshold: float, | ||||||
|     max_length: Optional[int], |     max_length: Optional[int], | ||||||
|     min_length: Optional[int], |     min_length: Optional[int], | ||||||
|     predicted_key: str = DEFAULT_PREDICTED_KEY, |     spans_key: str, | ||||||
|     training_key: str = DEFAULT_SPANS_KEY, |  | ||||||
| ) -> "SpanFinder": | ) -> "SpanFinder": | ||||||
|     """Create a SpanFinder component. The component predicts whether a token is |     """Create a SpanFinder component. The component predicts whether a token is | ||||||
|     the start or the end of a potential span. |     the start or the end of a potential span. | ||||||
|  | @ -84,10 +75,9 @@ def make_span_finder( | ||||||
|     model (Model[List[Doc], Floats2d]): A model instance that |     model (Model[List[Doc], Floats2d]): A model instance that | ||||||
|         is given a list of documents and predicts a probability for each token. |         is given a list of documents and predicts a probability for each token. | ||||||
|     threshold (float): Minimum probability to consider a prediction positive. |     threshold (float): Minimum probability to consider a prediction positive. | ||||||
|     predicted_key (str): Name of the span group the predicted spans are saved |     spans_key (str): Key of the doc.spans dict to save the spans under. During | ||||||
|         to |         initialization and training, the component will look for spans on the | ||||||
|     training_key (str): Name of the span group the training spans are read |         reference document under the same key. | ||||||
|         from |  | ||||||
|     max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length. |     max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length. | ||||||
|     min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1. |     min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1. | ||||||
|     """ |     """ | ||||||
|  | @ -99,51 +89,26 @@ def make_span_finder( | ||||||
|         scorer=scorer, |         scorer=scorer, | ||||||
|         max_length=max_length, |         max_length=max_length, | ||||||
|         min_length=min_length, |         min_length=min_length, | ||||||
|         predicted_key=predicted_key, |         spans_key=spans_key, | ||||||
|         training_key=training_key, |  | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @registry.scorers("spacy.span_finder_scorer.v1") | @registry.scorers("spacy.span_finder_scorer.v1") | ||||||
| def make_span_finder_scorer( | def make_span_finder_scorer(): | ||||||
|     predicted_key: str = DEFAULT_PREDICTED_KEY, |     return span_finder_score | ||||||
|     training_key: str = DEFAULT_SPANS_KEY, |  | ||||||
| ): |  | ||||||
|     return partial( |  | ||||||
|         span_finder_score, predicted_key=predicted_key, training_key=training_key |  | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def span_finder_score( | def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: | ||||||
|     examples: Iterable[Example], |  | ||||||
|     *, |  | ||||||
|     predicted_key: str = DEFAULT_PREDICTED_KEY, |  | ||||||
|     training_key: str = DEFAULT_SPANS_KEY, |  | ||||||
|     **kwargs, |  | ||||||
| ) -> Dict[str, Any]: |  | ||||||
|     kwargs = dict(kwargs) |     kwargs = dict(kwargs) | ||||||
|  |     print(kwargs) | ||||||
|     attr_prefix = "span_finder_" |     attr_prefix = "span_finder_" | ||||||
|     kwargs.setdefault("attr", f"{attr_prefix}{predicted_key}") |     key = kwargs["spans_key"] | ||||||
|     kwargs.setdefault("allow_overlap", True) |     kwargs.setdefault("attr", f"{attr_prefix}{key}") | ||||||
|     kwargs.setdefault( |     kwargs.setdefault( | ||||||
|         "getter", lambda doc, key: doc.spans.get(key[len(attr_prefix) :], []) |         "getter", lambda doc, key: doc.spans.get(key[len(attr_prefix) :], []) | ||||||
|     ) |     ) | ||||||
|     kwargs.setdefault("labeled", False) |     kwargs.setdefault("has_annotation", lambda doc: key in doc.spans) | ||||||
|     kwargs.setdefault("has_annotation", lambda doc: predicted_key in doc.spans) |     return Scorer.score_spans(examples, **kwargs) | ||||||
|     # score_spans can only score spans with the same key in both the reference |  | ||||||
|     # and predicted docs, so temporarily copy the reference spans from the |  | ||||||
|     # reference key to the candidates key in the reference docs, restoring the |  | ||||||
|     # original span groups afterwards |  | ||||||
|     orig_span_groups = [] |  | ||||||
|     for eg in examples: |  | ||||||
|         orig_span_groups.append(eg.reference.spans.get(predicted_key)) |  | ||||||
|         if training_key in eg.reference.spans: |  | ||||||
|             eg.reference.spans[predicted_key] = eg.reference.spans[training_key] |  | ||||||
|     scores = Scorer.score_spans(examples, **kwargs) |  | ||||||
|     for orig_span_group, eg in zip(orig_span_groups, examples): |  | ||||||
|         if orig_span_group is not None: |  | ||||||
|             eg.reference.spans[predicted_key] = orig_span_group |  | ||||||
|     return scores |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class _MaxInt(int): | class _MaxInt(int): | ||||||
|  | @ -179,13 +144,8 @@ class SpanFinder(TrainablePipe): | ||||||
|         max_length: Optional[int] = None, |         max_length: Optional[int] = None, | ||||||
|         min_length: Optional[int] = None, |         min_length: Optional[int] = None, | ||||||
|         # XXX I think this is weird and should be just None like in |         # XXX I think this is weird and should be just None like in | ||||||
|         scorer: Optional[Callable] = partial( |         scorer: Optional[Callable] = span_finder_score, | ||||||
|             span_finder_score, |         spans_key: str = DEFAULT_SPANS_KEY, | ||||||
|             predicted_key=DEFAULT_PREDICTED_KEY, |  | ||||||
|             training_key=DEFAULT_SPANS_KEY, |  | ||||||
|         ), |  | ||||||
|         predicted_key: str = DEFAULT_PREDICTED_KEY, |  | ||||||
|         training_key: str = DEFAULT_SPANS_KEY, |  | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         """Initialize the span boundary detector. |         """Initialize the span boundary detector. | ||||||
|         model (thinc.api.Model): The Thinc Model powering the pipeline component. |         model (thinc.api.Model): The Thinc Model powering the pipeline component. | ||||||
|  | @ -194,8 +154,9 @@ class SpanFinder(TrainablePipe): | ||||||
|         threshold (float): Minimum probability to consider a prediction |         threshold (float): Minimum probability to consider a prediction | ||||||
|             positive. |             positive. | ||||||
|         scorer (Optional[Callable]): The scoring method. |         scorer (Optional[Callable]): The scoring method. | ||||||
|         predicted_key (str): Name of the span group the candidate spans are saved to |         spans_key (str): Key of the doc.spans dict to save the spans under. During | ||||||
|         training_key (str): Name of the span group the training spans are read from |             initialization and training, the component will look for spans on the | ||||||
|  |             reference document under the same key. | ||||||
|         max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length. |         max_length (Optional[int]): Max length of the produced spans, defaults to None meaning unlimited length. | ||||||
|         min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1. |         min_length (Optional[int]): Min length of the produced spans, defaults to None meaining shortest span is length 1. | ||||||
|         """ |         """ | ||||||
|  | @ -211,11 +172,11 @@ class SpanFinder(TrainablePipe): | ||||||
|             ) |             ) | ||||||
|         self.min_length = min_length |         self.min_length = min_length | ||||||
|         self.max_length = max_length |         self.max_length = max_length | ||||||
|         self.predicted_key = predicted_key |         self.spans_key = spans_key | ||||||
|         self.training_key = training_key |  | ||||||
|         self.model = model |         self.model = model | ||||||
|         self.name = name |         self.name = name | ||||||
|         self.scorer = scorer |         self.scorer = scorer | ||||||
|  |         self.cfg = {"spans_key": spans_key} | ||||||
| 
 | 
 | ||||||
|     def predict(self, docs: Iterable[Doc]): |     def predict(self, docs: Iterable[Doc]): | ||||||
|         """Apply the pipeline's model to a batch of docs, without modifying them. |         """Apply the pipeline's model to a batch of docs, without modifying them. | ||||||
|  | @ -232,7 +193,7 @@ class SpanFinder(TrainablePipe): | ||||||
|         """ |         """ | ||||||
|         offset = 0 |         offset = 0 | ||||||
|         for i, doc in enumerate(docs): |         for i, doc in enumerate(docs): | ||||||
|             doc.spans[self.predicted_key] = [] |             doc.spans[self.spans_key] = [] | ||||||
|             starts = [] |             starts = [] | ||||||
|             ends = [] |             ends = [] | ||||||
|             doc_scores = scores[offset : offset + len(doc)] |             doc_scores = scores[offset : offset + len(doc)] | ||||||
|  | @ -249,7 +210,7 @@ class SpanFinder(TrainablePipe): | ||||||
|                     if span_length > self.max_length: |                     if span_length > self.max_length: | ||||||
|                         break |                         break | ||||||
|                     elif self.min_length <= span_length: |                     elif self.min_length <= span_length: | ||||||
|                         doc.spans[self.predicted_key].append(doc[start : end + 1]) |                         doc.spans[self.spans_key].append(doc[start : end + 1]) | ||||||
| 
 | 
 | ||||||
|     def update( |     def update( | ||||||
|         self, |         self, | ||||||
|  | @ -304,8 +265,8 @@ class SpanFinder(TrainablePipe): | ||||||
|             n_tokens = len(eg.predicted) |             n_tokens = len(eg.predicted) | ||||||
|             truth = ops.xp.zeros((n_tokens, 2), dtype="float32") |             truth = ops.xp.zeros((n_tokens, 2), dtype="float32") | ||||||
|             mask = ops.xp.ones((n_tokens, 2), dtype="float32") |             mask = ops.xp.ones((n_tokens, 2), dtype="float32") | ||||||
|             if self.training_key in eg.reference.spans: |             if self.spans_key in eg.reference.spans: | ||||||
|                 for span in eg.reference.spans[self.training_key]: |                 for span in eg.reference.spans[self.spans_key]: | ||||||
|                     ref_start_char, ref_end_char = _char_indices(span) |                     ref_start_char, ref_end_char = _char_indices(span) | ||||||
|                     pred_span = eg.predicted.char_span( |                     pred_span = eg.predicted.char_span( | ||||||
|                         ref_start_char, ref_end_char, alignment_mode="expand" |                         ref_start_char, ref_end_char, alignment_mode="expand" | ||||||
|  | @ -342,8 +303,8 @@ class SpanFinder(TrainablePipe): | ||||||
|             start_indices = set() |             start_indices = set() | ||||||
|             end_indices = set() |             end_indices = set() | ||||||
| 
 | 
 | ||||||
|             if self.training_key in doc.spans: |             if self.spans_key in doc.spans: | ||||||
|                 for span in doc.spans[self.training_key]: |                 for span in doc.spans[self.spans_key]: | ||||||
|                     start_indices.add(span.start) |                     start_indices.add(span.start) | ||||||
|                     end_indices.add(span.end - 1) |                     end_indices.add(span.end - 1) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -4,7 +4,7 @@ from thinc.types import Ragged | ||||||
| 
 | 
 | ||||||
| from spacy.language import Language | from spacy.language import Language | ||||||
| from spacy.lang.en import English | from spacy.lang.en import English | ||||||
| from spacy.pipeline.span_finder import DEFAULT_PREDICTED_KEY, span_finder_default_config | from spacy.pipeline.span_finder import span_finder_default_config | ||||||
| from spacy.tokens import Doc | from spacy.tokens import Doc | ||||||
| from spacy.training import Example | from spacy.training import Example | ||||||
| from spacy import util | from spacy import util | ||||||
|  | @ -12,22 +12,22 @@ from spacy.util import registry | ||||||
| from spacy.util import fix_random_seed, make_tempdir | from spacy.util import fix_random_seed, make_tempdir | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| TRAINING_KEY = "pytest" | SPANS_KEY = "pytest" | ||||||
| TRAIN_DATA = [ | TRAIN_DATA = [ | ||||||
|     ("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}), |     ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}), | ||||||
|     ( |     ( | ||||||
|         "I like London and Berlin.", |         "I like London and Berlin.", | ||||||
|         {"spans": {TRAINING_KEY: [(7, 13, "LOC"), (18, 24)]}}, |         {"spans": {SPANS_KEY: [(7, 13), (18, 24)]}}, | ||||||
|     ), |     ), | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| TRAIN_DATA_OVERLAPPING = [ | TRAIN_DATA_OVERLAPPING = [ | ||||||
|     ("Who is Shaka Khan?", {"spans": {TRAINING_KEY: [(7, 17)]}}), |     ("Who is Shaka Khan?", {"spans": {SPANS_KEY: [(7, 17)]}}), | ||||||
|     ( |     ( | ||||||
|         "I like London and Berlin", |         "I like London and Berlin", | ||||||
|         {"spans": {TRAINING_KEY: [(7, 13), (18, 24), (7, 24)]}}, |         {"spans": {SPANS_KEY: [(7, 13), (18, 24), (7, 24)]}}, | ||||||
|     ), |     ), | ||||||
|     ("", {"spans": {TRAINING_KEY: []}}), |     ("", {"spans": {SPANS_KEY: []}}), | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -88,8 +88,8 @@ def test_loss_alignment_example(tokens_predicted, tokens_reference, reference_tr | ||||||
|         nlp.vocab, words=tokens_reference, spaces=[False] * len(tokens_reference) |         nlp.vocab, words=tokens_reference, spaces=[False] * len(tokens_reference) | ||||||
|     ) |     ) | ||||||
|     example = Example(predicted, reference) |     example = Example(predicted, reference) | ||||||
|     example.reference.spans[TRAINING_KEY] = [example.reference.char_span(5, 9)] |     example.reference.spans[SPANS_KEY] = [example.reference.char_span(5, 9)] | ||||||
|     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) |     span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY}) | ||||||
|     nlp.initialize() |     nlp.initialize() | ||||||
|     ops = span_finder.model.ops |     ops = span_finder.model.ops | ||||||
|     if predicted.text != reference.text: |     if predicted.text != reference.text: | ||||||
|  | @ -107,8 +107,8 @@ def test_span_finder_model(): | ||||||
|     nlp = Language() |     nlp = Language() | ||||||
| 
 | 
 | ||||||
|     docs = [nlp("This is an example."), nlp("This is the second example.")] |     docs = [nlp("This is an example."), nlp("This is the second example.")] | ||||||
|     docs[0].spans[TRAINING_KEY] = [docs[0][3:4]] |     docs[0].spans[SPANS_KEY] = [docs[0][3:4]] | ||||||
|     docs[1].spans[TRAINING_KEY] = [docs[1][3:5]] |     docs[1].spans[SPANS_KEY] = [docs[1][3:5]] | ||||||
| 
 | 
 | ||||||
|     total_tokens = 0 |     total_tokens = 0 | ||||||
|     for doc in docs: |     for doc in docs: | ||||||
|  | @ -128,15 +128,15 @@ def test_span_finder_component(): | ||||||
|     nlp = Language() |     nlp = Language() | ||||||
| 
 | 
 | ||||||
|     docs = [nlp("This is an example."), nlp("This is the second example.")] |     docs = [nlp("This is an example."), nlp("This is the second example.")] | ||||||
|     docs[0].spans[TRAINING_KEY] = [docs[0][3:4]] |     docs[0].spans[SPANS_KEY] = [docs[0][3:4]] | ||||||
|     docs[1].spans[TRAINING_KEY] = [docs[1][3:5]] |     docs[1].spans[SPANS_KEY] = [docs[1][3:5]] | ||||||
| 
 | 
 | ||||||
|     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) |     span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY}) | ||||||
|     nlp.initialize() |     nlp.initialize() | ||||||
|     docs = list(span_finder.pipe(docs)) |     docs = list(span_finder.pipe(docs)) | ||||||
| 
 | 
 | ||||||
|     # TODO: update hard-coded name |     # TODO: update hard-coded name | ||||||
|     assert "span_candidates" in docs[0].spans |     assert SPANS_KEY in docs[0].spans | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
|  | @ -153,7 +153,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count): | ||||||
|                 config={ |                 config={ | ||||||
|                     "max_length": max_length, |                     "max_length": max_length, | ||||||
|                     "min_length": min_length, |                     "min_length": min_length, | ||||||
|                     "training_key": TRAINING_KEY, |                     "spans_key": SPANS_KEY, | ||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|         return |         return | ||||||
|  | @ -162,7 +162,7 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count): | ||||||
|         config={ |         config={ | ||||||
|             "max_length": max_length, |             "max_length": max_length, | ||||||
|             "min_length": min_length, |             "min_length": min_length, | ||||||
|             "training_key": TRAINING_KEY, |             "spans_key": SPANS_KEY, | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|     nlp.initialize() |     nlp.initialize() | ||||||
|  | @ -182,8 +182,8 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count): | ||||||
|     ] |     ] | ||||||
|     span_finder.set_annotations([doc], scores) |     span_finder.set_annotations([doc], scores) | ||||||
| 
 | 
 | ||||||
|     assert doc.spans[DEFAULT_PREDICTED_KEY] |     assert doc.spans[SPANS_KEY] | ||||||
|     assert len(doc.spans[DEFAULT_PREDICTED_KEY]) == span_count |     assert len(doc.spans[SPANS_KEY]) == span_count | ||||||
| 
 | 
 | ||||||
|     # Assert below will fail when max_length is set to 0 |     # Assert below will fail when max_length is set to 0 | ||||||
|     if max_length is None: |     if max_length is None: | ||||||
|  | @ -193,40 +193,39 @@ def test_set_annotations_span_lengths(min_length, max_length, span_count): | ||||||
| 
 | 
 | ||||||
|     assert all( |     assert all( | ||||||
|         min_length <= len(span) <= max_length |         min_length <= len(span) <= max_length | ||||||
|         for span in doc.spans[DEFAULT_PREDICTED_KEY] |         for span in doc.spans[SPANS_KEY] | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_span_finder_suggester(): | def test_span_finder_suggester(): | ||||||
|     nlp = Language() |     nlp = Language() | ||||||
|     docs = [nlp("This is an example."), nlp("This is the second example.")] |     docs = [nlp("This is an example."), nlp("This is the second example.")] | ||||||
|     docs[0].spans[TRAINING_KEY] = [docs[0][3:4]] |     docs[0].spans[SPANS_KEY] = [docs[0][3:4]] | ||||||
|     docs[1].spans[TRAINING_KEY] = [docs[1][3:5]] |     docs[1].spans[SPANS_KEY] = [docs[1][3:5]] | ||||||
|     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) |     span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY}) | ||||||
|     nlp.initialize() |     nlp.initialize() | ||||||
|     span_finder.set_annotations(docs, span_finder.predict(docs)) |     span_finder.set_annotations(docs, span_finder.predict(docs)) | ||||||
| 
 | 
 | ||||||
|     suggester = registry.misc.get("spacy.span_finder_suggester.v1")( |     suggester = registry.misc.get("spacy.span_finder_suggester.v1")( | ||||||
|         candidates_key="span_candidates" |         candidates_key=SPANS_KEY | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     candidates = suggester(docs) |     candidates = suggester(docs) | ||||||
| 
 | 
 | ||||||
|     span_length = 0 |     span_length = 0 | ||||||
|     for doc in docs: |     for doc in docs: | ||||||
|         span_length += len(doc.spans["span_candidates"]) |         span_length += len(doc.spans[SPANS_KEY]) | ||||||
| 
 | 
 | ||||||
|     assert span_length == len(candidates.dataXd) |     assert span_length == len(candidates.dataXd) | ||||||
|     assert type(candidates) == Ragged |     assert type(candidates) == Ragged | ||||||
|     assert len(candidates.dataXd[0]) == 2 |     assert len(candidates.dataXd[0]) == 2 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # XXX Fails because i think the suggester is not correctly implemented?  |  | ||||||
| def test_overfitting_IO(): | def test_overfitting_IO(): | ||||||
|     # Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly |     # Simple test to try and quickly overfit the spancat component - ensuring the ML models work correctly | ||||||
|     fix_random_seed(0) |     fix_random_seed(0) | ||||||
|     nlp = English() |     nlp = English() | ||||||
|     span_finder = nlp.add_pipe("span_finder", config={"training_key": TRAINING_KEY}) |     span_finder = nlp.add_pipe("span_finder", config={"spans_key": SPANS_KEY}) | ||||||
|     train_examples = make_examples(nlp) |     train_examples = make_examples(nlp) | ||||||
|     optimizer = nlp.initialize(get_examples=lambda: train_examples) |     optimizer = nlp.initialize(get_examples=lambda: train_examples) | ||||||
|     assert span_finder.model.get_dim("nO") == 2 |     assert span_finder.model.get_dim("nO") == 2 | ||||||
|  | @ -239,30 +238,27 @@ def test_overfitting_IO(): | ||||||
|     # test the trained model |     # test the trained model | ||||||
|     test_text = "I like London and Berlin" |     test_text = "I like London and Berlin" | ||||||
|     doc = nlp(test_text) |     doc = nlp(test_text) | ||||||
|     spans = doc.spans[span_finder.predicted_key] |     spans = doc.spans[span_finder.spans_key] | ||||||
|     assert len(spans) == 2 |     assert len(spans) == 3 | ||||||
|     assert len(spans.attrs["scores"]) == 2 |     assert set([span.text for span in spans]) == {"London", "Berlin", "London and Berlin"} | ||||||
|     assert min(spans.attrs["scores"]) > 0.9 |  | ||||||
|     assert set([span.text for span in spans]) == {"London", "Berlin"} |  | ||||||
| 
 | 
 | ||||||
|     # Also test the results are still the same after IO |     # Also test the results are still the same after IO | ||||||
|     with make_tempdir() as tmp_dir: |     with make_tempdir() as tmp_dir: | ||||||
|         nlp.to_disk(tmp_dir) |         nlp.to_disk(tmp_dir) | ||||||
|         nlp2 = util.load_model_from_path(tmp_dir) |         nlp2 = util.load_model_from_path(tmp_dir) | ||||||
|         doc2 = nlp2(test_text) |         doc2 = nlp2(test_text) | ||||||
|         spans2 = doc2.spans[TRAINING_KEY] |         spans2 = doc2.spans[span_finder.spans_key] | ||||||
|         assert len(spans2) == 2 |         assert len(spans2) == 3 | ||||||
|         assert len(spans2.attrs["scores"]) == 2 |         assert set([span.text for span in spans2]) == {"London", "Berlin", "London and Berlin"} | ||||||
|         assert min(spans2.attrs["scores"]) > 0.9 |  | ||||||
|         assert set([span.text for span in spans2]) == {"London", "Berlin"} |  | ||||||
| 
 | 
 | ||||||
|     # Test scoring |     # Test scoring | ||||||
|     scores = nlp.evaluate(train_examples) |     scores = nlp.evaluate(train_examples) | ||||||
|     assert f"spans_{TRAINING_KEY}_f" in scores |     sf = nlp.get_pipe("span_finder") | ||||||
|     assert scores[f"spans_{TRAINING_KEY}_p"] == 1.0 |     print(sf.spans_key) | ||||||
|     assert scores[f"spans_{TRAINING_KEY}_r"] == 1.0 |     assert f"span_finder_{span_finder.spans_key}_f" in scores | ||||||
|     assert scores[f"spans_{TRAINING_KEY}_f"] == 1.0 |     # XXX Its not perfect 1.0 F1 because we want it to overgenerate for now. | ||||||
|  |     assert scores[f"span_finder_{span_finder.spans_key}_f"] == 0.4 | ||||||
| 
 | 
 | ||||||
|     # also test that the spancat works for just a single entity in a sentence |     # also test that the spancat works for just a single entity in a sentence | ||||||
|     doc = nlp("London") |     doc = nlp("London") | ||||||
|     assert len(doc.spans[span_finder.predicted_key]) == 1 |     assert len(doc.spans[span_finder.spans_key]) == 1 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user