Merge branch 'whatif/arrow' of https://github.com/explosion/spaCy into whatif/arrow

This commit is contained in:
Matthew Honnibal 2020-06-19 02:30:27 +02:00
commit a389866df6
12 changed files with 104 additions and 63 deletions

View File

@ -4,8 +4,10 @@ import random
import warnings
import srsly
import spacy
from spacy.gold import Example
from spacy.util import minibatch, compounding
# TODO: further fix & test this script for v.3 ? (read_gold_data is never called)
LABEL = "ANIMAL"
TRAIN_DATA = [
@ -35,15 +37,13 @@ def read_raw_data(nlp, jsonl_loc):
def read_gold_data(nlp, gold_loc):
docs = []
golds = []
examples = []
for json_obj in srsly.read_jsonl(gold_loc):
doc = nlp.make_doc(json_obj["text"])
ents = [(ent["start"], ent["end"], ent["label"]) for ent in json_obj["spans"]]
gold = GoldParse(doc, entities=ents)
docs.append(doc)
golds.append(gold)
return list(zip(docs, golds))
example = Example.from_dict(doc, {"entities": ents})
examples.append(example)
return examples
def main(model_name, unlabelled_loc):

View File

@ -62,11 +62,10 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non
train_examples = []
for text, cats in zip(train_texts, train_cats):
doc = nlp.make_doc(text)
gold = GoldParse(doc, cats=cats)
example = Example.from_dict(doc, {"cats": cats})
for cat in cats:
textcat.add_label(cat)
ex = Example.from_gold(gold, doc=doc)
train_examples.append(ex)
train_examples.append(example)
with nlp.select_pipes(enable="textcat"): # only train textcat
optimizer = nlp.begin_training()

View File

@ -231,8 +231,8 @@ def train(
# check whether the setting 'exclusive_classes' corresponds to the provided training data
if textcat_multilabel:
multilabel_found = False
for ex in corpus.train_examples:
cats = ex.doc_annotation.cats
for eg in corpus.train_annotations:
cats = eg.reference.cats
textcat_labels.update(cats.keys())
if list(cats.values()).count(1.0) != 1:
multilabel_found = True
@ -244,8 +244,8 @@ def train(
"mutually exclusive classes more accurately."
)
else:
for ex in corpus.train_examples:
cats = ex.doc_annotation.cats
for eg in corpus.train_annotations:
cats = eg.reference.cats
textcat_labels.update(cats.keys())
if list(cats.values()).count(1.0) != 1:
msg.fail(
@ -346,10 +346,8 @@ def train(
progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
# Clean up the objects to faciliate garbage collection.
for eg in batch:
eg.doc = None
eg.goldparse = None
eg.doc_annotation = None
eg.token_annotation = None
eg.reference = None
eg.predicted = None
except Exception as e:
msg.warn(
f"Aborting and saving the final best model. "
@ -469,7 +467,7 @@ def train_while_improving(
Every iteration, the function yields out a tuple with:
* batch: A zipped sequence of Tuple[Doc, GoldParse] pairs.
* batch: A list of Example objects.
* info: A dict with various information about the last update (see below).
* is_best_checkpoint: A value in None, False, True, indicating whether this
was the best evaluation so far. You should use this to save the model

View File

@ -72,7 +72,7 @@ class GoldCorpus(object):
@staticmethod
def read_annotations(locs, limit=0):
""" Yield training examples """
""" Yield training examples as example dicts """
i = 0
for loc in locs:
loc = util.ensure_path(loc)

View File

@ -117,7 +117,7 @@ cdef class Example:
i = j2i_multi[j]
if output[i] is None:
output[i] = gold_values[j]
if as_string and field not in ["ENT_IOB"]:
if as_string and field not in ["ENT_IOB", "SENT_START"]:
output = [vocab.strings[o] if o is not None else o for o in output]
return output
@ -146,22 +146,19 @@ cdef class Example:
sent_starts and return a list of the new Examples"""
if not self.reference.is_sentenced:
return [self]
# TODO: Do this for misaligned somehow?
predicted_words = [t.text for t in self.predicted]
reference_words = [t.text for t in self.reference]
if predicted_words != reference_words:
raise NotImplementedError("TODO: Implement this")
# Implement the easy case.
sent_starts = self.get_aligned("SENT_START")
sent_starts.append(1) # appending virtual start of a next sentence to facilitate search
output = []
cls = self.__class__
pred_start = 0
for sent in self.reference.sents:
# I guess for misaligned we just need to use the gold_to_cand?
output.append(
cls(
self.predicted[sent.start : sent.end + 1].as_doc(),
sent.as_doc()
)
)
new_ref = sent.as_doc()
pred_end = sent_starts.index(1, pred_start+1) # find where the next sentence starts
new_pred = self.predicted[pred_start : pred_end].as_doc()
output.append(Example(new_pred, new_ref))
pred_start = pred_end
return output
property text:

View File

@ -108,12 +108,18 @@ def json_to_annotations(doc):
words.append(token["orth"])
spaces.append(token.get("space", True))
ids.append(token.get('id', sent_start_i + i))
tags.append(token.get('tag', "-"))
pos.append(token.get("pos", ""))
morphs.append(token.get("morph", ""))
lemmas.append(token.get("lemma", ""))
heads.append(token.get("head", 0) + sent_start_i + i)
labels.append(token.get("dep", ""))
if "tag" in token:
tags.append(token["tag"])
if "pos" in token:
pos.append(token["pos"])
if "morph" in token:
morphs.append(token["morph"])
if "lemma" in token:
lemmas.append(token["lemma"])
if "head" in token:
heads.append(token["head"] + sent_start_i + i)
if "dep" in token:
labels.append(token["dep"])
# Ensure ROOT label is case-insensitive
if labels[-1].lower() == "root":
labels[-1] = "ROOT"
@ -130,15 +136,24 @@ def json_to_annotations(doc):
ids=ids,
words=words,
spaces=spaces,
tags=tags,
pos=pos,
morphs=morphs,
lemmas=lemmas,
heads=heads,
deps=labels,
sent_starts=sent_starts,
brackets=brackets
)
# avoid including dummy values that looks like gold info was present
if tags:
example["token_annotation"]["tags"] = tags
if pos:
example["token_annotation"]["pos"] = pos
if morphs:
example["token_annotation"]["morphs"] = morphs
if lemmas:
example["token_annotation"]["lemmas"] = lemmas
if heads:
example["token_annotation"]["heads"] = heads
if labels:
example["token_annotation"]["deps"] = labels
if pos:
example["token_annotation"]["pos"] = pos
cats = {}
for cat in paragraph.get("cats", {}):

View File

@ -143,8 +143,7 @@ def _has_ner(eg):
def _get_labels(examples):
labels = set()
for eg in examples:
for ner_tag in eg.token_annotation.entities:
for ner_tag in eg.get_aligned("ENT_TYPE", as_string=True):
if ner_tag != 'O' and ner_tag != '-':
_, label = ner_tag.split('-', 1)
labels.add(label)
labels.add(ner_tag)
return list(sorted(labels))

View File

@ -59,10 +59,10 @@ class Tok2Vec(Pipe):
YIELDS (iterator): A sequence of `Doc` objects, in order of input.
"""
for docs in minibatch(stream, batch_size):
batch = list(batch)
docs = list(docs)
tokvecses = self.predict(docs)
self.set_annotations(docs, tokvecses)
yield from batch
yield from docs
def predict(self, docs):
"""Return a single tensor for a batch of documents.

View File

@ -11,6 +11,7 @@ from spacy.util import fix_random_seed
from ..util import make_tempdir
from spacy.pipeline.defaults import default_tok2vec
from ...gold import Example
TRAIN_DATA = [
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
@ -50,21 +51,20 @@ def test_textcat_learns_multilabel():
cats = {letter: float(w2 == letter) for letter in letters}
docs.append((Doc(nlp.vocab, words=["d"] * 3 + [w1, w2] + ["d"] * 3), cats))
random.shuffle(docs)
model = TextCategorizer(nlp.vocab, width=8)
textcat = TextCategorizer(nlp.vocab, width=8)
for letter in letters:
model.add_label(letter)
optimizer = model.begin_training()
textcat.add_label(letter)
optimizer = textcat.begin_training()
for i in range(30):
losses = {}
Ys = [GoldParse(doc, cats=cats) for doc, cats in docs]
Xs = [doc for doc, cats in docs]
model.update(Xs, Ys, sgd=optimizer, losses=losses)
examples = [Example.from_dict(doc, {"cats": cats}) for doc, cat in docs]
textcat.update(examples, sgd=optimizer, losses=losses)
random.shuffle(docs)
for w1 in letters:
for w2 in letters:
doc = Doc(nlp.vocab, words=["d"] * 3 + [w1, w2] + ["d"] * 3)
truth = {letter: w2 == letter for letter in letters}
model(doc)
textcat(doc)
for cat, score in doc.cats.items():
if not truth[cat]:
assert score < 0.5

View File

@ -1,5 +1,7 @@
from collections import defaultdict
import pytest
from spacy.pipeline.defaults import default_ner
from spacy.pipeline import EntityRecognizer
@ -7,6 +9,8 @@ from spacy.lang.en import English
from spacy.tokens import Span
# skipped after removing Beam stuff during the Example/GoldParse refactor
@pytest.mark.skip
def test_issue4313():
""" This should not crash or exit with some strange error code """
beam_width = 16

View File

@ -1,9 +1,11 @@
import pytest
from spacy.gold import Example
@pytest.mark.parametrize(
"text,words", [("A'B C", ["A", "'", "B", "C"]), ("A-B", ["A-B"])]
)
def test_gold_misaligned(en_tokenizer, text, words):
doc = en_tokenizer(text)
GoldParse(doc, words=words)
Example.from_dict(doc, {"words": words})

View File

@ -90,6 +90,7 @@ def merged_dict():
return {
"ids": [1, 2, 3, 4, 5, 6, 7],
"words": ["Hi", "there", "everyone", "It", "is", "just", "me"],
"spaces": [True, True, True, True, True, True, False],
"tags": ["INTJ", "ADV", "PRON", "PRON", "AUX", "ADV", "PRON"],
"sent_starts": [1, 0, 0, 1, 0, 0, 0],
}
@ -150,6 +151,30 @@ def test_gold_biluo_misalign(en_vocab):
assert tags == ["O", "O", "O", "-", "-", "-"]
def test_split_sentences(en_vocab):
words = ["I", "flew", "to", "San Francisco Valley", "had", "loads of fun"]
doc = Doc(en_vocab, words=words)
gold_words = ["I", "flew", "to", "San", "Francisco", "Valley", "had", "loads", "of", "fun"]
sent_starts = [True, False, False, False, False, False, True, False, False, False]
example = Example.from_dict(doc, {"words": gold_words, "sent_starts": sent_starts})
assert example.text == "I flew to San Francisco Valley had loads of fun "
split_examples = example.split_sents()
assert len(split_examples) == 2
assert split_examples[0].text == "I flew to San Francisco Valley "
assert split_examples[1].text == "had loads of fun "
words = ["I", "flew", "to", "San", "Francisco", "Valley", "had", "loads", "of fun"]
doc = Doc(en_vocab, words=words)
gold_words = ["I", "flew", "to", "San Francisco", "Valley", "had", "loads of", "fun"]
sent_starts = [True, False, False, False, False, True, False, False]
example = Example.from_dict(doc, {"words": gold_words, "sent_starts": sent_starts})
assert example.text == "I flew to San Francisco Valley had loads of fun "
split_examples = example.split_sents()
assert len(split_examples) == 2
assert split_examples[0].text == "I flew to San Francisco Valley "
assert split_examples[1].text == "had loads of fun "
def test_gold_biluo_different_tokenization(en_vocab, en_tokenizer):
# one-to-many
words = ["I", "flew to", "San Francisco Valley", "."]
@ -466,7 +491,7 @@ def _train(train_data):
def test_split_sents(merged_dict):
nlp = English()
example = Example.from_dict(
Doc(nlp.vocab, words=merged_dict["words"]),
Doc(nlp.vocab, words=merged_dict["words"], spaces=merged_dict["spaces"]),
merged_dict
)
assert len(get_parses_from_example(
@ -484,6 +509,8 @@ def test_split_sents(merged_dict):
split_examples = example.split_sents()
assert len(split_examples) == 2
assert split_examples[0].text == "Hi there everyone "
assert split_examples[1].text == "It is just me"
token_annotation_1 = split_examples[0].to_dict()["token_annotation"]
assert token_annotation_1["words"] == ["Hi", "there", "everyone"]