default models defined in component decorator (#5452)

* move defaults to pipeline and use in component decorator

* black formatting

* relative import
This commit is contained in:
Sofie Van Landeghem 2020-05-19 16:20:03 +02:00 committed by GitHub
parent 0d94737857
commit f00de445dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 64 additions and 61 deletions

View File

@ -184,33 +184,6 @@ class Language(object):
self.max_length = max_length self.max_length = max_length
self._optimizer = None self._optimizer = None
# TODO: de-uglify (incorporating into component decorator didn't work because of circular imports)
from .ml.models.defaults import (
default_tagger_config,
default_parser_config,
default_ner_config,
default_textcat_config,
default_nel_config,
default_morphologizer_config,
default_senter_config,
default_tensorizer_config,
default_tok2vec_config,
default_simple_ner_config
)
self.defaults = {
"tagger": default_tagger_config(),
"parser": default_parser_config(),
"ner": default_ner_config(),
"textcat": default_textcat_config(),
"entity_linker": default_nel_config(),
"morphologizer": default_morphologizer_config(),
"senter": default_senter_config(),
"simple_ner": default_simple_ner_config(),
"tensorizer": default_tensorizer_config(),
"tok2vec": default_tok2vec_config(),
}
@property @property
def path(self): def path(self):
return self._path return self._path
@ -338,7 +311,6 @@ class Language(object):
else: else:
raise KeyError(Errors.E002.format(name=name)) raise KeyError(Errors.E002.format(name=name))
factory = self.factories[name] factory = self.factories[name]
default_config = self.defaults.get(name, None)
# transform the model's config to an actual Model # transform the model's config to an actual Model
factory_cfg = dict(config) factory_cfg = dict(config)
@ -349,11 +321,6 @@ class Language(object):
warnings.warn(Warnings.W099.format(type=type(model_cfg), pipe=name)) warnings.warn(Warnings.W099.format(type=type(model_cfg), pipe=name))
model_cfg = None model_cfg = None
del factory_cfg["model"] del factory_cfg["model"]
if model_cfg is None and default_config is not None:
warnings.warn(Warnings.W098.format(name=name))
model_cfg = default_config["model"]
if model_cfg is None:
warnings.warn(Warnings.W097.format(name=name))
model = None model = None
if model_cfg is not None: if model_cfg is not None:
self.config[name] = {"model": model_cfg} self.config[name] = {"model": model_cfg}
@ -539,7 +506,11 @@ class Language(object):
to_disable = [pipe for pipe in self.pipe_names if pipe not in enable] to_disable = [pipe for pipe in self.pipe_names if pipe not in enable]
# raise an error if the enable and disable keywords are not consistent # raise an error if the enable and disable keywords are not consistent
if disable is not None and disable != to_disable: if disable is not None and disable != to_disable:
raise ValueError(Errors.E992.format(enable=enable, disable=disable, names=self.pipe_names)) raise ValueError(
Errors.E992.format(
enable=enable, disable=disable, names=self.pipe_names
)
)
disable = to_disable disable = to_disable
return DisabledPipes(self, disable) return DisabledPipes(self, disable)
@ -1085,7 +1056,14 @@ class component(object):
# NB: This decorator needs to live here, because it needs to write to # NB: This decorator needs to live here, because it needs to write to
# Language.factories. All other solutions would cause circular import. # Language.factories. All other solutions would cause circular import.
def __init__(self, name=None, assigns=tuple(), requires=tuple(), retokenizes=False): def __init__(
self,
name=None,
assigns=tuple(),
requires=tuple(),
retokenizes=False,
default_model=lambda: None,
):
"""Decorate a pipeline component. """Decorate a pipeline component.
name (unicode): Default component and factory name. name (unicode): Default component and factory name.
@ -1097,6 +1075,7 @@ class component(object):
self.assigns = validate_attrs(assigns) self.assigns = validate_attrs(assigns)
self.requires = validate_attrs(requires) self.requires = validate_attrs(requires)
self.retokenizes = retokenizes self.retokenizes = retokenizes
self.default_model = default_model
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
obj = args[0] obj = args[0]
@ -1109,6 +1088,11 @@ class component(object):
obj.retokenizes = self.retokenizes obj.retokenizes = self.retokenizes
def factory(nlp, model, **cfg): def factory(nlp, model, **cfg):
if model is None:
model = self.default_model()
warnings.warn(Warnings.W098.format(name=self.name))
if model is None:
warnings.warn(Warnings.W097.format(name=self.name))
if hasattr(obj, "from_nlp"): if hasattr(obj, "from_nlp"):
return obj.from_nlp(nlp, model, **cfg) return obj.from_nlp(nlp, model, **cfg)
elif isinstance(obj, type): elif isinstance(obj, type):

View File

@ -1,6 +1,6 @@
from pathlib import Path from pathlib import Path
from .... import util from ... import util
def default_nel_config(): def default_nel_config():

View File

@ -17,9 +17,10 @@ from ..util import link_vectors_to_models, create_default_optimizer
from ..errors import Errors, TempErrors from ..errors import Errors, TempErrors
from .pipes import Tagger, _load_cfg from .pipes import Tagger, _load_cfg
from .. import util from .. import util
from .defaults import default_morphologizer
@component("morphologizer", assigns=["token.morph", "token.pos"]) @component("morphologizer", assigns=["token.morph", "token.pos"], default_model=default_morphologizer)
class Morphologizer(Tagger): class Morphologizer(Tagger):
def __init__(self, vocab, model, **cfg): def __init__(self, vocab, model, **cfg):

View File

@ -2,6 +2,7 @@
import numpy import numpy
import srsly import srsly
import random import random
from thinc.api import CosineDistance, to_categorical, get_array_module from thinc.api import CosineDistance, to_categorical, get_array_module
from thinc.api import set_dropout_rate, SequenceCategoricalCrossentropy from thinc.api import set_dropout_rate, SequenceCategoricalCrossentropy
import warnings import warnings
@ -13,6 +14,8 @@ from ..syntax.arc_eager cimport ArcEager
from ..morphology cimport Morphology from ..morphology cimport Morphology
from ..vocab cimport Vocab from ..vocab cimport Vocab
from .defaults import default_tagger, default_parser, default_ner, default_textcat
from .defaults import default_nel, default_senter, default_tensorizer
from .functions import merge_subtokens from .functions import merge_subtokens
from ..language import Language, component from ..language import Language, component
from ..syntax import nonproj from ..syntax import nonproj
@ -234,7 +237,7 @@ class Pipe(object):
return self return self
@component("tensorizer", assigns=["doc.tensor"]) @component("tensorizer", assigns=["doc.tensor"], default_model=default_tensorizer)
class Tensorizer(Pipe): class Tensorizer(Pipe):
"""Pre-train position-sensitive vectors for tokens.""" """Pre-train position-sensitive vectors for tokens."""
@ -366,7 +369,7 @@ class Tensorizer(Pipe):
return sgd return sgd
@component("tagger", assigns=["token.tag", "token.pos", "token.lemma"]) @component("tagger", assigns=["token.tag", "token.pos", "token.lemma"], default_model=default_tagger)
class Tagger(Pipe): class Tagger(Pipe):
"""Pipeline component for part-of-speech tagging. """Pipeline component for part-of-speech tagging.
@ -636,7 +639,7 @@ class Tagger(Pipe):
return self return self
@component("senter", assigns=["token.is_sent_start"]) @component("senter", assigns=["token.is_sent_start"], default_model=default_senter)
class SentenceRecognizer(Tagger): class SentenceRecognizer(Tagger):
"""Pipeline component for sentence segmentation. """Pipeline component for sentence segmentation.
@ -976,7 +979,7 @@ class ClozeMultitask(Pipe):
losses[self.name] += loss losses[self.name] += loss
@component("textcat", assigns=["doc.cats"]) @component("textcat", assigns=["doc.cats"], default_model=default_textcat)
class TextCategorizer(Pipe): class TextCategorizer(Pipe):
"""Pipeline component for text classification. """Pipeline component for text classification.
@ -1227,7 +1230,8 @@ cdef class EntityRecognizer(Parser):
@component( @component(
"entity_linker", "entity_linker",
requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"], requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
assigns=["token.ent_kb_id"] assigns=["token.ent_kb_id"],
default_model=default_nel,
) )
class EntityLinker(Pipe): class EntityLinker(Pipe):
"""Pipeline component for named entity linking. """Pipeline component for named entity linking.
@ -1673,8 +1677,19 @@ class Sentencizer(Pipe):
# Cython classes can't be decorated, so we need to add the factories here # Cython classes can't be decorated, so we need to add the factories here
Language.factories["parser"] = lambda nlp, model, **cfg: DependencyParser.from_nlp(nlp, model, **cfg) Language.factories["parser"] = lambda nlp, model, **cfg: parser_factory(nlp, model, **cfg)
Language.factories["ner"] = lambda nlp, model, **cfg: EntityRecognizer.from_nlp(nlp, model, **cfg) Language.factories["ner"] = lambda nlp, model, **cfg: ner_factory(nlp, model, **cfg)
def parser_factory(nlp, model, **cfg):
if model is None:
model = default_parser()
warnings.warn(Warnings.W098.format(name="parser"))
return DependencyParser.from_nlp(nlp, model, **cfg)
def ner_factory(nlp, model, **cfg):
if model is None:
model = default_ner()
warnings.warn(Warnings.W098.format(name="ner"))
return EntityRecognizer.from_nlp(nlp, model, **cfg)
__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer", "SentenceRecognizer"] __all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer", "SentenceRecognizer"]

