This commit is contained in:
Ines Montani 2020-09-29 12:14:08 +02:00
parent 9c8b2524fe
commit 42f0e4c946
5 changed files with 8 additions and 10 deletions

View File

@ -8,7 +8,7 @@ from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
import warnings import warnings
from thinc.api import Model, get_current_ops, Config, require_gpu, Optimizer from thinc.api import Model, get_current_ops, Config, Optimizer
import srsly import srsly
import multiprocessing as mp import multiprocessing as mp
from itertools import chain, cycle from itertools import chain, cycle
@ -1153,10 +1153,9 @@ class Language:
get_examples: Optional[Callable[[], Iterable[Example]]] = None, get_examples: Optional[Callable[[], Iterable[Example]]] = None,
*, *,
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
device: int = -1,
) -> Optimizer: ) -> Optimizer:
warnings.warn(Warnings.W089, DeprecationWarning) warnings.warn(Warnings.W089, DeprecationWarning)
return self.initialize(get_examples, sgd=sgd, device=device) return self.initialize(get_examples, sgd=sgd)
def initialize( def initialize(
self, self,
@ -1220,7 +1219,6 @@ class Language:
proc.initialize, p_settings, section="components", name=name proc.initialize, p_settings, section="components", name=name
) )
proc.initialize( proc.initialize(
get_examples, pipeline=self.pipeline
get_examples, get_examples,
pipeline=self.pipeline, pipeline=self.pipeline,
**p_settings, **p_settings,

View File

@ -132,7 +132,7 @@ cdef class DependencyParser(Parser):
labeller.model.set_dim("nO", len(self.labels)) labeller.model.set_dim("nO", len(self.labels))
if labeller.model.has_ref("output_layer"): if labeller.model.has_ref("output_layer"):
labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels)) labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
labeller.initialize(get_examples, pipeline=pipeline, sgd=sgd) labeller.initialize(get_examples, pipeline=pipeline)
@property @property
def labels(self): def labels(self):

View File

@ -58,7 +58,7 @@ class Sentencizer(Pipe):
else: else:
self.punct_chars = set(self.default_punct_chars) self.punct_chars = set(self.default_punct_chars)
def initialize(self, get_examples, pipeline=None, sgd=None): def initialize(self, get_examples, pipeline=None):
pass pass
def __call__(self, doc): def __call__(self, doc):

View File

@ -107,7 +107,7 @@ def validate_init_settings(
*, *,
section: Optional[str] = None, section: Optional[str] = None,
name: str = "", name: str = "",
exclude: Iterable[str] = ("get_examples", "nlp", "pipeline", "sgd"), exclude: Iterable[str] = ("get_examples", "nlp", "pipeline"),
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Validate initialization settings against the expected arguments in """Validate initialization settings against the expected arguments in
the method signature. Will parse values if possible (e.g. int to string) the method signature. Will parse values if possible (e.g. int to string)

View File

@ -55,7 +55,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1, silent: bool = True) -> Langu
msg.info(f"Resuming training for: {resume_components}") msg.info(f"Resuming training for: {resume_components}")
nlp.resume_training(sgd=optimizer) nlp.resume_training(sgd=optimizer)
with nlp.select_pipes(disable=[*frozen_components, *resume_components]): with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer, settings=I) nlp.initialize(lambda: train_corpus(nlp), settings=I)
msg.good("Initialized pipeline components") msg.good("Initialized pipeline components")
# Verify the config after calling 'initialize' to ensure labels # Verify the config after calling 'initialize' to ensure labels
# are properly initialized # are properly initialized