mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Add as_example to Sentencizer pipe() (#4933)
This commit is contained in:
parent
d2f3a44b42
commit
199d89943e
|
@ -1711,12 +1711,24 @@ class Sentencizer(Pipe):
|
||||||
return example
|
return example
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
    """Annotate a stream of inputs batch by batch and yield the results.

    stream: Iterable of inputs. Every item is passed through
        `self._get_doc` to obtain the doc that gets annotated.
    batch_size: Batch size handed to `util.minibatch`.
    n_threads: Not used by this method.
    as_example: If True, yield the input items with their `doc` attribute
        set to the annotated doc. If False, yield the annotated docs.
    YIELDS: The annotated docs, or the updated input items when
        `as_example` is True, in input order.
    """
    for examples in util.minibatch(stream, size=batch_size):
        # The batch is traversed twice below (once to extract the docs,
        # once to pair them back up), so materialize it first.
        examples = list(examples)
        docs = [self._get_doc(ex) for ex in examples]
        predictions = self.predict(docs)
        # A 2-tuple result is unpacked as (scores, tensors). The length
        # check must be made on the result itself: calling len() on the
        # builtin `tuple` type raises TypeError.
        if isinstance(predictions, tuple) and len(predictions) == 2:
            scores, tensors = predictions
            self.set_annotations(docs, scores, tensors=tensors)
        else:
            self.set_annotations(docs, predictions)

        if as_example:
            # Attach every doc before yielding anything from this batch.
            for ex, doc in zip(examples, docs):
                ex.doc = doc
            yield from examples
        else:
            yield from docs
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
"""Apply the pipeline's model to a batch of docs, without
|
"""Apply the pipeline's model to a batch of docs, without
|
||||||
|
|
|
@ -24,6 +24,12 @@ def test_sentencizer_pipe():
|
||||||
sent_starts = [t.is_sent_start for t in doc]
|
sent_starts = [t.is_sent_start for t in doc]
|
||||||
assert sent_starts == [True, False, True, False, False, False, False]
|
assert sent_starts == [True, False, True, False, False, False, False]
|
||||||
assert len(list(doc.sents)) == 2
|
assert len(list(doc.sents)) == 2
|
||||||
|
for ex in nlp.pipe(texts, as_example=True):
|
||||||
|
doc = ex.doc
|
||||||
|
assert doc.is_sentenced
|
||||||
|
sent_starts = [t.is_sent_start for t in doc]
|
||||||
|
assert sent_starts == [True, False, True, False, False, False, False]
|
||||||
|
assert len(list(doc.sents)) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user