View File

@ -2,6 +2,8 @@ from typing import List
from thinc.types import Floats2d from thinc.types import Floats2d
from thinc.api import SequenceCategoricalCrossentropy, set_dropout_rate from thinc.api import SequenceCategoricalCrossentropy, set_dropout_rate
from thinc.util import to_numpy from thinc.util import to_numpy
from .defaults import default_simple_ner
from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob from ..gold import Example, spans_from_biluo_tags, iob_to_biluo, biluo_to_iob
from ..tokens import Doc from ..tokens import Doc
from ..language import component from ..language import component
@ -9,7 +11,7 @@ from ..util import link_vectors_to_models
from .pipes import Pipe from .pipes import Pipe
@component("simple_ner", assigns=["doc.ents"]) @component("simple_ner", assigns=["doc.ents"], default_model=default_simple_ner)
class SimpleNER(Pipe): class SimpleNER(Pipe):
"""Named entity recognition with a tagging model. The model should include """Named entity recognition with a tagging model. The model should include
validity constraints to ensure that only valid tag sequences are returned.""" validity constraints to ensure that only valid tag sequences are returned."""

View File

@ -6,9 +6,10 @@ from ..tokens import Doc
from ..vocab import Vocab from ..vocab import Vocab
from ..language import component from ..language import component
from ..util import link_vectors_to_models, minibatch, eg2doc from ..util import link_vectors_to_models, minibatch, eg2doc
from .defaults import default_tok2vec
@component("tok2vec", assigns=["doc.tensor"]) @component("tok2vec", assigns=["doc.tensor"], default_model=default_tok2vec)
class Tok2Vec(Pipe): class Tok2Vec(Pipe):
@classmethod @classmethod
def from_nlp(cls, nlp, model, **cfg): def from_nlp(cls, nlp, model, **cfg):

