Language.distill: copy both reference and predicted (#12209)

* Language.distill: copy both reference and predicted

During distillation we also modify the teacher docs (e.g. in tok2vec
components), so we need to copy both the reference and the predicted docs.

Problem caught by @shadeMe

* Make the new `_copy_examples` args keyword-only
This commit is contained in:
Daniël de Kok 2023-01-31 13:19:42 +01:00 committed by GitHub
parent fb7f018ded
commit c6cca4c00a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1062,7 +1062,7 @@ class Language:
return losses
validate_distillation_examples(examples, "Language.distill")
examples = _copy_examples(examples)
examples = _copy_examples(examples, copy_x=True, copy_y=True)
if sgd is None:
if self._optimizer is None:
@ -2331,13 +2331,18 @@ class DisabledPipes(list):
self[:] = []
def _copy_examples(examples: Iterable[Example]) -> List[Example]:
def _copy_examples(
examples: Iterable[Example], *, copy_x: bool = True, copy_y: bool = False
) -> List[Example]:
"""Make a copy of a batch of examples, copying the predicted Doc as well.
This is used in contexts where we need to take ownership of the examples
so that they can be mutated, for instance during Language.evaluate and
Language.update.
"""
return [Example(eg.x.copy(), eg.y) for eg in examples]
return [
Example(eg.x.copy() if copy_x else eg.x, eg.y.copy() if copy_y else eg.y)
for eg in examples
]
def _apply_pipes(