mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Tidy up and auto-format
This commit is contained in:
parent
89f2b87266
commit
009280fbc5
16
spacy/_ml.py
16
spacy/_ml.py
|
@ -674,14 +674,14 @@ def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg):
|
|||
with Model.define_operators({">>": chain, "**": clone}):
|
||||
# context encoder
|
||||
tok2vec = Tok2Vec(
|
||||
width=hidden_width,
|
||||
embed_size=embed_width,
|
||||
pretrained_vectors=pretrained_vectors,
|
||||
cnn_maxout_pieces=cnn_maxout_pieces,
|
||||
subword_features=True,
|
||||
conv_depth=conv_depth,
|
||||
bilstm_depth=0,
|
||||
)
|
||||
width=hidden_width,
|
||||
embed_size=embed_width,
|
||||
pretrained_vectors=pretrained_vectors,
|
||||
cnn_maxout_pieces=cnn_maxout_pieces,
|
||||
subword_features=True,
|
||||
conv_depth=conv_depth,
|
||||
bilstm_depth=0,
|
||||
)
|
||||
|
||||
model = (
|
||||
tok2vec
|
||||
|
|
|
@ -8,7 +8,7 @@ import sys
|
|||
import srsly
|
||||
from wasabi import Printer, MESSAGES
|
||||
|
||||
from ..gold import GoldCorpus, read_json_object
|
||||
from ..gold import GoldCorpus
|
||||
from ..syntax import nonproj
|
||||
from ..util import load_model, get_lang_class
|
||||
|
||||
|
@ -95,13 +95,19 @@ def debug_data(
|
|||
corpus = GoldCorpus(train_path, dev_path)
|
||||
try:
|
||||
train_docs = list(corpus.train_docs(nlp))
|
||||
train_docs_unpreprocessed = list(corpus.train_docs_without_preprocessing(nlp))
|
||||
train_docs_unpreprocessed = list(
|
||||
corpus.train_docs_without_preprocessing(nlp)
|
||||
)
|
||||
except ValueError as e:
|
||||
loading_train_error_message = "Training data cannot be loaded: {}".format(str(e))
|
||||
loading_train_error_message = "Training data cannot be loaded: {}".format(
|
||||
str(e)
|
||||
)
|
||||
try:
|
||||
dev_docs = list(corpus.dev_docs(nlp))
|
||||
except ValueError as e:
|
||||
loading_dev_error_message = "Development data cannot be loaded: {}".format(str(e))
|
||||
loading_dev_error_message = "Development data cannot be loaded: {}".format(
|
||||
str(e)
|
||||
)
|
||||
if loading_train_error_message or loading_dev_error_message:
|
||||
if loading_train_error_message:
|
||||
msg.fail(loading_train_error_message)
|
||||
|
@ -158,11 +164,15 @@ def debug_data(
|
|||
)
|
||||
if gold_train_data["n_misaligned_words"] > 0:
|
||||
msg.warn(
|
||||
"{} misaligned tokens in the training data".format(gold_train_data["n_misaligned_words"])
|
||||
"{} misaligned tokens in the training data".format(
|
||||
gold_train_data["n_misaligned_words"]
|
||||
)
|
||||
)
|
||||
if gold_dev_data["n_misaligned_words"] > 0:
|
||||
msg.warn(
|
||||
"{} misaligned tokens in the dev data".format(gold_dev_data["n_misaligned_words"])
|
||||
"{} misaligned tokens in the dev data".format(
|
||||
gold_dev_data["n_misaligned_words"]
|
||||
)
|
||||
)
|
||||
most_common_words = gold_train_data["words"].most_common(10)
|
||||
msg.text(
|
||||
|
@ -184,7 +194,9 @@ def debug_data(
|
|||
|
||||
if "ner" in pipeline:
|
||||
# Get all unique NER labels present in the data
|
||||
labels = set(label for label in gold_train_data["ner"] if label not in ("O", "-"))
|
||||
labels = set(
|
||||
label for label in gold_train_data["ner"] if label not in ("O", "-")
|
||||
)
|
||||
label_counts = gold_train_data["ner"]
|
||||
model_labels = _get_labels_from_model(nlp, "ner")
|
||||
new_labels = [l for l in labels if l not in model_labels]
|
||||
|
@ -222,7 +234,9 @@ def debug_data(
|
|||
)
|
||||
|
||||
if gold_train_data["ws_ents"]:
|
||||
msg.fail("{} invalid whitespace entity spans".format(gold_train_data["ws_ents"]))
|
||||
msg.fail(
|
||||
"{} invalid whitespace entity spans".format(gold_train_data["ws_ents"])
|
||||
)
|
||||
has_ws_ents_error = True
|
||||
|
||||
for label in new_labels:
|
||||
|
@ -323,33 +337,36 @@ def debug_data(
|
|||
"Found {} sentence{} with an average length of {:.1f} words.".format(
|
||||
gold_train_data["n_sents"],
|
||||
"s" if len(train_docs) > 1 else "",
|
||||
gold_train_data["n_words"] / gold_train_data["n_sents"]
|
||||
gold_train_data["n_words"] / gold_train_data["n_sents"],
|
||||
)
|
||||
)
|
||||
|
||||
# profile labels
|
||||
labels_train = [label for label in gold_train_data["deps"]]
|
||||
labels_train_unpreprocessed = [label for label in gold_train_unpreprocessed_data["deps"]]
|
||||
labels_train_unpreprocessed = [
|
||||
label for label in gold_train_unpreprocessed_data["deps"]
|
||||
]
|
||||
labels_dev = [label for label in gold_dev_data["deps"]]
|
||||
|
||||
if gold_train_unpreprocessed_data["n_nonproj"] > 0:
|
||||
msg.info(
|
||||
"Found {} nonprojective train sentence{}".format(
|
||||
gold_train_unpreprocessed_data["n_nonproj"],
|
||||
"s" if gold_train_unpreprocessed_data["n_nonproj"] > 1 else ""
|
||||
"s" if gold_train_unpreprocessed_data["n_nonproj"] > 1 else "",
|
||||
)
|
||||
)
|
||||
if gold_dev_data["n_nonproj"] > 0:
|
||||
msg.info(
|
||||
"Found {} nonprojective dev sentence{}".format(
|
||||
gold_dev_data["n_nonproj"],
|
||||
"s" if gold_dev_data["n_nonproj"] > 1 else ""
|
||||
"s" if gold_dev_data["n_nonproj"] > 1 else "",
|
||||
)
|
||||
)
|
||||
|
||||
msg.info(
|
||||
"{} {} in train data".format(
|
||||
len(labels_train_unpreprocessed), "label" if len(labels_train) == 1 else "labels"
|
||||
len(labels_train_unpreprocessed),
|
||||
"label" if len(labels_train) == 1 else "labels",
|
||||
)
|
||||
)
|
||||
msg.info(
|
||||
|
@ -373,43 +390,45 @@ def debug_data(
|
|||
)
|
||||
has_low_data_warning = True
|
||||
|
||||
|
||||
# rare labels in projectivized train
|
||||
rare_projectivized_labels = []
|
||||
for label in gold_train_data["deps"]:
|
||||
if gold_train_data["deps"][label] <= DEP_LABEL_THRESHOLD and "||" in label:
|
||||
rare_projectivized_labels.append("{}: {}".format(label, str(gold_train_data["deps"][label])))
|
||||
rare_projectivized_labels.append(
|
||||
"{}: {}".format(label, str(gold_train_data["deps"][label]))
|
||||
)
|
||||
|
||||
if len(rare_projectivized_labels) > 0:
|
||||
msg.warn(
|
||||
"Low number of examples for {} label{} in the "
|
||||
"projectivized dependency trees used for training. You may "
|
||||
"want to projectivize labels such as punct before "
|
||||
"training in order to improve parser performance.".format(
|
||||
len(rare_projectivized_labels),
|
||||
"s" if len(rare_projectivized_labels) > 1 else "")
|
||||
msg.warn(
|
||||
"Low number of examples for {} label{} in the "
|
||||
"projectivized dependency trees used for training. You may "
|
||||
"want to projectivize labels such as punct before "
|
||||
"training in order to improve parser performance.".format(
|
||||
len(rare_projectivized_labels),
|
||||
"s" if len(rare_projectivized_labels) > 1 else "",
|
||||
)
|
||||
msg.warn(
|
||||
"Projectivized labels with low numbers of examples: "
|
||||
"{}".format("\n".join(rare_projectivized_labels)),
|
||||
show=verbose
|
||||
)
|
||||
has_low_data_warning = True
|
||||
)
|
||||
msg.warn(
|
||||
"Projectivized labels with low numbers of examples: "
|
||||
"{}".format("\n".join(rare_projectivized_labels)),
|
||||
show=verbose,
|
||||
)
|
||||
has_low_data_warning = True
|
||||
|
||||
# labels only in train
|
||||
if set(labels_train) - set(labels_dev):
|
||||
msg.warn(
|
||||
"The following labels were found only in the train data: "
|
||||
"{}".format(", ".join(set(labels_train) - set(labels_dev))),
|
||||
show=verbose
|
||||
show=verbose,
|
||||
)
|
||||
|
||||
# labels only in dev
|
||||
if set(labels_dev) - set(labels_train):
|
||||
msg.warn(
|
||||
"The following labels were found only in the dev data: " +
|
||||
", ".join(set(labels_dev) - set(labels_train)),
|
||||
show=verbose
|
||||
"The following labels were found only in the dev data: "
|
||||
+ ", ".join(set(labels_dev) - set(labels_train)),
|
||||
show=verbose,
|
||||
)
|
||||
|
||||
if has_low_data_warning:
|
||||
|
@ -422,8 +441,10 @@ def debug_data(
|
|||
# multiple root labels
|
||||
if len(gold_train_unpreprocessed_data["roots"]) > 1:
|
||||
msg.warn(
|
||||
"Multiple root labels ({}) ".format(", ".join(gold_train_unpreprocessed_data["roots"])) +
|
||||
"found in training data. spaCy's parser uses a single root "
|
||||
"Multiple root labels ({}) ".format(
|
||||
", ".join(gold_train_unpreprocessed_data["roots"])
|
||||
)
|
||||
+ "found in training data. spaCy's parser uses a single root "
|
||||
"label ROOT so this distinction will not be available."
|
||||
)
|
||||
|
||||
|
@ -432,14 +453,14 @@ def debug_data(
|
|||
msg.fail(
|
||||
"Found {} nonprojective projectivized train sentence{}".format(
|
||||
gold_train_data["n_nonproj"],
|
||||
"s" if gold_train_data["n_nonproj"] > 1 else ""
|
||||
"s" if gold_train_data["n_nonproj"] > 1 else "",
|
||||
)
|
||||
)
|
||||
if gold_train_data["n_cycles"] > 0:
|
||||
msg.fail(
|
||||
"Found {} projectivized train sentence{} with cycles".format(
|
||||
gold_train_data["n_cycles"],
|
||||
"s" if gold_train_data["n_cycles"] > 1 else ""
|
||||
"s" if gold_train_data["n_cycles"] > 1 else "",
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -114,7 +114,7 @@ def read_attrs_from_deprecated(freqs_loc, clusters_loc):
|
|||
probs, _ = read_freqs(freqs_loc)
|
||||
msg.good("Counted frequencies")
|
||||
else:
|
||||
probs, _ = ({}, DEFAULT_OOV_PROB)
|
||||
probs, _ = ({}, DEFAULT_OOV_PROB) # noqa: F841
|
||||
if clusters_loc:
|
||||
with msg.loading("Reading clusters..."):
|
||||
clusters = read_clusters(clusters_loc)
|
||||
|
|
|
@ -429,6 +429,7 @@ class Errors(object):
|
|||
E155 = ("The `nlp` object should have access to pre-trained word vectors, cf. "
|
||||
"https://spacy.io/usage/models#languages.")
|
||||
|
||||
|
||||
@add_codes
|
||||
class TempErrors(object):
|
||||
T003 = ("Resizing pre-trained Tagger models is not currently supported.")
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
# encoding: utf8
|
||||
from __future__ import unicode_literals, print_function
|
||||
|
||||
import re
|
||||
import sys
|
||||
|
||||
|
||||
from .stop_words import STOP_WORDS
|
||||
from .tag_map import TAG_MAP
|
||||
from ...attrs import LANG
|
||||
|
|
|
@ -8,6 +8,7 @@ from ..tokenizer_exceptions import BASE_EXCEPTIONS
|
|||
from .stop_words import STOP_WORDS
|
||||
from .tag_map import TAG_MAP
|
||||
|
||||
|
||||
class ChineseDefaults(Language.Defaults):
|
||||
lex_attr_getters = dict(Language.Defaults.lex_attr_getters)
|
||||
lex_attr_getters[LANG] = lambda text: "zh"
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from ...symbols import POS, PUNCT, SYM, ADJ, CONJ, CCONJ, NUM, DET, ADV, ADP, X, VERB
|
||||
from ...symbols import NOUN, PROPN, PART, INTJ, SPACE, PRON, AUX
|
||||
from ...symbols import POS, PUNCT, ADJ, CONJ, CCONJ, NUM, DET, ADV, ADP, X, VERB
|
||||
from ...symbols import NOUN, PART, INTJ, PRON
|
||||
|
||||
# The Chinese part-of-speech tagger uses the OntoNotes 5 version of the Penn Treebank tag set.
|
||||
# We also map the tags to the simpler Google Universal POS tag set.
|
||||
|
@ -43,5 +43,5 @@ TAG_MAP = {
|
|||
"JJ": {POS: ADJ},
|
||||
"P": {POS: ADP},
|
||||
"PN": {POS: PRON},
|
||||
"PU": {POS: PUNCT}
|
||||
"PU": {POS: PUNCT},
|
||||
}
|
|
@ -160,14 +160,15 @@ class Scorer(object):
|
|||
cand_deps.add((gold_i, gold_head, token.dep_.lower()))
|
||||
if "-" not in [token[-1] for token in gold.orig_annot]:
|
||||
# Find all NER labels in gold and doc
|
||||
ent_labels = set([x[0] for x in gold_ents]
|
||||
+ [k.label_ for k in doc.ents])
|
||||
ent_labels = set([x[0] for x in gold_ents] + [k.label_ for k in doc.ents])
|
||||
# Set up all labels for per type scoring and prepare gold per type
|
||||
gold_per_ents = {ent_label: set() for ent_label in ent_labels}
|
||||
for ent_label in ent_labels:
|
||||
if ent_label not in self.ner_per_ents:
|
||||
self.ner_per_ents[ent_label] = PRFScore()
|
||||
gold_per_ents[ent_label].update([x for x in gold_ents if x[0] == ent_label])
|
||||
gold_per_ents[ent_label].update(
|
||||
[x for x in gold_ents if x[0] == ent_label]
|
||||
)
|
||||
# Find all candidate labels, for all and per type
|
||||
cand_ents = set()
|
||||
cand_per_ents = {ent_label: set() for ent_label in ent_labels}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import pytest
|
||||
from spacy.matcher import PhraseMatcher
|
||||
from spacy.tokens import Doc
|
||||
|
||||
|
|
|
@ -3,12 +3,13 @@ from __future__ import unicode_literals
|
|||
|
||||
from ..util import get_doc
|
||||
|
||||
|
||||
def test_issue4104(en_vocab):
|
||||
"""Test that English lookup lemmatization of spun & dry are correct
|
||||
expected mapping = {'dry': 'dry', 'spun': 'spin', 'spun-dry': 'spin-dry'}
|
||||
"""
|
||||
text = 'dry spun spun-dry'
|
||||
"""
|
||||
text = "dry spun spun-dry"
|
||||
doc = get_doc(en_vocab, [t for t in text.split(" ")])
|
||||
# using a simple list to preserve order
|
||||
expected = ['dry', 'spin', 'spin-dry']
|
||||
expected = ["dry", "spin", "spin-dry"]
|
||||
assert [token.lemma_ for token in doc] == expected
|
||||
|
|
|
@ -6,6 +6,7 @@ from spacy.gold import spans_from_biluo_tags, GoldParse
|
|||
from spacy.tokens import Doc
|
||||
import pytest
|
||||
|
||||
|
||||
def test_gold_biluo_U(en_vocab):
|
||||
words = ["I", "flew", "to", "London", "."]
|
||||
spaces = [True, True, True, False, True]
|
||||
|
@ -32,14 +33,18 @@ def test_gold_biluo_BIL(en_vocab):
|
|||
tags = biluo_tags_from_offsets(doc, entities)
|
||||
assert tags == ["O", "O", "O", "B-LOC", "I-LOC", "L-LOC", "O"]
|
||||
|
||||
|
||||
def test_gold_biluo_overlap(en_vocab):
|
||||
words = ["I", "flew", "to", "San", "Francisco", "Valley", "."]
|
||||
spaces = [True, True, True, True, True, False, True]
|
||||
doc = Doc(en_vocab, words=words, spaces=spaces)
|
||||
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
|
||||
(len("I flew to "), len("I flew to San Francisco"), "LOC")]
|
||||
entities = [
|
||||
(len("I flew to "), len("I flew to San Francisco Valley"), "LOC"),
|
||||
(len("I flew to "), len("I flew to San Francisco"), "LOC"),
|
||||
]
|
||||
with pytest.raises(ValueError):
|
||||
tags = biluo_tags_from_offsets(doc, entities)
|
||||
biluo_tags_from_offsets(doc, entities)
|
||||
|
||||
|
||||
def test_gold_biluo_misalign(en_vocab):
|
||||
words = ["I", "flew", "to", "San", "Francisco", "Valley."]
|
||||
|
|
|
@ -7,67 +7,62 @@ from spacy.scorer import Scorer
|
|||
from .util import get_doc
|
||||
|
||||
test_ner_cardinal = [
|
||||
[
|
||||
"100 - 200",
|
||||
{
|
||||
"entities": [
|
||||
[0, 3, "CARDINAL"],
|
||||
[6, 9, "CARDINAL"]
|
||||
]
|
||||
}
|
||||
]
|
||||
["100 - 200", {"entities": [[0, 3, "CARDINAL"], [6, 9, "CARDINAL"]]}]
|
||||
]
|
||||
|
||||
test_ner_apple = [
|
||||
[
|
||||
"Apple is looking at buying U.K. startup for $1 billion",
|
||||
{
|
||||
"entities": [
|
||||
(0, 5, "ORG"),
|
||||
(27, 31, "GPE"),
|
||||
(44, 54, "MONEY"),
|
||||
]
|
||||
}
|
||||
{"entities": [(0, 5, "ORG"), (27, 31, "GPE"), (44, 54, "MONEY")]},
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
def test_ner_per_type(en_vocab):
|
||||
# Gold and Doc are identical
|
||||
scorer = Scorer()
|
||||
for input_, annot in test_ner_cardinal:
|
||||
doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0, 1, 'CARDINAL'], [2, 3, 'CARDINAL']])
|
||||
gold = GoldParse(doc, entities = annot['entities'])
|
||||
doc = get_doc(
|
||||
en_vocab,
|
||||
words=input_.split(" "),
|
||||
ents=[[0, 1, "CARDINAL"], [2, 3, "CARDINAL"]],
|
||||
)
|
||||
gold = GoldParse(doc, entities=annot["entities"])
|
||||
scorer.score(doc, gold)
|
||||
results = scorer.scores
|
||||
|
||||
assert results['ents_p'] == 100
|
||||
assert results['ents_f'] == 100
|
||||
assert results['ents_r'] == 100
|
||||
assert results['ents_per_type']['CARDINAL']['p'] == 100
|
||||
assert results['ents_per_type']['CARDINAL']['f'] == 100
|
||||
assert results['ents_per_type']['CARDINAL']['r'] == 100
|
||||
assert results["ents_p"] == 100
|
||||
assert results["ents_f"] == 100
|
||||
assert results["ents_r"] == 100
|
||||
assert results["ents_per_type"]["CARDINAL"]["p"] == 100
|
||||
assert results["ents_per_type"]["CARDINAL"]["f"] == 100
|
||||
assert results["ents_per_type"]["CARDINAL"]["r"] == 100
|
||||
|
||||
# Doc has one missing and one extra entity
|
||||
# Entity type MONEY is not present in Doc
|
||||
scorer = Scorer()
|
||||
for input_, annot in test_ner_apple:
|
||||
doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0, 1, 'ORG'], [5, 6, 'GPE'], [6, 7, 'ORG']])
|
||||
gold = GoldParse(doc, entities = annot['entities'])
|
||||
doc = get_doc(
|
||||
en_vocab,
|
||||
words=input_.split(" "),
|
||||
ents=[[0, 1, "ORG"], [5, 6, "GPE"], [6, 7, "ORG"]],
|
||||
)
|
||||
gold = GoldParse(doc, entities=annot["entities"])
|
||||
scorer.score(doc, gold)
|
||||
results = scorer.scores
|
||||
|
||||
assert results['ents_p'] == approx(66.66666)
|
||||
assert results['ents_r'] == approx(66.66666)
|
||||
assert results['ents_f'] == approx(66.66666)
|
||||
assert 'GPE' in results['ents_per_type']
|
||||
assert 'MONEY' in results['ents_per_type']
|
||||
assert 'ORG' in results['ents_per_type']
|
||||
assert results['ents_per_type']['GPE']['p'] == 100
|
||||
assert results['ents_per_type']['GPE']['r'] == 100
|
||||
assert results['ents_per_type']['GPE']['f'] == 100
|
||||
assert results['ents_per_type']['MONEY']['p'] == 0
|
||||
assert results['ents_per_type']['MONEY']['r'] == 0
|
||||
assert results['ents_per_type']['MONEY']['f'] == 0
|
||||
assert results['ents_per_type']['ORG']['p'] == 50
|
||||
assert results['ents_per_type']['ORG']['r'] == 100
|
||||
assert results['ents_per_type']['ORG']['f'] == approx(66.66666)
|
||||
assert results["ents_p"] == approx(66.66666)
|
||||
assert results["ents_r"] == approx(66.66666)
|
||||
assert results["ents_f"] == approx(66.66666)
|
||||
assert "GPE" in results["ents_per_type"]
|
||||
assert "MONEY" in results["ents_per_type"]
|
||||
assert "ORG" in results["ents_per_type"]
|
||||
assert results["ents_per_type"]["GPE"]["p"] == 100
|
||||
assert results["ents_per_type"]["GPE"]["r"] == 100
|
||||
assert results["ents_per_type"]["GPE"]["f"] == 100
|
||||
assert results["ents_per_type"]["MONEY"]["p"] == 0
|
||||
assert results["ents_per_type"]["MONEY"]["r"] == 0
|
||||
assert results["ents_per_type"]["MONEY"]["f"] == 0
|
||||
assert results["ents_per_type"]["ORG"]["p"] == 50
|
||||
assert results["ents_per_type"]["ORG"]["r"] == 100
|
||||
assert results["ents_per_type"]["ORG"]["f"] == approx(66.66666)
|
||||
|
|
Loading…
Reference in New Issue
Block a user