mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Save span candidates produced by spancat suggesters (#10413)
* Add save_candidates attribute * Change spancat api * Add unit test * reimplement method to produce a list of doc * Add method to docs * Add new version tag * Add intended use to docstring * prettier formatting
This commit is contained in:
parent
b68bf43f5b
commit
2eef47dd26
|
@ -272,6 +272,24 @@ class SpanCategorizer(TrainablePipe):
|
|||
scores = self.model.predict((docs, indices)) # type: ignore
|
||||
return indices, scores
|
||||
|
||||
def set_candidates(
    self, docs: Iterable[Doc], *, candidates_key: str = "candidates"
) -> None:
    """Use the spancat suggester to add a list of span candidates to a list of docs.
    This method is intended to be used for debugging purposes.

    docs (Iterable[Doc]): The documents to modify. One-shot iterables
        (e.g. generators) are supported; they are materialized internally.
    candidates_key (str): Key of the Doc.spans dict to save the candidate spans under.

    DOCS: https://spacy.io/api/spancategorizer#set_candidates
    """
    # `docs` is handed to the suggester and then iterated again below, so
    # materialize it once; otherwise a generator would already be exhausted
    # by the time we zip and no candidates would be stored.
    docs = list(docs)
    suggester_output = self.suggester(docs, ops=self.model.ops)

    for candidates, doc in zip(suggester_output, docs):  # type: ignore
        # Each row of `dataXd` is read as a (start, end) token offset pair.
        doc.spans[candidates_key] = [
            doc[index[0] : index[1]] for index in candidates.dataXd
        ]
|
||||
|
||||
def set_annotations(self, docs: Iterable[Doc], indices_scores) -> None:
|
||||
"""Modify a batch of Doc objects, using pre-computed scores.
|
||||
|
||||
|
|
|
@ -397,3 +397,25 @@ def test_zero_suggestions():
|
|||
assert set(spancat.labels) == {"LOC", "PERSON"}
|
||||
|
||||
nlp.update(train_examples, sgd=optimizer)
|
||||
|
||||
|
||||
def test_set_candidates():
    """Check that set_candidates stores the suggester output under Doc.spans."""
    nlp = Language()
    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
    train_examples = make_examples(nlp)
    nlp.initialize(get_examples=lambda: train_examples)
    texts = [
        "Just a sentence.",
        "I like London and Berlin",
        "I like Berlin",
        "I eat ham.",
    ]

    docs = [nlp(text) for text in texts]
    # No candidates_key given, so the default key "candidates" is used.
    spancat.set_candidates(docs)

    assert len(docs) == len(texts)
    candidates = docs[0].spans["candidates"]
    assert isinstance(candidates, SpanGroup)
    # NOTE(review): 9 presumably = 4 unigrams + 3 bigrams + 2 trigrams from
    # the default suggester on a 4-token doc — confirm against the config.
    assert len(candidates) == 9
    assert candidates[0].text == "Just"
    assert candidates[4].text == "Just a"
|
||||
|
|
|
@ -239,6 +239,24 @@ Delegates to [`predict`](/api/spancategorizer#predict) and
|
|||
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
|
||||
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
|
||||
|
||||
## SpanCategorizer.set_candidates {#set_candidates tag="method", new="3.3"}
|
||||
|
||||
Use the suggester to add a list of [`Span`](/api/span) candidates to a list of
|
||||
[`Doc`](/api/doc) objects. This method is intended to be used for debugging
|
||||
purposes.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> spancat = nlp.add_pipe("spancat")
|
||||
> spancat.set_candidates(docs, candidates_key="candidates")
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ---------------- | -------------------------------------------------------------------- |
|
||||
| `docs` | The documents to modify. ~~Iterable[Doc]~~ |
|
||||
| `candidates_key` | Key of the Doc.spans dict to save the candidate spans under. ~~str~~ |
|
||||
|
||||
## SpanCategorizer.get_loss {#get_loss tag="method"}
|
||||
|
||||
Find the loss and gradient of loss for the batch of documents and their
|
||||
|
|
Loading…
Reference in New Issue
Block a user