spaCy/spacy/pipeline/sentencizer.pyx
Paul O'Leary McCann 89f974d4f5
Cleanup/remove backwards compat overwrite settings (#11888)
* Remove backwards-compatible overwrite from Entity Linker

This also adds a docstring about overwrite, since it wasn't present.

* Fix docstring

* Remove backward compat settings in Morphologizer

This also needed a docstring added.

For this component it's less clear what the right overwrite settings
are.

* Remove backward compat from sentencizer

This was simple

* Remove backward compat from senter

Another simple one

* Remove backward compat setting from tagger

* Add docstrings

* Update spacy/pipeline/morphologizer.pyx

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Update docs

---------

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
2023-02-02 14:13:38 +01:00

181 lines
6.7 KiB
Cython
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# cython: infer_types=True, profile=True, binding=True
from typing import Optional, List, Callable
import srsly
from ..tokens.doc cimport Doc
from .pipe import Pipe
from .senter import senter_score
from ..language import Language
from ..scorer import Scorer
from .. import util
@Language.factory(
    "sentencizer",
    assigns=["token.is_sent_start", "doc.sents"],
    default_config={
        "punct_chars": None,
        "overwrite": False,
        "scorer": {"@scorers": "spacy.senter_scorer.v1"},
    },
    default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
)
def make_sentencizer(
    nlp: Language,
    name: str,
    punct_chars: Optional[List[str]],
    overwrite: bool,
    scorer: Optional[Callable],
):
    """Factory registered under the name "sentencizer".

    nlp (Language): The pipeline object; accepted as part of the factory
        signature but not used here.
    name (str): The component instance name, passed on to the Sentencizer.
    punct_chars (Optional[List[str]]): Punctuation characters to split on.
    overwrite (bool): Whether to overwrite existing annotations.
    scorer (Optional[Callable]): The scoring method.
    RETURNS (Sentencizer): The constructed component.
    """
    component = Sentencizer(
        name,
        punct_chars=punct_chars,
        overwrite=overwrite,
        scorer=scorer,
    )
    return component
class Sentencizer(Pipe):
    """Segment the Doc into sentences using a rule-based strategy.

    A token starts a new sentence when it is the first token of the Doc, or
    when it is the first token that is neither punctuation nor one of
    `punct_chars` after a token whose text is in `punct_chars`.

    DOCS: https://spacy.io/api/sentencizer
    """

    # Sentence-final punctuation marks used when no `punct_chars` are given.
    # Each entry is compared against `token.text` as a whole, so every entry
    # must be the exact, non-empty character.
    # NOTE(review): the entries on the 2nd-6th lines up to '？', and the final
    # two entries, were blank in the copy this was restored from; they were
    # re-entered from the upstream list -- verify against the published
    # default `punct_chars`.
    default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
            '।', '॥', '၊', '။', '።', '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄',
            '᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿',
            '‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶',
            '꡷', '꣎', '꣏', '꤯', '꧈', '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒',
            '﹖', '﹗', '！', '．', '？', '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀',
            '𑃁', '𑅁', '𑅂', '𑅃', '𑇅', '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼',
            '𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
            '𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
            '𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈',
            '｡', '。']

    def __init__(
        self,
        name="sentencizer",
        *,
        punct_chars=None,
        overwrite=False,
        scorer=senter_score,
    ):
        """Initialize the sentencizer.

        name (str): The component instance name.
        punct_chars (list): Punctuation characters to split on. Will be
            serialized with the nlp object. None or an empty list selects
            `default_punct_chars`.
        overwrite (bool): Whether to overwrite existing annotations.
        scorer (Optional[Callable]): The scoring method. Defaults to
            senter_score.

        DOCS: https://spacy.io/api/sentencizer#init
        """
        self.name = name
        # Stored as a set: predict() does one membership test per token.
        if punct_chars:
            self.punct_chars = set(punct_chars)
        else:
            self.punct_chars = set(self.default_punct_chars)
        self.overwrite = overwrite
        self.scorer = scorer

    def __call__(self, doc):
        """Apply the sentencizer to a Doc and set Token.is_sent_start.

        doc (Doc): The document to process.
        RETURNS (Doc): The processed Doc. If an exception is raised and the
            error handler does not re-raise it, nothing is returned.

        DOCS: https://spacy.io/api/sentencizer#call
        """
        error_handler = self.get_error_handler()
        try:
            tags = self.predict([doc])
            self.set_annotations([doc], tags)
            return doc
        except Exception as e:
            error_handler(self.name, self, [doc], e)

    def predict(self, docs):
        """Apply the pipe to a batch of docs, without modifying them.

        docs (Iterable[Doc]): The documents to predict.
        RETURNS (List[List[bool]]): For each doc, one flag per token that is
            True where the token starts a sentence.
        """
        if not any(len(doc) for doc in docs):
            # Handle cases where there are no tokens in any docs.
            guesses = [[] for doc in docs]
            return guesses
        guesses = []
        for doc in docs:
            doc_guesses = [False] * len(doc)
            if len(doc) > 0:
                # `start` is the index of the first token of the sentence
                # currently being read.
                start = 0
                seen_period = False
                doc_guesses[0] = True
                for token in doc:
                    is_in_punct_chars = token.text in self.punct_chars
                    if seen_period and not token.is_punct and not is_in_punct_chars:
                        # First regular token after sentence-final
                        # punctuation: close the previous sentence and open a
                        # new one here. Trailing punctuation stays attached
                        # to the sentence it follows.
                        doc_guesses[start] = True
                        start = token.i
                        seen_period = False
                    elif is_in_punct_chars:
                        seen_period = True
                # Mark the start of the last sentence, which the loop above
                # never closes.
                if start < len(doc):
                    doc_guesses[start] = True
            guesses.append(doc_guesses)
        return guesses

    def set_annotations(self, docs, batch_tag_ids):
        """Modify a batch of documents, using pre-computed sentence starts.

        docs (Iterable[Doc]): The documents to modify. A single Doc is also
            accepted and treated as a batch of one.
        batch_tag_ids: The per-document flags produced by
            Sentencizer.predict.
        """
        if isinstance(docs, Doc):
            docs = [docs]
        cdef Doc doc
        for i, doc in enumerate(docs):
            doc_tag_ids = batch_tag_ids[i]
            for j, tag_id in enumerate(doc_tag_ids):
                # A sent_start of 0 is treated as "not annotated yet"; other
                # values are only replaced when overwrite is set.
                if doc.c[j].sent_start == 0 or self.overwrite:
                    if tag_id:
                        doc.c[j].sent_start = 1
                    else:
                        doc.c[j].sent_start = -1

    def to_bytes(self, *, exclude=tuple()):
        """Serialize the sentencizer to a bytestring.

        exclude (Iterable[str]): Accepted for API compatibility; not used.
        RETURNS (bytes): The serialized object.

        DOCS: https://spacy.io/api/sentencizer#to_bytes
        """
        return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars), "overwrite": self.overwrite})

    def from_bytes(self, bytes_data, *, exclude=tuple()):
        """Load the sentencizer from a bytestring.

        bytes_data (bytes): The data to load.
        exclude (Iterable[str]): Accepted for API compatibility; not used.
        RETURNS (Sentencizer): The loaded object.

        DOCS: https://spacy.io/api/sentencizer#from_bytes
        """
        cfg = srsly.msgpack_loads(bytes_data)
        # Keys missing from the payload fall back to the class defaults /
        # the current setting.
        self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
        self.overwrite = cfg.get("overwrite", self.overwrite)
        return self

    def to_disk(self, path, *, exclude=tuple()):
        """Serialize the sentencizer to disk.

        path (str / Path): The target location; its suffix is replaced with
            ".json".
        exclude (Iterable[str]): Accepted for API compatibility; not used.

        DOCS: https://spacy.io/api/sentencizer#to_disk
        """
        path = util.ensure_path(path)
        path = path.with_suffix(".json")
        srsly.write_json(path, {"punct_chars": list(self.punct_chars), "overwrite": self.overwrite})

    def from_disk(self, path, *, exclude=tuple()):
        """Load the sentencizer from disk.

        path (str / Path): The location to read; its suffix is replaced with
            ".json".
        exclude (Iterable[str]): Accepted for API compatibility; not used.
        RETURNS (Sentencizer): The loaded object.

        DOCS: https://spacy.io/api/sentencizer#from_disk
        """
        path = util.ensure_path(path)
        path = path.with_suffix(".json")
        cfg = srsly.read_json(path)
        # Keys missing from the file fall back to the class defaults / the
        # current setting.
        self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
        self.overwrite = cfg.get("overwrite", self.overwrite)
        return self