Fix pipeline intialize

This commit is contained in:
Paul O'Leary McCann 2021-05-18 19:56:27 +09:00
parent 0620820857
commit 2486b8ad4d

View File

@ -2,6 +2,7 @@ from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
from thinc.types import Floats2d, Ints2d from thinc.types import Floats2d, Ints2d
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
from thinc.api import set_dropout_rate
from itertools import islice from itertools import islice
from .trainable_pipe import TrainablePipe from .trainable_pipe import TrainablePipe
@ -341,10 +342,15 @@ class CoreferenceResolver(TrainablePipe):
DOCS: https://spacy.io/api/coref#initialize (TODO) DOCS: https://spacy.io/api/coref#initialize (TODO)
""" """
validate_get_examples(get_examples, "CoreferenceResolver.initialize") validate_get_examples(get_examples, "CoreferenceResolver.initialize")
subbatch = list(islice(get_examples(), 10))
doc_sample = [eg.reference for eg in subbatch] X = []
assert len(doc_sample) > 0, Errors.E923.format(name=self.name) Y = []
self.model.initialize(X=doc_sample) for ex in islice(get_examples(), 10):
X.append(ex.predicted)
Y.append(ex.reference)
assert len(X) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=X, Y=Y)
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]: def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
"""Score a batch of examples. """Score a batch of examples.