mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Restructure Sentencizer to follow Pipe API (#4721)
* Restructure Sentencizer to follow Pipe API Restructure Sentencizer to follow Pipe API so that it can be scored with `nlp.evaluate()`. * Add Sentencizer pipe() test
This commit is contained in:
parent
16cb19e960
commit
48ea2e8d0f
|
@ -1464,21 +1464,59 @@ class Sentencizer(object):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/sentencizer#call
|
DOCS: https://spacy.io/api/sentencizer#call
|
||||||
"""
|
"""
|
||||||
start = 0
|
tags = self.predict([doc])
|
||||||
seen_period = False
|
self.set_annotations([doc], tags)
|
||||||
for i, token in enumerate(doc):
|
|
||||||
is_in_punct_chars = token.text in self.punct_chars
|
|
||||||
token.is_sent_start = i == 0
|
|
||||||
if seen_period and not token.is_punct and not is_in_punct_chars:
|
|
||||||
doc[start].is_sent_start = True
|
|
||||||
start = token.i
|
|
||||||
seen_period = False
|
|
||||||
elif is_in_punct_chars:
|
|
||||||
seen_period = True
|
|
||||||
if start < len(doc):
|
|
||||||
doc[start].is_sent_start = True
|
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
|
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||||
|
for docs in util.minibatch(stream, size=batch_size):
|
||||||
|
docs = list(docs)
|
||||||
|
tag_ids = self.predict(docs)
|
||||||
|
self.set_annotations(docs, tag_ids)
|
||||||
|
yield from docs
|
||||||
|
|
||||||
|
def predict(self, docs):
|
||||||
|
"""Apply the pipeline's model to a batch of docs, without
|
||||||
|
modifying them.
|
||||||
|
"""
|
||||||
|
if not any(len(doc) for doc in docs):
|
||||||
|
# Handle cases where there are no tokens in any docs.
|
||||||
|
guesses = [[] for doc in docs]
|
||||||
|
return guesses
|
||||||
|
guesses = []
|
||||||
|
for doc in docs:
|
||||||
|
start = 0
|
||||||
|
seen_period = False
|
||||||
|
doc_guesses = [False] * len(doc)
|
||||||
|
doc_guesses[0] = True
|
||||||
|
for i, token in enumerate(doc):
|
||||||
|
is_in_punct_chars = token.text in self.punct_chars
|
||||||
|
if seen_period and not token.is_punct and not is_in_punct_chars:
|
||||||
|
doc_guesses[start] = True
|
||||||
|
start = token.i
|
||||||
|
seen_period = False
|
||||||
|
elif is_in_punct_chars:
|
||||||
|
seen_period = True
|
||||||
|
if start < len(doc):
|
||||||
|
doc_guesses[start] = True
|
||||||
|
guesses.append(doc_guesses)
|
||||||
|
return guesses
|
||||||
|
|
||||||
|
def set_annotations(self, docs, batch_tag_ids, tensors=None):
|
||||||
|
if isinstance(docs, Doc):
|
||||||
|
docs = [docs]
|
||||||
|
cdef Doc doc
|
||||||
|
cdef int idx = 0
|
||||||
|
for i, doc in enumerate(docs):
|
||||||
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
|
for j, tag_id in enumerate(doc_tag_ids):
|
||||||
|
# Don't clobber existing sentence boundaries
|
||||||
|
if doc.c[j].sent_start == 0:
|
||||||
|
if tag_id:
|
||||||
|
doc.c[j].sent_start = 1
|
||||||
|
else:
|
||||||
|
doc.c[j].sent_start = -1
|
||||||
|
|
||||||
def to_bytes(self, **kwargs):
|
def to_bytes(self, **kwargs):
|
||||||
"""Serialize the sentencizer to a bytestring.
|
"""Serialize the sentencizer to a bytestring.
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import pytest
|
||||||
import spacy
|
import spacy
|
||||||
from spacy.pipeline import Sentencizer
|
from spacy.pipeline import Sentencizer
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
|
from spacy.lang.en import English
|
||||||
|
|
||||||
|
|
||||||
def test_sentencizer(en_vocab):
|
def test_sentencizer(en_vocab):
|
||||||
|
@ -17,6 +18,17 @@ def test_sentencizer(en_vocab):
|
||||||
assert len(list(doc.sents)) == 2
|
assert len(list(doc.sents)) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_sentencizer_pipe():
|
||||||
|
texts = ["Hello! This is a test.", "Hi! This is a test."]
|
||||||
|
nlp = English()
|
||||||
|
nlp.add_pipe(nlp.create_pipe("sentencizer"))
|
||||||
|
for doc in nlp.pipe(texts):
|
||||||
|
assert doc.is_sentenced
|
||||||
|
sent_starts = [t.is_sent_start for t in doc]
|
||||||
|
assert sent_starts == [True, False, True, False, False, False, False]
|
||||||
|
assert len(list(doc.sents)) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"words,sent_starts,n_sents",
|
"words,sent_starts,n_sents",
|
||||||
[
|
[
|
||||||
|
|
Loading…
Reference in New Issue
Block a user