Tidy up code

This commit is contained in:
Adriane Boyd 2021-06-28 11:48:00 +02:00
parent 93572dc12a
commit 5eeb25f043
46 changed files with 276 additions and 136 deletions

View File

@ -4,6 +4,7 @@ import sys
# set library-specific custom warning handling before doing anything else # set library-specific custom warning handling before doing anything else
from .errors import setup_default_warnings from .errors import setup_default_warnings
setup_default_warnings() setup_default_warnings()
# These are imported as part of the API # These are imported as part of the API

View File

@ -139,7 +139,10 @@ def debug_model(
upstream_component = None upstream_component = None
if model.has_ref("tok2vec") and "tok2vec-listener" in model.get_ref("tok2vec").name: if model.has_ref("tok2vec") and "tok2vec-listener" in model.get_ref("tok2vec").name:
upstream_component = nlp.get_pipe("tok2vec") upstream_component = nlp.get_pipe("tok2vec")
if model.has_ref("tok2vec") and "transformer-listener" in model.get_ref("tok2vec").name: if (
model.has_ref("tok2vec")
and "transformer-listener" in model.get_ref("tok2vec").name
):
upstream_component = nlp.get_pipe("transformer") upstream_component = nlp.get_pipe("transformer")
goldY = None goldY = None
for e in range(3): for e in range(3):

View File

@ -127,7 +127,9 @@ def evaluate(
data["ents_per_type"] = scores["ents_per_type"] data["ents_per_type"] = scores["ents_per_type"]
if f"spans_{spans_key}_per_type" in scores: if f"spans_{spans_key}_per_type" in scores:
if scores[f"spans_{spans_key}_per_type"]: if scores[f"spans_{spans_key}_per_type"]:
print_prf_per_type(msg, scores[f"spans_{spans_key}_per_type"], "SPANS", "type") print_prf_per_type(
msg, scores[f"spans_{spans_key}_per_type"], "SPANS", "type"
)
data[f"spans_{spans_key}_per_type"] = scores[f"spans_{spans_key}_per_type"] data[f"spans_{spans_key}_per_type"] = scores[f"spans_{spans_key}_per_type"]
if "cats_f_per_type" in scores: if "cats_f_per_type" in scores:
if scores["cats_f_per_type"]: if scores["cats_f_per_type"]:

View File

@ -120,7 +120,9 @@ def parse_deps(orig_doc: Doc, options: Dict[str, Any] = {}) -> Dict[str, Any]:
doc (Doc): Document do parse. doc (Doc): Document do parse.
RETURNS (dict): Generated dependency parse keyed by words and arcs. RETURNS (dict): Generated dependency parse keyed by words and arcs.
""" """
doc = Doc(orig_doc.vocab).from_bytes(orig_doc.to_bytes(exclude=["user_data", "user_hooks"])) doc = Doc(orig_doc.vocab).from_bytes(
orig_doc.to_bytes(exclude=["user_data", "user_hooks"])
)
if not doc.has_annotation("DEP"): if not doc.has_annotation("DEP"):
warnings.warn(Warnings.W005) warnings.warn(Warnings.W005)
if options.get("collapse_phrases", False): if options.get("collapse_phrases", False):

View File

@ -22,13 +22,13 @@ _num_words = [
"тринадесет", "тринадесет",
"тринайсет", "тринайсет",
"четиринадесет", "четиринадесет",
"четиринайсет" "четиринайсет",
"петнадесет", "петнадесет",
"петнайсет" "петнайсет",
"шестнадесет", "шестнадесет",
"шестнайсет", "шестнайсет",
"седемнадесет", "седемнадесет",
"седемнайсет" "седемнайсет",
"осемнадесет", "осемнадесет",
"осемнайсет", "осемнайсет",
"деветнадесет", "деветнадесет",
@ -36,7 +36,7 @@ _num_words = [
"двадесет", "двадесет",
"двайсет", "двайсет",
"тридесет", "тридесет",
"трийсет" "трийсет",
"четиридесет", "четиридесет",
"четиресет", "четиресет",
"петдесет", "петдесет",

View File

@ -58,7 +58,6 @@ _abbr_dot_exc = [
{ORTH: "стр.", NORM: "страница"}, {ORTH: "стр.", NORM: "страница"},
{ORTH: "ул.", NORM: "улица"}, {ORTH: "ул.", NORM: "улица"},
{ORTH: "чл.", NORM: "член"}, {ORTH: "чл.", NORM: "член"},
] ]
for abbr in _abbr_dot_exc: for abbr in _abbr_dot_exc:

View File

@ -81,16 +81,32 @@ for exc_data in [
# Source: https://kaino.kotus.fi/visk/sisallys.php?p=141 # Source: https://kaino.kotus.fi/visk/sisallys.php?p=141
conj_contraction_bases = [ conj_contraction_bases = [
("ett", "että"), ("jott", "jotta"), ("kosk", "koska"), ("mutt", "mutta"), ("ett", "että"),
("vaikk", "vaikka"), ("ehk", "ehkä"), ("miks", "miksi"), ("siks", "siksi"), ("jott", "jotta"),
("joll", "jos"), ("ell", "jos") ("kosk", "koska"),
("mutt", "mutta"),
("vaikk", "vaikka"),
("ehk", "ehkä"),
("miks", "miksi"),
("siks", "siksi"),
("joll", "jos"),
("ell", "jos"),
] ]
conj_contraction_negations = [ conj_contraction_negations = [
("en", "en"), ("et", "et"), ("ei", "ei"), ("emme", "emme"), ("en", "en"),
("ette", "ette"), ("eivat", "eivät"), ("eivät", "eivät")] ("et", "et"),
("ei", "ei"),
("emme", "emme"),
("ette", "ette"),
("eivat", "eivät"),
("eivät", "eivät"),
]
for (base_lower, base_norm) in conj_contraction_bases: for (base_lower, base_norm) in conj_contraction_bases:
for base in [base_lower, base_lower.title()]: for base in [base_lower, base_lower.title()]:
for (suffix, suffix_norm) in conj_contraction_negations: for (suffix, suffix_norm) in conj_contraction_negations:
_exc[base + suffix] = [{ORTH: base, NORM: base_norm}, {ORTH: suffix, NORM: suffix_norm}] _exc[base + suffix] = [
{ORTH: base, NORM: base_norm},
{ORTH: suffix, NORM: suffix_norm},
]
TOKENIZER_EXCEPTIONS = update_exc(BASE_EXCEPTIONS, _exc) TOKENIZER_EXCEPTIONS = update_exc(BASE_EXCEPTIONS, _exc)

View File

@ -4,12 +4,12 @@ from ...pipeline import Lemmatizer
from ...tokens import Token from ...tokens import Token
class ItalianLemmatizer(Lemmatizer): class ItalianLemmatizer(Lemmatizer):
"""This lemmatizer was adapted from the Polish one (version of April 2021). """This lemmatizer was adapted from the Polish one (version of April 2021).
It implements lookup lemmatization based on the morphological lexicon It implements lookup lemmatization based on the morphological lexicon
morph-it (Baroni and Zanchetta). The table lemma_lookup with non-POS-aware morph-it (Baroni and Zanchetta). The table lemma_lookup with non-POS-aware
entries is used as a backup for words that aren't handled by morph-it.""" entries is used as a backup for words that aren't handled by morph-it."""
@classmethod @classmethod
def get_lookups_config(cls, mode: str) -> Tuple[List[str], List[str]]: def get_lookups_config(cls, mode: str) -> Tuple[List[str], List[str]]:
if mode == "pos_lookup": if mode == "pos_lookup":

View File

@ -25,7 +25,7 @@ for orth in [
"artt.", "artt.",
"att.", "att.",
"avv.", "avv.",
"Avv." "Avv.",
"by-pass", "by-pass",
"c.d.", "c.d.",
"c/c", "c/c",

View File

@ -687,9 +687,11 @@ class Language:
if not isinstance(source, Language): if not isinstance(source, Language):
raise ValueError(Errors.E945.format(name=source_name, source=type(source))) raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
# Check vectors, with faster checks first # Check vectors, with faster checks first
if self.vocab.vectors.shape != source.vocab.vectors.shape or \ if (
self.vocab.vectors.key2row != source.vocab.vectors.key2row or \ self.vocab.vectors.shape != source.vocab.vectors.shape
self.vocab.vectors.to_bytes() != source.vocab.vectors.to_bytes(): or self.vocab.vectors.key2row != source.vocab.vectors.key2row
or self.vocab.vectors.to_bytes() != source.vocab.vectors.to_bytes()
):
warnings.warn(Warnings.W113.format(name=source_name)) warnings.warn(Warnings.W113.format(name=source_name))
if not source_name in source.component_names: if not source_name in source.component_names:
raise KeyError( raise KeyError(
@ -1539,15 +1541,21 @@ class Language:
# Cycle channels not to break the order of docs. # Cycle channels not to break the order of docs.
# The received object is a batch of byte-encoded docs, so flatten them with chain.from_iterable. # The received object is a batch of byte-encoded docs, so flatten them with chain.from_iterable.
byte_tuples = chain.from_iterable(recv.recv() for recv in cycle(bytedocs_recv_ch)) byte_tuples = chain.from_iterable(
recv.recv() for recv in cycle(bytedocs_recv_ch)
)
try: try:
for i, (_, (byte_doc, byte_error)) in enumerate(zip(raw_texts, byte_tuples), 1): for i, (_, (byte_doc, byte_error)) in enumerate(
zip(raw_texts, byte_tuples), 1
):
if byte_doc is not None: if byte_doc is not None:
doc = Doc(self.vocab).from_bytes(byte_doc) doc = Doc(self.vocab).from_bytes(byte_doc)
yield doc yield doc
elif byte_error is not None: elif byte_error is not None:
error = srsly.msgpack_loads(byte_error) error = srsly.msgpack_loads(byte_error)
self.default_error_handler(None, None, None, ValueError(Errors.E871.format(error=error))) self.default_error_handler(
None, None, None, ValueError(Errors.E871.format(error=error))
)
if i % batch_size == 0: if i % batch_size == 0:
# tell `sender` that one batch was consumed. # tell `sender` that one batch was consumed.
sender.step() sender.step()
@ -1707,7 +1715,9 @@ class Language:
if "replace_listeners" in pipe_cfg: if "replace_listeners" in pipe_cfg:
for name, proc in source_nlps[model].pipeline: for name, proc in source_nlps[model].pipeline:
if source_name in getattr(proc, "listening_components", []): if source_name in getattr(proc, "listening_components", []):
source_nlps[model].replace_listeners(name, source_name, pipe_cfg["replace_listeners"]) source_nlps[model].replace_listeners(
name, source_name, pipe_cfg["replace_listeners"]
)
listeners_replaced = True listeners_replaced = True
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name) nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
# Delete from cache if listeners were replaced # Delete from cache if listeners were replaced
@ -1727,12 +1737,16 @@ class Language:
for name, proc in nlp.pipeline: for name, proc in nlp.pipeline:
# Remove listeners not in the pipeline # Remove listeners not in the pipeline
listener_names = getattr(proc, "listening_components", []) listener_names = getattr(proc, "listening_components", [])
unused_listener_names = [ll for ll in listener_names if ll not in nlp.pipe_names] unused_listener_names = [
ll for ll in listener_names if ll not in nlp.pipe_names
]
for listener_name in unused_listener_names: for listener_name in unused_listener_names:
for listener in proc.listener_map.get(listener_name, []): for listener in proc.listener_map.get(listener_name, []):
proc.remove_listener(listener, listener_name) proc.remove_listener(listener, listener_name)
for listener in getattr(proc, "listening_components", []): # e.g. tok2vec/transformer for listener in getattr(
proc, "listening_components", []
): # e.g. tok2vec/transformer
# If it's a component sourced from another pipeline, we check if # If it's a component sourced from another pipeline, we check if
# the tok2vec listeners should be replaced with standalone tok2vec # the tok2vec listeners should be replaced with standalone tok2vec
# models (e.g. so component can be frozen without its performance # models (e.g. so component can be frozen without its performance
@ -1827,7 +1841,9 @@ class Language:
new_config = tok2vec_cfg["model"] new_config = tok2vec_cfg["model"]
if "replace_listener_cfg" in tok2vec_model.attrs: if "replace_listener_cfg" in tok2vec_model.attrs:
replace_func = tok2vec_model.attrs["replace_listener_cfg"] replace_func = tok2vec_model.attrs["replace_listener_cfg"]
new_config = replace_func(tok2vec_cfg["model"], pipe_cfg["model"]["tok2vec"]) new_config = replace_func(
tok2vec_cfg["model"], pipe_cfg["model"]["tok2vec"]
)
util.set_dot_to_object(pipe_cfg, listener_path, new_config) util.set_dot_to_object(pipe_cfg, listener_path, new_config)
# Go over the listener layers and replace them # Go over the listener layers and replace them
for listener in pipe_listeners: for listener in pipe_listeners:
@ -1866,7 +1882,10 @@ class Language:
util.to_disk(path, serializers, exclude) util.to_disk(path, serializers, exclude)
def from_disk( def from_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList(), self,
path: Union[str, Path],
*,
exclude: Iterable[str] = SimpleFrozenList(),
overrides: Dict[str, Any] = SimpleFrozenDict(), overrides: Dict[str, Any] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Loads state from a directory. Modifies the object in place and """Loads state from a directory. Modifies the object in place and

View File

@ -12,9 +12,7 @@ from .strings import get_string_id
UNSET = object() UNSET = object()
def load_lookups( def load_lookups(lang: str, tables: List[str], strict: bool = True) -> "Lookups":
lang: str, tables: List[str], strict: bool = True
) -> 'Lookups':
"""Load the data from the spacy-lookups-data package for a given language, """Load the data from the spacy-lookups-data package for a given language,
if available. Returns an empty `Lookups` container if there's no data or if the package if available. Returns an empty `Lookups` container if there's no data or if the package
is not installed. is not installed.

View File

@ -309,9 +309,7 @@ class EntityLinker(TrainablePipe):
assert sent_index >= 0 assert sent_index >= 0
# get n_neighbour sentences, clipped to the length of the document # get n_neighbour sentences, clipped to the length of the document
start_sentence = max(0, sent_index - self.n_sents) start_sentence = max(0, sent_index - self.n_sents)
end_sentence = min( end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
len(sentences) - 1, sent_index + self.n_sents
)
start_token = sentences[start_sentence].start start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc() sent_doc = doc[start_token:end_token].as_doc()
@ -337,22 +335,16 @@ class EntityLinker(TrainablePipe):
else: else:
random.shuffle(candidates) random.shuffle(candidates)
# set all prior probabilities to 0 if incl_prior=False # set all prior probabilities to 0 if incl_prior=False
prior_probs = xp.asarray( prior_probs = xp.asarray([c.prior_prob for c in candidates])
[c.prior_prob for c in candidates]
)
if not self.incl_prior: if not self.incl_prior:
prior_probs = xp.asarray( prior_probs = xp.asarray([0.0 for _ in candidates])
[0.0 for _ in candidates]
)
scores = prior_probs scores = prior_probs
# add in similarity from the context # add in similarity from the context
if self.incl_context: if self.incl_context:
entity_encodings = xp.asarray( entity_encodings = xp.asarray(
[c.entity_vector for c in candidates] [c.entity_vector for c in candidates]
) )
entity_norm = xp.linalg.norm( entity_norm = xp.linalg.norm(entity_encodings, axis=1)
entity_encodings, axis=1
)
if len(entity_encodings) != len(prior_probs): if len(entity_encodings) != len(prior_probs):
raise RuntimeError( raise RuntimeError(
Errors.E147.format( Errors.E147.format(
@ -361,14 +353,12 @@ class EntityLinker(TrainablePipe):
) )
) )
# cosine similarity # cosine similarity
sims = xp.dot( sims = xp.dot(entity_encodings, sentence_encoding_t) / (
entity_encodings, sentence_encoding_t sentence_norm * entity_norm
) / (sentence_norm * entity_norm) )
if sims.shape != prior_probs.shape: if sims.shape != prior_probs.shape:
raise ValueError(Errors.E161) raise ValueError(Errors.E161)
scores = ( scores = prior_probs + sims - (prior_probs * sims)
prior_probs + sims - (prior_probs * sims)
)
# TODO: thresholding # TODO: thresholding
best_index = scores.argmax().item() best_index = scores.argmax().item()
best_candidate = candidates[best_index] best_candidate = candidates[best_index]

View File

@ -278,9 +278,7 @@ class EntityRuler(Pipe):
if self == pipe: if self == pipe:
current_index = i current_index = i
break break
subsequent_pipes = [ subsequent_pipes = [pipe for pipe in self.nlp.pipe_names[current_index:]]
pipe for pipe in self.nlp.pipe_names[current_index :]
]
except ValueError: except ValueError:
subsequent_pipes = [] subsequent_pipes = []
with self.nlp.select_pipes(disable=subsequent_pipes): with self.nlp.select_pipes(disable=subsequent_pipes):

View File

@ -61,7 +61,7 @@ def build_ngram_suggester(sizes: List[int]) -> Callable[[List[Doc]], Ragged]:
length = 0 length = 0
for size in sizes: for size in sizes:
if size <= len(doc): if size <= len(doc):
starts_size = starts[:len(doc) - (size - 1)] starts_size = starts[: len(doc) - (size - 1)]
spans.append(ops.xp.hstack((starts_size, starts_size + size))) spans.append(ops.xp.hstack((starts_size, starts_size + size)))
length += spans[-1].shape[0] length += spans[-1].shape[0]
if spans: if spans:
@ -70,7 +70,7 @@ def build_ngram_suggester(sizes: List[int]) -> Callable[[List[Doc]], Ragged]:
if len(spans) > 0: if len(spans) > 0:
output = Ragged(ops.xp.vstack(spans), ops.asarray(lengths, dtype="i")) output = Ragged(ops.xp.vstack(spans), ops.asarray(lengths, dtype="i"))
else: else:
output = Ragged(ops.xp.zeros((0,0)), ops.asarray(lengths, dtype="i")) output = Ragged(ops.xp.zeros((0, 0)), ops.asarray(lengths, dtype="i"))
assert output.dataXd.ndim == 2 assert output.dataXd.ndim == 2
return output return output

View File

@ -299,7 +299,9 @@ class TextCategorizer(TrainablePipe):
self._allow_extra_label() self._allow_extra_label()
self.cfg["labels"].append(label) self.cfg["labels"].append(label)
if self.model and "resize_output" in self.model.attrs: if self.model and "resize_output" in self.model.attrs:
self.model = self.model.attrs["resize_output"](self.model, len(self.cfg["labels"])) self.model = self.model.attrs["resize_output"](
self.model, len(self.cfg["labels"])
)
self.vocab.strings.add(label) self.vocab.strings.add(label)
return 1 return 1

View File

@ -365,7 +365,9 @@ class Scorer:
gold_spans.add(gold_span) gold_spans.add(gold_span)
gold_per_type[span.label_].add(gold_span) gold_per_type[span.label_].add(gold_span)
pred_per_type = {label: set() for label in labels} pred_per_type = {label: set() for label in labels}
for span in example.get_aligned_spans_x2y(getter(pred_doc, attr), allow_overlap): for span in example.get_aligned_spans_x2y(
getter(pred_doc, attr), allow_overlap
):
if labeled: if labeled:
pred_span = (span.label_, span.start, span.end - 1) pred_span = (span.label_, span.start, span.end - 1)
else: else:
@ -392,7 +394,9 @@ class Scorer:
final_scores[f"{attr}_r"] = score.recall final_scores[f"{attr}_r"] = score.recall
final_scores[f"{attr}_f"] = score.fscore final_scores[f"{attr}_f"] = score.fscore
if labeled: if labeled:
final_scores[f"{attr}_per_type"] = {k: v.to_dict() for k, v in score_per_type.items()} final_scores[f"{attr}_per_type"] = {
k: v.to_dict() for k, v in score_per_type.items()
}
return final_scores return final_scores
@staticmethod @staticmethod

View File

@ -1,6 +1,7 @@
import pytest import pytest
from spacy.lang.bg.lex_attrs import like_num from spacy.lang.bg.lex_attrs import like_num
@pytest.mark.parametrize( @pytest.mark.parametrize(
"word,match", "word,match",
[ [

View File

@ -40,20 +40,21 @@ CONTRACTION_TESTS = [
( (
"Päätimme ettemme tule.", "Päätimme ettemme tule.",
["Päätimme", "ett", "emme", "tule", "."], ["Päätimme", "ett", "emme", "tule", "."],
["päätimme", "että", "emme", "tule", "."] ["päätimme", "että", "emme", "tule", "."],
), ),
( (
"Miksei puhuttaisi?", "Miksei puhuttaisi?",
["Miks", "ei", "puhuttaisi", "?"], ["Miks", "ei", "puhuttaisi", "?"],
["miksi", "ei", "puhuttaisi", "?"] ["miksi", "ei", "puhuttaisi", "?"],
), ),
( (
"He tottelivat vaikkeivat halunneet", "He tottelivat vaikkeivat halunneet",
["He", "tottelivat", "vaikk", "eivat", "halunneet"], ["He", "tottelivat", "vaikk", "eivat", "halunneet"],
["he", "tottelivat", "vaikka", "eivät", "halunneet"] ["he", "tottelivat", "vaikka", "eivät", "halunneet"],
), ),
] ]
@pytest.mark.parametrize("text,expected_tokens", ABBREVIATION_TESTS) @pytest.mark.parametrize("text,expected_tokens", ABBREVIATION_TESTS)
def test_fi_tokenizer_abbreviations(fi_tokenizer, text, expected_tokens): def test_fi_tokenizer_abbreviations(fi_tokenizer, text, expected_tokens):
tokens = fi_tokenizer(text) tokens = fi_tokenizer(text)

View File

@ -257,11 +257,21 @@ def test_matcher_with_alignments_nongreedy(en_vocab):
(2, "aaab", "a a a b", [[0, 1, 2, 3]]), (2, "aaab", "a a a b", [[0, 1, 2, 3]]),
(3, "aaab", "a+ b", [[0, 1], [0, 0, 1], [0, 0, 0, 1]]), (3, "aaab", "a+ b", [[0, 1], [0, 0, 1], [0, 0, 0, 1]]),
(4, "aaba", "a+ b a+", [[0, 1, 2], [0, 0, 1, 2]]), (4, "aaba", "a+ b a+", [[0, 1, 2], [0, 0, 1, 2]]),
(5, "aabaa", "a+ b a+", [[0, 1, 2], [0, 0, 1, 2], [0, 0, 1, 2, 2], [0, 1, 2, 2] ]), (
5,
"aabaa",
"a+ b a+",
[[0, 1, 2], [0, 0, 1, 2], [0, 0, 1, 2, 2], [0, 1, 2, 2]],
),
(6, "aaba", "a+ b a*", [[0, 1], [0, 0, 1], [0, 0, 1, 2], [0, 1, 2]]), (6, "aaba", "a+ b a*", [[0, 1], [0, 0, 1], [0, 0, 1, 2], [0, 1, 2]]),
(7, "aaaa", "a*", [[0], [0, 0], [0, 0, 0], [0, 0, 0, 0]]), (7, "aaaa", "a*", [[0], [0, 0], [0, 0, 0], [0, 0, 0, 0]]),
(8, "baab", "b a* b b*", [[0, 1, 1, 2]]), (8, "baab", "b a* b b*", [[0, 1, 1, 2]]),
(9, "aabb", "a* b* a*", [[1], [2], [2, 2], [0, 1], [0, 0, 1], [0, 0, 1, 1], [0, 1, 1], [1, 1]]), (
9,
"aabb",
"a* b* a*",
[[1], [2], [2, 2], [0, 1], [0, 0, 1], [0, 0, 1, 1], [0, 1, 1], [1, 1]],
),
(10, "aaab", "a+ a+ a b", [[0, 1, 2, 3]]), (10, "aaab", "a+ a+ a b", [[0, 1, 2, 3]]),
(11, "aaab", "a+ a+ a+ b", [[0, 1, 2, 3]]), (11, "aaab", "a+ a+ a+ b", [[0, 1, 2, 3]]),
(12, "aaab", "a+ a a b", [[0, 1, 2, 3]]), (12, "aaab", "a+ a a b", [[0, 1, 2, 3]]),

View File

@ -557,7 +557,11 @@ def test_neg_annotation(neg_key):
ner.add_label("PERSON") ner.add_label("PERSON")
ner.add_label("ORG") ner.add_label("ORG")
example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]}) example = Example.from_dict(neg_doc, {"entities": [(7, 17, "PERSON")]})
example.reference.spans[neg_key] = [Span(neg_doc, 2, 4, "ORG"), Span(neg_doc, 2, 3, "PERSON"), Span(neg_doc, 1, 4, "PERSON")] example.reference.spans[neg_key] = [
Span(neg_doc, 2, 4, "ORG"),
Span(neg_doc, 2, 3, "PERSON"),
Span(neg_doc, 1, 4, "PERSON"),
]
optimizer = nlp.initialize() optimizer = nlp.initialize()
for i in range(2): for i in range(2):

View File

@ -254,7 +254,9 @@ def test_nel_nsents(nlp):
"""Test that n_sents can be set through the configuration""" """Test that n_sents can be set through the configuration"""
entity_linker = nlp.add_pipe("entity_linker", config={}) entity_linker = nlp.add_pipe("entity_linker", config={})
assert entity_linker.n_sents == 0 assert entity_linker.n_sents == 0
entity_linker = nlp.replace_pipe("entity_linker", "entity_linker", config={"n_sents": 2}) entity_linker = nlp.replace_pipe(
"entity_linker", "entity_linker", config={"n_sents": 2}
)
assert entity_linker.n_sents == 2 assert entity_linker.n_sents == 2
@ -596,7 +598,9 @@ def test_kb_to_bytes():
kb_1.add_entity(entity="Q66", freq=9, entity_vector=[1, 2, 3]) kb_1.add_entity(entity="Q66", freq=9, entity_vector=[1, 2, 3])
kb_1.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8]) kb_1.add_alias(alias="Russ Cochran", entities=["Q2146908"], probabilities=[0.8])
kb_1.add_alias(alias="Boeing", entities=["Q66"], probabilities=[0.5]) kb_1.add_alias(alias="Boeing", entities=["Q66"], probabilities=[0.5])
kb_1.add_alias(alias="Randomness", entities=["Q66", "Q2146908"], probabilities=[0.1, 0.2]) kb_1.add_alias(
alias="Randomness", entities=["Q66", "Q2146908"], probabilities=[0.1, 0.2]
)
assert kb_1.contains_alias("Russ Cochran") assert kb_1.contains_alias("Russ Cochran")
kb_bytes = kb_1.to_bytes() kb_bytes = kb_1.to_bytes()
kb_2 = KnowledgeBase(nlp.vocab, entity_vector_length=3) kb_2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
@ -611,8 +615,12 @@ def test_kb_to_bytes():
assert kb_2.contains_alias("Russ Cochran") assert kb_2.contains_alias("Russ Cochran")
assert kb_1.get_size_aliases() == kb_2.get_size_aliases() assert kb_1.get_size_aliases() == kb_2.get_size_aliases()
assert kb_1.get_alias_strings() == kb_2.get_alias_strings() assert kb_1.get_alias_strings() == kb_2.get_alias_strings()
assert len(kb_1.get_alias_candidates("Russ Cochran")) == len(kb_2.get_alias_candidates("Russ Cochran")) assert len(kb_1.get_alias_candidates("Russ Cochran")) == len(
assert len(kb_1.get_alias_candidates("Randomness")) == len(kb_2.get_alias_candidates("Randomness")) kb_2.get_alias_candidates("Russ Cochran")
)
assert len(kb_1.get_alias_candidates("Randomness")) == len(
kb_2.get_alias_candidates("Randomness")
)
def test_nel_to_bytes(): def test_nel_to_bytes():
@ -640,7 +648,9 @@ def test_nel_to_bytes():
kb_2 = nlp_2.get_pipe("entity_linker").kb kb_2 = nlp_2.get_pipe("entity_linker").kb
assert kb_2.contains_alias("Russ Cochran") assert kb_2.contains_alias("Russ Cochran")
assert kb_2.get_vector("Q2146908") == [6, -4, 3] assert kb_2.get_vector("Q2146908") == [6, -4, 3]
assert_almost_equal(kb_2.get_prior_prob(entity="Q2146908", alias="Russ Cochran"), 0.8) assert_almost_equal(
kb_2.get_prior_prob(entity="Q2146908", alias="Russ Cochran"), 0.8
)
def test_scorer_links(): def test_scorer_links():

View File

@ -82,7 +82,9 @@ def util_batch_unbatch_docs_list(
Y_batched = model.predict(in_data) Y_batched = model.predict(in_data)
Y_not_batched = [model.predict([u])[0] for u in in_data] Y_not_batched = [model.predict([u])[0] for u in in_data]
for i in range(len(Y_batched)): for i in range(len(Y_batched)):
assert_almost_equal(OPS.to_numpy(Y_batched[i]), OPS.to_numpy(Y_not_batched[i]), decimal=4) assert_almost_equal(
OPS.to_numpy(Y_batched[i]), OPS.to_numpy(Y_not_batched[i]), decimal=4
)
def util_batch_unbatch_docs_array( def util_batch_unbatch_docs_array(

View File

@ -351,9 +351,21 @@ def test_language_factories_invalid():
([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.0}, {"a": 0.0, "b": 1.0}), ([{"a": 0.5, "b": 0.5}, {"b": 1.0}], {"a": 0.0}, {"a": 0.0, "b": 1.0}),
([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {}, {"a": 0.0, "b": 0.0, "c": 0.0}), ([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {}, {"a": 0.0, "b": 0.0, "c": 0.0}),
([{"a": 0.0, "b": 0.0}, {"c": 1.0}], {}, {"a": 0.0, "b": 0.0, "c": 1.0}), ([{"a": 0.0, "b": 0.0}, {"c": 1.0}], {}, {"a": 0.0, "b": 0.0, "c": 1.0}),
([{"a": 0.0, "b": 0.0}, {"c": 0.0}], {"c": 0.2}, {"a": 0.0, "b": 0.0, "c": 1.0}), (
([{"a": 0.5, "b": 0.5, "c": 1.0, "d": 1.0}], {"a": 0.0, "b": 0.0}, {"a": 0.0, "b": 0.0, "c": 0.5, "d": 0.5}), [{"a": 0.0, "b": 0.0}, {"c": 0.0}],
([{"a": 0.5, "b": 0.5, "c": 1.0, "d": 1.0}], {"a": 0.0, "b": 0.0, "f": 0.0}, {"a": 0.0, "b": 0.0, "c": 0.5, "d": 0.5, "f": 0.0}), {"c": 0.2},
{"a": 0.0, "b": 0.0, "c": 1.0},
),
(
[{"a": 0.5, "b": 0.5, "c": 1.0, "d": 1.0}],
{"a": 0.0, "b": 0.0},
{"a": 0.0, "b": 0.0, "c": 0.5, "d": 0.5},
),
(
[{"a": 0.5, "b": 0.5, "c": 1.0, "d": 1.0}],
{"a": 0.0, "b": 0.0, "f": 0.0},
{"a": 0.0, "b": 0.0, "c": 0.5, "d": 0.5, "f": 0.0},
),
], ],
) )
def test_language_factories_combine_score_weights(weights, override, expected): def test_language_factories_combine_score_weights(weights, override, expected):

View File

@ -446,7 +446,12 @@ def test_update_with_annotates():
for text in texts: for text in texts:
examples.append(Example(nlp.make_doc(text), nlp.make_doc(text))) examples.append(Example(nlp.make_doc(text), nlp.make_doc(text)))
for components_to_annotate in [[], [f"{name}1"], [f"{name}1", f"{name}2"], [f"{name}2", f"{name}1"]]: for components_to_annotate in [
[],
[f"{name}1"],
[f"{name}1", f"{name}2"],
[f"{name}2", f"{name}1"],
]:
for key in results: for key in results:
results[key] = "" results[key] = ""
nlp = English(vocab=nlp.vocab) nlp = English(vocab=nlp.vocab)

View File

@ -79,10 +79,7 @@ def test_ngram_suggester(en_tokenizer):
assert spans.shape[0] == len(spans_set) assert spans.shape[0] == len(spans_set)
offset += ngrams.lengths[i] offset += ngrams.lengths[i]
# the number of spans is correct # the number of spans is correct
assert_equal( assert_equal(ngrams.lengths, [max(0, len(doc) - (size - 1)) for doc in docs])
ngrams.lengths,
[max(0, len(doc) - (size - 1)) for doc in docs]
)
# test 1-3-gram suggestions # test 1-3-gram suggestions
ngram_suggester = registry.misc.get("ngram_suggester.v1")(sizes=[1, 2, 3]) ngram_suggester = registry.misc.get("ngram_suggester.v1")(sizes=[1, 2, 3])

View File

@ -131,7 +131,7 @@ def test_implicit_label(name, get_examples):
nlp.initialize(get_examples=get_examples(nlp)) nlp.initialize(get_examples=get_examples(nlp))
#fmt: off # fmt: off
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name,textcat_config", "name,textcat_config",
[ [
@ -150,7 +150,7 @@ def test_implicit_label(name, get_examples):
("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}), ("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
], ],
) )
#fmt: on # fmt: on
def test_no_resize(name, textcat_config): def test_no_resize(name, textcat_config):
"""The old textcat architectures weren't resizable""" """The old textcat architectures weren't resizable"""
nlp = Language() nlp = Language()
@ -165,7 +165,7 @@ def test_no_resize(name, textcat_config):
textcat.add_label("NEUTRAL") textcat.add_label("NEUTRAL")
#fmt: off # fmt: off
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name,textcat_config", "name,textcat_config",
[ [
@ -179,7 +179,7 @@ def test_no_resize(name, textcat_config):
("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}), ("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
], ],
) )
#fmt: on # fmt: on
def test_resize(name, textcat_config): def test_resize(name, textcat_config):
"""The new textcat architectures are resizable""" """The new textcat architectures are resizable"""
nlp = Language() nlp = Language()
@ -194,7 +194,7 @@ def test_resize(name, textcat_config):
assert textcat.model.maybe_get_dim("nO") in [3, None] assert textcat.model.maybe_get_dim("nO") in [3, None]
#fmt: off # fmt: off
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name,textcat_config", "name,textcat_config",
[ [
@ -208,7 +208,7 @@ def test_resize(name, textcat_config):
("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}), ("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
], ],
) )
#fmt: on # fmt: on
def test_resize_same_results(name, textcat_config): def test_resize_same_results(name, textcat_config):
# Ensure that the resized textcat classifiers still produce the same results for old labels # Ensure that the resized textcat classifiers still produce the same results for old labels
fix_random_seed(0) fix_random_seed(0)
@ -511,7 +511,9 @@ def test_textcat_threshold():
macro_f = scores["cats_score"] macro_f = scores["cats_score"]
assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0 assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0, "positive_label": "POSITIVE"}) scores = nlp.evaluate(
train_examples, scorer_cfg={"threshold": 0, "positive_label": "POSITIVE"}
)
pos_f = scores["cats_score"] pos_f = scores["cats_score"]
assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0 assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
assert pos_f > macro_f assert pos_f > macro_f

