Tidy up and auto-format

This commit is contained in:
Ines Montani 2021-07-18 15:44:56 +10:00
parent 313f55e560
commit f90482d077
19 changed files with 304 additions and 319 deletions

View File

@ -5,7 +5,7 @@ import sys
# set library-specific custom warning handling before doing anything else # set library-specific custom warning handling before doing anything else
from .errors import setup_default_warnings from .errors import setup_default_warnings
setup_default_warnings() setup_default_warnings() # noqa: E402
# These are imported as part of the API # These are imported as part of the API
from thinc.api import prefer_gpu, require_gpu, require_cpu # noqa: F401 from thinc.api import prefer_gpu, require_gpu, require_cpu # noqa: F401

View File

@ -1447,7 +1447,7 @@ class Language:
) -> Iterator[Tuple[Doc, _AnyContext]]: ) -> Iterator[Tuple[Doc, _AnyContext]]:
... ...
def pipe( def pipe( # noqa: F811
self, self,
texts: Iterable[str], texts: Iterable[str],
*, *,

View File

@ -69,4 +69,4 @@ def test_create_with_heads_and_no_deps(vocab):
words = "I like ginger".split() words = "I like ginger".split()
heads = list(range(len(words))) heads = list(range(len(words)))
with pytest.raises(ValueError): with pytest.raises(ValueError):
doc = Doc(vocab, words=words, heads=heads) Doc(vocab, words=words, heads=heads)

View File

@ -329,8 +329,8 @@ def test_ner_constructor(en_vocab):
} }
cfg = {"model": DEFAULT_NER_MODEL} cfg = {"model": DEFAULT_NER_MODEL}
model = registry.resolve(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
ner_1 = EntityRecognizer(en_vocab, model, **config) EntityRecognizer(en_vocab, model, **config)
ner_2 = EntityRecognizer(en_vocab, model) EntityRecognizer(en_vocab, model)
def test_ner_before_ruler(): def test_ner_before_ruler():

View File

@ -224,8 +224,8 @@ def test_parser_constructor(en_vocab):
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.resolve(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
parser_1 = DependencyParser(en_vocab, model, **config) DependencyParser(en_vocab, model, **config)
parser_2 = DependencyParser(en_vocab, model) DependencyParser(en_vocab, model)
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"]) @pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])

View File

@ -74,7 +74,7 @@ def test_annotates_on_update():
nlp.add_pipe("assert_sents") nlp.add_pipe("assert_sents")
# When the pipeline runs, annotations are set # When the pipeline runs, annotations are set
doc = nlp("This is a sentence.") nlp("This is a sentence.")
examples = [] examples = []
for text in ["a a", "b b", "c c"]: for text in ["a a", "b b", "c c"]:

View File

@ -110,4 +110,4 @@ def test_lemmatizer_serialize(nlp):
assert doc2[0].lemma_ == "cope" assert doc2[0].lemma_ == "cope"
# Make sure that lemmatizer cache can be pickled # Make sure that lemmatizer cache can be pickled
b = pickle.dumps(lemmatizer2) pickle.dumps(lemmatizer2)

View File

@ -52,7 +52,7 @@ def test_cant_add_pipe_first_and_last(nlp):
nlp.add_pipe("new_pipe", first=True, last=True) nlp.add_pipe("new_pipe", first=True, last=True)
@pytest.mark.parametrize("name", ["my_component"]) @pytest.mark.parametrize("name", ["test_get_pipe"])
def test_get_pipe(nlp, name): def test_get_pipe(nlp, name):
with pytest.raises(KeyError): with pytest.raises(KeyError):
nlp.get_pipe(name) nlp.get_pipe(name)
@ -62,7 +62,7 @@ def test_get_pipe(nlp, name):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name,replacement,invalid_replacement", "name,replacement,invalid_replacement",
[("my_component", "other_pipe", lambda doc: doc)], [("test_replace_pipe", "other_pipe", lambda doc: doc)],
) )
def test_replace_pipe(nlp, name, replacement, invalid_replacement): def test_replace_pipe(nlp, name, replacement, invalid_replacement):
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -435,8 +435,8 @@ def test_update_with_annotates():
return component return component
c1 = Language.component(f"{name}1", func=make_component(f"{name}1")) Language.component(f"{name}1", func=make_component(f"{name}1"))
c2 = Language.component(f"{name}2", func=make_component(f"{name}2")) Language.component(f"{name}2", func=make_component(f"{name}2"))
components = set([f"{name}1", f"{name}2"]) components = set([f"{name}1", f"{name}2"])