View File

@ -3,7 +3,7 @@ from spacy.tokens import Span
import pytest import pytest
from ..util import get_doc from ..util import get_doc
from ...ml.models.defaults import default_ner from spacy.pipeline.defaults import default_ner
def test_doc_add_entities_set_ents_iob(en_vocab): def test_doc_add_entities_set_ents_iob(en_vocab):

View File

@ -4,7 +4,7 @@ from spacy.attrs import NORM
from spacy.gold import GoldParse from spacy.gold import GoldParse
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.ml.models.defaults import default_parser, default_ner from spacy.pipeline.defaults import default_parser, default_ner
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.pipeline import DependencyParser, EntityRecognizer from spacy.pipeline import DependencyParser, EntityRecognizer
from spacy.util import fix_random_seed from spacy.util import fix_random_seed

View File

@ -1,7 +1,7 @@
import pytest import pytest
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.ml.models.defaults import default_parser from spacy.pipeline.defaults import default_parser
from spacy.pipeline import DependencyParser from spacy.pipeline import DependencyParser
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.gold import GoldParse from spacy.gold import GoldParse

View File

@ -2,7 +2,7 @@ import pytest
from spacy import util from spacy import util
from spacy.lang.en import English from spacy.lang.en import English
from spacy.ml.models.defaults import default_ner from spacy.pipeline.defaults import default_ner
from spacy.pipeline import EntityRecognizer, EntityRuler from spacy.pipeline import EntityRecognizer, EntityRuler
from spacy.vocab import Vocab from spacy.vocab import Vocab

View File

