Save span candidates produced by spancat suggesters (#10413)

* Add save_candidates attribute

* Change spancat api

* Add unit test

* reimplement method to produce a list of doc

* Add method to docs

* Add new version tag

* Add intended use to docstring

* prettier formatting
This commit is contained in:
Edward 2022-03-14 16:46:58 +01:00 committed by GitHub
parent b68bf43f5b
commit 2eef47dd26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 0 deletions

View File

@ -272,6 +272,24 @@ class SpanCategorizer(TrainablePipe):
scores = self.model.predict((docs, indices)) # type: ignore
return indices, scores
def set_candidates(
self, docs: Iterable[Doc], *, candidates_key: str = "candidates"
) -> None:
"""Use the spancat suggester to add a list of span candidates to a list of docs.
This method is intended to be used for debugging purposes.
docs (Iterable[Doc]): The documents to modify.
candidates_key (str): Key of the Doc.spans dict to save the candidate spans under.
DOCS: https://spacy.io/api/spancategorizer#set_candidates
"""
suggester_output = self.suggester(docs, ops=self.model.ops)
for candidates, doc in zip(suggester_output, docs): # type: ignore
doc.spans[candidates_key] = []
for index in candidates.dataXd:
doc.spans[candidates_key].append(doc[index[0] : index[1]])
def set_annotations(self, docs: Iterable[Doc], indices_scores) -> None:
"""Modify a batch of Doc objects, using pre-computed scores.

View File

@ -397,3 +397,25 @@ def test_zero_suggestions():
assert set(spancat.labels) == {"LOC", "PERSON"}
nlp.update(train_examples, sgd=optimizer)
def test_set_candidates():
nlp = Language()
spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
train_examples = make_examples(nlp)
nlp.initialize(get_examples=lambda: train_examples)
texts = [
"Just a sentence.",
"I like London and Berlin",
"I like Berlin",
"I eat ham.",
]
docs = [nlp(text) for text in texts]
spancat.set_candidates(docs)
assert len(docs) == len(texts)
assert type(docs[0].spans["candidates"]) == SpanGroup
assert len(docs[0].spans["candidates"]) == 9
assert docs[0].spans["candidates"][0].text == "Just"
assert docs[0].spans["candidates"][4].text == "Just a"

View File

@ -239,6 +239,24 @@ Delegates to [`predict`](/api/spancategorizer#predict) and
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
## SpanCategorizer.set_candidates {#set_candidates tag="method", new="3.3"}
Use the suggester to add a list of [`Span`](/api/span) candidates to a list of
[`Doc`](/api/doc) objects. This method is intended to be used for debugging
purposes.
> #### Example
>
> ```python
> spancat = nlp.add_pipe("spancat")
> spancat.set_candidates(docs, "candidates")
> ```
| Name | Description |
| ---------------- | -------------------------------------------------------------------- |
| `docs` | The documents to modify. ~~Iterable[Doc]~~ |
| `candidates_key` | Key of the Doc.spans dict to save the candidate spans under. ~~str~~ |
## SpanCategorizer.get_loss {#get_loss tag="method"}
Find the loss and gradient of loss for the batch of documents and their