Fix pipeline intialize

This commit is contained in:
Paul O'Leary McCann 2021-05-18 19:56:27 +09:00
parent 0620820857
commit 2486b8ad4d

View File

@ -2,6 +2,7 @@ from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
from thinc.types import Floats2d, Ints2d
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
from thinc.api import set_dropout_rate
from itertools import islice
from .trainable_pipe import TrainablePipe
@ -341,10 +342,15 @@ class CoreferenceResolver(TrainablePipe):
DOCS: https://spacy.io/api/coref#initialize (TODO)
"""
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
subbatch = list(islice(get_examples(), 10))
doc_sample = [eg.reference for eg in subbatch]
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=doc_sample)
X = []
Y = []
for ex in islice(get_examples(), 10):
X.append(ex.predicted)
Y.append(ex.reference)
assert len(X) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=X, Y=Y)
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
"""Score a batch of examples.