small fixes

This commit is contained in:
svlandeg 2020-06-22 10:05:12 +02:00
parent 6a75992af6
commit 0d64c435b0
7 changed files with 18 additions and 20 deletions

View File

@ -30,7 +30,7 @@ ENTITIES = {"Q2146908": ("American golfer", 342), "Q7381115": ("publisher", 17)}
model=("Model name, should have pretrained word embeddings", "positional", None, str), model=("Model name, should have pretrained word embeddings", "positional", None, str),
output_dir=("Optional output directory", "option", "o", Path), output_dir=("Optional output directory", "option", "o", Path),
) )
def main(model=None, output_dir=None): def main(model, output_dir=None):
"""Load the model and create the KB with pre-defined entity encodings. """Load the model and create the KB with pre-defined entity encodings.
If an output_dir is provided, the KB will be stored there in a file 'kb'. If an output_dir is provided, the KB will be stored there in a file 'kb'.
The updated vocab will also be written to a directory in the output_dir.""" The updated vocab will also be written to a directory in the output_dir."""

View File

@ -14,11 +14,11 @@ class Corpus:
""" """
def __init__(self, train_loc, dev_loc, limit=0): def __init__(self, train_loc, dev_loc, limit=0):
"""Create a GoldCorpus. """Create a Corpus.
train (str / Path): File or directory of training data. train (str / Path): File or directory of training data.
dev (str / Path): File or directory of development data. dev (str / Path): File or directory of development data.
RETURNS (GoldCorpus): The newly created object. RETURNS (Corpus): The newly created object.
""" """
self.train_loc = train_loc self.train_loc = train_loc
self.dev_loc = dev_loc self.dev_loc = dev_loc

View File

@ -1,3 +1,5 @@
import warnings
import numpy import numpy
from ..tokens import Token from ..tokens import Token
@ -204,24 +206,23 @@ def _annot2array(vocab, tok_annot, doc_annot):
values = [] values = []
for key, value in doc_annot.items(): for key, value in doc_annot.items():
if key == "entities": if value:
if value: if key == "entities":
words = tok_annot["ORTH"] words = tok_annot["ORTH"]
spaces = tok_annot["SPACY"] spaces = tok_annot["SPACY"]
ent_iobs, ent_types = _parse_ner_tags(value, vocab, words, spaces) ent_iobs, ent_types = _parse_ner_tags(value, vocab, words, spaces)
tok_annot["ENT_IOB"] = ent_iobs tok_annot["ENT_IOB"] = ent_iobs
tok_annot["ENT_TYPE"] = ent_types tok_annot["ENT_TYPE"] = ent_types
elif key == "links": elif key == "links":
if value:
entities = doc_annot.get("entities", {}) entities = doc_annot.get("entities", {})
if value and not entities: if value and not entities:
raise ValueError(Errors.E981) raise ValueError(Errors.E981)
ent_kb_ids = _parse_links(vocab, tok_annot["ORTH"], value, entities) ent_kb_ids = _parse_links(vocab, tok_annot["ORTH"], value, entities)
tok_annot["ENT_KB_ID"] = ent_kb_ids tok_annot["ENT_KB_ID"] = ent_kb_ids
elif key == "cats": elif key == "cats":
pass pass
else: else:
raise ValueError(f"Unknown doc attribute: {key}") raise ValueError(f"Unknown doc attribute: {key}")
for key, value in tok_annot.items(): for key, value in tok_annot.items():
if key not in IDS: if key not in IDS:
@ -298,6 +299,7 @@ def _fix_legacy_dict_data(example_dict):
if "HEAD" in token_dict and "SENT_START" in token_dict: if "HEAD" in token_dict and "SENT_START" in token_dict:
# If heads are set, we don't also redundantly specify SENT_START. # If heads are set, we don't also redundantly specify SENT_START.
token_dict.pop("SENT_START") token_dict.pop("SENT_START")
warnings.warn("Ignoring annotations for sentence starts, as dependency heads are set")
return { return {
"token_annotation": token_dict, "token_annotation": token_dict,
"doc_annotation": doc_dict "doc_annotation": doc_dict

View File

@ -48,9 +48,7 @@ def build_masked_language_model(vocab, wrapped_model, mask_prob=0.15):
def mlm_forward(model, docs, is_train): def mlm_forward(model, docs, is_train):
mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob) mask, docs = _apply_mask(docs, random_words, mask_prob=mask_prob)
mask = model.ops.asarray(mask).reshape((mask.shape[0], 1)) mask = model.ops.asarray(mask).reshape((mask.shape[0], 1))
output, backprop = model.get_ref("wrapped-model").begin_update( output, backprop = model.get_ref("wrapped-model").begin_update(docs)
docs
) # drop=drop
def mlm_backward(d_output): def mlm_backward(d_output):
d_output *= 1 - mask d_output *= 1 - mask

View File

@ -147,7 +147,7 @@ def hash_char_embed_bilstm_v1(
@registry.architectures.register("spacy.LayerNormalizedMaxout.v1") @registry.architectures.register("spacy.LayerNormalizedMaxout.v1")
def LayerNormalizedMaxout(width, maxout_pieces): def LayerNormalizedMaxout(width, maxout_pieces):
return Maxout(nO=width, nP=maxout_pieces, dropout=0.0, normalize=True,) return Maxout(nO=width, nP=maxout_pieces, dropout=0.0, normalize=True)
@registry.architectures.register("spacy.MultiHashEmbed.v1") @registry.architectures.register("spacy.MultiHashEmbed.v1")

View File

@ -7,10 +7,10 @@ from spacy.pipeline.defaults import default_ner
from spacy.pipeline import EntityRecognizer, EntityRuler from spacy.pipeline import EntityRecognizer, EntityRuler
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.syntax.ner import BiluoPushDown from spacy.syntax.ner import BiluoPushDown
from spacy.gold import Example
from spacy.tokens import Doc from spacy.tokens import Doc
from ..util import make_tempdir from ..util import make_tempdir
from ...gold import Example
TRAIN_DATA = [ TRAIN_DATA = [
("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}), ("Who is Shaka Khan?", {"entities": [(7, 17, "PERSON")]}),

View File

@ -596,8 +596,6 @@ def test_split_sents(merged_dict):
assert token_annotation_2["sent_starts"] == [1, 0, 0, 0] assert token_annotation_2["sent_starts"] == [1, 0, 0, 0]
# This fails on some None value? Need to look into that.
@pytest.mark.xfail # TODO
def test_tuples_to_example(vocab, merged_dict): def test_tuples_to_example(vocab, merged_dict):
cats = {"TRAVEL": 1.0, "BAKING": 0.0} cats = {"TRAVEL": 1.0, "BAKING": 0.0}
merged_dict = dict(merged_dict) merged_dict = dict(merged_dict)
@ -607,6 +605,6 @@ def test_tuples_to_example(vocab, merged_dict):
assert words == merged_dict["words"] assert words == merged_dict["words"]
tags = [token.tag_ for token in ex.reference] tags = [token.tag_ for token in ex.reference]
assert tags == merged_dict["tags"] assert tags == merged_dict["tags"]
sent_starts = [token.is_sent_start for token in ex.reference] sent_starts = [bool(token.is_sent_start) for token in ex.reference]
assert sent_starts == [bool(v) for v in merged_dict["sent_starts"]] assert sent_starts == [bool(v) for v in merged_dict["sent_starts"]]
ex.reference.cats == cats assert ex.reference.cats == cats