Language.distill: copy both reference and predicted

In distillation we also modify the teacher docs (e.g. in tok2vec
components), so we need to copy both the reference and predicted doc.

Problem caught by @shadeMe
This commit is contained in:
Daniël de Kok 2023-01-31 12:17:44 +01:00
parent 1b5aba9e22
commit 4f9101ee56

View File

@ -1059,7 +1059,7 @@ class Language:
return losses
validate_distillation_examples(examples, "Language.distill")
examples = _copy_examples(examples)
examples = _copy_examples(examples, copy_x=True, copy_y=True)
if sgd is None:
if self._optimizer is None:
@ -2328,13 +2328,18 @@ class DisabledPipes(list):
self[:] = []
def _copy_examples(
    examples: Iterable[Example], copy_x: bool = True, copy_y: bool = False
) -> List[Example]:
    """Make a copy of a batch of examples, optionally copying their Docs.

    This is used in contexts where we need to take ownership of the examples
    so that they can be mutated, for instance during Language.evaluate,
    Language.update and Language.distill.

    examples (Iterable[Example]): The examples to copy.
    copy_x (bool): Copy the predicted Doc (eg.x) rather than sharing it with
        the input example. Enabled by default.
    copy_y (bool): Copy the reference Doc (eg.y) rather than sharing it with
        the input example. Disabled by default; Language.distill enables it
        because distillation also modifies the teacher docs (e.g. in tok2vec
        components).
    RETURNS (List[Example]): Newly constructed examples. A Doc that was not
        copied is the same object as in the corresponding input example.
    """
    return [
        Example(eg.x.copy() if copy_x else eg.x, eg.y.copy() if copy_y else eg.y)
        for eg in examples
    ]
def _apply_pipes(