View File

@ -69,9 +69,12 @@ def test_issue5082():
def test_issue5137(): def test_issue5137():
@Language.factory("my_component") factory_name = "test_issue5137"
pipe_name = "my_component"
@Language.factory(factory_name)
class MyComponent: class MyComponent:
def __init__(self, nlp, name="my_component", categories="all_categories"): def __init__(self, nlp, name=pipe_name, categories="all_categories"):
self.nlp = nlp self.nlp = nlp
self.categories = categories self.categories = categories
self.name = name self.name = name
@ -86,13 +89,13 @@ def test_issue5137():
pass pass
nlp = English() nlp = English()
my_component = nlp.add_pipe("my_component") my_component = nlp.add_pipe(factory_name, name=pipe_name)
assert my_component.categories == "all_categories" assert my_component.categories == "all_categories"
with make_tempdir() as tmpdir: with make_tempdir() as tmpdir:
nlp.to_disk(tmpdir) nlp.to_disk(tmpdir)
overrides = {"components": {"my_component": {"categories": "my_categories"}}} overrides = {"components": {pipe_name: {"categories": "my_categories"}}}
nlp2 = spacy.load(tmpdir, config=overrides) nlp2 = spacy.load(tmpdir, config=overrides)
assert nlp2.get_pipe("my_component").categories == "my_categories" assert nlp2.get_pipe(pipe_name).categories == "my_categories"
def test_issue5141(en_vocab): def test_issue5141(en_vocab):

View File

@ -0,0 +1,281 @@
from spacy.cli.evaluate import print_textcats_auc_per_cat, print_prf_per_type
from spacy.lang.en import English
from spacy.training import Example
from spacy.tokens.doc import Doc
from spacy.vocab import Vocab
from spacy.kb import KnowledgeBase
from spacy.pipeline._parser_internals.arc_eager import ArcEager
from spacy.util import load_config_from_str, load_config
from spacy.cli.init_config import fill_config
from thinc.api import Config
from wasabi import msg
from ..util import make_tempdir
def test_issue7019():
scores = {"LABEL_A": 0.39829102, "LABEL_B": 0.938298329382, "LABEL_C": None}
print_textcats_auc_per_cat(msg, scores)
scores = {
"LABEL_A": {"p": 0.3420302, "r": 0.3929020, "f": 0.49823928932},
"LABEL_B": {"p": None, "r": None, "f": None},
}
print_prf_per_type(msg, scores, name="foo", type="bar")
CONFIG_7029 = """
[nlp]
lang = "en"
pipeline = ["tok2vec", "tagger"]
[components]
[components.tok2vec]
factory = "tok2vec"
[components.tok2vec.model]
@architectures = "spacy.Tok2Vec.v1"
[components.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = ${components.tok2vec.model.encode:width}
attrs = ["NORM","PREFIX","SUFFIX","SHAPE"]
rows = [5000,2500,2500,2500]
include_static_vectors = false
[components.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v1"
width = 96
depth = 4
window_size = 1
maxout_pieces = 3
[components.tagger]
factory = "tagger"
[components.tagger.model]
@architectures = "spacy.Tagger.v1"
nO = null
[components.tagger.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode:width}
upstream = "*"
"""
def test_issue7029():
"""Test that an empty document doesn't mess up an entire batch."""
TRAIN_DATA = [
("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
("Eat blue ham", {"tags": ["V", "J", "N"]}),
]
nlp = English.from_config(load_config_from_str(CONFIG_7029))
train_examples = []
for t in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(50):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
texts = ["first", "second", "third", "fourth", "and", "then", "some", ""]
docs1 = list(nlp.pipe(texts, batch_size=1))
docs2 = list(nlp.pipe(texts, batch_size=4))
assert [doc[0].tag_ for doc in docs1[:-1]] == [doc[0].tag_ for doc in docs2[:-1]]
def test_issue7055():
"""Test that fill-config doesn't turn sourced components into factories."""
source_cfg = {
"nlp": {"lang": "en", "pipeline": ["tok2vec", "tagger"]},
"components": {
"tok2vec": {"factory": "tok2vec"},
"tagger": {"factory": "tagger"},
},
}
source_nlp = English.from_config(source_cfg)
with make_tempdir() as dir_path:
# We need to create a loadable source pipeline
source_path = dir_path / "test_model"
source_nlp.to_disk(source_path)
base_cfg = {
"nlp": {"lang": "en", "pipeline": ["tok2vec", "tagger", "ner"]},
"components": {
"tok2vec": {"source": str(source_path)},
"tagger": {"source": str(source_path)},
"ner": {"factory": "ner"},
},
}
base_cfg = Config(base_cfg)
base_path = dir_path / "base.cfg"
base_cfg.to_disk(base_path)
output_path = dir_path / "config.cfg"
fill_config(output_path, base_path, silent=True)
filled_cfg = load_config(output_path)
assert filled_cfg["components"]["tok2vec"]["source"] == str(source_path)
assert filled_cfg["components"]["tagger"]["source"] == str(source_path)
assert filled_cfg["components"]["ner"]["factory"] == "ner"
assert "model" in filled_cfg["components"]["ner"]
def test_issue7056():
"""Test that the Unshift transition works properly, and doesn't cause
sentence segmentation errors."""
vocab = Vocab()
ae = ArcEager(
vocab.strings, ArcEager.get_actions(left_labels=["amod"], right_labels=["pobj"])
)
doc = Doc(vocab, words="Severe pain , after trauma".split())
state = ae.init_batch([doc])[0]
ae.apply_transition(state, "S")
ae.apply_transition(state, "L-amod")
ae.apply_transition(state, "S")
ae.apply_transition(state, "S")
ae.apply_transition(state, "S")
ae.apply_transition(state, "R-pobj")
ae.apply_transition(state, "D")
ae.apply_transition(state, "D")
ae.apply_transition(state, "D")
assert not state.eol()
def test_partial_links():
# Test that having some entities on the doc without gold links, doesn't crash
TRAIN_DATA = [
(
"Russ Cochran his reprints include EC Comics.",
{
"links": {(0, 12): {"Q2146908": 1.0}},
"entities": [(0, 12, "PERSON")],
"sent_starts": [1, -1, 0, 0, 0, 0, 0, 0],
},
)
]
nlp = English()
vector_length = 3
train_examples = []
for text, annotation in TRAIN_DATA:
doc = nlp(text)
train_examples.append(Example.from_dict(doc, annotation))
def create_kb(vocab):
# create artificial KB
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
mykb.add_alias("Russ Cochran", ["Q2146908"], [0.9])
return mykb
# Create and train the Entity Linker
entity_linker = nlp.add_pipe("entity_linker", last=True)
entity_linker.set_kb(create_kb)
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(2):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
# adding additional components that are required for the entity_linker
nlp.add_pipe("sentencizer", first=True)
patterns = [
{"label": "PERSON", "pattern": [{"LOWER": "russ"}, {"LOWER": "cochran"}]},
{"label": "ORG", "pattern": [{"LOWER": "ec"}, {"LOWER": "comics"}]},
]
ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
ruler.add_patterns(patterns)
# this will run the pipeline on the examples and shouldn't crash
results = nlp.evaluate(train_examples)
assert "PERSON" in results["ents_per_type"]
assert "PERSON" in results["nel_f_per_type"]
assert "ORG" in results["ents_per_type"]
assert "ORG" not in results["nel_f_per_type"]
def test_issue7065():
text = "Kathleen Battle sang in Mahler 's Symphony No. 8 at the Cincinnati Symphony Orchestra 's May Festival."
nlp = English()
nlp.add_pipe("sentencizer")
ruler = nlp.add_pipe("entity_ruler")
patterns = [
{
"label": "THING",
"pattern": [
{"LOWER": "symphony"},
{"LOWER": "no"},
{"LOWER": "."},
{"LOWER": "8"},
],
}
]
ruler.add_patterns(patterns)
doc = nlp(text)
sentences = [s for s in doc.sents]
assert len(sentences) == 2
sent0 = sentences[0]
ent = doc.ents[0]
assert ent.start < sent0.end < ent.end
assert sentences.index(ent.sent) == 0
def test_issue7065_b():
# Test that the NEL doesn't crash when an entity crosses a sentence boundary
nlp = English()
vector_length = 3
nlp.add_pipe("sentencizer")
text = "Mahler 's Symphony No. 8 was beautiful."
entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
links = {
(0, 6): {"Q7304": 1.0, "Q270853": 0.0},
(10, 24): {"Q7304": 0.0, "Q270853": 1.0},
}
sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
doc = nlp(text)
example = Example.from_dict(
doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
)
train_examples = [example]
def create_kb(vocab):
# create artificial KB
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
mykb.add_entity(entity="Q270853", freq=12, entity_vector=[9, 1, -7])
mykb.add_alias(
alias="No. 8",
entities=["Q270853"],
probabilities=[1.0],
)
mykb.add_entity(entity="Q7304", freq=12, entity_vector=[6, -4, 3])
mykb.add_alias(
alias="Mahler",
entities=["Q7304"],
probabilities=[1.0],
)
return mykb
# Create the Entity Linker component and add it to the pipeline
entity_linker = nlp.add_pipe("entity_linker", last=True)
entity_linker.set_kb(create_kb)
# train the NEL pipe
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(2):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
# Add a custom rule-based component to mimick NER
patterns = [
{"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
{
"label": "WORK",
"pattern": [
{"LOWER": "symphony"},
{"LOWER": "no"},
{"LOWER": "."},
{"LOWER": "8"},
],
},
]
ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
ruler.add_patterns(patterns)
# test the trained model - this should not throw E148
doc = nlp(text)
assert doc

View File

@ -1,12 +0,0 @@
from spacy.cli.evaluate import print_textcats_auc_per_cat, print_prf_per_type
from wasabi import msg
def test_issue7019():
scores = {"LABEL_A": 0.39829102, "LABEL_B": 0.938298329382, "LABEL_C": None}
print_textcats_auc_per_cat(msg, scores)
scores = {
"LABEL_A": {"p": 0.3420302, "r": 0.3929020, "f": 0.49823928932},
"LABEL_B": {"p": None, "r": None, "f": None},
}
print_prf_per_type(msg, scores, name="foo", type="bar")

View File

@ -1,66 +0,0 @@
from spacy.lang.en import English
from spacy.training import Example
from spacy.util import load_config_from_str
CONFIG = """
[nlp]
lang = "en"
pipeline = ["tok2vec", "tagger"]
[components]
[components.tok2vec]
factory = "tok2vec"
[components.tok2vec.model]
@architectures = "spacy.Tok2Vec.v1"
[components.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = ${components.tok2vec.model.encode:width}
attrs = ["NORM","PREFIX","SUFFIX","SHAPE"]
rows = [5000,2500,2500,2500]
include_static_vectors = false
[components.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v1"
width = 96
depth = 4
window_size = 1
maxout_pieces = 3
[components.tagger]
factory = "tagger"
[components.tagger.model]
@architectures = "spacy.Tagger.v1"
nO = null
[components.tagger.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode:width}
upstream = "*"
"""
TRAIN_DATA = [
("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
("Eat blue ham", {"tags": ["V", "J", "N"]}),
]
def test_issue7029():
"""Test that an empty document doesn't mess up an entire batch."""
nlp = English.from_config(load_config_from_str(CONFIG))
train_examples = []
for t in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(50):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
texts = ["first", "second", "third", "fourth", "and", "then", "some", ""]
docs1 = list(nlp.pipe(texts, batch_size=1))
docs2 = list(nlp.pipe(texts, batch_size=4))
assert [doc[0].tag_ for doc in docs1[:-1]] == [doc[0].tag_ for doc in docs2[:-1]]

View File

@ -1,40 +0,0 @@
from spacy.cli.init_config import fill_config
from spacy.util import load_config
from spacy.lang.en import English
from thinc.api import Config
from ..util import make_tempdir
def test_issue7055():
"""Test that fill-config doesn't turn sourced components into factories."""
source_cfg = {
"nlp": {"lang": "en", "pipeline": ["tok2vec", "tagger"]},
"components": {
"tok2vec": {"factory": "tok2vec"},
"tagger": {"factory": "tagger"},
},
}
source_nlp = English.from_config(source_cfg)
with make_tempdir() as dir_path:
# We need to create a loadable source pipeline
source_path = dir_path / "test_model"
source_nlp.to_disk(source_path)
base_cfg = {
"nlp": {"lang": "en", "pipeline": ["tok2vec", "tagger", "ner"]},
"components": {
"tok2vec": {"source": str(source_path)},
"tagger": {"source": str(source_path)},
"ner": {"factory": "ner"},
},
}
base_cfg = Config(base_cfg)
base_path = dir_path / "base.cfg"
base_cfg.to_disk(base_path)
output_path = dir_path / "config.cfg"
fill_config(output_path, base_path, silent=True)
filled_cfg = load_config(output_path)
assert filled_cfg["components"]["tok2vec"]["source"] == str(source_path)
assert filled_cfg["components"]["tagger"]["source"] == str(source_path)
assert filled_cfg["components"]["ner"]["factory"] == "ner"
assert "model" in filled_cfg["components"]["ner"]

View File

@ -1,24 +0,0 @@
from spacy.tokens.doc import Doc
from spacy.vocab import Vocab
from spacy.pipeline._parser_internals.arc_eager import ArcEager
def test_issue7056():
"""Test that the Unshift transition works properly, and doesn't cause
sentence segmentation errors."""
vocab = Vocab()
ae = ArcEager(
vocab.strings, ArcEager.get_actions(left_labels=["amod"], right_labels=["pobj"])
)
doc = Doc(vocab, words="Severe pain , after trauma".split())
state = ae.init_batch([doc])[0]
ae.apply_transition(state, "S")
ae.apply_transition(state, "L-amod")
ae.apply_transition(state, "S")
ae.apply_transition(state, "S")
ae.apply_transition(state, "S")
ae.apply_transition(state, "R-pobj")
ae.apply_transition(state, "D")
ae.apply_transition(state, "D")
ae.apply_transition(state, "D")
assert not state.eol()

View File

@ -1,54 +0,0 @@
from spacy.kb import KnowledgeBase
from spacy.training import Example
from spacy.lang.en import English
# fmt: off
TRAIN_DATA = [
("Russ Cochran his reprints include EC Comics.",
{"links": {(0, 12): {"Q2146908": 1.0}},
"entities": [(0, 12, "PERSON")],
"sent_starts": [1, -1, 0, 0, 0, 0, 0, 0]})
]
# fmt: on
def test_partial_links():
# Test that having some entities on the doc without gold links, doesn't crash
nlp = English()
vector_length = 3
train_examples = []
for text, annotation in TRAIN_DATA:
doc = nlp(text)
train_examples.append(Example.from_dict(doc, annotation))
def create_kb(vocab):
# create artificial KB
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
mykb.add_alias("Russ Cochran", ["Q2146908"], [0.9])
return mykb
# Create and train the Entity Linker
entity_linker = nlp.add_pipe("entity_linker", last=True)
entity_linker.set_kb(create_kb)
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(2):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
# adding additional components that are required for the entity_linker
nlp.add_pipe("sentencizer", first=True)
patterns = [
{"label": "PERSON", "pattern": [{"LOWER": "russ"}, {"LOWER": "cochran"}]},
{"label": "ORG", "pattern": [{"LOWER": "ec"}, {"LOWER": "comics"}]},
]
ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
ruler.add_patterns(patterns)
# this will run the pipeline on the examples and shouldn't crash
results = nlp.evaluate(train_examples)
assert "PERSON" in results["ents_per_type"]
assert "PERSON" in results["nel_f_per_type"]
assert "ORG" in results["ents_per_type"]
assert "ORG" not in results["nel_f_per_type"]

View File

@ -1,97 +0,0 @@
from spacy.kb import KnowledgeBase
from spacy.lang.en import English
from spacy.training import Example
def test_issue7065():
text = "Kathleen Battle sang in Mahler 's Symphony No. 8 at the Cincinnati Symphony Orchestra 's May Festival."
nlp = English()
nlp.add_pipe("sentencizer")
ruler = nlp.add_pipe("entity_ruler")
patterns = [
{
"label": "THING",
"pattern": [
{"LOWER": "symphony"},
{"LOWER": "no"},
{"LOWER": "."},
{"LOWER": "8"},
],
}
]
ruler.add_patterns(patterns)
doc = nlp(text)
sentences = [s for s in doc.sents]
assert len(sentences) == 2
sent0 = sentences[0]
ent = doc.ents[0]
assert ent.start < sent0.end < ent.end
assert sentences.index(ent.sent) == 0
def test_issue7065_b():
# Test that the NEL doesn't crash when an entity crosses a sentence boundary
nlp = English()
vector_length = 3
nlp.add_pipe("sentencizer")
text = "Mahler 's Symphony No. 8 was beautiful."
entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
links = {
(0, 6): {"Q7304": 1.0, "Q270853": 0.0},
(10, 24): {"Q7304": 0.0, "Q270853": 1.0},
}
sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
doc = nlp(text)
example = Example.from_dict(
doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
)
train_examples = [example]
def create_kb(vocab):
# create artificial KB
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
mykb.add_entity(entity="Q270853", freq=12, entity_vector=[9, 1, -7])
mykb.add_alias(
alias="No. 8",
entities=["Q270853"],
probabilities=[1.0],
)
mykb.add_entity(entity="Q7304", freq=12, entity_vector=[6, -4, 3])
mykb.add_alias(
alias="Mahler",
entities=["Q7304"],
probabilities=[1.0],
)
return mykb
# Create the Entity Linker component and add it to the pipeline
entity_linker = nlp.add_pipe("entity_linker", last=True)
entity_linker.set_kb(create_kb)
# train the NEL pipe
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(2):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
# Add a custom rule-based component to mimick NER
patterns = [
{"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
{
"label": "WORK",
"pattern": [
{"LOWER": "symphony"},
{"LOWER": "no"},
{"LOWER": "."},
{"LOWER": "8"},
],
},
]
ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
ruler.add_patterns(patterns)
# test the trained model - this should not throw E148
doc = nlp(text)
assert doc

View File

@ -60,12 +60,6 @@ def taggers(en_vocab):
@pytest.mark.parametrize("Parser", test_parsers) @pytest.mark.parametrize("Parser", test_parsers)
def test_serialize_parser_roundtrip_bytes(en_vocab, Parser): def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
config = {
"update_with_oracle_cut_size": 100,
"beam_width": 1,
"beam_update_prob": 1.0,
"beam_density": 0.0,
}
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.resolve(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
parser = Parser(en_vocab, model) parser = Parser(en_vocab, model)

View File

@ -440,7 +440,7 @@ def test_init_config(lang, pipeline, optimize, pretraining):
assert isinstance(config, Config) assert isinstance(config, Config)
if pretraining: if pretraining:
config["paths"]["raw_text"] = "my_data.jsonl" config["paths"]["raw_text"] = "my_data.jsonl"
nlp = load_model_from_config(config, auto_fill=True) load_model_from_config(config, auto_fill=True)
def test_model_recommendations(): def test_model_recommendations():

View File

@ -211,7 +211,7 @@ def test_empty_docs(model_func, kwargs):
def test_init_extract_spans(): def test_init_extract_spans():
model = extract_spans().initialize() extract_spans().initialize()
def test_extract_spans_span_indices(): def test_extract_spans_span_indices():