@ -1,5 +1,5 @@
import pytest import pytest
from spacy.ml.models.defaults import default_parser, default_tok2vec from spacy.pipeline.defaults import default_parser, default_tok2vec
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.syntax.arc_eager import ArcEager from spacy.syntax.arc_eager import ArcEager
from spacy.syntax.nn_parser import Parser from spacy.syntax.nn_parser import Parser

View File

@ -2,7 +2,7 @@ import pytest
import numpy import numpy
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.language import Language from spacy.language import Language
from spacy.ml.models.defaults import default_parser from spacy.pipeline.defaults import default_parser
from spacy.pipeline import DependencyParser from spacy.pipeline import DependencyParser
from spacy.syntax.arc_eager import ArcEager from spacy.syntax.arc_eager import ArcEager
from spacy.tokens import Doc from spacy.tokens import Doc

View File

@ -4,7 +4,7 @@ from spacy.attrs import NORM
from spacy.gold import GoldParse from spacy.gold import GoldParse
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.ml.models.defaults import default_parser from spacy.pipeline.defaults import default_parser
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.pipeline import DependencyParser from spacy.pipeline import DependencyParser

View File

@ -11,7 +11,7 @@ from spacy.gold import GoldParse
from spacy.util import fix_random_seed from spacy.util import fix_random_seed
from ..util import make_tempdir from ..util import make_tempdir
from ...ml.models.defaults import default_tok2vec from spacy.pipeline.defaults import default_tok2vec
TRAIN_DATA = [ TRAIN_DATA = [
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}), ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),

View File

@ -10,7 +10,7 @@ from spacy.lang.lex_attrs import is_stop
from spacy.vectors import Vectors from spacy.vectors import Vectors
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.language import Language from spacy.language import Language
from spacy.ml.models.defaults import default_ner, default_tagger from spacy.pipeline.defaults import default_ner, default_tagger
from spacy.tokens import Doc, Span, Token from spacy.tokens import Doc, Span, Token
from spacy.pipeline import Tagger, EntityRecognizer from spacy.pipeline import Tagger, EntityRecognizer
from spacy.attrs import HEAD, DEP from spacy.attrs import HEAD, DEP

View File

@ -1,7 +1,7 @@
import pytest import pytest
from spacy.lang.en import English from spacy.lang.en import English
from spacy.lang.de import German from spacy.lang.de import German
from spacy.ml.models.defaults import default_ner from spacy.pipeline.defaults import default_ner
from spacy.pipeline import EntityRuler, EntityRecognizer from spacy.pipeline import EntityRuler, EntityRecognizer
from spacy.matcher import Matcher, PhraseMatcher from spacy.matcher import Matcher, PhraseMatcher
from spacy.tokens import Doc from spacy.tokens import Doc

View File

@ -1,7 +1,7 @@
from spacy.pipeline.pipes import DependencyParser from spacy.pipeline.pipes import DependencyParser
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.ml.models.defaults import default_parser from spacy.pipeline.defaults import default_parser
def test_issue3830_no_subtok(): def test_issue3830_no_subtok():

View File

@ -3,7 +3,7 @@ from spacy.pipeline import EntityRecognizer, EntityRuler
from spacy.lang.en import English from spacy.lang.en import English
from spacy.tokens import Span from spacy.tokens import Span
from spacy.util import ensure_path from spacy.util import ensure_path
from spacy.ml.models.defaults import default_ner from spacy.pipeline.defaults import default_ner
from ..util import make_tempdir from ..util import make_tempdir

View File

@ -1,6 +1,6 @@
from collections import defaultdict from collections import defaultdict
from spacy.ml.models.defaults import default_ner from spacy.pipeline.defaults import default_ner
from spacy.pipeline import EntityRecognizer from spacy.pipeline import EntityRecognizer
from spacy.lang.en import English from spacy.lang.en import English

View File

@ -1,8 +1,8 @@
import pytest import pytest
from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
from spacy.pipeline import Tensorizer, TextCategorizer, SentenceRecognizer from spacy.pipeline import Tensorizer, TextCategorizer, SentenceRecognizer
from spacy.ml.models.defaults import default_parser, default_tensorizer, default_tagger from spacy.pipeline.defaults import default_parser, default_tensorizer, default_tagger
from spacy.ml.models.defaults import default_textcat, default_senter from spacy.pipeline.defaults import default_textcat, default_senter
from ..util import make_tempdir from ..util import make_tempdir