mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Tidy up and auto-format
This commit is contained in:
parent
89f2b87266
commit
009280fbc5
16
spacy/_ml.py
16
spacy/_ml.py
|
@ -674,14 +674,14 @@ def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg):
|
||||||
with Model.define_operators({">>": chain, "**": clone}):
|
with Model.define_operators({">>": chain, "**": clone}):
|
||||||
# context encoder
|
# context encoder
|
||||||
tok2vec = Tok2Vec(
|
tok2vec = Tok2Vec(
|
||||||
width=hidden_width,
|
width=hidden_width,
|
||||||
embed_size=embed_width,
|
embed_size=embed_width,
|
||||||
pretrained_vectors=pretrained_vectors,
|
pretrained_vectors=pretrained_vectors,
|
||||||
cnn_maxout_pieces=cnn_maxout_pieces,
|
cnn_maxout_pieces=cnn_maxout_pieces,
|
||||||
subword_features=True,
|
subword_features=True,
|
||||||
conv_depth=conv_depth,
|
conv_depth=conv_depth,
|
||||||
bilstm_depth=0,
|
bilstm_depth=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = (
|
model = (
|
||||||
tok2vec
|
tok2vec
|
||||||
|
|
|
@ -8,7 +8,7 @@ import sys
|
||||||
import srsly
|
import srsly
|
||||||
from wasabi import Printer, MESSAGES
|
from wasabi import Printer, MESSAGES
|
||||||
|
|
||||||
from ..gold import GoldCorpus, read_json_object
|
from ..gold import GoldCorpus
|
||||||
from ..syntax import nonproj
|
from ..syntax import nonproj
|
||||||
from ..util import load_model, get_lang_class
|
from ..util import load_model, get_lang_class
|
||||||
|
|
||||||
|
@ -95,13 +95,19 @@ def debug_data(
|
||||||
corpus = GoldCorpus(train_path, dev_path)
|
corpus = GoldCorpus(train_path, dev_path)
|
||||||
try:
|
try:
|
||||||
train_docs = list(corpus.train_docs(nlp))
|
train_docs = list(corpus.train_docs(nlp))
|
||||||
train_docs_unpreprocessed = list(corpus.train_docs_without_preprocessing(nlp))
|
train_docs_unpreprocessed = list(
|
||||||
|
corpus.train_docs_without_preprocessing(nlp)
|
||||||
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
loading_train_error_message = "Training data cannot be loaded: {}".format(str(e))
|
loading_train_error_message = "Training data cannot be loaded: {}".format(
|
||||||
|
str(e)
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
dev_docs = list(corpus.dev_docs(nlp))
|
dev_docs = list(corpus.dev_docs(nlp))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
loading_dev_error_message = "Development data cannot be loaded: {}".format(str(e))
|
loading_dev_error_message = "Development data cannot be loaded: {}".format(
|
||||||
|
str(e)
|
||||||
|
)
|
||||||
if loading_train_error_message or loading_dev_error_message:
|
if loading_train_error_message or loading_dev_error_message:
|
||||||
if loading_train_error_message:
|
if loading_train_error_message:
|
||||||
msg.fail(loading_train_error_message)
|
msg.fail(loading_train_error_message)
|
||||||
|
@ -158,11 +164,15 @@ def debug_data(
|
||||||
)
|
)
|
||||||
if gold_train_data["n_misaligned_words"] > 0:
|
if gold_train_data["n_misaligned_words"] > 0:
|
||||||
msg.warn(
|
msg.warn(
|
||||||
"{} misaligned tokens in the training data".format(gold_train_data["n_misaligned_words"])
|
"{} misaligned tokens in the training data".format(
|
||||||
|
gold_train_data["n_misaligned_words"]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if gold_dev_data["n_misaligned_words"] > 0:
|
if gold_dev_data["n_misaligned_words"] > 0:
|
||||||
msg.warn(
|
msg.warn(
|
||||||
"{} misaligned tokens in the dev data".format(gold_dev_data["n_misaligned_words"])
|
"{} misaligned tokens in the dev data".format(
|
||||||
|
gold_dev_data["n_misaligned_words"]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
most_common_words = gold_train_data["words"].most_common(10)
|
most_common_words = gold_train_data["words"].most_common(10)
|
||||||
msg.text(
|
msg.text(
|
||||||
|
@ -184,7 +194,9 @@ def debug_data(
|
||||||
|
|
||||||
if "ner" in pipeline:
|
if "ner" in pipeline:
|
||||||
# Get all unique NER labels present in the data
|
# Get all unique NER labels present in the data
|
||||||
labels = set(label for label in gold_train_data["ner"] if label not in ("O", "-"))
|
labels = set(
|
||||||
|
label for label in gold_train_data["ner"] if label not in ("O", "-")
|
||||||
|
)
|
||||||
label_counts = gold_train_data["ner"]
|
label_counts = gold_train_data["ner"]
|
||||||
model_labels = _get_labels_from_model(nlp, "ner")
|
model_labels = _get_labels_from_model(nlp, "ner")
|
||||||
new_labels = [l for l in labels if l not in model_labels]
|
new_labels = [l for l in labels if l not in model_labels]
|
||||||
|
@ -222,7 +234,9 @@ def debug_data(
|
||||||
)
|
)
|
||||||
|
|
||||||
if gold_train_data["ws_ents"]:
|
if gold_train_data["ws_ents"]:
|
||||||
msg.fail("{} invalid whitespace entity spans".format(gold_train_data["ws_ents"]))
|
msg.fail(
|
||||||
|
"{} invalid whitespace entity spans".format(gold_train_data["ws_ents"])
|
||||||
|
)
|
||||||
has_ws_ents_error = True
|
has_ws_ents_error = True
|
||||||
|
|
||||||
for label in new_labels:
|
for label in new_labels:
|
||||||
|
@ -323,33 +337,36 @@ def debug_data(
|
||||||
"Found {} sentence{} with an average length of {:.1f} words.".format(
|
"Found {} sentence{} with an average length of {:.1f} words.".format(
|
||||||
gold_train_data["n_sents"],
|
gold_train_data["n_sents"],
|
||||||
"s" if len(train_docs) > 1 else "",
|
"s" if len(train_docs) > 1 else "",
|
||||||
gold_train_data["n_words"] / gold_train_data["n_sents"]
|
gold_train_data["n_words"] / gold_train_data["n_sents"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# profile labels
|
# profile labels
|
||||||
labels_train = [label for label in gold_train_data["deps"]]
|
labels_train = [label for label in gold_train_data["deps"]]
|
||||||
labels_train_unpreprocessed = [label for label in gold_train_unpreprocessed_data["deps"]]
|
labels_train_unpreprocessed = [
|
||||||
|
label for label in gold_train_unpreprocessed_data["deps"]
|
||||||
|
]
|
||||||
labels_dev = [label for label in gold_dev_data["deps"]]
|
labels_dev = [label for label in gold_dev_data["deps"]]
|
||||||
|
|
||||||
if gold_train_unpreprocessed_data["n_nonproj"] > 0:
|
if gold_train_unpreprocessed_data["n_nonproj"] > 0:
|
||||||
msg.info(
|
msg.info(
|
||||||
"Found {} nonprojective train sentence{}".format(
|
"Found {} nonprojective train sentence{}".format(
|
||||||
gold_train_unpreprocessed_data["n_nonproj"],
|
gold_train_unpreprocessed_data["n_nonproj"],
|
||||||
"s" if gold_train_unpreprocessed_data["n_nonproj"] > 1 else ""
|
"s" if gold_train_unpreprocessed_data["n_nonproj"] > 1 else "",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if gold_dev_data["n_nonproj"] > 0:
|
if gold_dev_data["n_nonproj"] > 0:
|
||||||
msg.info(
|
msg.info(
|
||||||
"Found {} nonprojective dev sentence{}".format(
|
"Found {} nonprojective dev sentence{}".format(
|
||||||
gold_dev_data["n_nonproj"],
|
gold_dev_data["n_nonproj"],
|
||||||
"s" if gold_dev_data["n_nonproj"] > 1 else ""
|
"s" if gold_dev_data["n_nonproj"] > 1 else "",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
msg.info(
|
msg.info(
|
||||||
"{} {} in train data".format(
|
"{} {} in train data".format(
|
||||||
len(labels_train_unpreprocessed), "label" if len(labels_train) == 1 else "labels"
|
len(labels_train_unpreprocessed),
|
||||||
|
"label" if len(labels_train) == 1 else "labels",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
msg.info(
|
msg.info(
|
||||||
|
@ -373,43 +390,45 @@ def debug_data(
|
||||||
)
|
)
|
||||||
has_low_data_warning = True
|
has_low_data_warning = True
|
||||||
|
|
||||||
|
|
||||||
# rare labels in projectivized train
|
# rare labels in projectivized train
|
||||||
rare_projectivized_labels = []
|
rare_projectivized_labels = []
|
||||||
for label in gold_train_data["deps"]:
|
for label in gold_train_data["deps"]:
|
||||||
if gold_train_data["deps"][label] <= DEP_LABEL_THRESHOLD and "||" in label:
|
if gold_train_data["deps"][label] <= DEP_LABEL_THRESHOLD and "||" in label:
|
||||||
rare_projectivized_labels.append("{}: {}".format(label, str(gold_train_data["deps"][label])))
|
rare_projectivized_labels.append(
|
||||||
|
"{}: {}".format(label, str(gold_train_data["deps"][label]))
|
||||||
|
)
|
||||||
|
|
||||||
if len(rare_projectivized_labels) > 0:
|
if len(rare_projectivized_labels) > 0:
|
||||||
msg.warn(
|
msg.warn(
|
||||||
"Low number of examples for {} label{} in the "
|
"Low number of examples for {} label{} in the "
|
||||||
"projectivized dependency trees used for training. You may "
|
"projectivized dependency trees used for training. You may "
|
||||||
"want to projectivize labels such as punct before "
|
"want to projectivize labels such as punct before "
|
||||||
"training in order to improve parser performance.".format(
|
"training in order to improve parser performance.".format(
|
||||||
len(rare_projectivized_labels),
|
len(rare_projectivized_labels),
|
||||||
"s" if len(rare_projectivized_labels) > 1 else "")
|
"s" if len(rare_projectivized_labels) > 1 else "",
|
||||||
)
|
)
|
||||||
msg.warn(
|
)
|
||||||
"Projectivized labels with low numbers of examples: "
|
msg.warn(
|
||||||
"{}".format("\n".join(rare_projectivized_labels)),
|
"Projectivized labels with low numbers of examples: "
|
||||||
show=verbose
|
"{}".format("\n".join(rare_projectivized_labels)),
|
||||||
)
|
show=verbose,
|
||||||
has_low_data_warning = True
|
)
|
||||||
|
has_low_data_warning = True
|
||||||
|
|
||||||
# labels only in train
|
# labels only in train
|
||||||
if set(labels_train) - set(labels_dev):
|
if set(labels_train) - set(labels_dev):
|
||||||
msg.warn(
|
msg.warn(
|
||||||
"The following labels were found only in the train data: "
|
"The following labels were found only in the train data: "
|
||||||
"{}".format(", ".join(set(labels_train) - set(labels_dev))),
|
"{}".format(", ".join(set(labels_train) - set(labels_dev))),
|
||||||
show=verbose
|
show=verbose,
|
||||||
)
|
)
|
||||||
|
|
||||||
# labels only in dev
|
# labels only in dev
|
||||||
if set(labels_dev) - set(labels_train):
|
if set(labels_dev) - set(labels_train):
|
||||||
msg.warn(
|
msg.warn(
|
||||||
"The following labels were found only in the dev data: " +
|
"The following labels were found only in the dev data: "
|
||||||
", ".join(set(labels_dev) - set(labels_train)),
|
+ ", ".join(set(labels_dev) - set(labels_train)),
|
||||||
show=verbose
|
show=verbose,
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_low_data_warning:
|
if has_low_data_warning:
|
||||||
|
@ -422,8 +441,10 @@ def debug_data(
|
||||||
# multiple root labels
|
# multiple root labels
|
||||||
if len(gold_train_unpreprocessed_data["roots"]) > 1:
|
if len(gold_train_unpreprocessed_data["roots"]) > 1:
|
||||||
msg.warn(
|
msg.warn(
|
||||||
"Multiple root labels ({}) ".format(", ".join(gold_train_unpreprocessed_data["roots"])) +
|
"Multiple root labels ({}) ".format(
|
||||||
"found in training data. spaCy's parser uses a single root "
|
", ".join(gold_train_unpreprocessed_data["roots"])
|
||||||
|
)
|
||||||
|
+ "found in training data. spaCy's parser uses a single root "
|
||||||
"label ROOT so this distinction will not be available."
|
"label ROOT so this distinction will not be available."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -432,14 +453,14 @@ def debug_data(
|
||||||
msg.fail(
|
msg.fail(
|
||||||
"Found {} nonprojective projectivized train sentence{}".format(
|
"Found {} nonprojective projectivized train sentence{}".format(
|
||||||
gold_train_data["n_nonproj"],
|
gold_train_data["n_nonproj"],
|
||||||
"s" if gold_train_data["n_nonproj"] > 1 else ""
|
"s" if gold_train_data["n_nonproj"] > 1 else "",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if gold_train_data["n_cycles"] > 0:
|
if gold_train_data["n_cycles"] > 0:
|
||||||
msg.fail(
|
msg.fail(
|
||||||
"Found {} projectivized train sentence{} with cycles".format(
|
"Found {} projectivized train sentence{} with cycles".format(
|
||||||
gold_train_data["n_cycles"],
|
gold_train_data["n_cycles"],
|
||||||
"s" if gold_train_data["n_cycles"] > 1 else ""
|
"s" if gold_train_data["n_cycles"] > 1 else "",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -114,7 +114,7 @@ def read_attrs_from_deprecated(freqs_loc, clusters_loc):
|
||||||
probs, _ = read_freqs(freqs_loc)
|
probs, _ = read_freqs(freqs_loc)
|
||||||
msg.good("Counted frequencies")
|
msg.good("Counted frequencies")
|
||||||
else:
|
else:
|
||||||
probs, _ = ({}, DEFAULT_OOV_PROB)
|
probs, _ = ({}, DEFAULT_OOV_PROB) # noqa: F841
|
||||||
if clusters_loc:
|
if clusters_loc:
|
||||||
with msg.loading("Reading clusters..."):
|
with msg.loading("Reading clusters..."):
|
||||||
clusters = read_clusters(clusters_loc)
|
clusters = read_clusters(clusters_loc)
|
||||||
|
|
|
@ -429,6 +429,7 @@ class Errors(object):
|
||||||
E155 = ("The `nlp` object should have access to pre-trained word vectors, cf. "
|
E155 = ("The `nlp` object should have access to pre-trained word vectors, cf. "
|
||||||
"https://spacy.io/usage/models#languages.")
|
"https://spacy.io/usage/models#languages.")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
class TempErrors(object):
|
class TempErrors(object):
|
||||||
T003 = ("Resizing pre-trained Tagger models is not currently supported.")
|
T003 = ("Resizing pre-trained Tagger models is not currently supported.")
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
# encoding: utf8
|
# encoding: utf8
|
||||||
from __future__ import unicode_literals, print_function
|
from __future__ import unicode_literals, print_function
|
||||||
|
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
from .stop_words import STOP_WORDS
|
from .stop_words import STOP_WORDS
|
||||||
from .tag_map import TAG_MAP
|
from .tag_map import TAG_MAP
|
||||||
from ...attrs import LANG
|
from ...attrs import LANG
|
||||||
|
@ -32,7 +30,7 @@ else:
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
class Morpheme(NamedTuple):
|
class Morpheme(NamedTuple):
|
||||||
|
|
||||||
surface = str("")
|
surface = str("")
|
||||||
lemma = str("")
|
lemma = str("")
|
||||||
tag = str("")
|
tag = str("")
|
||||||
|
|
|
@ -8,6 +8,7 @@ from ..tokenizer_exceptions import BASE_EXCEPTIONS
|
||||||
from .stop_words import STOP_WORDS
|
from .stop_words import STOP_WORDS
|
||||||
from .tag_map import TAG_MAP
|
from .tag_map import TAG_MAP
|
||||||
|
|
||||||
|
|
||||||
class ChineseDefaults(Language.Defaults):
|
class ChineseDefaults(Language.Defaults):
|
||||||
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
|
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
|
||||||
lex_attr_getters[LANG] = lambda text: "zh"
|
lex_attr_getters[LANG] = lambda text: "zh"
|
||||||
|
@ -45,4 +46,4 @@ class Chinese(Language):
|
||||||
return Doc(self.vocab, words=words, spaces=spaces)
|
return Doc(self.vocab, words=words, spaces=spaces)
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Chinese"]
|
__all__ = ["Chinese"]
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from ...symbols import POS, PUNCT, SYM, ADJ, CONJ, CCONJ, NUM, DET, ADV, ADP, X, VERB
|
from ...symbols import POS, PUNCT, ADJ, CONJ, CCONJ, NUM, DET, ADV, ADP, X, VERB
|
||||||
from ...symbols import NOUN, PROPN, PART, INTJ, SPACE, PRON, AUX
|
from ...symbols import NOUN, PART, INTJ, PRON
|
||||||
|
|
||||||
# The Chinese part-of-speech tagger uses the OntoNotes 5 version of the Penn Treebank tag set.
|
# The Chinese part-of-speech tagger uses the OntoNotes 5 version of the Penn Treebank tag set.
|
||||||
# We also map the tags to the simpler Google Universal POS tag set.
|
# We also map the tags to the simpler Google Universal POS tag set.
|
||||||
|
@ -43,5 +43,5 @@ TAG_MAP = {
|
||||||
"JJ": {POS: ADJ},
|
"JJ": {POS: ADJ},
|
||||||
"P": {POS: ADP},
|
"P": {POS: ADP},
|
||||||
"PN": {POS: PRON},
|
"PN": {POS: PRON},
|
||||||
"PU": {POS: PUNCT}
|
"PU": {POS: PUNCT},
|
||||||
}
|
}
|
||||||
|
|
|
@ -160,14 +160,15 @@ class Scorer(object):
|
||||||
cand_deps.add((gold_i, gold_head, token.dep_.lower()))
|
cand_deps.add((gold_i, gold_head, token.dep_.lower()))
|
||||||
if "-" not in [token[-1] for token in gold.orig_annot]:
|
if "-" not in [token[-1] for token in gold.orig_annot]:
|
||||||
# Find all NER labels in gold and doc
|
# Find all NER labels in gold and doc
|
||||||
ent_labels = set([x[0] for x in gold_ents]
|
ent_labels = set([x[0] for x in gold_ents] + [k.label_ for k in doc.ents])
|
||||||
+ [k.label_ for k in doc.ents])
|
|
||||||
# Set up all labels for per type scoring and prepare gold per type
|
# Set up all labels for per type scoring and prepare gold per type
|
||||||
gold_per_ents = {ent_label: set() for ent_label in ent_labels}
|
gold_per_ents = {ent_label: set() for ent_label in ent_labels}
|
||||||
for ent_label in ent_labels:
|
for ent_label in ent_labels:
|
||||||
if ent_label not in self.ner_per_ents:
|
if ent_label not in self.ner_per_ents:
|
||||||
self.ner_per_ents[ent_label] = PRFScore()
|
self.ner_per_ents[ent_label] = PRFScore()
|
||||||
gold_per_ents[ent_label].update([x for x in gold_ents if x[0] == ent_label])
|
gold_per_ents[ent_label].update(
|
||||||
|
[x for x in gold_ents if x[0] == ent_label]
|
||||||
|
)
|
||||||
# Find all candidate labels, for all and per type
|
# Find all candidate labels, for all and per type
|
||||||
cand_ents = set()
|
cand_ents = set()
|
||||||
cand_per_ents = {ent_label: set() for ent_label in ent_labels}
|
cand_per_ents = {ent_label: set() for ent_label in ent_labels}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import pytest
|
|
||||||
from spacy.matcher import PhraseMatcher
|
from spacy.matcher import PhraseMatcher
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
|
|
||||||
|
|
|
@ -3,12 +3,13 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
from ..util import get_doc
|
from ..util import get_doc
|
||||||
|
|
||||||
|
|
||||||
def test_issue4104(en_vocab):
|
def test_issue4104(en_vocab):
|
||||||
"""Test that English lookup lemmatization of spun & dry are correct
|
"""Test that English lookup lemmatization of spun & dry are correct
|
||||||
expected mapping = {'dry': 'dry', 'spun': 'spin', 'spun-dry': 'spin-dry'}
|
expected mapping = {'dry': 'dry', 'spun': 'spin', 'spun-dry': 'spin-dry'}
|
||||||
"""
|
"""
|
||||||
text = 'dry spun spun-dry'
|
text = "dry spun spun-dry"
|
||||||
doc = get_doc(en_vocab, [t for t in text.split(" ")])
|
doc = get_doc(en_vocab, [t for t in text.split(" ")])
|
||||||
# using a simple list to preserve order
|
# using a simple list to preserve order
|
||||||
expected = ['dry', 'spin', 'spin-dry']
|
expected = ["dry", "spin", "spin-dry"]
|
||||||
assert [token.lemma_ for token in doc] == expected
|
assert [token.lemma_ for token in doc] == expected
|
||||||
|
|
|
@ -6,6 +6,7 @@ from spacy.gold import spans_from_biluo_tags, GoldParse
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def test_gold_biluo_U(en_vocab):
|
def test_gold_biluo_U(en_vocab):
|
||||||
words = ["I", "flew", "to", "London", "."]
|
words = ["I", "flew", "to", "London", "."]
|
||||||
spaces = [True, True, True, False, True]
|
spaces = [True, True, True, False, True]
|
||||||
|
@ -32,14 +33,18 @@ def test_gold_biluo_BIL(en_vocab):
|
||||||
tags = biluo_tags_from_offsets(doc, entities)
|
tags = biluo_tags_from_offsets(doc, entities)
|
||||||
assert tags == ["O", "O", "O", "B-LOC", "I-LOC", "L-LOC", "O"]
|
assert tags == ["O", "O", "O", "B-LOC", "I-LOC", "L-LOC", "O"]
|
||||||
|
|
||||||
|
|
||||||
def test_gold_biluo_overlap(en_vocab):
|
def test_gold_biluo_overlap(en_vocab):
|
||||||
words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
|
words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
|
||||||
spaces = [True, True, True, True, True, False, True]
|
spaces = [True, True, True, True, True, False, True]
|
||||||
doc = Doc(en_vocab, words=words, spaces=spaces)
|
doc = Doc(en_vocab, words=words, spaces=spaces)
|
||||||
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
|
entities = [
|
||||||
(len("I flew to "), len("I flew to San Francisco"), "LOC")]
|
(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
|
||||||
|
(len("I flew to "), len("I flew to San Francisco"), "LOC"),
|
||||||
|
]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
tags = biluo_tags_from_offsets(doc, entities)
|
biluo_tags_from_offsets(doc, entities)
|
||||||
|
|
||||||
|
|
||||||
def test_gold_biluo_misalign(en_vocab):
|
def test_gold_biluo_misalign(en_vocab):
|
||||||
words = ["I", "flew", "to", "San", "Francisco", "Valley."]
|
words = ["I", "flew", "to", "San", "Francisco", "Valley."]
|
||||||
|
|
|
@ -7,67 +7,62 @@ from spacy.scorer import Scorer
|
||||||
from .util import get_doc
|
from .util import get_doc
|
||||||
|
|
||||||
test_ner_cardinal = [
|
test_ner_cardinal = [
|
||||||
[
|
["100 - 200", {"entities": [[0, 3, "CARDINAL"], [6, 9, "CARDINAL"]]}]
|
||||||
"100 - 200",
|
|
||||||
{
|
|
||||||
"entities": [
|
|
||||||
[0, 3, "CARDINAL"],
|
|
||||||
[6, 9, "CARDINAL"]
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
]
|
]
|
||||||
|
|
||||||
test_ner_apple = [
|
test_ner_apple = [
|
||||||
[
|
[
|
||||||
"Apple is looking at buying U.K. startup for $1 billion",
|
"Apple is looking at buying U.K. startup for $1 billion",
|
||||||
{
|
{"entities": [(0, 5, "ORG"), (27, 31, "GPE"), (44, 54, "MONEY")]},
|
||||||
"entities": [
|
|
||||||
(0, 5, "ORG"),
|
|
||||||
(27, 31, "GPE"),
|
|
||||||
(44, 54, "MONEY"),
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_ner_per_type(en_vocab):
|
def test_ner_per_type(en_vocab):
|
||||||
# Gold and Doc are identical
|
# Gold and Doc are identical
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
for input_, annot in test_ner_cardinal:
|
for input_, annot in test_ner_cardinal:
|
||||||
doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0, 1, 'CARDINAL'], [2, 3, 'CARDINAL']])
|
doc = get_doc(
|
||||||
gold = GoldParse(doc, entities = annot['entities'])
|
en_vocab,
|
||||||
|
words=input_.split(" "),
|
||||||
|
ents=[[0, 1, "CARDINAL"], [2, 3, "CARDINAL"]],
|
||||||
|
)
|
||||||
|
gold = GoldParse(doc, entities=annot["entities"])
|
||||||
scorer.score(doc, gold)
|
scorer.score(doc, gold)
|
||||||
results = scorer.scores
|
results = scorer.scores
|
||||||
|
|
||||||
assert results['ents_p'] == 100
|
assert results["ents_p"] == 100
|
||||||
assert results['ents_f'] == 100
|
assert results["ents_f"] == 100
|
||||||
assert results['ents_r'] == 100
|
assert results["ents_r"] == 100
|
||||||
assert results['ents_per_type']['CARDINAL']['p'] == 100
|
assert results["ents_per_type"]["CARDINAL"]["p"] == 100
|
||||||
assert results['ents_per_type']['CARDINAL']['f'] == 100
|
assert results["ents_per_type"]["CARDINAL"]["f"] == 100
|
||||||
assert results['ents_per_type']['CARDINAL']['r'] == 100
|
assert results["ents_per_type"]["CARDINAL"]["r"] == 100
|
||||||
|
|
||||||
# Doc has one missing and one extra entity
|
# Doc has one missing and one extra entity
|
||||||
# Entity type MONEY is not present in Doc
|
# Entity type MONEY is not present in Doc
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
for input_, annot in test_ner_apple:
|
for input_, annot in test_ner_apple:
|
||||||
doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0, 1, 'ORG'], [5, 6, 'GPE'], [6, 7, 'ORG']])
|
doc = get_doc(
|
||||||
gold = GoldParse(doc, entities = annot['entities'])
|
en_vocab,
|
||||||
|
words=input_.split(" "),
|
||||||
|
ents=[[0, 1, "ORG"], [5, 6, "GPE"], [6, 7, "ORG"]],
|
||||||
|
)
|
||||||
|
gold = GoldParse(doc, entities=annot["entities"])
|
||||||
scorer.score(doc, gold)
|
scorer.score(doc, gold)
|
||||||
results = scorer.scores
|
results = scorer.scores
|
||||||
|
|
||||||
assert results['ents_p'] == approx(66.66666)
|
assert results["ents_p"] == approx(66.66666)
|
||||||
assert results['ents_r'] == approx(66.66666)
|
assert results["ents_r"] == approx(66.66666)
|
||||||
assert results['ents_f'] == approx(66.66666)
|
assert results["ents_f"] == approx(66.66666)
|
||||||
assert 'GPE' in results['ents_per_type']
|
assert "GPE" in results["ents_per_type"]
|
||||||
assert 'MONEY' in results['ents_per_type']
|
assert "MONEY" in results["ents_per_type"]
|
||||||
assert 'ORG' in results['ents_per_type']
|
assert "ORG" in results["ents_per_type"]
|
||||||
assert results['ents_per_type']['GPE']['p'] == 100
|
assert results["ents_per_type"]["GPE"]["p"] == 100
|
||||||
assert results['ents_per_type']['GPE']['r'] == 100
|
assert results["ents_per_type"]["GPE"]["r"] == 100
|
||||||
assert results['ents_per_type']['GPE']['f'] == 100
|
assert results["ents_per_type"]["GPE"]["f"] == 100
|
||||||
assert results['ents_per_type']['MONEY']['p'] == 0
|
assert results["ents_per_type"]["MONEY"]["p"] == 0
|
||||||
assert results['ents_per_type']['MONEY']['r'] == 0
|
assert results["ents_per_type"]["MONEY"]["r"] == 0
|
||||||
assert results['ents_per_type']['MONEY']['f'] == 0
|
assert results["ents_per_type"]["MONEY"]["f"] == 0
|
||||||
assert results['ents_per_type']['ORG']['p'] == 50
|
assert results["ents_per_type"]["ORG"]["p"] == 50
|
||||||
assert results['ents_per_type']['ORG']['r'] == 100
|
assert results["ents_per_type"]["ORG"]["r"] == 100
|
||||||
assert results['ents_per_type']['ORG']['f'] == approx(66.66666)
|
assert results["ents_per_type"]["ORG"]["f"] == approx(66.66666)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user