View File

@ -129,8 +129,14 @@ cfg_string = """
""" """
TRAIN_DATA = [ TRAIN_DATA = [
("I like green eggs", {"tags": ["N", "V", "J", "N"], "cats": {"preference": 1.0, "imperative": 0.0}}), (
("Eat blue ham", {"tags": ["V", "J", "N"], "cats": {"preference": 0.0, "imperative": 1.0}}), "I like green eggs",
{"tags": ["N", "V", "J", "N"], "cats": {"preference": 1.0, "imperative": 0.0}},
),
(
"Eat blue ham",
{"tags": ["V", "J", "N"], "cats": {"preference": 0.0, "imperative": 1.0}},
),
] ]
@ -405,5 +411,5 @@ def test_tok2vec_listeners_textcat():
cats1 = docs[1].cats cats1 = docs[1].cats
assert cats1["preference"] > 0.1 assert cats1["preference"] > 0.1
assert cats1["imperative"] < 0.9 assert cats1["imperative"] < 0.9
assert([t.tag_ for t in docs[0]] == ["V", "J", "N"]) assert [t.tag_ for t in docs[0]] == ["V", "J", "N"]
assert([t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]) assert [t.tag_ for t in docs[1]] == ["N", "V", "J", "N"]

View File

@ -152,7 +152,8 @@ labels = ['label1', 'label2']
@pytest.mark.parametrize( @pytest.mark.parametrize(
"component_name", ["textcat", "textcat_multilabel"], "component_name",
["textcat", "textcat_multilabel"],
) )
def test_issue6908(component_name): def test_issue6908(component_name):
"""Test intializing textcat with labels in a list""" """Test intializing textcat with labels in a list"""

View File

@ -8,8 +8,7 @@ def test_issue7056():
sentence segmentation errors.""" sentence segmentation errors."""
vocab = Vocab() vocab = Vocab()
ae = ArcEager( ae = ArcEager(
vocab.strings, vocab.strings, ArcEager.get_actions(left_labels=["amod"], right_labels=["pobj"])
ArcEager.get_actions(left_labels=["amod"], right_labels=["pobj"])
) )
doc = Doc(vocab, words="Severe pain , after trauma".split()) doc = Doc(vocab, words="Severe pain , after trauma".split())
state = ae.init_batch([doc])[0] state = ae.init_batch([doc])[0]

View File

@ -41,7 +41,7 @@ def test_partial_links():
nlp.add_pipe("sentencizer", first=True) nlp.add_pipe("sentencizer", first=True)
patterns = [ patterns = [
{"label": "PERSON", "pattern": [{"LOWER": "russ"}, {"LOWER": "cochran"}]}, {"label": "PERSON", "pattern": [{"LOWER": "russ"}, {"LOWER": "cochran"}]},
{"label": "ORG", "pattern": [{"LOWER": "ec"}, {"LOWER": "comics"}]} {"label": "ORG", "pattern": [{"LOWER": "ec"}, {"LOWER": "comics"}]},
] ]
ruler = nlp.add_pipe("entity_ruler", before="entity_linker") ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
ruler.add_patterns(patterns) ruler.add_patterns(patterns)

View File

@ -8,7 +8,17 @@ def test_issue7065():
nlp = English() nlp = English()
nlp.add_pipe("sentencizer") nlp.add_pipe("sentencizer")
ruler = nlp.add_pipe("entity_ruler") ruler = nlp.add_pipe("entity_ruler")
patterns = [{"label": "THING", "pattern": [{"LOWER": "symphony"}, {"LOWER": "no"}, {"LOWER": "."}, {"LOWER": "8"}]}] patterns = [
{
"label": "THING",
"pattern": [
{"LOWER": "symphony"},
{"LOWER": "no"},
{"LOWER": "."},
{"LOWER": "8"},
],
}
]
ruler.add_patterns(patterns) ruler.add_patterns(patterns)
doc = nlp(text) doc = nlp(text)
@ -28,11 +38,15 @@ def test_issue7065_b():
text = "Mahler 's Symphony No. 8 was beautiful." text = "Mahler 's Symphony No. 8 was beautiful."
entities = [(0, 6, "PERSON"), (10, 24, "WORK")] entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
links = {(0, 6): {"Q7304": 1.0, "Q270853": 0.0}, links = {
(10, 24): {"Q7304": 0.0, "Q270853": 1.0}} (0, 6): {"Q7304": 1.0, "Q270853": 0.0},
(10, 24): {"Q7304": 0.0, "Q270853": 1.0},
}
sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0] sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
doc = nlp(text) doc = nlp(text)
example = Example.from_dict(doc, {"entities": entities, "links": links, "sent_starts": sent_starts}) example = Example.from_dict(
doc, {"entities": entities, "links": links, "sent_starts": sent_starts}
)
train_examples = [example] train_examples = [example]
def create_kb(vocab): def create_kb(vocab):
@ -65,7 +79,15 @@ def test_issue7065_b():
# Add a custom rule-based component to mimick NER # Add a custom rule-based component to mimick NER
patterns = [ patterns = [
{"label": "PERSON", "pattern": [{"LOWER": "mahler"}]}, {"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
{"label": "WORK", "pattern": [{"LOWER": "symphony"}, {"LOWER": "no"}, {"LOWER": "."}, {"LOWER": "8"}]} {
"label": "WORK",
"pattern": [
{"LOWER": "symphony"},
{"LOWER": "no"},
{"LOWER": "."},
{"LOWER": "8"},
],
},
] ]
ruler = nlp.add_pipe("entity_ruler", before="entity_linker") ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
ruler.add_patterns(patterns) ruler.add_patterns(patterns)

View File

@ -1,11 +1,22 @@
from spacy.lang.en import English from spacy.lang.en import English
def test_issue8168(): def test_issue8168():
nlp = English() nlp = English()
ruler = nlp.add_pipe("entity_ruler") ruler = nlp.add_pipe("entity_ruler")
patterns = [{"label": "ORG", "pattern": "Apple"}, patterns = [
{"label": "GPE", "pattern": [{"LOWER": "san"}, {"LOWER": "francisco"}], "id": "san-francisco"}, {"label": "ORG", "pattern": "Apple"},
{"label": "GPE", "pattern": [{"LOWER": "san"}, {"LOWER": "fran"}], "id": "san-francisco"}] {
"label": "GPE",
"pattern": [{"LOWER": "san"}, {"LOWER": "francisco"}],
"id": "san-francisco",
},
{
"label": "GPE",
"pattern": [{"LOWER": "san"}, {"LOWER": "fran"}],
"id": "san-francisco",
},
]
ruler.add_patterns(patterns) ruler.add_patterns(patterns)
assert ruler._ent_ids == {8043148519967183733: ('GPE', 'san-francisco')} assert ruler._ent_ids == {8043148519967183733: ("GPE", "san-francisco")}

View File

@ -9,20 +9,13 @@ def test_issue8190():
"nlp": { "nlp": {
"lang": "en", "lang": "en",
}, },
"custom": { "custom": {"key": "value"},
"key": "value"
}
} }
source_nlp = English.from_config(source_cfg) source_nlp = English.from_config(source_cfg)
with make_tempdir() as dir_path: with make_tempdir() as dir_path:
# We need to create a loadable source pipeline # We need to create a loadable source pipeline
source_path = dir_path / "test_model" source_path = dir_path / "test_model"
source_nlp.to_disk(source_path) source_nlp.to_disk(source_path)
nlp = spacy.load(source_path, config={ nlp = spacy.load(source_path, config={"custom": {"key": "updated_value"}})
"custom": {
"key": "updated_value"
}
})
assert nlp.config["custom"]["key"] == "updated_value" assert nlp.config["custom"]["key"] == "updated_value"

View File

@ -4,7 +4,12 @@ import spacy
from spacy.lang.en import English from spacy.lang.en import English
from spacy.lang.de import German from spacy.lang.de import German
from spacy.language import Language, DEFAULT_CONFIG, DEFAULT_CONFIG_PRETRAIN_PATH from spacy.language import Language, DEFAULT_CONFIG, DEFAULT_CONFIG_PRETRAIN_PATH
from spacy.util import registry, load_model_from_config, load_config, load_config_from_str from spacy.util import (
registry,
load_model_from_config,
load_config,
load_config_from_str,
)
from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
from spacy.ml.models import MultiHashEmbed, MaxoutWindowEncoder from spacy.ml.models import MultiHashEmbed, MaxoutWindowEncoder
from spacy.schemas import ConfigSchema, ConfigSchemaPretrain from spacy.schemas import ConfigSchema, ConfigSchemaPretrain
@ -493,4 +498,4 @@ def test_hyphen_in_config():
self.punctuation = punctuation self.punctuation = punctuation
nlp = English.from_config(load_config_from_str(hyphen_config_str)) nlp = English.from_config(load_config_from_str(hyphen_config_str))
assert nlp.get_pipe("my_punctual_component").punctuation == ['?', '-'] assert nlp.get_pipe("my_punctual_component").punctuation == ["?", "-"]

View File

@ -64,7 +64,9 @@ def test_serialize_doc_span_groups(en_vocab):
def test_serialize_doc_bin(): def test_serialize_doc_bin():
doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE", "NORM", "ENT_ID"], store_user_data=True) doc_bin = DocBin(
attrs=["LEMMA", "ENT_IOB", "ENT_TYPE", "NORM", "ENT_ID"], store_user_data=True
)
texts = ["Some text", "Lots of texts...", "..."] texts = ["Some text", "Lots of texts...", "..."]
cats = {"A": 0.5} cats = {"A": 0.5}
nlp = English() nlp = English()

View File

@ -5,7 +5,6 @@ from catalogue import RegistryError
def test_get_architecture(): def test_get_architecture():
@registry.architectures("my_test_function") @registry.architectures("my_test_function")
def create_model(nr_in, nr_out): def create_model(nr_in, nr_out):
return Linear(nr_in, nr_out) return Linear(nr_in, nr_out)

View File

@ -143,7 +143,9 @@ def sample_vectors():
@pytest.fixture @pytest.fixture
def nlp2(nlp, sample_vectors): def nlp2(nlp, sample_vectors):
Language.component("test_language_vector_modification_pipe", func=vector_modification_pipe) Language.component(
"test_language_vector_modification_pipe", func=vector_modification_pipe
)
Language.component("test_language_userdata_pipe", func=userdata_pipe) Language.component("test_language_userdata_pipe", func=userdata_pipe)
Language.component("test_language_ner_pipe", func=ner_pipe) Language.component("test_language_ner_pipe", func=ner_pipe)
add_vecs_to_vocab(nlp.vocab, sample_vectors) add_vecs_to_vocab(nlp.vocab, sample_vectors)

View File

@ -444,7 +444,9 @@ def test_score_spans():
assert f"{key}_per_type" in scores assert f"{key}_per_type" in scores
# Discard labels from the evaluation # Discard labels from the evaluation
scores = Scorer.score_spans([eg], attr=key, getter=span_getter, allow_overlap=True, labeled=False) scores = Scorer.score_spans(
[eg], attr=key, getter=span_getter, allow_overlap=True, labeled=False
)
assert scores[f"{key}_p"] == 1.0 assert scores[f"{key}_p"] == 1.0
assert scores[f"{key}_r"] == 1.0 assert scores[f"{key}_r"] == 1.0
assert f"{key}_per_type" not in scores assert f"{key}_per_type" not in scores
@ -467,4 +469,6 @@ def test_prf_score():
assert (c.precision, c.recall, c.fscore) == approx((0.25, 0.5, 0.33333333)) assert (c.precision, c.recall, c.fscore) == approx((0.25, 0.5, 0.33333333))
a += b a += b
assert (a.precision, a.recall, a.fscore) == approx((c.precision, c.recall, c.fscore)) assert (a.precision, a.recall, a.fscore) == approx(
(c.precision, c.recall, c.fscore)
)

View File

@ -278,7 +278,9 @@ def test_pretraining_training():
filled = filled.interpolate() filled = filled.interpolate()
P = filled["pretraining"] P = filled["pretraining"]
nlp_base = init_nlp(filled) nlp_base = init_nlp(filled)
model_base = nlp_base.get_pipe(P["component"]).model.get_ref(P["layer"]).get_ref("embed") model_base = (
nlp_base.get_pipe(P["component"]).model.get_ref(P["layer"]).get_ref("embed")
)
embed_base = None embed_base = None
for node in model_base.walk(): for node in model_base.walk():
if node.name == "hashembed": if node.name == "hashembed":
@ -331,11 +333,12 @@ def write_sample_training(tmp_dir):
def write_vectors_model(tmp_dir): def write_vectors_model(tmp_dir):
import numpy import numpy
vocab = Vocab() vocab = Vocab()
vector_data = { vector_data = {
"dog": numpy.random.uniform(-1, 1, (300,)), "dog": numpy.random.uniform(-1, 1, (300,)),
"cat": numpy.random.uniform(-1, 1, (300,)), "cat": numpy.random.uniform(-1, 1, (300,)),
"orange": numpy.random.uniform(-1, 1, (300,)) "orange": numpy.random.uniform(-1, 1, (300,)),
} }
for word, vector in vector_data.items(): for word, vector in vector_data.items():
vocab.set_vector(word, vector) vocab.set_vector(word, vector)

View File

@ -434,8 +434,14 @@ def test_aligned_spans_y2x_overlap(en_vocab, en_tokenizer):
gold_doc = nlp.make_doc(text) gold_doc = nlp.make_doc(text)
spans = [] spans = []
prefix = "I flew to " prefix = "I flew to "
spans.append(gold_doc.char_span(len(prefix), len(prefix + "San Francisco"), label="CITY")) spans.append(
spans.append(gold_doc.char_span(len(prefix), len(prefix + "San Francisco Valley"), label="VALLEY")) gold_doc.char_span(len(prefix), len(prefix + "San Francisco"), label="CITY")
)
spans.append(
gold_doc.char_span(
len(prefix), len(prefix + "San Francisco Valley"), label="VALLEY"
)
)
spans_key = "overlap_ents" spans_key = "overlap_ents"
gold_doc.spans[spans_key] = spans gold_doc.spans[spans_key] = spans
example = Example(doc, gold_doc) example = Example(doc, gold_doc)
@ -443,7 +449,9 @@ def test_aligned_spans_y2x_overlap(en_vocab, en_tokenizer):
assert [(ent.start, ent.end) for ent in spans_gold] == [(3, 5), (3, 6)] assert [(ent.start, ent.end) for ent in spans_gold] == [(3, 5), (3, 6)]
# Ensure that 'get_aligned_spans_y2x' has the aligned entities correct # Ensure that 'get_aligned_spans_y2x' has the aligned entities correct
spans_y2x_no_overlap = example.get_aligned_spans_y2x(spans_gold, allow_overlap=False) spans_y2x_no_overlap = example.get_aligned_spans_y2x(
spans_gold, allow_overlap=False
)
assert [(ent.start, ent.end) for ent in spans_y2x_no_overlap] == [(3, 5)] assert [(ent.start, ent.end) for ent in spans_y2x_no_overlap] == [(3, 5)]
spans_y2x_overlap = example.get_aligned_spans_y2x(spans_gold, allow_overlap=True) spans_y2x_overlap = example.get_aligned_spans_y2x(spans_gold, allow_overlap=True)
assert [(ent.start, ent.end) for ent in spans_y2x_overlap] == [(3, 5), (3, 6)] assert [(ent.start, ent.end) for ent in spans_y2x_overlap] == [(3, 5), (3, 6)]

View File

@ -12,6 +12,7 @@ from ..util import add_vecs_to_vocab, get_cosine, make_tempdir
OPS = get_current_ops() OPS = get_current_ops()
@pytest.fixture @pytest.fixture
def strings(): def strings():
return ["apple", "orange"] return ["apple", "orange"]

View File

@ -66,7 +66,11 @@ def configure_minibatch_by_words(
""" """
optionals = {"get_length": get_length} if get_length is not None else {} optionals = {"get_length": get_length} if get_length is not None else {}
return partial( return partial(
minibatch_by_words, size=size, tolerance=tolerance, discard_oversize=discard_oversize, **optionals minibatch_by_words,
size=size,
tolerance=tolerance,
discard_oversize=discard_oversize,
**optionals
) )

View File

@ -70,14 +70,18 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
nlp._link_components() nlp._link_components()
with nlp.select_pipes(disable=[*frozen_components, *resume_components]): with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
if T["max_epochs"] == -1: if T["max_epochs"] == -1:
logger.debug("Due to streamed train corpus, using only first 100 examples for initialization. If necessary, provide all labels in [initialize]. More info: https://spacy.io/api/cli#init_labels") logger.debug(
"Due to streamed train corpus, using only first 100 examples for initialization. If necessary, provide all labels in [initialize]. More info: https://spacy.io/api/cli#init_labels"
)
nlp.initialize(lambda: islice(train_corpus(nlp), 100), sgd=optimizer) nlp.initialize(lambda: islice(train_corpus(nlp), 100), sgd=optimizer)
else: else:
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer) nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
logger.info(f"Initialized pipeline components: {nlp.pipe_names}") logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
# Detect components with listeners that are not frozen consistently # Detect components with listeners that are not frozen consistently
for name, proc in nlp.pipeline: for name, proc in nlp.pipeline:
for listener in getattr(proc, "listening_components", []): # e.g. tok2vec/transformer for listener in getattr(
proc, "listening_components", []
): # e.g. tok2vec/transformer
# Don't warn about components not in the pipeline # Don't warn about components not in the pipeline
if listener not in nlp.pipe_names: if listener not in nlp.pipe_names:
continue continue

View File

@ -96,8 +96,7 @@ def train(
stdout.write(msg.info(f"Frozen components: {frozen_components}") + "\n") stdout.write(msg.info(f"Frozen components: {frozen_components}") + "\n")
if annotating_components: if annotating_components:
stdout.write( stdout.write(
msg.info(f"Set annotations on update for: {annotating_components}") msg.info(f"Set annotations on update for: {annotating_components}") + "\n"
+ "\n"
) )
stdout.write(msg.info(f"Initial learn rate: {optimizer.learn_rate}") + "\n") stdout.write(msg.info(f"Initial learn rate: {optimizer.learn_rate}") + "\n")
with nlp.select_pipes(disable=frozen_components): with nlp.select_pipes(disable=frozen_components):

View File

@ -57,13 +57,13 @@ if TYPE_CHECKING:
from .vocab import Vocab # noqa: F401 from .vocab import Vocab # noqa: F401
# fmt: off
OOV_RANK = numpy.iinfo(numpy.uint64).max OOV_RANK = numpy.iinfo(numpy.uint64).max
DEFAULT_OOV_PROB = -20 DEFAULT_OOV_PROB = -20
LEXEME_NORM_LANGS = ["cs", "da", "de", "el", "en", "id", "lb", "mk", "pt", "ru", "sr", "ta", "th"] LEXEME_NORM_LANGS = ["cs", "da", "de", "el", "en", "id", "lb", "mk", "pt", "ru", "sr", "ta", "th"]
# Default order of sections in the config.cfg. Not all sections needs to exist, # Default order of sections in the config.cfg. Not all sections needs to exist,
# and additional sections are added at the end, in alphabetical order. # and additional sections are added at the end, in alphabetical order.
# fmt: off
CONFIG_SECTION_ORDER = ["paths", "variables", "system", "nlp", "components", "corpora", "training", "pretraining", "initialize"] CONFIG_SECTION_ORDER = ["paths", "variables", "system", "nlp", "components", "corpora", "training", "pretraining", "initialize"]
# fmt: on # fmt: on
@ -649,8 +649,7 @@ def get_model_version_range(spacy_version: str) -> str:
def get_model_lower_version(constraint: str) -> Optional[str]: def get_model_lower_version(constraint: str) -> Optional[str]:
"""From a version range like >=1.2.3,<1.3.0 return the lower pin. """From a version range like >=1.2.3,<1.3.0 return the lower pin."""
"""
try: try:
specset = SpecifierSet(constraint) specset = SpecifierSet(constraint)
for spec in specset: for spec in specset: