mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-23 15:54:13 +03:00
Ensure training doesn't crash with empty batches (#4360)
* unit test for previously resolved unflatten issue * prevent batch of empty docs to cause problems
This commit is contained in:
parent
52b5912dbf
commit
9d3ce7cba2
|
@ -454,6 +454,10 @@ class Tagger(Pipe):
|
|||
if losses is not None and self.name not in losses:
|
||||
losses[self.name] = 0.
|
||||
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return
|
||||
|
||||
tag_scores, bp_tag_scores = self.model.begin_update(docs, drop=drop)
|
||||
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
|
||||
bp_tag_scores(d_tag_scores, sgd=sgd)
|
||||
|
@ -467,6 +471,9 @@ class Tagger(Pipe):
|
|||
"""
|
||||
if self._rehearsal_model is None:
|
||||
return
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return
|
||||
guesses, backprop = self.model.begin_update(docs, drop=drop)
|
||||
target = self._rehearsal_model(docs)
|
||||
gradient = guesses - target
|
||||
|
@ -968,6 +975,9 @@ class TextCategorizer(Pipe):
|
|||
|
||||
def update(self, docs, golds, state=None, drop=0., sgd=None, losses=None):
|
||||
self.require_model()
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return
|
||||
scores, bp_scores = self.model.begin_update(docs, drop=drop)
|
||||
loss, d_scores = self.get_loss(docs, golds, scores)
|
||||
bp_scores(d_scores, sgd=sgd)
|
||||
|
@ -978,6 +988,9 @@ class TextCategorizer(Pipe):
|
|||
def rehearse(self, docs, drop=0., sgd=None, losses=None):
|
||||
if self._rehearsal_model is None:
|
||||
return
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return
|
||||
scores, bp_scores = self.model.begin_update(docs, drop=drop)
|
||||
target = self._rehearsal_model(docs)
|
||||
gradient = scores - target
|
||||
|
|
|
@ -318,6 +318,14 @@ def test_issue3449():
|
|||
assert t3[5].text == "I"
|
||||
|
||||
|
||||
def test_issue3456():
|
||||
# this crashed because of a padding error in layer.ops.unflatten in thinc
|
||||
nlp = English()
|
||||
nlp.add_pipe(nlp.create_pipe("tagger"))
|
||||
nlp.begin_training()
|
||||
list(nlp.pipe(['hi', '']))
|
||||
|
||||
|
||||
def test_issue3468():
|
||||
"""Test that sentence boundaries are set correctly so Doc.is_sentenced can
|
||||
be restored after serialization."""
|
||||
|
|
23
spacy/tests/regression/test_issue4348.py
Normal file
23
spacy/tests/regression/test_issue4348.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.util import minibatch, compounding
|
||||
|
||||
|
||||
def test_issue4348():
|
||||
"""Test that training the tagger with empty data, doesn't throw errors"""
|
||||
|
||||
TRAIN_DATA = [("", {"tags": []}), ("", {"tags": []})]
|
||||
|
||||
nlp = English()
|
||||
tagger = nlp.create_pipe("tagger")
|
||||
nlp.add_pipe(tagger)
|
||||
|
||||
optimizer = nlp.begin_training()
|
||||
for i in range(5):
|
||||
losses = {}
|
||||
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
|
||||
for batch in batches:
|
||||
texts, annotations = zip(*batch)
|
||||
nlp.update(texts, annotations, sgd=optimizer, losses=losses)
|
Loading…
Reference in New Issue
Block a user