Restructure Sentencizer to follow Pipe API (#4721)

* Restructure Sentencizer to follow Pipe API

Restructure Sentencizer to follow Pipe API so that it can be scored with
`nlp.evaluate()`.

* Add Sentencizer pipe() test
This commit is contained in:
adrianeboyd 2019-11-27 16:33:34 +01:00 committed by Matthew Honnibal
parent 16cb19e960
commit 48ea2e8d0f
2 changed files with 63 additions and 13 deletions

View File

@ -1464,21 +1464,59 @@ class Sentencizer(object):
DOCS: https://spacy.io/api/sentencizer#call DOCS: https://spacy.io/api/sentencizer#call
""" """
start = 0 tags = self.predict([doc])
seen_period = False self.set_annotations([doc], tags)
for i, token in enumerate(doc):
is_in_punct_chars = token.text in self.punct_chars
token.is_sent_start = i == 0
if seen_period and not token.is_punct and not is_in_punct_chars:
doc[start].is_sent_start = True
start = token.i
seen_period = False
elif is_in_punct_chars:
seen_period = True
if start < len(doc):
doc[start].is_sent_start = True
return doc return doc
def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size):
docs = list(docs)
tag_ids = self.predict(docs)
self.set_annotations(docs, tag_ids)
yield from docs
def predict(self, docs):
"""Apply the pipeline's model to a batch of docs, without
modifying them.
"""
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
guesses = [[] for doc in docs]
return guesses
guesses = []
for doc in docs:
start = 0
seen_period = False
doc_guesses = [False] * len(doc)
doc_guesses[0] = True
for i, token in enumerate(doc):
is_in_punct_chars = token.text in self.punct_chars
if seen_period and not token.is_punct and not is_in_punct_chars:
doc_guesses[start] = True
start = token.i
seen_period = False
elif is_in_punct_chars:
seen_period = True
if start < len(doc):
doc_guesses[start] = True
guesses.append(doc_guesses)
return guesses
def set_annotations(self, docs, batch_tag_ids, tensors=None):
if isinstance(docs, Doc):
docs = [docs]
cdef Doc doc
cdef int idx = 0
for i, doc in enumerate(docs):
doc_tag_ids = batch_tag_ids[i]
for j, tag_id in enumerate(doc_tag_ids):
# Don't clobber existing sentence boundaries
if doc.c[j].sent_start == 0:
if tag_id:
doc.c[j].sent_start = 1
else:
doc.c[j].sent_start = -1
def to_bytes(self, **kwargs): def to_bytes(self, **kwargs):
"""Serialize the sentencizer to a bytestring. """Serialize the sentencizer to a bytestring.

View File

@ -5,6 +5,7 @@ import pytest
import spacy import spacy
from spacy.pipeline import Sentencizer from spacy.pipeline import Sentencizer
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.lang.en import English
def test_sentencizer(en_vocab): def test_sentencizer(en_vocab):
@ -17,6 +18,17 @@ def test_sentencizer(en_vocab):
assert len(list(doc.sents)) == 2 assert len(list(doc.sents)) == 2
def test_sentencizer_pipe():
texts = ["Hello! This is a test.", "Hi! This is a test."]
nlp = English()
nlp.add_pipe(nlp.create_pipe("sentencizer"))
for doc in nlp.pipe(texts):
assert doc.is_sentenced
sent_starts = [t.is_sent_start for t in doc]
assert sent_starts == [True, False, True, False, False, False, False]
assert len(list(doc.sents)) == 2
@pytest.mark.parametrize( @pytest.mark.parametrize(
"words,sent_starts,n_sents", "words,sent_starts,n_sents",
[ [