mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-13 05:07:03 +03:00
pretrain architectures (#6451)
* define new architectures for the pretraining objective * add loss function as attr of the omdel * cleanup * cleanup * shorten name * fix typo * remove unused error
This commit is contained in:
parent
29b058ebdc
commit
f98a04434a
|
@ -17,7 +17,9 @@ tolerance = 0.2
|
||||||
get_length = null
|
get_length = null
|
||||||
|
|
||||||
[pretraining.objective]
|
[pretraining.objective]
|
||||||
type = "characters"
|
@architectures = "spacy.PretrainCharacters.v1"
|
||||||
|
maxout_pieces = 3
|
||||||
|
hidden_size = 300
|
||||||
n_characters = 4
|
n_characters = 4
|
||||||
|
|
||||||
[pretraining.optimizer]
|
[pretraining.optimizer]
|
||||||
|
|
|
@ -484,8 +484,8 @@ class Errors:
|
||||||
"has been applied.")
|
"has been applied.")
|
||||||
E905 = ("Cannot initialize StaticVectors layer: nM dimension unset. This "
|
E905 = ("Cannot initialize StaticVectors layer: nM dimension unset. This "
|
||||||
"dimension refers to the width of the vectors table.")
|
"dimension refers to the width of the vectors table.")
|
||||||
E906 = ("Unexpected `loss` value in pretraining objective: {loss_type}")
|
E906 = ("Unexpected `loss` value in pretraining objective: '{found}'. Supported values "
|
||||||
E907 = ("Unexpected `objective_type` value in pretraining objective: {objective_type}")
|
"are: {supported}")
|
||||||
E908 = ("Can't set `spaces` without `words` in `Doc.__init__`.")
|
E908 = ("Can't set `spaces` without `words` in `Doc.__init__`.")
|
||||||
E909 = ("Expected {name} in parser internals. This is likely a bug in spaCy.")
|
E909 = ("Expected {name} in parser internals. This is likely a bug in spaCy.")
|
||||||
E910 = ("Encountered NaN value when computing loss for component '{name}'.")
|
E910 = ("Encountered NaN value when computing loss for component '{name}'.")
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from .entity_linker import * # noqa
|
from .entity_linker import * # noqa
|
||||||
|
from .multi_task import * # noqa
|
||||||
from .parser import * # noqa
|
from .parser import * # noqa
|
||||||
from .tagger import * # noqa
|
from .tagger import * # noqa
|
||||||
from .textcat import * # noqa
|
from .textcat import * # noqa
|
||||||
|
|
|
@ -1,7 +1,14 @@
|
||||||
from typing import Optional, Iterable, Tuple, List, TYPE_CHECKING
|
from typing import Optional, Iterable, Tuple, List, Callable, TYPE_CHECKING
|
||||||
import numpy
|
|
||||||
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
|
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init, Model
|
||||||
from thinc.api import MultiSoftmax, list2array
|
from thinc.api import MultiSoftmax, list2array
|
||||||
|
from thinc.api import to_categorical, CosineDistance, L2Distance
|
||||||
|
|
||||||
|
from ...util import registry
|
||||||
|
from ...errors import Errors
|
||||||
|
from ...attrs import ID
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# This lets us add type hints for mypy etc. without causing circular imports
|
# This lets us add type hints for mypy etc. without causing circular imports
|
||||||
|
@ -9,6 +16,74 @@ if TYPE_CHECKING:
|
||||||
from ...tokens import Doc # noqa: F401
|
from ...tokens import Doc # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
@registry.architectures.register("spacy.PretrainVectors.v1")
|
||||||
|
def create_pretrain_vectors(
|
||||||
|
maxout_pieces: int, hidden_size: int, loss: str
|
||||||
|
) -> Callable[["Vocab", Model], Model]:
|
||||||
|
def create_vectors_objective(vocab: "Vocab", tok2vec: Model) -> Model:
|
||||||
|
model = build_cloze_multi_task_model(
|
||||||
|
vocab, tok2vec, hidden_size=hidden_size, maxout_pieces=maxout_pieces
|
||||||
|
)
|
||||||
|
model.attrs["loss"] = create_vectors_loss()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def create_vectors_loss() -> Callable:
|
||||||
|
if loss == "cosine":
|
||||||
|
distance = CosineDistance(normalize=True, ignore_zeros=True)
|
||||||
|
return partial(get_vectors_loss, distance=distance)
|
||||||
|
elif loss == "L2":
|
||||||
|
distance = L2Distance(normalize=True)
|
||||||
|
return partial(get_vectors_loss, distance=distance)
|
||||||
|
else:
|
||||||
|
raise ValueError(Errors.E906.format(found=loss, supported="'cosine', 'L2'"))
|
||||||
|
|
||||||
|
return create_vectors_objective
|
||||||
|
|
||||||
|
|
||||||
|
@registry.architectures.register("spacy.PretrainCharacters.v1")
|
||||||
|
def create_pretrain_characters(
|
||||||
|
maxout_pieces: int, hidden_size: int, n_characters: int
|
||||||
|
) -> Callable[["Vocab", Model], Model]:
|
||||||
|
def create_characters_objective(vocab: "Vocab", tok2vec: Model) -> Model:
|
||||||
|
model = build_cloze_characters_multi_task_model(
|
||||||
|
vocab,
|
||||||
|
tok2vec,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
maxout_pieces=maxout_pieces,
|
||||||
|
nr_char=n_characters,
|
||||||
|
)
|
||||||
|
model.attrs["loss"] = partial(get_characters_loss, nr_char=n_characters)
|
||||||
|
return model
|
||||||
|
|
||||||
|
return create_characters_objective
|
||||||
|
|
||||||
|
|
||||||
|
def get_vectors_loss(ops, docs, prediction, distance):
|
||||||
|
"""Compute a loss based on a distance between the documents' vectors and
|
||||||
|
the prediction.
|
||||||
|
"""
|
||||||
|
# The simplest way to implement this would be to vstack the
|
||||||
|
# token.vector values, but that's a bit inefficient, especially on GPU.
|
||||||
|
# Instead we fetch the index into the vectors table for each of our tokens,
|
||||||
|
# and look them up all at once. This prevents data copying.
|
||||||
|
ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
|
||||||
|
target = docs[0].vocab.vectors.data[ids]
|
||||||
|
d_target, loss = distance(prediction, target)
|
||||||
|
return loss, d_target
|
||||||
|
|
||||||
|
|
||||||
|
def get_characters_loss(ops, docs, prediction, nr_char):
|
||||||
|
"""Compute a loss based on a number of characters predicted from the docs."""
|
||||||
|
target_ids = numpy.vstack([doc.to_utf8_array(nr_char=nr_char) for doc in docs])
|
||||||
|
target_ids = target_ids.reshape((-1,))
|
||||||
|
target = ops.asarray(to_categorical(target_ids, n_classes=256), dtype="f")
|
||||||
|
target = target.reshape((-1, 256 * nr_char))
|
||||||
|
diff = prediction - target
|
||||||
|
loss = (diff ** 2).sum()
|
||||||
|
d_target = diff / float(prediction.shape[0])
|
||||||
|
return loss, d_target
|
||||||
|
|
||||||
|
|
||||||
def build_multi_task_model(
|
def build_multi_task_model(
|
||||||
tok2vec: Model,
|
tok2vec: Model,
|
||||||
maxout_pieces: int,
|
maxout_pieces: int,
|
||||||
|
@ -33,23 +108,19 @@ def build_multi_task_model(
|
||||||
|
|
||||||
|
|
||||||
def build_cloze_multi_task_model(
|
def build_cloze_multi_task_model(
|
||||||
vocab: "Vocab",
|
vocab: "Vocab", tok2vec: Model, maxout_pieces: int, hidden_size: int
|
||||||
tok2vec: Model,
|
|
||||||
maxout_pieces: int,
|
|
||||||
hidden_size: int,
|
|
||||||
nO: Optional[int] = None,
|
|
||||||
) -> Model:
|
) -> Model:
|
||||||
# nO = vocab.vectors.data.shape[1]
|
nO = vocab.vectors.data.shape[1]
|
||||||
output_layer = chain(
|
output_layer = chain(
|
||||||
list2array(),
|
list2array(),
|
||||||
Maxout(
|
Maxout(
|
||||||
nO=nO,
|
nO=hidden_size,
|
||||||
nI=tok2vec.get_dim("nO"),
|
nI=tok2vec.get_dim("nO"),
|
||||||
nP=maxout_pieces,
|
nP=maxout_pieces,
|
||||||
normalize=True,
|
normalize=True,
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
),
|
),
|
||||||
Linear(nO=nO, nI=nO, init_W=zero_init),
|
Linear(nO=nO, nI=hidden_size, init_W=zero_init),
|
||||||
)
|
)
|
||||||
model = chain(tok2vec, output_layer)
|
model = chain(tok2vec, output_layer)
|
||||||
model = build_masked_language_model(vocab, model)
|
model = build_masked_language_model(vocab, model)
|
||||||
|
|
|
@ -351,9 +351,7 @@ class ConfigSchemaPretrain(BaseModel):
|
||||||
batcher: Batcher = Field(..., title="Batcher for the training data")
|
batcher: Batcher = Field(..., title="Batcher for the training data")
|
||||||
component: str = Field(..., title="Component to find the layer to pretrain")
|
component: str = Field(..., title="Component to find the layer to pretrain")
|
||||||
layer: str = Field(..., title="Layer to pretrain. Whole model if empty.")
|
layer: str = Field(..., title="Layer to pretrain. Whole model if empty.")
|
||||||
|
objective: Callable[["Vocab", "Model"], "Model"] = Field(..., title="A function that creates the pretraining objective.")
|
||||||
# TODO: use a more detailed schema for this?
|
|
||||||
objective: Dict[str, Any] = Field(..., title="Pretraining objective")
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
|
@ -3,15 +3,15 @@ from thinc.api import Config, ConfigValidationError
|
||||||
import spacy
|
import spacy
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.lang.de import German
|
from spacy.lang.de import German
|
||||||
from spacy.language import Language, DEFAULT_CONFIG
|
from spacy.language import Language, DEFAULT_CONFIG, DEFAULT_CONFIG_PRETRAIN_PATH
|
||||||
from spacy.util import registry, load_model_from_config
|
from spacy.util import registry, load_model_from_config, load_config
|
||||||
from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
|
from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
|
||||||
from spacy.ml.models import MultiHashEmbed, MaxoutWindowEncoder
|
from spacy.ml.models import MultiHashEmbed, MaxoutWindowEncoder
|
||||||
from spacy.schemas import ConfigSchema
|
from spacy.schemas import ConfigSchema, ConfigSchemaPretrain
|
||||||
|
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
|
||||||
|
|
||||||
nlp_config_string = """
|
nlp_config_string = """
|
||||||
[paths]
|
[paths]
|
||||||
train = null
|
train = null
|
||||||
|
@ -63,6 +63,59 @@ factory = "tagger"
|
||||||
width = ${components.tok2vec.model.width}
|
width = ${components.tok2vec.model.width}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
pretrain_config_string = """
|
||||||
|
[paths]
|
||||||
|
train = null
|
||||||
|
dev = null
|
||||||
|
|
||||||
|
[corpora]
|
||||||
|
|
||||||
|
[corpora.train]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths.train}
|
||||||
|
|
||||||
|
[corpora.dev]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths.dev}
|
||||||
|
|
||||||
|
[training]
|
||||||
|
|
||||||
|
[training.batcher]
|
||||||
|
@batchers = "spacy.batch_by_words.v1"
|
||||||
|
size = 666
|
||||||
|
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["tok2vec", "tagger"]
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.tok2vec]
|
||||||
|
factory = "tok2vec"
|
||||||
|
|
||||||
|
[components.tok2vec.model]
|
||||||
|
@architectures = "spacy.HashEmbedCNN.v1"
|
||||||
|
pretrained_vectors = null
|
||||||
|
width = 342
|
||||||
|
depth = 4
|
||||||
|
window_size = 1
|
||||||
|
embed_size = 2000
|
||||||
|
maxout_pieces = 3
|
||||||
|
subword_features = true
|
||||||
|
|
||||||
|
[components.tagger]
|
||||||
|
factory = "tagger"
|
||||||
|
|
||||||
|
[components.tagger.model]
|
||||||
|
@architectures = "spacy.Tagger.v1"
|
||||||
|
|
||||||
|
[components.tagger.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
width = ${components.tok2vec.model.width}
|
||||||
|
|
||||||
|
[pretraining]
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
parser_config_string = """
|
parser_config_string = """
|
||||||
[model]
|
[model]
|
||||||
|
@ -126,6 +179,14 @@ def test_create_nlp_from_config():
|
||||||
load_model_from_config(Config(bad_cfg), auto_fill=True)
|
load_model_from_config(Config(bad_cfg), auto_fill=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_nlp_from_pretraining_config():
|
||||||
|
"""Test that the default pretraining config validates properly"""
|
||||||
|
config = Config().from_str(pretrain_config_string)
|
||||||
|
pretrain_config = load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
|
||||||
|
filled = config.merge(pretrain_config)
|
||||||
|
resolved = registry.resolve(filled["pretraining"], schema=ConfigSchemaPretrain)
|
||||||
|
|
||||||
|
|
||||||
def test_create_nlp_from_config_multiple_instances():
|
def test_create_nlp_from_config_multiple_instances():
|
||||||
"""Test that the nlp object is created correctly for a config with multiple
|
"""Test that the nlp object is created correctly for a config with multiple
|
||||||
instances of the same component."""
|
instances of the same component."""
|
||||||
|
|
|
@ -1,22 +1,16 @@
|
||||||
from typing import Optional, Callable, Iterable, Union, List
|
from typing import Optional, Callable, Iterable, Union, List
|
||||||
from thinc.api import Config, fix_random_seed, set_gpu_allocator, Model, Optimizer
|
from thinc.api import Config, fix_random_seed, set_gpu_allocator, Model, Optimizer
|
||||||
from thinc.api import set_dropout_rate, to_categorical, CosineDistance, L2Distance
|
from thinc.api import set_dropout_rate
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from functools import partial
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import srsly
|
import srsly
|
||||||
import numpy
|
|
||||||
import time
|
import time
|
||||||
import re
|
import re
|
||||||
from wasabi import Printer
|
from wasabi import Printer
|
||||||
|
|
||||||
from .example import Example
|
from .example import Example
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
from ..attrs import ID
|
|
||||||
from ..ml.models.multi_task import build_cloze_multi_task_model
|
|
||||||
from ..ml.models.multi_task import build_cloze_characters_multi_task_model
|
|
||||||
from ..schemas import ConfigSchemaTraining, ConfigSchemaPretrain
|
from ..schemas import ConfigSchemaTraining, ConfigSchemaPretrain
|
||||||
from ..errors import Errors
|
|
||||||
from ..util import registry, load_model_from_config, dot_to_object
|
from ..util import registry, load_model_from_config, dot_to_object
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,6 +43,7 @@ def pretrain(
|
||||||
else:
|
else:
|
||||||
# Without '--resume-path' the '--epoch-resume' argument is ignored
|
# Without '--resume-path' the '--epoch-resume' argument is ignored
|
||||||
epoch_resume = 0
|
epoch_resume = 0
|
||||||
|
objective = model.attrs["loss"]
|
||||||
# TODO: move this to logger function?
|
# TODO: move this to logger function?
|
||||||
tracker = ProgressTracker(frequency=10000)
|
tracker = ProgressTracker(frequency=10000)
|
||||||
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume}")
|
msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume}")
|
||||||
|
@ -69,7 +64,6 @@ def pretrain(
|
||||||
with (output_dir / "log.jsonl").open("a") as file_:
|
with (output_dir / "log.jsonl").open("a") as file_:
|
||||||
file_.write(srsly.json_dumps(log) + "\n")
|
file_.write(srsly.json_dumps(log) + "\n")
|
||||||
|
|
||||||
objective = create_objective(P["objective"])
|
|
||||||
# TODO: I think we probably want this to look more like the
|
# TODO: I think we probably want this to look more like the
|
||||||
# 'create_train_batches' function?
|
# 'create_train_batches' function?
|
||||||
for epoch in range(epoch_resume, P["max_epochs"]):
|
for epoch in range(epoch_resume, P["max_epochs"]):
|
||||||
|
@ -132,58 +126,6 @@ def make_update(
|
||||||
return float(loss)
|
return float(loss)
|
||||||
|
|
||||||
|
|
||||||
def create_objective(config: Config):
|
|
||||||
"""Create the objective for pretraining.
|
|
||||||
|
|
||||||
We'd like to replace this with a registry function but it's tricky because
|
|
||||||
we're also making a model choice based on this. For now we hard-code support
|
|
||||||
for two types (characters, vectors). For characters you can specify
|
|
||||||
n_characters, for vectors you can specify the loss.
|
|
||||||
|
|
||||||
Bleh.
|
|
||||||
"""
|
|
||||||
objective_type = config["type"]
|
|
||||||
if objective_type == "characters":
|
|
||||||
return partial(get_characters_loss, nr_char=config["n_characters"])
|
|
||||||
elif objective_type == "vectors":
|
|
||||||
if config["loss"] == "cosine":
|
|
||||||
distance = CosineDistance(normalize=True, ignore_zeros=True)
|
|
||||||
return partial(get_vectors_loss, distance=distance)
|
|
||||||
elif config["loss"] == "L2":
|
|
||||||
distance = L2Distance(normalize=True, ignore_zeros=True)
|
|
||||||
return partial(get_vectors_loss, distance=distance)
|
|
||||||
else:
|
|
||||||
raise ValueError(Errors.E906.format(loss_type=config["loss"]))
|
|
||||||
else:
|
|
||||||
raise ValueError(Errors.E907.format(objective_type=objective_type))
|
|
||||||
|
|
||||||
|
|
||||||
def get_vectors_loss(ops, docs, prediction, distance):
|
|
||||||
"""Compute a loss based on a distance between the documents' vectors and
|
|
||||||
the prediction.
|
|
||||||
"""
|
|
||||||
# The simplest way to implement this would be to vstack the
|
|
||||||
# token.vector values, but that's a bit inefficient, especially on GPU.
|
|
||||||
# Instead we fetch the index into the vectors table for each of our tokens,
|
|
||||||
# and look them up all at once. This prevents data copying.
|
|
||||||
ids = ops.flatten([doc.to_array(ID).ravel() for doc in docs])
|
|
||||||
target = docs[0].vocab.vectors.data[ids]
|
|
||||||
d_target, loss = distance(prediction, target)
|
|
||||||
return loss, d_target
|
|
||||||
|
|
||||||
|
|
||||||
def get_characters_loss(ops, docs, prediction, nr_char):
|
|
||||||
"""Compute a loss based on a number of characters predicted from the docs."""
|
|
||||||
target_ids = numpy.vstack([doc.to_utf8_array(nr_char=nr_char) for doc in docs])
|
|
||||||
target_ids = target_ids.reshape((-1,))
|
|
||||||
target = ops.asarray(to_categorical(target_ids, n_classes=256), dtype="f")
|
|
||||||
target = target.reshape((-1, 256 * nr_char))
|
|
||||||
diff = prediction - target
|
|
||||||
loss = (diff ** 2).sum()
|
|
||||||
d_target = diff / float(prediction.shape[0])
|
|
||||||
return loss, d_target
|
|
||||||
|
|
||||||
|
|
||||||
def create_pretraining_model(nlp, pretrain_config):
|
def create_pretraining_model(nlp, pretrain_config):
|
||||||
"""Define a network for the pretraining. We simply add an output layer onto
|
"""Define a network for the pretraining. We simply add an output layer onto
|
||||||
the tok2vec input model. The tok2vec input model needs to be a model that
|
the tok2vec input model. The tok2vec input model needs to be a model that
|
||||||
|
@ -192,27 +134,15 @@ def create_pretraining_model(nlp, pretrain_config):
|
||||||
The actual tok2vec layer is stored as a reference, and only this bit will be
|
The actual tok2vec layer is stored as a reference, and only this bit will be
|
||||||
serialized to file and read back in when calling the 'train' command.
|
serialized to file and read back in when calling the 'train' command.
|
||||||
"""
|
"""
|
||||||
|
nlp.initialize()
|
||||||
component = nlp.get_pipe(pretrain_config["component"])
|
component = nlp.get_pipe(pretrain_config["component"])
|
||||||
if pretrain_config.get("layer"):
|
if pretrain_config.get("layer"):
|
||||||
tok2vec = component.model.get_ref(pretrain_config["layer"])
|
tok2vec = component.model.get_ref(pretrain_config["layer"])
|
||||||
else:
|
else:
|
||||||
tok2vec = component.model
|
tok2vec = component.model
|
||||||
|
|
||||||
# TODO
|
create_function = pretrain_config["objective"]
|
||||||
maxout_pieces = 3
|
model = create_function(nlp.vocab, tok2vec)
|
||||||
hidden_size = 300
|
|
||||||
if pretrain_config["objective"]["type"] == "vectors":
|
|
||||||
model = build_cloze_multi_task_model(
|
|
||||||
nlp.vocab, tok2vec, hidden_size=hidden_size, maxout_pieces=maxout_pieces
|
|
||||||
)
|
|
||||||
elif pretrain_config["objective"]["type"] == "characters":
|
|
||||||
model = build_cloze_characters_multi_task_model(
|
|
||||||
nlp.vocab,
|
|
||||||
tok2vec,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
maxout_pieces=maxout_pieces,
|
|
||||||
nr_char=pretrain_config["objective"]["n_characters"],
|
|
||||||
)
|
|
||||||
model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
|
model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])
|
||||||
set_dropout_rate(model, pretrain_config["dropout"])
|
set_dropout_rate(model, pretrain_config["dropout"])
|
||||||
return model
|
return model
|
||||||
|
|
Loading…
Reference in New Issue
Block a user