mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Merge branch 'master' into feat/add-pipe-instance
This commit is contained in:
commit
8a79a71190
|
@ -32,6 +32,7 @@ def init_vectors_cli(
|
|||
name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
|
||||
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
|
||||
jsonl_loc: Optional[Path] = Opt(None, "--lexemes-jsonl", "-j", help="Location of JSONL-formatted attributes file", hidden=True),
|
||||
attr: str = Opt("ORTH", "--attr", "-a", help="Optional token attribute to use for vectors, e.g. LOWER or NORM"),
|
||||
# fmt: on
|
||||
):
|
||||
"""Convert word vectors for use with spaCy. Will export an nlp object that
|
||||
|
@ -50,6 +51,7 @@ def init_vectors_cli(
|
|||
prune=prune,
|
||||
name=name,
|
||||
mode=mode,
|
||||
attr=attr,
|
||||
)
|
||||
msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors")
|
||||
nlp.to_disk(output_dir)
|
||||
|
|
|
@ -216,7 +216,10 @@ class Warnings(metaclass=ErrorsWithCodes):
|
|||
W123 = ("Argument `enable` with value {enable} does not contain all values specified in the config option "
|
||||
"`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.")
|
||||
W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.")
|
||||
W125 = (
|
||||
W125 = ("The StaticVectors key_attr is no longer used. To set a custom "
|
||||
"key attribute for vectors, configure it through Vectors(attr=) or "
|
||||
"'spacy init vectors --attr'")
|
||||
W126 = (
|
||||
"Pipe instance '{name}' is being added with a vocab "
|
||||
"instance that will not match other components. This is "
|
||||
"usually an error."
|
||||
|
|
|
@ -742,6 +742,11 @@ class Language:
|
|||
)
|
||||
)
|
||||
pipe = source.get_pipe(source_name)
|
||||
# There is no actual solution here. Either the component has the right
|
||||
# name for the source pipeline or the component has the right name for
|
||||
# the current pipeline. This prioritizes the current pipeline.
|
||||
if hasattr(pipe, "name"):
|
||||
pipe.name = name
|
||||
# Make sure the source config is interpolated so we don't end up with
|
||||
# orphaned variables in our final config
|
||||
source_config = source.config.interpolate()
|
||||
|
@ -822,6 +827,7 @@ class Language:
|
|||
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||
self._components.insert(pipe_index, (name, pipe_component))
|
||||
self._link_components()
|
||||
return pipe_component
|
||||
|
||||
def add_pipe_instance(
|
||||
|
@ -1006,6 +1012,7 @@ class Language:
|
|||
if old_name in self._config["initialize"]["components"]:
|
||||
init_cfg = self._config["initialize"]["components"].pop(old_name)
|
||||
self._config["initialize"]["components"][new_name] = init_cfg
|
||||
self._link_components()
|
||||
|
||||
def remove_pipe(self, name: str) -> Tuple[str, PipeCallable]:
|
||||
"""Remove a component from the pipeline.
|
||||
|
@ -1029,6 +1036,7 @@ class Language:
|
|||
# Make sure the name is also removed from the set of disabled components
|
||||
if name in self.disabled:
|
||||
self._disabled.remove(name)
|
||||
self._link_components()
|
||||
return removed
|
||||
|
||||
def disable_pipe(self, name: str) -> None:
|
||||
|
@ -1757,8 +1765,16 @@ class Language:
|
|||
# The problem is we need to do it during deserialization...And the
|
||||
# components don't receive the pipeline then. So this does have to be
|
||||
# here :(
|
||||
# First, fix up all the internal component names in case they have
|
||||
# gotten out of sync due to sourcing components from different
|
||||
# pipelines, since find_listeners uses proc2.name for the listener
|
||||
# map.
|
||||
for name, proc in self.pipeline:
|
||||
if hasattr(proc, "name"):
|
||||
proc.name = name
|
||||
for i, (name1, proc1) in enumerate(self.pipeline):
|
||||
if isinstance(proc1, ty.ListenedToComponent):
|
||||
proc1.listener_map = {}
|
||||
for name2, proc2 in self.pipeline[i + 1 :]:
|
||||
proc1.find_listeners(proc2)
|
||||
|
||||
|
@ -1913,6 +1929,7 @@ class Language:
|
|||
raw_config=raw_config,
|
||||
)
|
||||
else:
|
||||
assert "source" in pipe_cfg
|
||||
# We need the sourced components to reference the same
|
||||
# vocab without modifying the current vocab state **AND**
|
||||
# we still want to load the source model vectors to perform
|
||||
|
@ -1932,6 +1949,10 @@ class Language:
|
|||
source_name = pipe_cfg.get("component", pipe_name)
|
||||
listeners_replaced = False
|
||||
if "replace_listeners" in pipe_cfg:
|
||||
# Make sure that the listened-to component has the
|
||||
# state of the source pipeline listener map so that the
|
||||
# replace_listeners method below works as intended.
|
||||
source_nlps[model]._link_components()
|
||||
for name, proc in source_nlps[model].pipeline:
|
||||
if source_name in getattr(proc, "listening_components", []):
|
||||
source_nlps[model].replace_listeners(
|
||||
|
@ -1943,6 +1964,8 @@ class Language:
|
|||
nlp.add_pipe(
|
||||
source_name, source=source_nlps[model], name=pipe_name
|
||||
)
|
||||
# At this point after nlp.add_pipe, the listener map
|
||||
# corresponds to the new pipeline.
|
||||
if model not in source_nlp_vectors_hashes:
|
||||
source_nlp_vectors_hashes[model] = hash(
|
||||
source_nlps[model].vocab.vectors.to_bytes(
|
||||
|
@ -1997,27 +2020,6 @@ class Language:
|
|||
raise ValueError(
|
||||
Errors.E942.format(name="pipeline_creation", value=type(nlp))
|
||||
)
|
||||
# Detect components with listeners that are not frozen consistently
|
||||
for name, proc in nlp.pipeline:
|
||||
if isinstance(proc, ty.ListenedToComponent):
|
||||
# Remove listeners not in the pipeline
|
||||
listener_names = proc.listening_components
|
||||
unused_listener_names = [
|
||||
ll for ll in listener_names if ll not in nlp.pipe_names
|
||||
]
|
||||
for listener_name in unused_listener_names:
|
||||
for listener in proc.listener_map.get(listener_name, []):
|
||||
proc.remove_listener(listener, listener_name)
|
||||
|
||||
for listener_name in proc.listening_components:
|
||||
# e.g. tok2vec/transformer
|
||||
# If it's a component sourced from another pipeline, we check if
|
||||
# the tok2vec listeners should be replaced with standalone tok2vec
|
||||
# models (e.g. so component can be frozen without its performance
|
||||
# degrading when other components/tok2vec are updated)
|
||||
paths = sourced.get(listener_name, {}).get("replace_listeners", [])
|
||||
if paths:
|
||||
nlp.replace_listeners(name, listener_name, paths)
|
||||
return nlp
|
||||
|
||||
def replace_listeners(
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import warnings
|
||||
from typing import Callable, List, Optional, Sequence, Tuple, cast
|
||||
|
||||
from thinc.api import Model, Ops, registry
|
||||
|
@ -5,7 +6,8 @@ from thinc.initializers import glorot_uniform_init
|
|||
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
|
||||
from thinc.util import partial
|
||||
|
||||
from ..errors import Errors
|
||||
from ..attrs import ORTH
|
||||
from ..errors import Errors, Warnings
|
||||
from ..tokens import Doc
|
||||
from ..vectors import Mode
|
||||
from ..vocab import Vocab
|
||||
|
@ -24,6 +26,8 @@ def StaticVectors(
|
|||
linear projection to control the dimensionality. If a dropout rate is
|
||||
specified, the dropout is applied per dimension over the whole batch.
|
||||
"""
|
||||
if key_attr != "ORTH":
|
||||
warnings.warn(Warnings.W125, DeprecationWarning)
|
||||
return Model(
|
||||
"static_vectors",
|
||||
forward,
|
||||
|
@ -40,9 +44,9 @@ def forward(
|
|||
token_count = sum(len(doc) for doc in docs)
|
||||
if not token_count:
|
||||
return _handle_empty(model.ops, model.get_dim("nO"))
|
||||
key_attr: int = model.attrs["key_attr"]
|
||||
keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
|
||||
vocab: Vocab = docs[0].vocab
|
||||
key_attr: int = getattr(vocab.vectors, "attr", ORTH)
|
||||
keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
|
||||
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
||||
if vocab.vectors.mode == Mode.default:
|
||||
V = model.ops.asarray(vocab.vectors.data)
|
||||
|
|
|
@ -53,9 +53,9 @@ DEFAULT_SPAN_FINDER_MODEL = Config().from_str(span_finder_default_config)["model
|
|||
"scorer": {"@scorers": "spacy.span_finder_scorer.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
f"span_finder_{DEFAULT_SPANS_KEY}_f": 1.0,
|
||||
f"span_finder_{DEFAULT_SPANS_KEY}_p": 0.0,
|
||||
f"span_finder_{DEFAULT_SPANS_KEY}_r": 0.0,
|
||||
f"spans_{DEFAULT_SPANS_KEY}_f": 1.0,
|
||||
f"spans_{DEFAULT_SPANS_KEY}_p": 0.0,
|
||||
f"spans_{DEFAULT_SPANS_KEY}_r": 0.0,
|
||||
},
|
||||
)
|
||||
def make_span_finder(
|
||||
|
@ -104,7 +104,7 @@ def make_span_finder_scorer():
|
|||
|
||||
def span_finder_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||
kwargs = dict(kwargs)
|
||||
attr_prefix = "span_finder_"
|
||||
attr_prefix = "spans_"
|
||||
key = kwargs["spans_key"]
|
||||
kwargs.setdefault("attr", f"{attr_prefix}{key}")
|
||||
kwargs.setdefault(
|
||||
|
|
|
@ -230,10 +230,10 @@ def test_overfitting_IO():
|
|||
|
||||
# Test scoring
|
||||
scores = nlp.evaluate(train_examples)
|
||||
assert f"span_finder_{SPANS_KEY}_f" in scores
|
||||
assert f"spans_{SPANS_KEY}_f" in scores
|
||||
# It's not perfect 1.0 F1 because it's designed to overgenerate for now.
|
||||
assert scores[f"span_finder_{SPANS_KEY}_p"] == 0.75
|
||||
assert scores[f"span_finder_{SPANS_KEY}_r"] == 1.0
|
||||
assert scores[f"spans_{SPANS_KEY}_p"] == 0.75
|
||||
assert scores[f"spans_{SPANS_KEY}_r"] == 1.0
|
||||
|
||||
# also test that the spancat works for just a single entity in a sentence
|
||||
doc = nlp("London")
|
||||
|
|
|
@ -192,8 +192,7 @@ def test_tok2vec_listener(with_vectors):
|
|||
for tag in t[1]["tags"]:
|
||||
tagger.add_label(tag)
|
||||
|
||||
# Check that the Tok2Vec component finds it listeners
|
||||
assert tok2vec.listeners == []
|
||||
# Check that the Tok2Vec component finds its listeners
|
||||
optimizer = nlp.initialize(lambda: train_examples)
|
||||
assert tok2vec.listeners == [tagger_tok2vec]
|
||||
|
||||
|
@ -221,7 +220,6 @@ def test_tok2vec_listener_callback():
|
|||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||
tagger = nlp.get_pipe("tagger")
|
||||
tok2vec = nlp.get_pipe("tok2vec")
|
||||
nlp._link_components()
|
||||
docs = [nlp.make_doc("A random sentence")]
|
||||
tok2vec.model.initialize(X=docs)
|
||||
gold_array = [[1.0 for tag in ["V", "Z"]] for word in docs]
|
||||
|
@ -430,29 +428,46 @@ def test_replace_listeners_from_config():
|
|||
nlp.to_disk(dir_path)
|
||||
base_model = str(dir_path)
|
||||
new_config = {
|
||||
"nlp": {"lang": "en", "pipeline": ["tok2vec", "tagger", "ner"]},
|
||||
"nlp": {
|
||||
"lang": "en",
|
||||
"pipeline": ["tok2vec", "tagger2", "ner3", "tagger4"],
|
||||
},
|
||||
"components": {
|
||||
"tok2vec": {"source": base_model},
|
||||
"tagger": {
|
||||
"tagger2": {
|
||||
"source": base_model,
|
||||
"component": "tagger",
|
||||
"replace_listeners": ["model.tok2vec"],
|
||||
},
|
||||
"ner": {"source": base_model},
|
||||
"ner3": {
|
||||
"source": base_model,
|
||||
"component": "ner",
|
||||
},
|
||||
"tagger4": {
|
||||
"source": base_model,
|
||||
"component": "tagger",
|
||||
},
|
||||
},
|
||||
}
|
||||
new_nlp = util.load_model_from_config(new_config, auto_fill=True)
|
||||
new_nlp.initialize(lambda: examples)
|
||||
tok2vec = new_nlp.get_pipe("tok2vec")
|
||||
tagger = new_nlp.get_pipe("tagger")
|
||||
ner = new_nlp.get_pipe("ner")
|
||||
assert tok2vec.listening_components == ["ner"]
|
||||
tagger = new_nlp.get_pipe("tagger2")
|
||||
ner = new_nlp.get_pipe("ner3")
|
||||
assert "ner" not in new_nlp.pipe_names
|
||||
assert "tagger" not in new_nlp.pipe_names
|
||||
assert tok2vec.listening_components == ["ner3", "tagger4"]
|
||||
assert any(isinstance(node, Tok2VecListener) for node in ner.model.walk())
|
||||
assert not any(isinstance(node, Tok2VecListener) for node in tagger.model.walk())
|
||||
t2v_cfg = new_nlp.config["components"]["tok2vec"]["model"]
|
||||
assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2"
|
||||
assert new_nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg
|
||||
assert new_nlp.config["components"]["tagger2"]["model"]["tok2vec"] == t2v_cfg
|
||||
assert (
|
||||
new_nlp.config["components"]["ner"]["model"]["tok2vec"]["@architectures"]
|
||||
new_nlp.config["components"]["ner3"]["model"]["tok2vec"]["@architectures"]
|
||||
== "spacy.Tok2VecListener.v1"
|
||||
)
|
||||
assert (
|
||||
new_nlp.config["components"]["tagger4"]["model"]["tok2vec"]["@architectures"]
|
||||
== "spacy.Tok2VecListener.v1"
|
||||
)
|
||||
|
||||
|
@ -544,3 +559,57 @@ def test_tok2vec_listeners_textcat():
|
|||
assert cats1["imperative"] < 0.9
|
||||
assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
|
||||
assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]
|
||||
|
||||
|
||||
def test_tok2vec_listener_source_link_name():
|
||||
"""The component's internal name and the tok2vec listener map correspond
|
||||
to the most recently modified pipeline.
|
||||
"""
|
||||
orig_config = Config().from_str(cfg_string_multi)
|
||||
nlp1 = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||
assert nlp1.get_pipe("tok2vec").listening_components == ["tagger", "ner"]
|
||||
|
||||
nlp2 = English()
|
||||
nlp2.add_pipe("tok2vec", source=nlp1)
|
||||
nlp2.add_pipe("tagger", name="tagger2", source=nlp1)
|
||||
|
||||
# there is no way to have the component have the right name for both
|
||||
# pipelines, right now the most recently modified pipeline is prioritized
|
||||
assert nlp1.get_pipe("tagger").name == nlp2.get_pipe("tagger2").name == "tagger2"
|
||||
|
||||
# there is no way to have the tok2vec have the right listener map for both
|
||||
# pipelines, right now the most recently modified pipeline is prioritized
|
||||
assert nlp2.get_pipe("tok2vec").listening_components == ["tagger2"]
|
||||
nlp2.add_pipe("ner", name="ner3", source=nlp1)
|
||||
assert nlp2.get_pipe("tok2vec").listening_components == ["tagger2", "ner3"]
|
||||
nlp2.remove_pipe("ner3")
|
||||
assert nlp2.get_pipe("tok2vec").listening_components == ["tagger2"]
|
||||
nlp2.remove_pipe("tagger2")
|
||||
assert nlp2.get_pipe("tok2vec").listening_components == []
|
||||
|
||||
# at this point the tok2vec component corresponds to nlp2
|
||||
assert nlp1.get_pipe("tok2vec").listening_components == []
|
||||
|
||||
# modifying the nlp1 pipeline syncs the tok2vec listener map back to nlp1
|
||||
nlp1.add_pipe("sentencizer")
|
||||
assert nlp1.get_pipe("tok2vec").listening_components == ["tagger", "ner"]
|
||||
|
||||
# modifying nlp2 syncs it back to nlp2
|
||||
nlp2.add_pipe("sentencizer")
|
||||
assert nlp1.get_pipe("tok2vec").listening_components == []
|
||||
|
||||
|
||||
def test_tok2vec_listener_source_replace_listeners():
|
||||
orig_config = Config().from_str(cfg_string_multi)
|
||||
nlp1 = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||
assert nlp1.get_pipe("tok2vec").listening_components == ["tagger", "ner"]
|
||||
nlp1.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
|
||||
assert nlp1.get_pipe("tok2vec").listening_components == ["ner"]
|
||||
|
||||
nlp2 = English()
|
||||
nlp2.add_pipe("tok2vec", source=nlp1)
|
||||
assert nlp2.get_pipe("tok2vec").listening_components == []
|
||||
nlp2.add_pipe("tagger", source=nlp1)
|
||||
assert nlp2.get_pipe("tok2vec").listening_components == []
|
||||
nlp2.add_pipe("ner", name="ner2", source=nlp1)
|
||||
assert nlp2.get_pipe("tok2vec").listening_components == ["ner2"]
|
||||
|
|
|
@ -13,6 +13,7 @@ from spacy.ml.models import (
|
|||
build_Tok2Vec_model,
|
||||
)
|
||||
from spacy.schemas import ConfigSchema, ConfigSchemaPretrain
|
||||
from spacy.training import Example
|
||||
from spacy.util import (
|
||||
load_config,
|
||||
load_config_from_str,
|
||||
|
@ -422,6 +423,55 @@ def test_config_overrides():
|
|||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:\\[W036")
|
||||
def test_config_overrides_registered_functions():
|
||||
nlp = spacy.blank("en")
|
||||
nlp.add_pipe("attribute_ruler")
|
||||
with make_tempdir() as d:
|
||||
nlp.to_disk(d)
|
||||
nlp_re1 = spacy.load(
|
||||
d,
|
||||
config={
|
||||
"components": {
|
||||
"attribute_ruler": {
|
||||
"scorer": {"@scorers": "spacy.tagger_scorer.v1"}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
assert (
|
||||
nlp_re1.config["components"]["attribute_ruler"]["scorer"]["@scorers"]
|
||||
== "spacy.tagger_scorer.v1"
|
||||
)
|
||||
|
||||
@registry.misc("test_some_other_key")
|
||||
def misc_some_other_key():
|
||||
return "some_other_key"
|
||||
|
||||
nlp_re2 = spacy.load(
|
||||
d,
|
||||
config={
|
||||
"components": {
|
||||
"attribute_ruler": {
|
||||
"scorer": {
|
||||
"@scorers": "spacy.overlapping_labeled_spans_scorer.v1",
|
||||
"spans_key": {"@misc": "test_some_other_key"},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
assert nlp_re2.config["components"]["attribute_ruler"]["scorer"][
|
||||
"spans_key"
|
||||
] == {"@misc": "test_some_other_key"}
|
||||
# run dummy evaluation (will return None scores) in order to test that
|
||||
# the spans_key value in the nested override is working as intended in
|
||||
# the config
|
||||
example = Example.from_dict(nlp_re2.make_doc("a b c"), {})
|
||||
scores = nlp_re2.evaluate([example])
|
||||
assert "spans_some_other_key_f" in scores
|
||||
|
||||
|
||||
def test_config_interpolation():
|
||||
config = Config().from_str(nlp_config_string, interpolate=False)
|
||||
assert config["corpora"]["train"]["path"] == "${paths.train}"
|
||||
|
|
|
@ -252,6 +252,10 @@ def test_minor_version(a1, a2, b1, b2, is_match):
|
|||
{"training.batch_size": 128, "training.optimizer.learn_rate": 0.01},
|
||||
{"training": {"batch_size": 128, "optimizer": {"learn_rate": 0.01}}},
|
||||
),
|
||||
(
|
||||
{"attribute_ruler.scorer.@scorers": "spacy.tagger_scorer.v1"},
|
||||
{"attribute_ruler": {"scorer": {"@scorers": "spacy.tagger_scorer.v1"}}},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_dot_to_dict(dot_notation, expected):
|
||||
|
@ -260,6 +264,29 @@ def test_dot_to_dict(dot_notation, expected):
|
|||
assert util.dict_to_dot(result) == dot_notation
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dot_notation,expected",
|
||||
[
|
||||
(
|
||||
{"token.pos": True, "token._.xyz": True},
|
||||
{"token": {"pos": True, "_": {"xyz": True}}},
|
||||
),
|
||||
(
|
||||
{"training.batch_size": 128, "training.optimizer.learn_rate": 0.01},
|
||||
{"training": {"batch_size": 128, "optimizer": {"learn_rate": 0.01}}},
|
||||
),
|
||||
(
|
||||
{"attribute_ruler.scorer": {"@scorers": "spacy.tagger_scorer.v1"}},
|
||||
{"attribute_ruler": {"scorer": {"@scorers": "spacy.tagger_scorer.v1"}}},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_dot_to_dict_overrides(dot_notation, expected):
|
||||
result = util.dot_to_dict(dot_notation)
|
||||
assert result == expected
|
||||
assert util.dict_to_dot(result, for_overrides=True) == dot_notation
|
||||
|
||||
|
||||
def test_set_dot_to_object():
|
||||
config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}}
|
||||
with pytest.raises(KeyError):
|
||||
|
|
|
@ -402,6 +402,7 @@ def test_vectors_serialize():
|
|||
row_r = v_r.add("D", vector=OPS.asarray([10, 20, 30, 40], dtype="f"))
|
||||
assert row == row_r
|
||||
assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data))
|
||||
assert v.attr == v_r.attr
|
||||
|
||||
|
||||
def test_vector_is_oov():
|
||||
|
@ -646,3 +647,32 @@ def test_equality():
|
|||
vectors1.resize((5, 9))
|
||||
vectors2.resize((5, 9))
|
||||
assert vectors1 == vectors2
|
||||
|
||||
|
||||
def test_vectors_attr():
|
||||
data = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f")
|
||||
# default ORTH
|
||||
nlp = English()
|
||||
nlp.vocab.vectors = Vectors(data=data, keys=["A", "B", "C"])
|
||||
assert nlp.vocab.strings["A"] in nlp.vocab.vectors.key2row
|
||||
assert nlp.vocab.strings["a"] not in nlp.vocab.vectors.key2row
|
||||
assert nlp.vocab["A"].has_vector is True
|
||||
assert nlp.vocab["a"].has_vector is False
|
||||
assert nlp("A")[0].has_vector is True
|
||||
assert nlp("a")[0].has_vector is False
|
||||
|
||||
# custom LOWER
|
||||
nlp = English()
|
||||
nlp.vocab.vectors = Vectors(data=data, keys=["a", "b", "c"], attr="LOWER")
|
||||
assert nlp.vocab.strings["A"] not in nlp.vocab.vectors.key2row
|
||||
assert nlp.vocab.strings["a"] in nlp.vocab.vectors.key2row
|
||||
assert nlp.vocab["A"].has_vector is True
|
||||
assert nlp.vocab["a"].has_vector is True
|
||||
assert nlp("A")[0].has_vector is True
|
||||
assert nlp("a")[0].has_vector is True
|
||||
# add a new vectors entry
|
||||
assert nlp.vocab["D"].has_vector is False
|
||||
assert nlp.vocab["d"].has_vector is False
|
||||
nlp.vocab.set_vector("D", numpy.asarray([4, 5, 6]))
|
||||
assert nlp.vocab["D"].has_vector is True
|
||||
assert nlp.vocab["d"].has_vector is True
|
||||
|
|
|
@ -35,6 +35,7 @@ from ..attrs cimport (
|
|||
LENGTH,
|
||||
MORPH,
|
||||
NORM,
|
||||
ORTH,
|
||||
POS,
|
||||
SENT_START,
|
||||
SPACY,
|
||||
|
@ -613,13 +614,26 @@ cdef class Doc:
|
|||
"""
|
||||
if "similarity" in self.user_hooks:
|
||||
return self.user_hooks["similarity"](self, other)
|
||||
if isinstance(other, (Lexeme, Token)) and self.length == 1:
|
||||
if self.c[0].lex.orth == other.orth:
|
||||
attr = getattr(self.vocab.vectors, "attr", ORTH)
|
||||
cdef Token this_token
|
||||
cdef Token other_token
|
||||
cdef Lexeme other_lex
|
||||
if len(self) == 1 and isinstance(other, Token):
|
||||
this_token = self[0]
|
||||
other_token = other
|
||||
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
|
||||
return 1.0
|
||||
elif isinstance(other, (Span, Doc)) and len(self) == len(other):
|
||||
elif len(self) == 1 and isinstance(other, Lexeme):
|
||||
this_token = self[0]
|
||||
other_lex = other
|
||||
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
|
||||
return 1.0
|
||||
elif isinstance(other, (Doc, Span)) and len(self) == len(other):
|
||||
similar = True
|
||||
for i in range(self.length):
|
||||
if self[i].orth != other[i].orth:
|
||||
for i in range(len(self)):
|
||||
this_token = self[i]
|
||||
other_token = other[i]
|
||||
if Token.get_struct_attr(this_token.c, attr) != Token.get_struct_attr(other_token.c, attr):
|
||||
similar = False
|
||||
break
|
||||
if similar:
|
||||
|
|
|
@ -8,13 +8,14 @@ import numpy
|
|||
from thinc.api import get_array_module
|
||||
|
||||
from ..attrs cimport *
|
||||
from ..attrs cimport attr_id_t
|
||||
from ..attrs cimport ORTH, attr_id_t
|
||||
from ..lexeme cimport Lexeme
|
||||
from ..parts_of_speech cimport univ_pos_t
|
||||
from ..structs cimport LexemeC, TokenC
|
||||
from ..symbols cimport dep
|
||||
from ..typedefs cimport attr_t, flags_t, hash_t
|
||||
from .doc cimport _get_lca_matrix, get_token_attr, token_by_end, token_by_start
|
||||
from .token cimport Token
|
||||
|
||||
from ..errors import Errors, Warnings
|
||||
from ..util import normalize_slice
|
||||
|
@ -341,13 +342,26 @@ cdef class Span:
|
|||
"""
|
||||
if "similarity" in self.doc.user_span_hooks:
|
||||
return self.doc.user_span_hooks["similarity"](self, other)
|
||||
if len(self) == 1 and hasattr(other, "orth"):
|
||||
if self[0].orth == other.orth:
|
||||
attr = getattr(self.doc.vocab.vectors, "attr", ORTH)
|
||||
cdef Token this_token
|
||||
cdef Token other_token
|
||||
cdef Lexeme other_lex
|
||||
if len(self) == 1 and isinstance(other, Token):
|
||||
this_token = self[0]
|
||||
other_token = other
|
||||
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
|
||||
return 1.0
|
||||
elif len(self) == 1 and isinstance(other, Lexeme):
|
||||
this_token = self[0]
|
||||
other_lex = other
|
||||
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
|
||||
return 1.0
|
||||
elif isinstance(other, (Doc, Span)) and len(self) == len(other):
|
||||
similar = True
|
||||
for i in range(len(self)):
|
||||
if self[i].orth != getattr(other[i], "orth", None):
|
||||
this_token = self[i]
|
||||
other_token = other[i]
|
||||
if Token.get_struct_attr(this_token.c, attr) != Token.get_struct_attr(other_token.c, attr):
|
||||
similar = False
|
||||
break
|
||||
if similar:
|
||||
|
|
|
@ -28,6 +28,7 @@ from ..attrs cimport (
|
|||
LIKE_EMAIL,
|
||||
LIKE_NUM,
|
||||
LIKE_URL,
|
||||
ORTH,
|
||||
)
|
||||
from ..lexeme cimport Lexeme
|
||||
from ..symbols cimport conj
|
||||
|
@ -214,11 +215,17 @@ cdef class Token:
|
|||
"""
|
||||
if "similarity" in self.doc.user_token_hooks:
|
||||
return self.doc.user_token_hooks["similarity"](self, other)
|
||||
if hasattr(other, "__len__") and len(other) == 1 and hasattr(other, "__getitem__"):
|
||||
if self.c.lex.orth == getattr(other[0], "orth", None):
|
||||
attr = getattr(self.doc.vocab.vectors, "attr", ORTH)
|
||||
cdef Token this_token = self
|
||||
cdef Token other_token
|
||||
cdef Lexeme other_lex
|
||||
if isinstance(other, Token):
|
||||
other_token = other
|
||||
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
|
||||
return 1.0
|
||||
elif hasattr(other, "orth"):
|
||||
if self.c.lex.orth == other.orth:
|
||||
elif isinstance(other, Lexeme):
|
||||
other_lex = other
|
||||
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
|
||||
return 1.0
|
||||
if self.vocab.vectors.n_keys == 0:
|
||||
warnings.warn(Warnings.W007.format(obj="Token"))
|
||||
|
@ -415,7 +422,7 @@ cdef class Token:
|
|||
return self.doc.user_token_hooks["has_vector"](self)
|
||||
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
|
||||
return True
|
||||
return self.vocab.has_vector(self.c.lex.orth)
|
||||
return self.vocab.has_vector(Token.get_struct_attr(self.c, self.vocab.vectors.attr))
|
||||
|
||||
@property
|
||||
def vector(self):
|
||||
|
@ -431,7 +438,7 @@ cdef class Token:
|
|||
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
|
||||
return self.doc.tensor[self.i]
|
||||
else:
|
||||
return self.vocab.get_vector(self.c.lex.orth)
|
||||
return self.vocab.get_vector(Token.get_struct_attr(self.c, self.vocab.vectors.attr))
|
||||
|
||||
@property
|
||||
def vector_norm(self):
|
||||
|
|
|
@ -76,7 +76,8 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
|||
with nlp.select_pipes(enable=resume_components):
|
||||
logger.info("Resuming training for: %s", resume_components)
|
||||
nlp.resume_training(sgd=optimizer)
|
||||
# Make sure that listeners are defined before initializing further
|
||||
# Make sure that internal component names are synced and listeners are
|
||||
# defined before initializing further
|
||||
nlp._link_components()
|
||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||
if T["max_epochs"] == -1:
|
||||
|
@ -215,9 +216,14 @@ def convert_vectors(
|
|||
prune: int,
|
||||
name: Optional[str] = None,
|
||||
mode: str = VectorsMode.default,
|
||||
attr: str = "ORTH",
|
||||
) -> None:
|
||||
vectors_loc = ensure_path(vectors_loc)
|
||||
if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
|
||||
if attr != "ORTH":
|
||||
raise ValueError(
|
||||
"ORTH is the only attribute supported for vectors in .npz format."
|
||||
)
|
||||
nlp.vocab.vectors = Vectors(
|
||||
strings=nlp.vocab.strings, data=numpy.load(vectors_loc.open("rb"))
|
||||
)
|
||||
|
@ -245,11 +251,15 @@ def convert_vectors(
|
|||
nlp.vocab.vectors = Vectors(
|
||||
strings=nlp.vocab.strings,
|
||||
data=vectors_data,
|
||||
attr=attr,
|
||||
**floret_settings,
|
||||
)
|
||||
else:
|
||||
nlp.vocab.vectors = Vectors(
|
||||
strings=nlp.vocab.strings, data=vectors_data, keys=vector_keys
|
||||
strings=nlp.vocab.strings,
|
||||
data=vectors_data,
|
||||
keys=vector_keys,
|
||||
attr=attr,
|
||||
)
|
||||
nlp.vocab.deduplicate_vectors()
|
||||
if name is None:
|
||||
|
|
|
@ -547,7 +547,7 @@ def load_model_from_path(
|
|||
if not meta:
|
||||
meta = get_model_meta(model_path)
|
||||
config_path = model_path / "config.cfg"
|
||||
overrides = dict_to_dot(config)
|
||||
overrides = dict_to_dot(config, for_overrides=True)
|
||||
config = load_config(config_path, overrides=overrides)
|
||||
nlp = load_model_from_config(
|
||||
config,
|
||||
|
@ -1525,14 +1525,19 @@ def dot_to_dict(values: Dict[str, Any]) -> Dict[str, dict]:
|
|||
return result
|
||||
|
||||
|
||||
def dict_to_dot(obj: Dict[str, dict]) -> Dict[str, Any]:
|
||||
def dict_to_dot(obj: Dict[str, dict], *, for_overrides: bool = False) -> Dict[str, Any]:
|
||||
"""Convert dot notation to a dict. For example: {"token": {"pos": True,
|
||||
"_": {"xyz": True }}} becomes {"token.pos": True, "token._.xyz": True}.
|
||||
|
||||
values (Dict[str, dict]): The dict to convert.
|
||||
obj (Dict[str, dict]): The dict to convert.
|
||||
for_overrides (bool): Whether to enable special handling for registered
|
||||
functions in overrides.
|
||||
RETURNS (Dict[str, Any]): The key/value pairs.
|
||||
"""
|
||||
return {".".join(key): value for key, value in walk_dict(obj)}
|
||||
return {
|
||||
".".join(key): value
|
||||
for key, value in walk_dict(obj, for_overrides=for_overrides)
|
||||
}
|
||||
|
||||
|
||||
def dot_to_object(config: Config, section: str):
|
||||
|
@ -1574,13 +1579,20 @@ def set_dot_to_object(config: Config, section: str, value: Any) -> None:
|
|||
|
||||
|
||||
def walk_dict(
|
||||
node: Dict[str, Any], parent: List[str] = []
|
||||
node: Dict[str, Any], parent: List[str] = [], *, for_overrides: bool = False
|
||||
) -> Iterator[Tuple[List[str], Any]]:
|
||||
"""Walk a dict and yield the path and values of the leaves."""
|
||||
"""Walk a dict and yield the path and values of the leaves.
|
||||
|
||||
for_overrides (bool): Whether to treat registered functions that start with
|
||||
@ as final values rather than dicts to traverse.
|
||||
"""
|
||||
for key, value in node.items():
|
||||
key_parent = [*parent, key]
|
||||
if isinstance(value, dict):
|
||||
yield from walk_dict(value, key_parent)
|
||||
if isinstance(value, dict) and (
|
||||
not for_overrides
|
||||
or not any(value_key.startswith("@") for value_key in value)
|
||||
):
|
||||
yield from walk_dict(value, key_parent, for_overrides=for_overrides)
|
||||
else:
|
||||
yield (key_parent, value)
|
||||
|
||||
|
|
|
@ -15,9 +15,11 @@ from thinc.api import Ops, get_array_module, get_current_ops
|
|||
from thinc.backends import get_array_ops
|
||||
from thinc.types import Floats2d
|
||||
|
||||
from .attrs cimport ORTH, attr_id_t
|
||||
from .strings cimport StringStore
|
||||
|
||||
from . import util
|
||||
from .attrs import IDS
|
||||
from .errors import Errors, Warnings
|
||||
from .strings import get_string_id
|
||||
|
||||
|
@ -64,8 +66,9 @@ cdef class Vectors:
|
|||
cdef readonly uint32_t hash_seed
|
||||
cdef readonly unicode bow
|
||||
cdef readonly unicode eow
|
||||
cdef readonly attr_id_t attr
|
||||
|
||||
def __init__(self, *, strings=None, shape=None, data=None, keys=None, name=None, mode=Mode.default, minn=0, maxn=0, hash_count=1, hash_seed=0, bow="<", eow=">"):
|
||||
def __init__(self, *, strings=None, shape=None, data=None, keys=None, name=None, mode=Mode.default, minn=0, maxn=0, hash_count=1, hash_seed=0, bow="<", eow=">", attr="ORTH"):
|
||||
"""Create a new vector store.
|
||||
|
||||
strings (StringStore): The string store.
|
||||
|
@ -80,6 +83,8 @@ cdef class Vectors:
|
|||
hash_seed (int): The floret hash seed (default: 0).
|
||||
bow (str): The floret BOW string (default: "<").
|
||||
eow (str): The floret EOW string (default: ">").
|
||||
attr (Union[int, str]): The token attribute for the vector keys
|
||||
(default: "ORTH").
|
||||
|
||||
DOCS: https://spacy.io/api/vectors#init
|
||||
"""
|
||||
|
@ -103,6 +108,14 @@ cdef class Vectors:
|
|||
self.hash_seed = hash_seed
|
||||
self.bow = bow
|
||||
self.eow = eow
|
||||
if isinstance(attr, (int, long)):
|
||||
self.attr = attr
|
||||
else:
|
||||
attr = attr.upper()
|
||||
if attr == "TEXT":
|
||||
attr = "ORTH"
|
||||
self.attr = IDS.get(attr, ORTH)
|
||||
|
||||
if self.mode == Mode.default:
|
||||
if data is None:
|
||||
if shape is None:
|
||||
|
@ -546,6 +559,7 @@ cdef class Vectors:
|
|||
"hash_seed": self.hash_seed,
|
||||
"bow": self.bow,
|
||||
"eow": self.eow,
|
||||
"attr": self.attr,
|
||||
}
|
||||
|
||||
def _set_cfg(self, cfg):
|
||||
|
@ -556,6 +570,7 @@ cdef class Vectors:
|
|||
self.hash_seed = cfg.get("hash_seed", 0)
|
||||
self.bow = cfg.get("bow", "<")
|
||||
self.eow = cfg.get("eow", ">")
|
||||
self.attr = cfg.get("attr", ORTH)
|
||||
|
||||
def to_disk(self, path, *, exclude=tuple()):
|
||||
"""Save the current state to a directory.
|
||||
|
|
|
@ -365,8 +365,13 @@ cdef class Vocab:
|
|||
self[orth]
|
||||
# Make prob negative so it sorts by rank ascending
|
||||
# (key2row contains the rank)
|
||||
priority = [(-lex.prob, self.vectors.key2row[lex.orth], lex.orth)
|
||||
for lex in self if lex.orth in self.vectors.key2row]
|
||||
priority = []
|
||||
cdef Lexeme lex
|
||||
cdef attr_t value
|
||||
for lex in self:
|
||||
value = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
|
||||
if value in self.vectors.key2row:
|
||||
priority.append((-lex.prob, self.vectors.key2row[value], value))
|
||||
priority.sort()
|
||||
indices = xp.asarray([i for (prob, i, key) in priority], dtype="uint64")
|
||||
keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64")
|
||||
|
@ -399,8 +404,10 @@ cdef class Vocab:
|
|||
"""
|
||||
if isinstance(orth, str):
|
||||
orth = self.strings.add(orth)
|
||||
if self.has_vector(orth):
|
||||
return self.vectors[orth]
|
||||
cdef Lexeme lex = self[orth]
|
||||
key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
|
||||
if self.has_vector(key):
|
||||
return self.vectors[key]
|
||||
xp = get_array_module(self.vectors.data)
|
||||
vectors = xp.zeros((self.vectors_length,), dtype="f")
|
||||
return vectors
|
||||
|
@ -416,15 +423,16 @@ cdef class Vocab:
|
|||
"""
|
||||
if isinstance(orth, str):
|
||||
orth = self.strings.add(orth)
|
||||
if self.vectors.is_full and orth not in self.vectors:
|
||||
cdef Lexeme lex = self[orth]
|
||||
key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
|
||||
if self.vectors.is_full and key not in self.vectors:
|
||||
new_rows = max(100, int(self.vectors.shape[0]*1.3))
|
||||
if self.vectors.shape[1] == 0:
|
||||
width = vector.size
|
||||
else:
|
||||
width = self.vectors.shape[1]
|
||||
self.vectors.resize((new_rows, width))
|
||||
lex = self[orth] # Add word to vocab if necessary
|
||||
row = self.vectors.add(orth, vector=vector)
|
||||
row = self.vectors.add(key, vector=vector)
|
||||
if row >= 0:
|
||||
lex.rank = row
|
||||
|
||||
|
@ -439,7 +447,9 @@ cdef class Vocab:
|
|||
"""
|
||||
if isinstance(orth, str):
|
||||
orth = self.strings.add(orth)
|
||||
return orth in self.vectors
|
||||
cdef Lexeme lex = self[orth]
|
||||
key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
|
||||
return key in self.vectors
|
||||
|
||||
property lookups:
|
||||
def __get__(self):
|
||||
|
|
|
@ -303,7 +303,7 @@ mapped to a zero vector. See the documentation on
|
|||
| `nM` | The width of the static vectors. ~~Optional[int]~~ |
|
||||
| `dropout` | Optional dropout rate. If set, it's applied per dimension over the whole batch. Defaults to `None`. ~~Optional[float]~~ |
|
||||
| `init_W` | The [initialization function](https://thinc.ai/docs/api-initializers). Defaults to [`glorot_uniform_init`](https://thinc.ai/docs/api-initializers#glorot_uniform_init). ~~Callable[[Ops, Tuple[int, ...]]], FloatsXd]~~ |
|
||||
| `key_attr` | Defaults to `"ORTH"`. ~~str~~ |
|
||||
| `key_attr` | This setting is ignored in spaCy v3.6+. To set a custom key attribute for vectors, configure it through [`Vectors`](/api/vectors) or [`spacy init vectors`](/api/cli#init-vectors). Defaults to `"ORTH"`. ~~str~~ |
|
||||
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Ragged]~~ |
|
||||
|
||||
### spacy.FeatureExtractor.v1 {id="FeatureExtractor"}
|
||||
|
|
|
@ -211,7 +211,8 @@ $ python -m spacy init vectors [lang] [vectors_loc] [output_dir] [--prune] [--tr
|
|||
| `output_dir` | Pipeline output directory. Will be created if it doesn't exist. ~~Path (positional)~~ |
|
||||
| `--truncate`, `-t` | Number of vectors to truncate to when reading in vectors file. Defaults to `0` for no truncation. ~~int (option)~~ |
|
||||
| `--prune`, `-p` | Number of vectors to prune the vocabulary to. Defaults to `-1` for no pruning. ~~int (option)~~ |
|
||||
| `--mode`, `-m` | Vectors mode: `default` or [`floret`](https://github.com/explosion/floret). Defaults to `default`. ~~Optional[str] \(option)~~ |
|
||||
| `--mode`, `-m` | Vectors mode: `default` or [`floret`](https://github.com/explosion/floret). Defaults to `default`. ~~str \(option)~~ |
|
||||
| `--attr`, `-a` | Token attribute to use for vectors, e.g. `LOWER` or `NORM`) Defaults to `ORTH`. ~~str \(option)~~ |
|
||||
| `--name`, `-n` | Name to assign to the word vectors in the `meta.json`, e.g. `en_core_web_md.vectors`. ~~Optional[str] \(option)~~ |
|
||||
| `--verbose`, `-V` | Print additional information and explanations. ~~bool (flag)~~ |
|
||||
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
||||
|
|
|
@ -60,6 +60,7 @@ modified later.
|
|||
| `hash_seed` <Tag variant="new">3.2</Tag> | The floret hash seed (default: `0`). ~~int~~ |
|
||||
| `bow` <Tag variant="new">3.2</Tag> | The floret BOW string (default: `"<"`). ~~str~~ |
|
||||
| `eow` <Tag variant="new">3.2</Tag> | The floret EOW string (default: `">"`). ~~str~~ |
|
||||
| `attr` <Tag variant="new">3.6</Tag> | The token attribute for the vector keys (default: `"ORTH"`). ~~Union[int, str]~~ |
|
||||
|
||||
## Vectors.\_\_getitem\_\_ {id="getitem",tag="method"}
|
||||
|
||||
|
@ -453,8 +454,9 @@ Load state from a binary string.
|
|||
|
||||
## Attributes {id="attributes"}
|
||||
|
||||
| Name | Description |
|
||||
| --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `data` | Stored vectors data. `numpy` is used for CPU vectors, `cupy` for GPU vectors. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ |
|
||||
| `key2row` | Dictionary mapping word hashes to rows in the `Vectors.data` table. ~~Dict[int, int]~~ |
|
||||
| `keys` | Array keeping the keys in order, such that `keys[vectors.key2row[key]] == key`. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ |
|
||||
| Name | Description |
|
||||
| ----------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `data` | Stored vectors data. `numpy` is used for CPU vectors, `cupy` for GPU vectors. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ |
|
||||
| `key2row` | Dictionary mapping word hashes to rows in the `Vectors.data` table. ~~Dict[int, int]~~ |
|
||||
| `keys` | Array keeping the keys in order, such that `keys[vectors.key2row[key]] == key`. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ |
|
||||
| `attr` <Tag variant="new">3.6</Tag> | The token attribute for the vector keys. ~~int~~ |
|
||||
|
|
|
@ -11,7 +11,6 @@ menu:
|
|||
- ['Custom Functions', 'custom-functions']
|
||||
- ['Initialization', 'initialization']
|
||||
- ['Data Utilities', 'data']
|
||||
- ['Parallel Training', 'parallel-training']
|
||||
- ['Internal API', 'api']
|
||||
---
|
||||
|
||||
|
@ -1565,77 +1564,6 @@ token-based annotations like the dependency parse or entity labels, you'll need
|
|||
to take care to adjust the `Example` object so its annotations match and remain
|
||||
valid.
|
||||
|
||||
## Parallel & distributed training with Ray {id="parallel-training"}
|
||||
|
||||
> #### Installation
|
||||
>
|
||||
> ```bash
|
||||
> $ pip install -U %%SPACY_PKG_NAME[ray]%%SPACY_PKG_FLAGS
|
||||
> # Check that the CLI is registered
|
||||
> $ python -m spacy ray --help
|
||||
> ```
|
||||
|
||||
[Ray](https://ray.io/) is a fast and simple framework for building and running
|
||||
**distributed applications**. You can use Ray to train spaCy on one or more
|
||||
remote machines, potentially speeding up your training process. Parallel
|
||||
training won't always be faster though – it depends on your batch size, models,
|
||||
and hardware.
|
||||
|
||||
<Infobox variant="warning">
|
||||
|
||||
To use Ray with spaCy, you need the
|
||||
[`spacy-ray`](https://github.com/explosion/spacy-ray) package installed.
|
||||
Installing the package will automatically add the `ray` command to the spaCy
|
||||
CLI.
|
||||
|
||||
</Infobox>
|
||||
|
||||
The [`spacy ray train`](/api/cli#ray-train) command follows the same API as
|
||||
[`spacy train`](/api/cli#train), with a few extra options to configure the Ray
|
||||
setup. You can optionally set the `--address` option to point to your Ray
|
||||
cluster. If it's not set, Ray will run locally.
|
||||
|
||||
```bash
|
||||
python -m spacy ray train config.cfg --n-workers 2
|
||||
```
|
||||
|
||||
<Project id="integrations/ray">
|
||||
|
||||
Get started with parallel training using our project template. It trains a
|
||||
simple model on a Universal Dependencies Treebank and lets you parallelize the
|
||||
training with Ray.
|
||||
|
||||
</Project>
|
||||
|
||||
### How parallel training works {id="parallel-training-details"}
|
||||
|
||||
Each worker receives a shard of the **data** and builds a copy of the **model
|
||||
and optimizer** from the [`config.cfg`](#config). It also has a communication
|
||||
channel to **pass gradients and parameters** to the other workers. Additionally,
|
||||
each worker is given ownership of a subset of the parameter arrays. Every
|
||||
parameter array is owned by exactly one worker, and the workers are given a
|
||||
mapping so they know which worker owns which parameter.
|
||||
|
||||
![Illustration of setup](/images/spacy-ray.svg)
|
||||
|
||||
As training proceeds, every worker will be computing gradients for **all** of
|
||||
the model parameters. When they compute gradients for parameters they don't own,
|
||||
they'll **send them to the worker** that does own that parameter, along with a
|
||||
version identifier so that the owner can decide whether to discard the gradient.
|
||||
Workers use the gradients they receive and the ones they compute locally to
|
||||
update the parameters they own, and then broadcast the updated array and a new
|
||||
version ID to the other workers.
|
||||
|
||||
This training procedure is **asynchronous** and **non-blocking**. Workers always
|
||||
push their gradient increments and parameter updates, they do not have to pull
|
||||
them and block on the result, so the transfers can happen in the background,
|
||||
overlapped with the actual training work. The workers also do not have to stop
|
||||
and wait for each other ("synchronize") at the start of each batch. This is very
|
||||
useful for spaCy, because spaCy is often trained on long documents, which means
|
||||
**batches can vary in size** significantly. Uneven workloads make synchronous
|
||||
gradient descent inefficient, because if one batch is slow, all of the other
|
||||
workers are stuck waiting for it to complete before they can continue.
|
||||
|
||||
## Internal training API {id="api"}
|
||||
|
||||
<Infobox variant="danger">
|
||||
|
|
|
@ -4372,7 +4372,7 @@
|
|||
"code_example": [
|
||||
"import spacy",
|
||||
"",
|
||||
"nlp = spacy.load(\"en_core_web_sm\", disable=[\"ner\"])",
|
||||
"nlp = spacy.load(\"en_core_web_sm\", exclude=[\"ner\"])",
|
||||
"nlp.add_pipe(\"span_marker\", config={\"model\": \"tomaarsen/span-marker-roberta-large-ontonotes5\"})",
|
||||
"",
|
||||
"text = \"\"\"Cleopatra VII, also known as Cleopatra the Great, was the last active ruler of the \\",
|
||||
|
|
|
@ -13,6 +13,8 @@ import 'prismjs/components/prism-json.min.js'
|
|||
import 'prismjs/components/prism-markdown.min.js'
|
||||
import 'prismjs/components/prism-python.min.js'
|
||||
import 'prismjs/components/prism-yaml.min.js'
|
||||
import 'prismjs/components/prism-docker.min.js'
|
||||
import 'prismjs/components/prism-r.min.js'
|
||||
|
||||
import { isString } from './util'
|
||||
import Link, { OptionalLink } from './link'
|
||||
|
@ -172,7 +174,7 @@ const convertLine = ({ line, prompt, lang }) => {
|
|||
return handlePromot({ lineFlat, prompt })
|
||||
}
|
||||
|
||||
return lang === 'none' || !lineFlat ? (
|
||||
return lang === 'none' || !lineFlat || !(lang in Prism.languages) ? (
|
||||
lineFlat
|
||||
) : (
|
||||
<span
|
||||
|
|
Loading…
Reference in New Issue
Block a user