mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Fix pipeline intialize
This commit is contained in:
parent
0620820857
commit
2486b8ad4d
|
@ -2,6 +2,7 @@ from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
|
|||
|
||||
from thinc.types import Floats2d, Ints2d
|
||||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
||||
from thinc.api import set_dropout_rate
|
||||
from itertools import islice
|
||||
|
||||
from .trainable_pipe import TrainablePipe
|
||||
|
@ -341,10 +342,15 @@ class CoreferenceResolver(TrainablePipe):
|
|||
DOCS: https://spacy.io/api/coref#initialize (TODO)
|
||||
"""
|
||||
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
|
||||
subbatch = list(islice(get_examples(), 10))
|
||||
doc_sample = [eg.reference for eg in subbatch]
|
||||
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||
self.model.initialize(X=doc_sample)
|
||||
|
||||
X = []
|
||||
Y = []
|
||||
for ex in islice(get_examples(), 10):
|
||||
X.append(ex.predicted)
|
||||
Y.append(ex.reference)
|
||||
|
||||
assert len(X) > 0, Errors.E923.format(name=self.name)
|
||||
self.model.initialize(X=X, Y=Y)
|
||||
|
||||
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||
"""Score a batch of examples.
|
||||
|
|
Loading…
Reference in New Issue
Block a user