mirror of
https://github.com/explosion/spaCy.git
synced 2025-09-18 18:12:45 +03:00
Fix pipeline intialize
This commit is contained in:
parent
0620820857
commit
2486b8ad4d
|
@ -2,6 +2,7 @@ from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
|
||||||
|
|
||||||
from thinc.types import Floats2d, Ints2d
|
from thinc.types import Floats2d, Ints2d
|
||||||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
||||||
|
from thinc.api import set_dropout_rate
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
|
||||||
from .trainable_pipe import TrainablePipe
|
from .trainable_pipe import TrainablePipe
|
||||||
|
@ -341,10 +342,15 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
DOCS: https://spacy.io/api/coref#initialize (TODO)
|
DOCS: https://spacy.io/api/coref#initialize (TODO)
|
||||||
"""
|
"""
|
||||||
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
|
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
|
||||||
subbatch = list(islice(get_examples(), 10))
|
|
||||||
doc_sample = [eg.reference for eg in subbatch]
|
X = []
|
||||||
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
Y = []
|
||||||
self.model.initialize(X=doc_sample)
|
for ex in islice(get_examples(), 10):
|
||||||
|
X.append(ex.predicted)
|
||||||
|
Y.append(ex.reference)
|
||||||
|
|
||||||
|
assert len(X) > 0, Errors.E923.format(name=self.name)
|
||||||
|
self.model.initialize(X=X, Y=Y)
|
||||||
|
|
||||||
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||||
"""Score a batch of examples.
|
"""Score a batch of examples.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user