make it clear that the span_finder_suggester is more general (not specific to span_finder)

This commit is contained in:
kadarakos 2023-06-02 09:11:53 +00:00
parent 37c4ad5007
commit 752b3066cf
3 changed files with 37 additions and 37 deletions

View File

@ -1,7 +1,7 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate
from thinc.types import Floats2d, Ints1d, Ragged
from thinc.api import Config, Model, Optimizer, set_dropout_rate
from thinc.types import Floats2d
from spacy.language import Language
from spacy.pipeline.trainable_pipe import TrainablePipe
@ -11,7 +11,7 @@ from spacy.training import Example
from spacy.errors import Errors
from ..util import registry
from .spancat import DEFAULT_SPANS_KEY, Suggester
from .spancat import DEFAULT_SPANS_KEY
span_finder_default_config = """
[model]
@ -310,34 +310,3 @@ class SpanFinder(TrainablePipe):
self.model.initialize(X=docs, Y=Y)
else:
self.model.initialize()
@registry.misc("spacy.span_finder_suggester.v1")
def build_span_finder_suggester(spans_key: str) -> Suggester:
    """Build a suggester that proposes every span stored under
    doc.spans[spans_key], e.g. the candidates predicted by the SpanFinder.

    spans_key (str): Key of the span group in doc.spans to read from.
    RETURNS (Suggester): A callable mapping a batch of docs to a Ragged
        array of (start, end) token offsets.
    """

    def span_finder_suggester(
        docs: Iterable[Doc], *, ops: Optional[Ops] = None
    ) -> Ragged:
        # Fall back to the currently active backend when none is passed in.
        backend = get_current_ops() if ops is None else ops
        boundaries = []
        counts = []
        for doc in docs:
            group = doc.spans[spans_key]
            # An empty group contributes no rows, only a zero length.
            pairs = [[span.start, span.end] for span in group] if group else []
            boundaries.extend(pairs)
            counts.append(len(pairs))
        counts_array = cast(Ints1d, backend.asarray(counts, dtype="i"))
        if boundaries:
            return Ragged(backend.asarray(boundaries, dtype="i"), counts_array)
        # No spans in the whole batch: return an empty data array while
        # still reporting one (zero) length per doc.
        return Ragged(backend.xp.zeros((0, 0), dtype="i"), counts_array)

    return span_finder_suggester

View File

@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union,
import numpy
from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate
from thinc.types import Floats2d, Ints2d, Ragged
from thinc.types import Floats2d, Ints1d, Ints2d, Ragged
from ..compat import Protocol, runtime_checkable
from ..errors import Errors
@ -111,6 +111,29 @@ def ngram_suggester(
return output
def preset_spans_suggester(
    docs: Iterable[Doc], spans_key: str, *, ops: Optional[Ops] = None
) -> Ragged:
    """Suggest the spans that are already stored in doc.spans[spans_key].

    docs (Iterable[Doc]): The batch of documents to collect spans from.
    spans_key (str): Key of the span group in doc.spans to read from.
    ops (Optional[Ops]): Backend used to allocate the arrays. Defaults to
        the result of get_current_ops().
    RETURNS (Ragged): The (start, end) token offsets of all collected spans,
        with one length entry per doc.
    """
    # Fall back to the currently active backend when none is passed in.
    backend = get_current_ops() if ops is None else ops
    boundaries = []
    counts = []
    for doc in docs:
        group = doc.spans[spans_key]
        # An empty group contributes no rows, only a zero length.
        pairs = [[span.start, span.end] for span in group] if group else []
        boundaries.extend(pairs)
        counts.append(len(pairs))
    counts_array = cast(Ints1d, backend.asarray(counts, dtype="i"))
    if boundaries:
        return Ragged(backend.asarray(boundaries, dtype="i"), counts_array)
    # No spans in the whole batch: return an empty data array while still
    # reporting one (zero) length per doc.
    return Ragged(backend.xp.zeros((0, 0), dtype="i"), counts_array)
@registry.misc("spacy.ngram_suggester.v1")
def build_ngram_suggester(sizes: List[int]) -> Suggester:
"""Suggest all spans of the given lengths. Spans are returned as a ragged
@ -129,6 +152,14 @@ def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
return build_ngram_suggester(sizes)
@registry.misc("spacy.preset_spans_suggester.v1")
def build_preset_spans_suggester(spans_key: str) -> Suggester:
    """Create a suggester that proposes the spans already present in
    doc.spans[spans_key].

    Intended for pipelines where an upstream component, such as a SpanRuler
    or SpanFinder, has set the spans on the Doc beforehand.

    spans_key (str): Key of the span group in doc.spans to read from.
    RETURNS (Suggester): preset_spans_suggester with spans_key bound.
    """
    suggester = partial(preset_spans_suggester, spans_key=spans_key)
    return suggester
@Language.factory(
"spancat",
assigns=["doc.spans"],

View File

@ -199,7 +199,7 @@ def test_span_finder_suggester():
docs = [nlp("This is an example."), nlp("This is the second example.")]
docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
docs[1].spans[SPANS_KEY] = [docs[1][0:4], docs[1][3:5]]
suggester = registry.misc.get("spacy.span_finder_suggester.v1")(
suggester = registry.misc.get("spacy.preset_spans_suggester.v1")(
spans_key=SPANS_KEY
)
candidates = suggester(docs)