diff --git a/spacy/pipeline/span_finder.py b/spacy/pipeline/span_finder.py index ba023c6c0..cc65e2e36 100644 --- a/spacy/pipeline/span_finder.py +++ b/spacy/pipeline/span_finder.py @@ -1,7 +1,7 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple -from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate -from thinc.types import Floats2d, Ints1d, Ragged +from thinc.api import Config, Model, Optimizer, set_dropout_rate +from thinc.types import Floats2d from spacy.language import Language from spacy.pipeline.trainable_pipe import TrainablePipe @@ -11,7 +11,7 @@ from spacy.training import Example from spacy.errors import Errors from ..util import registry -from .spancat import DEFAULT_SPANS_KEY, Suggester +from .spancat import DEFAULT_SPANS_KEY span_finder_default_config = """ [model] @@ -310,34 +310,3 @@ class SpanFinder(TrainablePipe): self.model.initialize(X=docs, Y=Y) else: self.model.initialize() - - -@registry.misc("spacy.span_finder_suggester.v1") -def build_span_finder_suggester(spans_key: str) -> Suggester: - """Suggest every candidate predicted by the SpanFinder""" - - def span_finder_suggester( - docs: Iterable[Doc], *, ops: Optional[Ops] = None - ) -> Ragged: - if ops is None: - ops = get_current_ops() - spans = [] - lengths = [] - for doc in docs: - length = 0 - if doc.spans[spans_key]: - for span in doc.spans[spans_key]: - spans.append([span.start, span.end]) - length += 1 - - lengths.append(length) - - lengths_array = cast(Ints1d, ops.asarray(lengths, dtype="i")) - if len(spans) > 0: - output = Ragged(ops.asarray(spans, dtype="i"), lengths_array) - else: - output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array) - - return output - - return span_finder_suggester diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 0b252996e..185c0cd5d 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, import numpy from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate -from thinc.types import Floats2d, Ints2d, Ragged +from thinc.types import Floats2d, Ints1d, Ints2d, Ragged from ..compat import Protocol, runtime_checkable from ..errors import Errors @@ -111,6 +111,29 @@ def ngram_suggester( return output +def preset_spans_suggester( + docs: Iterable[Doc], spans_key: str, *, ops: Optional[Ops] = None +) -> Ragged: + if ops is None: + ops = get_current_ops() + spans = [] + lengths = [] + for doc in docs: + length = 0 + if doc.spans[spans_key]: + for span in doc.spans[spans_key]: + spans.append([span.start, span.end]) + length += 1 + + lengths.append(length) + lengths_array = cast(Ints1d, ops.asarray(lengths, dtype="i")) + if len(spans) > 0: + output = Ragged(ops.asarray(spans, dtype="i"), lengths_array) + else: + output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array) + return output + + @registry.misc("spacy.ngram_suggester.v1") def build_ngram_suggester(sizes: List[int]) -> Suggester: """Suggest all spans of the given lengths. Spans are returned as a ragged @@ -129,6 +152,14 @@ def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester: return build_ngram_suggester(sizes) +@registry.misc("spacy.preset_spans_suggester.v1") +def build_preset_spans_suggester(spans_key: str) -> Suggester: + """Suggest all spans that are already stored in doc.spans[spans_key]. + This is useful when an upstream component is used to set the spans + on the Doc such as a SpanRuler or SpanFinder.""" + return partial(preset_spans_suggester, spans_key=spans_key) + + @Language.factory( "spancat", assigns=["doc.spans"], diff --git a/spacy/tests/pipeline/test_span_finder.py b/spacy/tests/pipeline/test_span_finder.py index ebe1879d4..b035a2aa5 100644 --- a/spacy/tests/pipeline/test_span_finder.py +++ b/spacy/tests/pipeline/test_span_finder.py @@ -199,7 +199,7 @@ def test_span_finder_suggester(): docs = [nlp("This is an example."), nlp("This is the second example.")] docs[0].spans[SPANS_KEY] = [docs[0][3:4]] docs[1].spans[SPANS_KEY] = [docs[1][0:4], docs[1][3:5]] - suggester = registry.misc.get("spacy.span_finder_suggester.v1")( + suggester = registry.misc.get("spacy.preset_spans_suggester.v1")( spans_key=SPANS_KEY ) candidates = suggester(docs)