mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-21 17:41:59 +03:00
make it clear that the span_finder_suggester is more general (not specific to span_finder)
This commit is contained in:
parent
37c4ad5007
commit
752b3066cf
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate
|
||||
from thinc.types import Floats2d, Ints1d, Ragged
|
||||
from thinc.api import Config, Model, Optimizer, set_dropout_rate
|
||||
from thinc.types import Floats2d
|
||||
|
||||
from spacy.language import Language
|
||||
from spacy.pipeline.trainable_pipe import TrainablePipe
|
||||
|
@ -11,7 +11,7 @@ from spacy.training import Example
|
|||
from spacy.errors import Errors
|
||||
|
||||
from ..util import registry
|
||||
from .spancat import DEFAULT_SPANS_KEY, Suggester
|
||||
from .spancat import DEFAULT_SPANS_KEY
|
||||
|
||||
span_finder_default_config = """
|
||||
[model]
|
||||
|
@ -310,34 +310,3 @@ class SpanFinder(TrainablePipe):
|
|||
self.model.initialize(X=docs, Y=Y)
|
||||
else:
|
||||
self.model.initialize()
|
||||
|
||||
|
||||
@registry.misc("spacy.span_finder_suggester.v1")
|
||||
def build_span_finder_suggester(spans_key: str) -> Suggester:
|
||||
"""Suggest every candidate predicted by the SpanFinder"""
|
||||
|
||||
def span_finder_suggester(
|
||||
docs: Iterable[Doc], *, ops: Optional[Ops] = None
|
||||
) -> Ragged:
|
||||
if ops is None:
|
||||
ops = get_current_ops()
|
||||
spans = []
|
||||
lengths = []
|
||||
for doc in docs:
|
||||
length = 0
|
||||
if doc.spans[spans_key]:
|
||||
for span in doc.spans[spans_key]:
|
||||
spans.append([span.start, span.end])
|
||||
length += 1
|
||||
|
||||
lengths.append(length)
|
||||
|
||||
lengths_array = cast(Ints1d, ops.asarray(lengths, dtype="i"))
|
||||
if len(spans) > 0:
|
||||
output = Ragged(ops.asarray(spans, dtype="i"), lengths_array)
|
||||
else:
|
||||
output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
|
||||
|
||||
return output
|
||||
|
||||
return span_finder_suggester
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union,
|
|||
|
||||
import numpy
|
||||
from thinc.api import Config, Model, Ops, Optimizer, get_current_ops, set_dropout_rate
|
||||
from thinc.types import Floats2d, Ints2d, Ragged
|
||||
from thinc.types import Floats2d, Ints1d, Ints2d, Ragged
|
||||
|
||||
from ..compat import Protocol, runtime_checkable
|
||||
from ..errors import Errors
|
||||
|
@ -111,6 +111,29 @@ def ngram_suggester(
|
|||
return output
|
||||
|
||||
|
||||
def preset_spans_suggester(
|
||||
docs: Iterable[Doc], spans_key: str, *, ops: Optional[Ops] = None
|
||||
) -> Ragged:
|
||||
if ops is None:
|
||||
ops = get_current_ops()
|
||||
spans = []
|
||||
lengths = []
|
||||
for doc in docs:
|
||||
length = 0
|
||||
if doc.spans[spans_key]:
|
||||
for span in doc.spans[spans_key]:
|
||||
spans.append([span.start, span.end])
|
||||
length += 1
|
||||
|
||||
lengths.append(length)
|
||||
lengths_array = cast(Ints1d, ops.asarray(lengths, dtype="i"))
|
||||
if len(spans) > 0:
|
||||
output = Ragged(ops.asarray(spans, dtype="i"), lengths_array)
|
||||
else:
|
||||
output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
|
||||
return output
|
||||
|
||||
|
||||
@registry.misc("spacy.ngram_suggester.v1")
|
||||
def build_ngram_suggester(sizes: List[int]) -> Suggester:
|
||||
"""Suggest all spans of the given lengths. Spans are returned as a ragged
|
||||
|
@ -129,6 +152,14 @@ def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
|
|||
return build_ngram_suggester(sizes)
|
||||
|
||||
|
||||
@registry.misc("spacy.preset_spans_suggester.v1")
|
||||
def build_preset_spans_suggester(spans_key: str) -> Suggester:
|
||||
"""Suggest all spans that are already stored in doc.spans[spans_key].
|
||||
This is useful when an upstream component is used to set the spans
|
||||
on the Doc such as a SpanRuler or SpanFinder."""
|
||||
return partial(preset_spans_suggester, spans_key=spans_key)
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"spancat",
|
||||
assigns=["doc.spans"],
|
||||
|
|
|
@ -199,7 +199,7 @@ def test_span_finder_suggester():
|
|||
docs = [nlp("This is an example."), nlp("This is the second example.")]
|
||||
docs[0].spans[SPANS_KEY] = [docs[0][3:4]]
|
||||
docs[1].spans[SPANS_KEY] = [docs[1][0:4], docs[1][3:5]]
|
||||
suggester = registry.misc.get("spacy.span_finder_suggester.v1")(
|
||||
suggester = registry.misc.get("spacy.preset_spans_suggester.v1")(
|
||||
spans_key=SPANS_KEY
|
||||
)
|
||||
candidates = suggester(docs)
|
||||
|
|
Loading…
Reference in New Issue
Block a user