mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Restructure Sentencizer to follow Pipe API (#4721)
* Restructure Sentencizer to follow Pipe API Restructure Sentencizer to follow Pipe API so that it can be scored with `nlp.evaluate()`. * Add Sentencizer pipe() test
This commit is contained in:
parent
16cb19e960
commit
48ea2e8d0f
|
@ -1464,21 +1464,59 @@ class Sentencizer(object):
|
|||
|
||||
DOCS: https://spacy.io/api/sentencizer#call
|
||||
"""
|
||||
start = 0
|
||||
seen_period = False
|
||||
for i, token in enumerate(doc):
|
||||
is_in_punct_chars = token.text in self.punct_chars
|
||||
token.is_sent_start = i == 0
|
||||
if seen_period and not token.is_punct and not is_in_punct_chars:
|
||||
doc[start].is_sent_start = True
|
||||
start = token.i
|
||||
seen_period = False
|
||||
elif is_in_punct_chars:
|
||||
seen_period = True
|
||||
if start < len(doc):
|
||||
doc[start].is_sent_start = True
|
||||
tags = self.predict([doc])
|
||||
self.set_annotations([doc], tags)
|
||||
return doc
|
||||
|
||||
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||
for docs in util.minibatch(stream, size=batch_size):
|
||||
docs = list(docs)
|
||||
tag_ids = self.predict(docs)
|
||||
self.set_annotations(docs, tag_ids)
|
||||
yield from docs
|
||||
|
||||
def predict(self, docs):
|
||||
"""Apply the pipeline's model to a batch of docs, without
|
||||
modifying them.
|
||||
"""
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
guesses = [[] for doc in docs]
|
||||
return guesses
|
||||
guesses = []
|
||||
for doc in docs:
|
||||
start = 0
|
||||
seen_period = False
|
||||
doc_guesses = [False] * len(doc)
|
||||
doc_guesses[0] = True
|
||||
for i, token in enumerate(doc):
|
||||
is_in_punct_chars = token.text in self.punct_chars
|
||||
if seen_period and not token.is_punct and not is_in_punct_chars:
|
||||
doc_guesses[start] = True
|
||||
start = token.i
|
||||
seen_period = False
|
||||
elif is_in_punct_chars:
|
||||
seen_period = True
|
||||
if start < len(doc):
|
||||
doc_guesses[start] = True
|
||||
guesses.append(doc_guesses)
|
||||
return guesses
|
||||
|
||||
def set_annotations(self, docs, batch_tag_ids, tensors=None):
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
cdef int idx = 0
|
||||
for i, doc in enumerate(docs):
|
||||
doc_tag_ids = batch_tag_ids[i]
|
||||
for j, tag_id in enumerate(doc_tag_ids):
|
||||
# Don't clobber existing sentence boundaries
|
||||
if doc.c[j].sent_start == 0:
|
||||
if tag_id:
|
||||
doc.c[j].sent_start = 1
|
||||
else:
|
||||
doc.c[j].sent_start = -1
|
||||
|
||||
def to_bytes(self, **kwargs):
|
||||
"""Serialize the sentencizer to a bytestring.
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import pytest
|
|||
import spacy
|
||||
from spacy.pipeline import Sentencizer
|
||||
from spacy.tokens import Doc
|
||||
from spacy.lang.en import English
|
||||
|
||||
|
||||
def test_sentencizer(en_vocab):
|
||||
|
@ -17,6 +18,17 @@ def test_sentencizer(en_vocab):
|
|||
assert len(list(doc.sents)) == 2
|
||||
|
||||
|
||||
def test_sentencizer_pipe():
|
||||
texts = ["Hello! This is a test.", "Hi! This is a test."]
|
||||
nlp = English()
|
||||
nlp.add_pipe(nlp.create_pipe("sentencizer"))
|
||||
for doc in nlp.pipe(texts):
|
||||
assert doc.is_sentenced
|
||||
sent_starts = [t.is_sent_start for t in doc]
|
||||
assert sent_starts == [True, False, True, False, False, False, False]
|
||||
assert len(list(doc.sents)) == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"words,sent_starts,n_sents",
|
||||
[
|
||||
|
|
Loading…
Reference in New Issue
Block a user