Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-12 10:16:27 +03:00)
Pipe API (#6034)
* ensure Language passes on valid examples for initialization
* fix tagger model initialization
* check for valid get_examples across components
* assume labels were added before begin_training
* fix senter initialization
* fix morphologizer initialization
* use methods to check arguments
* test textcat init, requires thinc>=8.0.0a31
* fix tok2vec init
* fix entity linker init
* use islice
* fix simple NER
* cleanup debug model
* fix assert statements
* fix tests
* throw error when adding a label if the output layer can't be resized anymore
* fix test
* add failing test for simple_ner
* UX improvements
* morphologizer UX
* assume begin_training gets a representative set and processes the labels
* remove assumptions for output of untrained NER model
* restore test for original purpose
parent 4b82882767
commit 60f22e1800
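The common thread in the diff below is that every trainable component's begin_training now validates its get_examples argument (via the new Pipe._ensure_examples) and initializes its model from a small sample of real Example objects (islice(get_examples(), 10)) rather than from dummy docs. A minimal sketch of what a caller is expected to provide under this contract follows; the pipeline choice and the tiny training set are made-up placeholders, not part of the commit:

import spacy
from spacy.gold import Example  # on this branch, Example lives in spacy.gold

nlp = spacy.blank("en")
nlp.add_pipe("tagger")

# A couple of hand-annotated sentences (placeholder data only).
TRAIN_DATA = [
    ("I like cats", {"tags": ["PRON", "VERB", "NOUN"]}),
    ("Dogs bark loudly", {"tags": ["NOUN", "VERB", "ADV"]}),
]

def get_examples():
    # Must be a callable returning Example objects; components peek at the
    # first few of these to build doc/label samples for model.initialize().
    return [
        Example.from_dict(nlp.make_doc(text), annots)
        for text, annots in TRAIN_DATA
    ]

optimizer = nlp.begin_training(get_examples)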
@@ -6,7 +6,7 @@ requires = [
     "cymem>=2.0.2,<2.1.0",
     "preshed>=3.0.2,<3.1.0",
     "murmurhash>=0.28.0,<1.1.0",
-    "thinc>=8.0.0a30,<8.0.0a40",
+    "thinc>=8.0.0a31,<8.0.0a40",
     "blis>=0.4.0,<0.5.0",
     "pytokenizations",
     "pathy"
@@ -1,7 +1,7 @@
 # Our libraries
 cymem>=2.0.2,<2.1.0
 preshed>=3.0.2,<3.1.0
-thinc>=8.0.0a30,<8.0.0a40
+thinc>=8.0.0a31,<8.0.0a40
 blis>=0.4.0,<0.5.0
 ml_datasets>=0.1.1
 murmurhash>=0.28.0,<1.1.0
@@ -34,13 +34,13 @@ setup_requires =
    cymem>=2.0.2,<2.1.0
    preshed>=3.0.2,<3.1.0
    murmurhash>=0.28.0,<1.1.0
-    thinc>=8.0.0a30,<8.0.0a40
+    thinc>=8.0.0a31,<8.0.0a40
 install_requires =
    # Our libraries
    murmurhash>=0.28.0,<1.1.0
    cymem>=2.0.2,<2.1.0
    preshed>=3.0.2,<3.1.0
-    thinc>=8.0.0a30,<8.0.0a40
+    thinc>=8.0.0a31,<8.0.0a40
    blis>=0.4.0,<0.5.0
    wasabi>=0.8.0,<1.1.0
    srsly>=2.1.0,<3.0.0
@@ -84,11 +84,11 @@ def debug_model(model: Model, *, print_settings: Optional[Dict[str, Any]] = None
         _print_model(model, print_settings)

     # STEP 1: Initializing the model and printing again
+    X = _get_docs()
     Y = _get_output(model.ops.xp)
-    _set_output_dim(nO=Y.shape[-1], model=model)
     # The output vector might differ from the official type of the output layer
     with data_validation(False):
-        model.initialize(X=_get_docs(), Y=Y)
+        model.initialize(X=X, Y=Y)
     if print_settings.get("print_after_init"):
         msg.divider(f"STEP 1 - after initialization")
         _print_model(model, print_settings)
@@ -135,15 +135,6 @@ def _get_output(xp):
     return xp.asarray([i + 10 for i, _ in enumerate(_get_docs())], dtype="float32")


-def _set_output_dim(model, nO):
-    # the dim inference doesn't always work 100%, we need this hack like we have it in pipe.pyx
-    if model.has_dim("nO") is None:
-        model.set_dim("nO", nO)
-    if model.has_ref("output_layer"):
-        if model.get_ref("output_layer").has_dim("nO") is None:
-            model.get_ref("output_layer").set_dim("nO", nO)
-
-
 def _print_model(model, print_settings):
     layers = print_settings.get("layers", "")
     parameters = print_settings.get("parameters", False)
@@ -247,8 +247,8 @@ class Errors:
             "Query string: {string}\nOrth cached: {orth}\nOrth ID: {orth_id}")
     E065 = ("Only one of the vector table's width and shape can be specified. "
             "Got width {width} and shape {shape}.")
-    E067 = ("Invalid BILUO tag sequence: Got a tag starting with 'I' (inside "
-            "an entity) without a preceding 'B' (beginning of an entity). "
+    E067 = ("Invalid BILUO tag sequence: Got a tag starting with {start} "
+            "without a preceding 'B' (beginning of an entity). "
             "Tag sequence:\n{tags}")
     E068 = ("Invalid BILUO tag: '{tag}'.")
     E071 = ("Error creating lexeme: specified orth ID ({orth}) does not "
@@ -320,10 +320,6 @@ class Errors:
             "So instead of pickling the span, pickle the Doc it belongs to or "
             "use Span.as_doc to convert the span to a standalone Doc object.")
     E115 = ("All subtokens must have associated heads.")
-    E116 = ("Cannot currently add labels to pretrained text classifier. Add "
-            "labels before training begins. This functionality was available "
-            "in previous versions, but had significant bugs that led to poor "
-            "performance.")
     E117 = ("The newly split tokens must match the text of the original token. "
             "New orths: {new}. Old text: {old}.")
     E118 = ("The custom extension attribute '{attr}' is not registered on the "
@@ -378,8 +374,9 @@ class Errors:
             "should be of equal length.")
     E141 = ("Entity vectors should be of length {required} instead of the "
             "provided {found}.")
-    E143 = ("Labels for component '{name}' not initialized. Did you forget to "
-            "call add_label()?")
+    E143 = ("Labels for component '{name}' not initialized. This can be fixed "
+            "by calling add_label, or by providing a representative batch of "
+            "examples to the component's begin_training method.")
     E145 = ("Error reading `{param}` from input file.")
     E146 = ("Could not access `{path}`.")
     E147 = ("Unexpected error in the {method} functionality of the "
@@ -483,6 +480,16 @@ class Errors:
     E201 = ("Span index out of range.")

     # TODO: fix numbering after merging develop into master
+    E921 = ("The method 'set_output' can only be called on components that have "
+            "a Model with a 'resize_output' attribute. Otherwise, the output "
+            "layer can not be dynamically changed.")
+    E922 = ("Component '{name}' has been initialized with an output dimension of "
+            "{nO} - cannot add any more labels.")
+    E923 = ("It looks like there is no proper sample data to initialize the "
+            "Model of component '{name}'. "
+            "This is likely a bug in spaCy, so feel free to open an issue.")
+    E924 = ("The '{name}' component does not seem to be initialized properly. "
+            "This is likely a bug in spaCy, so feel free to open an issue.")
     E925 = ("Invalid color values for displaCy visualizer: expected dictionary "
             "mapping label names to colors but got: {obj}")
     E926 = ("It looks like you're trying to modify nlp.{attr} directly. This "
@@ -195,13 +195,15 @@ def tags_to_entities(tags):
             continue
         elif tag.startswith("I"):
             if start is None:
-                raise ValueError(Errors.E067.format(tags=tags[: i + 1]))
+                raise ValueError(Errors.E067.format(start="I", tags=tags[: i + 1]))
             continue
         if tag.startswith("U"):
             entities.append((tag[2:], i, i))
         elif tag.startswith("B"):
             start = i
         elif tag.startswith("L"):
+            if start is None:
+                raise ValueError(Errors.E067.format(start="L", tags=tags[: i + 1]))
             entities.append((tag[2:], start, i))
             start = None
         else:
@@ -656,7 +656,7 @@ class Language:
         return resolved[factory_name]

     def create_pipe_from_source(
-        self, source_name: str, source: "Language", *, name: str,
+        self, source_name: str, source: "Language", *, name: str
     ) -> Tuple[Callable[[Doc], Doc], str]:
         """Create a pipeline component by copying it from an existing model.

@@ -1155,21 +1155,24 @@ class Language:

         DOCS: https://nightly.spacy.io/api/language#begin_training
         """
-        # TODO: throw warning when get_gold_tuples is provided instead of get_examples
         if get_examples is None:
-            get_examples = lambda: []
-        else:  # Populate vocab
-            if not hasattr(get_examples, "__call__"):
-                err = Errors.E930.format(name="Language", obj=type(get_examples))
+            util.logger.debug(
+                "No 'get_examples' callback provided to 'Language.begin_training', creating dummy examples"
+            )
+            doc = Doc(self.vocab, words=["x", "y", "z"])
+            get_examples = lambda: [Example.from_dict(doc, {})]
+        # Populate vocab
+        if not hasattr(get_examples, "__call__"):
+            err = Errors.E930.format(name="Language", obj=type(get_examples))
+            raise ValueError(err)
+        for example in get_examples():
+            if not isinstance(example, Example):
+                err = Errors.E978.format(
+                    name="Language.begin_training", types=type(example)
+                )
                 raise ValueError(err)
-            for example in get_examples():
-                if not isinstance(example, Example):
-                    err = Errors.E978.format(
-                        name="Language.begin_training", types=type(example)
-                    )
-                    raise ValueError(err)
-                for word in [t.text for t in example.reference]:
-                    _ = self.vocab[word]  # noqa: F841
+            for word in [t.text for t in example.reference]:
+                _ = self.vocab[word]  # noqa: F841
         if device >= 0:  # TODO: do we need this here?
             require_gpu(device)
         if self.vocab.vectors.data.shape[1] >= 1:
@@ -1187,7 +1190,7 @@ class Language:
         return self._optimizer

     def resume_training(
-        self, *, sgd: Optional[Optimizer] = None, device: int = -1,
+        self, *, sgd: Optional[Optimizer] = None, device: int = -1
     ) -> Optimizer:
         """Continue training a pretrained model.

@@ -62,8 +62,6 @@ def forward(model: Model[Padded, Padded], Xp: Padded, is_train: bool):
 def get_num_actions(n_labels: int) -> int:
     # One BEGIN action per label
     # One IN action per label
-    # One LAST action per label
-    # One UNIT action per label
     # One OUT action
     return n_labels * 2 + 1

@@ -21,7 +21,7 @@ def BiluoTagger(
     A BILUO tag sequence encodes a sequence of non-overlapping labelled spans
     into tags assigned to each token. The first token of a span is given the
     tag B-LABEL, the last token of the span is given the tag L-LABEL, and tokens
-    within the span are given the tag U-LABEL. Single-token spans are given
+    within the span are given the tag I-LABEL. Single-token spans are given
     the tag U-LABEL. All other tokens are assigned the tag O.

     The BILUO tag scheme generally results in better linear separation between
@@ -86,7 +86,7 @@ def IOBTagger(


 def init(model: Model[List[Doc], List[Floats2d]], X=None, Y=None) -> None:
-    if model.get_dim("nO") is None and Y:
+    if model.has_dim("nO") is None and Y:
         model.set_dim("nO", Y[0].shape[1])
     nO = model.get_dim("nO")
     biluo = model.get_ref("biluo")
@@ -1,3 +1,4 @@
+from itertools import islice
 from typing import Optional, Iterable, Callable, Dict, Iterator, Union, List, Tuple
 from pathlib import Path
 import srsly
@@ -128,7 +129,7 @@ class EntityLinker(Pipe):
         # how many neightbour sentences to take into account
         self.n_sents = cfg.get("n_sents", 0)

-    def require_kb(self) -> None:
+    def _require_kb(self) -> None:
         # Raise an error if the knowledge base is not initialized.
         if len(self.kb) == 0:
             raise ValueError(Errors.E139.format(name=self.name))
@@ -140,10 +141,11 @@ class EntityLinker(Pipe):
         pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
         sgd: Optional[Optimizer] = None,
     ) -> Optimizer:
-        """Initialize the pipe for training, using data examples if available.
+        """Initialize the pipe for training, using a representative set
+        of data examples.

-        get_examples (Callable[[], Iterable[Example]]): Optional function that
-            returns gold-standard Example objects.
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects.
         pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
             components that this component is part of. Corresponds to
             nlp.pipeline.
@@ -153,10 +155,19 @@ class EntityLinker(Pipe):

         DOCS: https://nightly.spacy.io/api/entitylinker#begin_training
         """
-        self.require_kb()
+        self._ensure_examples(get_examples)
+        self._require_kb()
         nO = self.kb.entity_vector_length
-        self.set_output(nO)
-        self.model.initialize()
+        doc_sample = []
+        vector_sample = []
+        for example in islice(get_examples(), 10):
+            doc_sample.append(example.x)
+            vector_sample.append(self.model.ops.alloc1f(nO))
+        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
+        assert len(vector_sample) > 0, Errors.E923.format(name=self.name)
+        self.model.initialize(
+            X=doc_sample, Y=self.model.ops.asarray(vector_sample, dtype="float32")
+        )
         if sgd is None:
             sgd = self.create_optimizer()
         return sgd
@@ -184,7 +195,7 @@ class EntityLinker(Pipe):

         DOCS: https://nightly.spacy.io/api/entitylinker#update
         """
-        self.require_kb()
+        self._require_kb()
         if losses is None:
             losses = {}
         losses.setdefault(self.name, 0.0)
@@ -296,7 +307,7 @@ class EntityLinker(Pipe):

         DOCS: https://nightly.spacy.io/api/entitylinker#predict
         """
-        self.require_kb()
+        self._require_kb()
         entity_count = 0
         final_kb_ids = []
         if not docs:
@@ -405,7 +416,7 @@ class EntityLinker(Pipe):
                     token.ent_kb_id_ = kb_id

     def to_disk(
-        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList(),
+        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
     ) -> None:
         """Serialize the pipe to disk.

@@ -422,7 +433,7 @@ class EntityLinker(Pipe):
         util.to_disk(path, serialize, exclude)

     def from_disk(
-        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList(),
+        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
     ) -> "EntityLinker":
         """Load the pipe from disk. Modifies the object in place and returns it.

@@ -2,6 +2,7 @@
 from typing import Optional
 import srsly
 from thinc.api import SequenceCategoricalCrossentropy, Model, Config
+from itertools import islice

 from ..tokens.doc cimport Doc
 from ..vocab cimport Vocab
@@ -112,6 +113,7 @@ class Morphologizer(Tagger):
             raise ValueError(Errors.E187)
         if label in self.labels:
             return 0
+        self._allow_extra_label()
         # normalize label
         norm_label = self.vocab.morphology.normalize_features(label)
         # extract separate POS and morph tags
@@ -128,10 +130,11 @@ class Morphologizer(Tagger):
         return 1

     def begin_training(self, get_examples, *, pipeline=None, sgd=None):
-        """Initialize the pipe for training, using data examples if available.
+        """Initialize the pipe for training, using a representative set
+        of data examples.

-        get_examples (Callable[[], Iterable[Example]]): Optional function that
-            returns gold-standard Example objects.
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects.
         pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
             components that this component is part of. Corresponds to
             nlp.pipeline.
@@ -141,9 +144,8 @@ class Morphologizer(Tagger):

         DOCS: https://nightly.spacy.io/api/morphologizer#begin_training
         """
-        if not hasattr(get_examples, "__call__"):
-            err = Errors.E930.format(name="Morphologizer", obj=type(get_examples))
-            raise ValueError(err)
+        self._ensure_examples(get_examples)
+        # First, fetch all labels from the data
         for example in get_examples():
             for i, token in enumerate(example.reference):
                 pos = token.pos_
@@ -157,8 +159,25 @@ class Morphologizer(Tagger):
                 if norm_label not in self.cfg["labels_morph"]:
                     self.cfg["labels_morph"][norm_label] = morph
                     self.cfg["labels_pos"][norm_label] = POS_IDS[pos]
-        self.set_output(len(self.labels))
-        self.model.initialize()
+        if len(self.labels) <= 1:
+            raise ValueError(Errors.E143.format(name=self.name))
+        doc_sample = []
+        label_sample = []
+        for example in islice(get_examples(), 10):
+            gold_array = []
+            for i, token in enumerate(example.reference):
+                pos = token.pos_
+                morph = token.morph_
+                morph_dict = Morphology.feats_to_dict(morph)
+                if pos:
+                    morph_dict[self.POS_FEAT] = pos
+                norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
+                gold_array.append([1.0 if label == norm_label else 0.0 for label in self.labels])
+            doc_sample.append(example.x)
+            label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
+        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
+        assert len(label_sample) > 0, Errors.E923.format(name=self.name)
+        self.model.initialize(X=doc_sample, Y=label_sample)
         if sgd is None:
             sgd = self.create_optimizer()
         return sgd
@@ -90,7 +90,7 @@ class MultitaskObjective(Tagger):
                 label = self.make_label(token)
                 if label is not None and label not in self.labels:
                     self.labels[label] = len(self.labels)
-        self.model.initialize()
+        self.model.initialize()  # TODO: fix initialization by defining X and Y
         if sgd is None:
             sgd = self.create_optimizer()
         return sgd
@@ -178,7 +178,7 @@ class ClozeMultitask(Pipe):
         pass

     def begin_training(self, get_examples, pipeline=None, sgd=None):
-        self.model.initialize()
+        self.model.initialize()  # TODO: fix initialization by defining X and Y
         X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
         self.model.output_layer.begin_training(X)
         if sgd is None:
@@ -160,6 +160,20 @@ cdef class Pipe:
         """
         raise NotImplementedError(Errors.E931.format(method="add_label", name=self.name))

+
+    def _require_labels(self) -> None:
+        """Raise an error if the component's model has no labels defined."""
+        if not self.labels or list(self.labels) == [""]:
+            raise ValueError(Errors.E143.format(name=self.name))
+
+
+    def _allow_extra_label(self) -> None:
+        """Raise an error if the component can not add any more labels."""
+        if self.model.has_dim("nO") and self.model.get_dim("nO") == len(self.labels):
+            if not self.is_resizable():
+                raise ValueError(Errors.E922.format(name=self.name, nO=self.model.get_dim("nO")))
+
+
     def create_optimizer(self):
         """Create an optimizer for the pipeline component.

@@ -171,9 +185,12 @@ cdef class Pipe:

     def begin_training(self, get_examples, *, pipeline=None, sgd=None):
         """Initialize the pipe for training, using data examples if available.
+        This method needs to be implemented by each Pipe component,
+        ensuring the internal model (if available) is initialized properly
+        using the provided sample of Example objects.

-        get_examples (Callable[[], Iterable[Example]]): Optional function that
-            returns gold-standard Example objects.
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects.
         pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
             components that this component is part of. Corresponds to
             nlp.pipeline.
@@ -183,16 +200,24 @@ cdef class Pipe:

         DOCS: https://nightly.spacy.io/api/pipe#begin_training
         """
-        self.model.initialize()
-        if sgd is None:
-            sgd = self.create_optimizer()
-        return sgd
+        raise NotImplementedError(Errors.E931.format(method="add_label", name=self.name))
+
+    def _ensure_examples(self, get_examples):
+        if get_examples is None or not hasattr(get_examples, "__call__"):
+            err = Errors.E930.format(name=self.name, obj=type(get_examples))
+            raise ValueError(err)
+        if not get_examples():
+            err = Errors.E930.format(name=self.name, obj=get_examples())
+            raise ValueError(err)
+
+    def is_resizable(self):
+        return hasattr(self, "model") and "resize_output" in self.model.attrs

     def set_output(self, nO):
-        if self.model.has_dim("nO") is not False:
-            self.model.set_dim("nO", nO)
-        if self.model.has_ref("output_layer"):
-            self.model.get_ref("output_layer").set_dim("nO", nO)
+        if self.is_resizable():
+            self.model.attrs["resize_output"](self.model, nO)
+        else:
+            raise NotImplementedError(Errors.E921)

     def use_params(self, params):
         """Modify the pipe's model, to use the given parameter values. At the
@@ -1,4 +1,6 @@
 # cython: infer_types=True, profile=True, binding=True
+from itertools import islice
+
 import srsly
 from thinc.api import Model, SequenceCategoricalCrossentropy, Config

@@ -124,10 +126,11 @@ class SentenceRecognizer(Tagger):
         return float(loss), d_scores

     def begin_training(self, get_examples, *, pipeline=None, sgd=None):
-        """Initialize the pipe for training, using data examples if available.
+        """Initialize the pipe for training, using a representative set
+        of data examples.

-        get_examples (Callable[[], Iterable[Example]]): Optional function that
-            returns gold-standard Example objects.
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects.
         pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
             components that this component is part of. Corresponds to
             nlp.pipeline.
@@ -137,8 +140,18 @@ class SentenceRecognizer(Tagger):

         DOCS: https://nightly.spacy.io/api/sentencerecognizer#begin_training
         """
-        self.set_output(len(self.labels))
-        self.model.initialize()
+        self._ensure_examples(get_examples)
+        doc_sample = []
+        label_sample = []
+        assert self.labels, Errors.E924.format(name=self.name)
+        for example in islice(get_examples(), 10):
+            doc_sample.append(example.x)
+            gold_tags = example.get_aligned("SENT_START")
+            gold_array = [[1.0 if tag == gold_tag else 0.0 for tag in self.labels] for gold_tag in gold_tags]
+            label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
+        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
+        assert len(label_sample) > 0, Errors.E923.format(name=self.name)
+        self.model.initialize(X=doc_sample, Y=label_sample)
         if sgd is None:
             sgd = self.create_optimizer()
         return sgd
@@ -3,6 +3,7 @@ from thinc.types import Floats2d
 from thinc.api import SequenceCategoricalCrossentropy, set_dropout_rate, Model
 from thinc.api import Optimizer, Config
 from thinc.util import to_numpy
+from itertools import islice

 from ..errors import Errors
 from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob
@@ -168,18 +169,29 @@ class SimpleNER(Pipe):
         pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
         sgd: Optional[Optimizer] = None,
     ):
+        self._ensure_examples(get_examples)
         all_labels = set()
-        if not hasattr(get_examples, "__call__"):
-            err = Errors.E930.format(name="SimpleNER", obj=type(get_examples))
-            raise ValueError(err)
         for example in get_examples():
             all_labels.update(_get_labels(example))
         for label in sorted(all_labels):
-            self.add_label(label)
-        labels = self.labels
-        n_actions = self.model.attrs["get_num_actions"](len(labels))
-        self.model.set_dim("nO", n_actions)
-        self.model.initialize()
+            if label != "":
+                self.add_label(label)
+        doc_sample = []
+        label_sample = []
+        self._require_labels()
+        for example in islice(get_examples(), 10):
+            doc_sample.append(example.x)
+            gold_tags = example.get_aligned_ner()
+            if not self.is_biluo:
+                gold_tags = biluo_to_iob(gold_tags)
+            gold_array = [
+                [1.0 if tag == gold_tag else 0.0 for tag in self.get_tag_names()]
+                for gold_tag in gold_tags
+            ]
+            label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
+        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
+        assert len(label_sample) > 0, Errors.E923.format(name=self.name)
+        self.model.initialize(X=doc_sample, Y=label_sample)
         if pipeline is not None:
             self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
         self.loss_func = SequenceCategoricalCrossentropy(
@@ -206,6 +218,6 @@ def _has_ner(example: Example) -> bool:
 def _get_labels(example: Example) -> Set[str]:
     labels = set()
     for ner_tag in example.get_aligned("ENT_TYPE", as_string=True):
-        if ner_tag != "O" and ner_tag != "-":
+        if ner_tag != "O" and ner_tag != "-" and ner_tag != "":
             labels.add(ner_tag)
     return labels
@@ -5,6 +5,7 @@ import srsly
 from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config
 from thinc.types import Floats2d
 import warnings
+from itertools import islice

 from ..tokens.doc cimport Doc
 from ..morphology cimport Morphology
@@ -258,10 +259,11 @@ class Tagger(Pipe):
         return float(loss), d_scores

     def begin_training(self, get_examples, *, pipeline=None, sgd=None):
-        """Initialize the pipe for training, using data examples if available.
+        """Initialize the pipe for training, using a representative set
+        of data examples.

-        get_examples (Callable[[], Iterable[Example]]): Optional function that
-            returns gold-standard Example objects.
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects..
         pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
             components that this component is part of. Corresponds to
             nlp.pipeline.
|
@ -271,32 +273,24 @@ class Tagger(Pipe):
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/tagger#begin_training
|
DOCS: https://nightly.spacy.io/api/tagger#begin_training
|
||||||
"""
|
"""
|
||||||
if not hasattr(get_examples, "__call__"):
|
self._ensure_examples(get_examples)
|
||||||
err = Errors.E930.format(name="Tagger", obj=type(get_examples))
|
|
||||||
raise ValueError(err)
|
|
||||||
tags = set()
|
|
||||||
doc_sample = []
|
doc_sample = []
|
||||||
|
label_sample = []
|
||||||
|
tags = set()
|
||||||
for example in get_examples():
|
for example in get_examples():
|
||||||
for token in example.y:
|
for token in example.y:
|
||||||
tags.add(token.tag_)
|
if token.tag_:
|
||||||
if len(doc_sample) < 10:
|
tags.add(token.tag_)
|
||||||
doc_sample.append(example.x)
|
|
||||||
if not doc_sample:
|
|
||||||
doc_sample.append(Doc(self.vocab, words=["hello"]))
|
|
||||||
for tag in sorted(tags):
|
for tag in sorted(tags):
|
||||||
self.add_label(tag)
|
self.add_label(tag)
|
||||||
if len(self.labels) == 0:
|
for example in islice(get_examples(), 10):
|
||||||
err = Errors.E1006.format(name="Tagger")
|
doc_sample.append(example.x)
|
||||||
raise ValueError(err)
|
gold_tags = example.get_aligned("TAG", as_string=True)
|
||||||
self.set_output(len(self.labels))
|
gold_array = [[1.0 if tag == gold_tag else 0.0 for tag in self.labels] for gold_tag in gold_tags]
|
||||||
if doc_sample:
|
label_sample.append(self.model.ops.asarray(gold_array, dtype="float32"))
|
||||||
label_sample = [
|
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
self.model.ops.alloc2f(len(doc), len(self.labels))
|
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
for doc in doc_sample
|
self.model.initialize(X=doc_sample, Y=label_sample)
|
||||||
]
|
|
||||||
self.model.initialize(X=doc_sample, Y=label_sample)
|
|
||||||
else:
|
|
||||||
self.model.initialize()
|
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
sgd = self.create_optimizer()
|
sgd = self.create_optimizer()
|
||||||
return sgd
|
return sgd
|
||||||
|
@@ -313,6 +307,7 @@ class Tagger(Pipe):
             raise ValueError(Errors.E187)
         if label in self.labels:
             return 0
+        self._allow_extra_label()
         self.cfg["labels"].append(label)
         self.vocab.strings.add(label)
         return 1
@@ -1,3 +1,4 @@
+from itertools import islice
 from typing import Iterable, Tuple, Optional, Dict, List, Callable, Iterator, Any
 from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
 from thinc.types import Floats2d
@@ -128,11 +129,6 @@ class TextCategorizer(Pipe):
         """
         return tuple(self.cfg.setdefault("labels", []))

-    def require_labels(self) -> None:
-        """Raise an error if the component's model has no labels defined."""
-        if not self.labels:
-            raise ValueError(Errors.E143.format(name=self.name))
-
     @labels.setter
     def labels(self, value: Iterable[str]) -> None:
         self.cfg["labels"] = tuple(value)
@@ -311,17 +307,7 @@ class TextCategorizer(Pipe):
             raise ValueError(Errors.E187)
         if label in self.labels:
             return 0
-        if self.model.has_dim("nO"):
-            # This functionality was available previously, but was broken.
-            # The problem is that we resize the last layer, but the last layer
-            # is actually just an ensemble. We're not resizing the child layers
-            # - a huge problem.
-            raise ValueError(Errors.E116)
-            # smaller = self.model._layers[-1]
-            # larger = Linear(len(self.labels)+1, smaller.nI)
-            # copy_array(larger.W[:smaller.nO], smaller.W)
-            # copy_array(larger.b[:smaller.nO], smaller.b)
-            # self.model._layers[-1] = larger
+        self._allow_extra_label()
         self.labels = tuple(list(self.labels) + [label])
         return 1

@@ -332,10 +318,11 @@ class TextCategorizer(Pipe):
         pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
         sgd: Optional[Optimizer] = None,
     ) -> Optimizer:
-        """Initialize the pipe for training, using data examples if available.
+        """Initialize the pipe for training, using a representative set
+        of data examples.

-        get_examples (Callable[[], Iterable[Example]]): Optional function that
-            returns gold-standard Example objects.
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects.
         pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
             components that this component is part of. Corresponds to
             nlp.pipeline.
@@ -345,22 +332,19 @@ class TextCategorizer(Pipe):

         DOCS: https://nightly.spacy.io/api/textcategorizer#begin_training
         """
-        if not hasattr(get_examples, "__call__"):
-            err = Errors.E930.format(name="TextCategorizer", obj=type(get_examples))
-            raise ValueError(err)
+        self._ensure_examples(get_examples)
         subbatch = []  # Select a subbatch of examples to initialize the model
-        for example in get_examples():
+        for example in islice(get_examples(), 10):
             if len(subbatch) < 2:
                 subbatch.append(example)
             for cat in example.y.cats:
                 self.add_label(cat)
-        self.require_labels()
-        docs = [eg.reference for eg in subbatch]
-        if not docs:  # need at least one doc
-            docs = [Doc(self.vocab, words=["hello"])]
-        truths, _ = self._examples_to_truth(subbatch)
-        self.set_output(len(self.labels))
-        self.model.initialize(X=docs, Y=truths)
+        doc_sample = [eg.reference for eg in subbatch]
+        label_sample, _ = self._examples_to_truth(subbatch)
+        self._require_labels()
+        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
+        assert len(label_sample) > 0, Errors.E923.format(name=self.name)
+        self.model.initialize(X=doc_sample, Y=label_sample)
         if sgd is None:
             sgd = self.create_optimizer()
         return sgd
@@ -1,5 +1,6 @@
 from typing import Iterator, Sequence, Iterable, Optional, Dict, Callable, List, Tuple
 from thinc.api import Model, set_dropout_rate, Optimizer, Config
+from itertools import islice

 from .pipe import Pipe
 from ..gold import Example, validate_examples
@@ -209,10 +210,11 @@ class Tok2Vec(Pipe):
         pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
         sgd: Optional[Optimizer] = None,
     ):
-        """Initialize the pipe for training, using data examples if available.
+        """Initialize the pipe for training, using a representative set
+        of data examples.

-        get_examples (Callable[[], Iterable[Example]]): Optional function that
-            returns gold-standard Example objects.
+        get_examples (Callable[[], Iterable[Example]]): Function that
+            returns a representative sample of gold-standard Example objects.
         pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
             components that this component is part of. Corresponds to
             nlp.pipeline.
@@ -222,8 +224,12 @@ class Tok2Vec(Pipe):

         DOCS: https://nightly.spacy.io/api/tok2vec#begin_training
         """
-        docs = [Doc(self.vocab, words=["hello"])]
-        self.model.initialize(X=docs)
+        self._ensure_examples(get_examples)
+        doc_sample = []
+        for example in islice(get_examples(), 10):
+            doc_sample.append(example.x)
+        assert doc_sample, Errors.E923.format(name=self.name)
+        self.model.initialize(X=doc_sample)

     def add_label(self, label):
         raise NotImplementedError
@@ -244,7 +244,7 @@ cdef class Parser(Pipe):
             int nr_class, int batch_size) nogil:
         # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
         with gil:
-            assert self.moves.n_moves > 0
+            assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)
         is_valid = <int*>calloc(self.moves.n_moves, sizeof(int))
         cdef int i, guess
         cdef Transition action
@@ -378,7 +378,7 @@ cdef class Parser(Pipe):
         cdef int i

         # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
-        assert self.moves.n_moves > 0
+        assert self.moves.n_moves > 0, Errors.E924.format(name=self.name)

         is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
         costs = <float*>mem.alloc(self.moves.n_moves, sizeof(float))
@@ -406,9 +406,7 @@ cdef class Parser(Pipe):
         self.model.attrs["resize_output"](self.model, nO)

     def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs):
-        if not hasattr(get_examples, "__call__"):
-            err = Errors.E930.format(name="DependencyParser/EntityRecognizer", obj=type(get_examples))
-            raise ValueError(err)
+        self._ensure_examples(get_examples)
         self.cfg.update(kwargs)
         lexeme_norms = self.vocab.lookups.get_table("lexeme_norm", {})
         if len(lexeme_norms) == 0 and self.vocab.lang in util.LEXEME_NORM_LANGS:
@@ -430,9 +428,6 @@ cdef class Parser(Pipe):
         if sgd is None:
             sgd = self.create_optimizer()
         doc_sample = []
-        for example in islice(get_examples(), 10):
-            doc_sample.append(example.predicted)
-
         if pipeline is not None:
             for name, component in pipeline:
                 if component is self:
@@ -441,10 +436,11 @@ cdef class Parser(Pipe):
                     doc_sample = list(component.pipe(doc_sample, batch_size=8))
                 else:
                     doc_sample = [component(doc) for doc in doc_sample]
-        if doc_sample:
-            self.model.initialize(doc_sample)
-        else:
-            self.model.initialize()
+        if not doc_sample:
+            for example in islice(get_examples(), 10):
+                doc_sample.append(example.predicted)
+        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
+        self.model.initialize(doc_sample)
         if pipeline is not None:
             self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
         return sgd
@@ -1,5 +1,6 @@
+from spacy.gold import Example
 from spacy.pipeline import EntityRecognizer
-from spacy.tokens import Span
+from spacy.tokens import Span, Doc
 from spacy import registry
 import pytest

@@ -7,6 +8,12 @@ from ..util import get_doc
 from spacy.pipeline.ner import DEFAULT_NER_MODEL


+def _ner_example(ner):
+    doc = Doc(ner.vocab, words=["Joe", "loves", "visiting", "London", "during", "the", "weekend"])
+    gold = {"entities": [(0, 3, "PERSON"), (19, 25, "LOC")]}
+    return Example.from_dict(doc, gold)
+
+
 def test_doc_add_entities_set_ents_iob(en_vocab):
     text = ["This", "is", "a", "lion"]
     doc = get_doc(en_vocab, text)
|
@ -18,10 +25,8 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
|
||||||
cfg = {"model": DEFAULT_NER_MODEL}
|
cfg = {"model": DEFAULT_NER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||||
ner = EntityRecognizer(en_vocab, model, **config)
|
ner = EntityRecognizer(en_vocab, model, **config)
|
||||||
ner.begin_training(lambda: [])
|
ner.begin_training(lambda: [_ner_example(ner)])
|
||||||
ner(doc)
|
ner(doc)
|
||||||
assert len(list(doc.ents)) == 0
|
|
||||||
assert [w.ent_iob_ for w in doc] == (["O"] * len(doc))
|
|
||||||
|
|
||||||
doc.ents = [(doc.vocab.strings["ANIMAL"], 3, 4)]
|
doc.ents = [(doc.vocab.strings["ANIMAL"], 3, 4)]
|
||||||
assert [w.ent_iob_ for w in doc] == ["O", "O", "O", "B"]
|
assert [w.ent_iob_ for w in doc] == ["O", "O", "O", "B"]
|
||||||
|
@@ -31,6 +36,7 @@ def test_doc_add_entities_set_ents_iob(en_vocab):


 def test_ents_reset(en_vocab):
+    """Ensure that resetting doc.ents does not change anything"""
     text = ["This", "is", "a", "lion"]
     doc = get_doc(en_vocab, text)
     config = {
|
@ -41,11 +47,11 @@ def test_ents_reset(en_vocab):
|
||||||
cfg = {"model": DEFAULT_NER_MODEL}
|
cfg = {"model": DEFAULT_NER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||||
ner = EntityRecognizer(en_vocab, model, **config)
|
ner = EntityRecognizer(en_vocab, model, **config)
|
||||||
ner.begin_training(lambda: [])
|
ner.begin_training(lambda: [_ner_example(ner)])
|
||||||
ner(doc)
|
ner(doc)
|
||||||
assert [t.ent_iob_ for t in doc] == (["O"] * len(doc))
|
orig_iobs = [t.ent_iob_ for t in doc]
|
||||||
doc.ents = list(doc.ents)
|
doc.ents = list(doc.ents)
|
||||||
assert [t.ent_iob_ for t in doc] == (["O"] * len(doc))
|
assert [t.ent_iob_ for t in doc] == orig_iobs
|
||||||
|
|
||||||
|
|
||||||
def test_add_overlapping_entities(en_vocab):
|
def test_add_overlapping_entities(en_vocab):
|
||||||
|
|
|
@@ -35,7 +35,7 @@ def test_init_parser(parser):
 def _train_parser(parser):
     fix_random_seed(1)
     parser.add_label("left")
-    parser.begin_training(lambda: [], **parser.cfg)
+    parser.begin_training(lambda: [_parser_example(parser)], **parser.cfg)
     sgd = Adam(0.001)

     for i in range(5):
|
@ -47,16 +47,25 @@ def _train_parser(parser):
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def _parser_example(parser):
|
||||||
|
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
||||||
|
gold = {"heads": [1, 1, 3, 3], "deps": ["right", "ROOT", "left", "ROOT"]}
|
||||||
|
return Example.from_dict(doc, gold)
|
||||||
|
|
||||||
|
|
||||||
|
def _ner_example(ner):
|
||||||
|
doc = Doc(ner.vocab, words=["Joe", "loves", "visiting", "London", "during", "the", "weekend"])
|
||||||
|
gold = {"entities": [(0, 3, "PERSON"), (19, 25, "LOC")]}
|
||||||
|
return Example.from_dict(doc, gold)
|
||||||
|
|
||||||
|
|
||||||
def test_add_label(parser):
|
def test_add_label(parser):
|
||||||
parser = _train_parser(parser)
|
parser = _train_parser(parser)
|
||||||
parser.add_label("right")
|
parser.add_label("right")
|
||||||
sgd = Adam(0.001)
|
sgd = Adam(0.001)
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
losses = {}
|
losses = {}
|
||||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
parser.update([_parser_example(parser)], sgd=sgd, losses=losses)
|
||||||
gold = {"heads": [1, 1, 3, 3], "deps": ["right", "ROOT", "left", "ROOT"]}
|
|
||||||
example = Example.from_dict(doc, gold)
|
|
||||||
parser.update([example], sgd=sgd, losses=losses)
|
|
||||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
||||||
doc = parser(doc)
|
doc = parser(doc)
|
||||||
assert doc[0].dep_ == "right"
|
assert doc[0].dep_ == "right"
|
||||||
|
@ -75,7 +84,7 @@ def test_add_label_deserializes_correctly():
|
||||||
ner1.add_label("C")
|
ner1.add_label("C")
|
||||||
ner1.add_label("B")
|
ner1.add_label("B")
|
||||||
ner1.add_label("A")
|
ner1.add_label("A")
|
||||||
ner1.begin_training(lambda: [])
|
ner1.begin_training(lambda: [_ner_example(ner1)])
|
||||||
ner2 = EntityRecognizer(Vocab(), model, **config)
|
ner2 = EntityRecognizer(Vocab(), model, **config)
|
||||||
|
|
||||||
# the second model needs to be resized before we can call from_bytes
|
# the second model needs to be resized before we can call from_bytes
|
||||||
|
|
|
@@ -85,7 +85,7 @@ def test_parser_merge_pp(en_tokenizer):
     pos = ["DET", "NOUN", "ADP", "DET", "NOUN", "VERB"]
     tokens = en_tokenizer(text)
     doc = get_doc(
-        tokens.vocab, words=[t.text for t in tokens], deps=deps, heads=heads, pos=pos,
+        tokens.vocab, words=[t.text for t in tokens], deps=deps, heads=heads, pos=pos
     )
     with doc.retokenize() as retokenizer:
         for np in doc.noun_chunks:
@@ -14,6 +14,12 @@ def vocab():
     return Vocab(lex_attr_getters={NORM: lambda s: s})


+def _parser_example(parser):
+    doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
+    gold = {"heads": [1, 1, 3, 3], "deps": ["right", "ROOT", "left", "ROOT"]}
+    return Example.from_dict(doc, gold)
+
+
 @pytest.fixture
 def parser(vocab):
     config = {
|
@ -28,7 +34,7 @@ def parser(vocab):
|
||||||
parser.cfg["hidden_width"] = 32
|
parser.cfg["hidden_width"] = 32
|
||||||
# parser.add_label('right')
|
# parser.add_label('right')
|
||||||
parser.add_label("left")
|
parser.add_label("left")
|
||||||
parser.begin_training(lambda: [], **parser.cfg)
|
parser.begin_training(lambda: [_parser_example(parser)], **parser.cfg)
|
||||||
sgd = Adam(0.001)
|
sgd = Adam(0.001)
|
||||||
|
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
|
|
|
@@ -281,11 +281,12 @@ def test_append_invalid_alias(nlp):

 def test_preserving_links_asdoc(nlp):
     """Test that Span.as_doc preserves the existing entity links"""
+    vector_length = 1

     @registry.misc.register("myLocationsKB.v1")
     def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
         def create_kb(vocab):
-            mykb = KnowledgeBase(vocab, entity_vector_length=1)
+            mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
             # adding entities
             mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
             mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
@@ -305,10 +306,9 @@ def test_preserving_links_asdoc(nlp):
     ruler = nlp.add_pipe("entity_ruler")
     ruler.add_patterns(patterns)
     el_config = {"kb_loader": {"@misc": "myLocationsKB.v1"}, "incl_prior": False}
-    el_pipe = nlp.add_pipe("entity_linker", config=el_config, last=True)
-    el_pipe.begin_training(lambda: [])
-    el_pipe.incl_context = False
-    el_pipe.incl_prior = True
+    entity_linker = nlp.add_pipe("entity_linker", config=el_config, last=True)
+    nlp.begin_training()
+    assert entity_linker.model.get_dim("nO") == vector_length

     # test whether the entity links are preserved by the `as_doc()` function
     text = "She lives in Boston. He lives in Denver."
@@ -373,6 +373,7 @@ def test_overfitting_IO():
     # Simple test to try and quickly overfit the NEL component - ensuring the ML models work correctly
     nlp = English()
     nlp.add_pipe("sentencizer")
+    vector_length = 3

     # Add a custom component to recognize "Russ Cochran" as an entity for the example training data
     patterns = [
@@ -393,7 +394,7 @@ def test_overfitting_IO():
             # create artificial KB - assign same prior weight to the two russ cochran's
             # Q2146908 (Russ Cochran): American golfer
             # Q7381115 (Russ Cochran): publisher
-            mykb = KnowledgeBase(vocab, entity_vector_length=3)
+            mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
             mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
             mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
             mykb.add_alias(
@@ -406,14 +407,17 @@ def test_overfitting_IO():
         return create_kb

     # Create the Entity Linker component and add it to the pipeline
-    nlp.add_pipe(
+    entity_linker = nlp.add_pipe(
         "entity_linker",
         config={"kb_loader": {"@misc": "myOverfittingKB.v1"}},
         last=True,
     )

     # train the NEL pipe
-    optimizer = nlp.begin_training()
+    optimizer = nlp.begin_training(get_examples=lambda: train_examples)
+    assert entity_linker.model.get_dim("nO") == vector_length
+    assert entity_linker.model.get_dim("nO") == entity_linker.kb.entity_vector_length

     for i in range(50):
         losses = {}
         nlp.update(train_examples, sgd=optimizer, losses=losses)
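These entity-linker hunks initialize the component through the pipeline and check that the model's output dimension ("nO") follows the knowledge base's entity_vector_length. As a reference, here is a minimal sketch (not part of the diff) of a KB loader registered for the kb_loader config block; the registry name "myToyKB.v1", the entity IDs, and the vectors are illustrative assumptions.

    from spacy.kb import KnowledgeBase
    from spacy.util import registry

    vector_length = 3

    @registry.misc.register("myToyKB.v1")
    def toy_kb():
        def create_kb(vocab):
            # The entity vector length chosen here becomes the linker's output dimension.
            kb = KnowledgeBase(vocab, entity_vector_length=vector_length)
            kb.add_entity(entity="Q1", freq=10, entity_vector=[1, 2, 3])
            kb.add_alias(alias="One", entities=["Q1"], probabilities=[1.0])
            return kb
        return create_kb

    # Used roughly as: nlp.add_pipe("entity_linker", config={"kb_loader": {"@misc": "myToyKB.v1"}})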
@@ -25,27 +25,61 @@ TRAIN_DATA = [
         },
     ),
     # test combinations of morph+POS
-    ("Eat blue ham", {"morphs": ["Feat=V", "", ""], "pos": ["", "ADJ", ""]},),
+    ("Eat blue ham", {"morphs": ["Feat=V", "", ""], "pos": ["", "ADJ", ""]}),
 ]


+def test_no_label():
+    nlp = Language()
+    nlp.add_pipe("morphologizer")
+    with pytest.raises(ValueError):
+        nlp.begin_training()
+
+
+def test_implicit_label():
+    nlp = Language()
+    nlp.add_pipe("morphologizer")
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    nlp.begin_training(get_examples=lambda: train_examples)
+
+
+def test_no_resize():
+    nlp = Language()
+    morphologizer = nlp.add_pipe("morphologizer")
+    morphologizer.add_label("POS" + Morphology.FIELD_SEP + "NOUN")
+    morphologizer.add_label("POS" + Morphology.FIELD_SEP + "VERB")
+    nlp.begin_training()
+    # this throws an error because the morphologizer can't be resized after initialization
+    with pytest.raises(ValueError):
+        morphologizer.add_label("POS" + Morphology.FIELD_SEP + "ADJ")
+
+
+def test_begin_training_examples():
+    nlp = Language()
+    morphologizer = nlp.add_pipe("morphologizer")
+    morphologizer.add_label("POS" + Morphology.FIELD_SEP + "NOUN")
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    # you shouldn't really call this more than once, but for testing it should be fine
+    nlp.begin_training()
+    nlp.begin_training(get_examples=lambda: train_examples)
+    with pytest.raises(TypeError):
+        nlp.begin_training(get_examples=lambda: None)
+    with pytest.raises(ValueError):
+        nlp.begin_training(get_examples=train_examples)
+
+
 def test_overfitting_IO():
     # Simple test to try and quickly overfit the morphologizer - ensuring the ML models work correctly
     nlp = English()
-    morphologizer = nlp.add_pipe("morphologizer")
+    nlp.add_pipe("morphologizer")
     train_examples = []
     for inst in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(inst[0]), inst[1]))
-        for morph, pos in zip(inst[1]["morphs"], inst[1]["pos"]):
-            if morph and pos:
-                morphologizer.add_label(
-                    morph + Morphology.FEATURE_SEP + "POS" + Morphology.FIELD_SEP + pos
-                )
-            elif pos:
-                morphologizer.add_label("POS" + Morphology.FIELD_SEP + pos)
-            elif morph:
-                morphologizer.add_label(morph)
-    optimizer = nlp.begin_training()
+    optimizer = nlp.begin_training(get_examples=lambda: train_examples)

     for i in range(50):
         losses = {}
@@ -55,18 +89,8 @@ def test_overfitting_IO():
     # test the trained model
     test_text = "I like blue ham"
     doc = nlp(test_text)
-    gold_morphs = [
-        "Feat=N",
-        "Feat=V",
-        "",
-        "",
-    ]
-    gold_pos_tags = [
-        "NOUN",
-        "VERB",
-        "ADJ",
-        "",
-    ]
+    gold_morphs = ["Feat=N", "Feat=V", "", ""]
+    gold_pos_tags = ["NOUN", "VERB", "ADJ", ""]
     assert [t.morph_ for t in doc] == gold_morphs
     assert [t.pos_ for t in doc] == gold_pos_tags

@@ -30,6 +30,20 @@ TRAIN_DATA = [
     ),
 ]


+def test_begin_training_examples():
+    nlp = Language()
+    senter = nlp.add_pipe("senter")
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    # you shouldn't really call this more than once, but for testing it should be fine
+    nlp.begin_training()
+    nlp.begin_training(get_examples=lambda: train_examples)
+    with pytest.raises(TypeError):
+        nlp.begin_training(get_examples=lambda: None)
+    with pytest.raises(ValueError):
+        nlp.begin_training(get_examples=train_examples)
+
+
 def test_overfitting_IO():
     # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
@@ -1,3 +1,4 @@
+import pytest
 from spacy.lang.en import English
 from spacy.gold import Example
 from spacy import util
@@ -5,11 +6,73 @@ from ..util import make_tempdir


 TRAIN_DATA = [
-    ("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),
+    ("Who is Shaka S Khan?", {"entities": [(7, 19, "PERSON")]}),
     ("I like London and Berlin.", {"entities": [(7, 13, "LOC"), (18, 24, "LOC")]}),
 ]


+def test_no_label():
+    nlp = English()
+    nlp.add_pipe("simple_ner")
+    with pytest.raises(ValueError):
+        nlp.begin_training()
+
+
+def test_implicit_label():
+    nlp = English()
+    ner = nlp.add_pipe("simple_ner")
+    train_examples = []
+    ner.add_label("ORG")
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    nlp.begin_training(get_examples=lambda: train_examples)
+
+
+@pytest.mark.skip(reason="Should be fixed")
+def test_untrained():
+    # This shouldn't crash, but it does when the simple_ner produces an invalid sequence like ['L-PERSON', 'L-ORG']
+    nlp = English()
+    ner = nlp.add_pipe("simple_ner")
+    ner.add_label("PERSON")
+    ner.add_label("LOC")
+    ner.add_label("ORG")
+    nlp.begin_training()
+    nlp("Example sentence")
+
+
+def test_resize():
+    nlp = English()
+    ner = nlp.add_pipe("simple_ner")
+    ner.add_label("PERSON")
+    ner.add_label("LOC")
+    nlp.begin_training()
+    assert len(ner.labels) == 2
+    ner.add_label("ORG")
+    nlp.begin_training()
+    assert len(ner.labels) == 3
+
+
+def test_begin_training_examples():
+    nlp = English()
+    ner = nlp.add_pipe("simple_ner")
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
+        for ent in annotations.get("entities"):
+            ner.add_label(ent[2])
+    # you shouldn't really call this more than once, but for testing it should be fine
+    nlp.begin_training()
+    nlp.begin_training(get_examples=lambda: train_examples)
+    with pytest.raises(TypeError):
+        nlp.begin_training(get_examples=lambda: None)
+    with pytest.raises(TypeError):
+        nlp.begin_training(get_examples=lambda: train_examples[0])
+    with pytest.raises(ValueError):
+        nlp.begin_training(get_examples=lambda: [])
+    with pytest.raises(ValueError):
+        nlp.begin_training(get_examples=train_examples)
+
+
 def test_overfitting_IO():
     # Simple test to try and quickly overfit the SimpleNER component - ensuring the ML models work correctly
     nlp = English()
@@ -17,9 +80,7 @@ def test_overfitting_IO():
     train_examples = []
     for text, annotations in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
-        for ent in annotations.get("entities"):
-            ner.add_label(ent[2])
-    optimizer = nlp.begin_training()
+    optimizer = nlp.begin_training(get_examples=lambda: train_examples)

     for i in range(50):
         losses = {}
@@ -34,6 +34,56 @@ TRAIN_DATA = [
 ]


+def test_no_label():
+    nlp = Language()
+    nlp.add_pipe("tagger")
+    with pytest.raises(ValueError):
+        nlp.begin_training()
+
+
+def test_no_resize():
+    nlp = Language()
+    tagger = nlp.add_pipe("tagger")
+    tagger.add_label("N")
+    tagger.add_label("V")
+    assert tagger.labels == ("N", "V")
+    nlp.begin_training()
+    assert tagger.model.get_dim("nO") == 2
+    # this throws an error because the tagger can't be resized after initialization
+    with pytest.raises(ValueError):
+        tagger.add_label("J")
+
+
+def test_implicit_label():
+    nlp = Language()
+    nlp.add_pipe("tagger")
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    nlp.begin_training(get_examples=lambda: train_examples)
+
+
+def test_begin_training_examples():
+    nlp = Language()
+    tagger = nlp.add_pipe("tagger")
+    train_examples = []
+    for tag in TAGS:
+        tagger.add_label(tag)
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    # you shouldn't really call this more than once, but for testing it should be fine
+    nlp.begin_training()
+    nlp.begin_training(get_examples=lambda: train_examples)
+    with pytest.raises(TypeError):
+        nlp.begin_training(get_examples=lambda: None)
+    with pytest.raises(TypeError):
+        nlp.begin_training(get_examples=lambda: train_examples[0])
+    with pytest.raises(ValueError):
+        nlp.begin_training(get_examples=lambda: [])
+    with pytest.raises(ValueError):
+        nlp.begin_training(get_examples=train_examples)
+
+
 def test_overfitting_IO():
     # Simple test to try and quickly overfit the tagger - ensuring the ML models work correctly
     nlp = English()
@@ -41,9 +91,8 @@ def test_overfitting_IO():
     train_examples = []
     for t in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
-    for tag in TAGS:
-        tagger.add_label(tag)
-    optimizer = nlp.begin_training()
+    optimizer = nlp.begin_training(get_examples=lambda: train_examples)
+    assert tagger.model.get_dim("nO") == len(TAGS)

     for i in range(50):
         losses = {}
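The tagger tests above pin down the reworked begin_training contract: labels can be inferred from the examples, and get_examples must be a callable returning Example objects. A minimal sketch of that contract follows (not part of the diff); the toy sentence and tag values are illustrative assumptions.

    import pytest
    from spacy.gold import Example
    from spacy.lang.en import English

    nlp = English()
    nlp.add_pipe("tagger")
    train_examples = [
        Example.from_dict(nlp.make_doc("I like ham"), {"tags": ["N", "V", "N"]})
    ]
    # A callable returning representative examples is accepted; labels are read from them.
    nlp.begin_training(get_examples=lambda: train_examples)
    # Passing the list itself (not a callable) is rejected.
    with pytest.raises(ValueError):
        nlp.begin_training(get_examples=train_examples)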
@@ -80,6 +80,51 @@ def test_label_types():
         textcat.add_label(9)


+def test_no_label():
+    nlp = Language()
+    nlp.add_pipe("textcat")
+    with pytest.raises(ValueError):
+        nlp.begin_training()
+
+
+def test_implicit_label():
+    nlp = Language()
+    textcat = nlp.add_pipe("textcat")
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    nlp.begin_training(get_examples=lambda: train_examples)
+
+
+def test_no_resize():
+    nlp = Language()
+    textcat = nlp.add_pipe("textcat")
+    textcat.add_label("POSITIVE")
+    textcat.add_label("NEGATIVE")
+    nlp.begin_training()
+    assert textcat.model.get_dim("nO") == 2
+    # this throws an error because the textcat can't be resized after initialization
+    with pytest.raises(ValueError):
+        textcat.add_label("NEUTRAL")
+
+
+def test_begin_training_examples():
+    nlp = Language()
+    textcat = nlp.add_pipe("textcat")
+    train_examples = []
+    for text, annotations in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
+        for label, value in annotations.get("cats").items():
+            textcat.add_label(label)
+    # you shouldn't really call this more than once, but for testing it should be fine
+    nlp.begin_training()
+    nlp.begin_training(get_examples=lambda: train_examples)
+    with pytest.raises(TypeError):
+        nlp.begin_training(get_examples=lambda: None)
+    with pytest.raises(ValueError):
+        nlp.begin_training(get_examples=train_examples)
+
+
 def test_overfitting_IO():
     # Simple test to try and quickly overfit the textcat component - ensuring the ML models work correctly
     fix_random_seed(0)
@@ -89,9 +134,8 @@ def test_overfitting_IO():
     train_examples = []
     for text, annotations in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
-        for label, value in annotations.get("cats").items():
-            textcat.add_label(label)
-    optimizer = nlp.begin_training()
+    optimizer = nlp.begin_training(get_examples=lambda: train_examples)
+    assert textcat.model.get_dim("nO") == 2

     for i in range(50):
         losses = {}
@@ -20,7 +20,7 @@ def test_issue2564():
     nlp = Language()
     tagger = nlp.add_pipe("tagger")
     tagger.add_label("A")
-    tagger.begin_training(lambda: [])
+    nlp.begin_training()
     doc = nlp("hello world")
     assert doc.is_tagged
     docs = nlp.pipe(["hello", "world"])

@@ -251,6 +251,12 @@ def test_issue3803():
     assert [t.like_num for t in doc] == [True, True, True, True, True, True]


+def _parser_example(parser):
+    doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
+    gold = {"heads": [1, 1, 3, 3], "deps": ["right", "ROOT", "left", "ROOT"]}
+    return Example.from_dict(doc, gold)
+
+
 def test_issue3830_no_subtok():
     """Test that the parser doesn't have subtok label if not learn_tokens"""
     config = {
@@ -264,7 +270,7 @@ def test_issue3830_no_subtok():
     parser = DependencyParser(Vocab(), model, **config)
     parser.add_label("nsubj")
     assert "subtok" not in parser.labels
-    parser.begin_training(lambda: [])
+    parser.begin_training(lambda: [_parser_example(parser)])
     assert "subtok" not in parser.labels


@@ -281,7 +287,7 @@ def test_issue3830_with_subtok():
     parser = DependencyParser(Vocab(), model, **config)
     parser.add_label("nsubj")
     assert "subtok" not in parser.labels
-    parser.begin_training(lambda: [])
+    parser.begin_training(lambda: [_parser_example(parser)])
     assert "subtok" in parser.labels

@@ -64,7 +64,7 @@ def tagger():
     # 1. no model leads to error in serialization,
     # 2. the affected line is the one for model serialization
     tagger.add_label("A")
-    tagger.begin_training(lambda: [], pipeline=nlp.pipeline)
+    nlp.begin_training()
     return tagger


@@ -85,7 +85,7 @@ def entity_linker():
     # need to add model for two reasons:
     # 1. no model leads to error in serialization,
     # 2. the affected line is the one for model serialization
-    entity_linker.begin_training(lambda: [], pipeline=nlp.pipeline)
+    nlp.begin_training()
     return entity_linker

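In the serialization fixtures above, components are now initialized through the pipeline instead of by calling their own begin_training with an empty callable. A minimal sketch of the resulting pattern (not part of the diff), assuming a toy tagger with a single label:

    from spacy.lang.en import English

    nlp = English()
    tagger = nlp.add_pipe("tagger")
    tagger.add_label("A")
    # Initializing via the pipeline builds the tagger's model so it can be serialized.
    nlp.begin_training()
    tagger_bytes = tagger.to_bytes()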
@@ -89,6 +89,7 @@ def test_init_tok2vec():
     tok2vec = nlp.add_pipe("tok2vec")
     assert tok2vec.listeners == []
     nlp.begin_training()
+    assert tok2vec.model.get_dim("nO")


 cfg_string = """