Match pop with append for training format (#4516)

* trying to fix script - not successful yet

* match pop() with extend() to avoid changing the data

* few more pop-extend fixes

* reinsert deleted print statement

* fix print statement

* add last tested version

* append instead of extend

* add in few comments

* quick fix for 4402 + unit test

* fixing number of docs (not counting cats)

* more fixes

* fix len

* print tmp file instead of using data from examples dir

* print tmp file instead of using data from examples dir (2)
This commit is contained in:
Sofie Van Landeghem 2019-10-27 16:01:32 +01:00 committed by Matthew Honnibal
parent fcd25db033
commit 8e7414dace
9 changed files with 164 additions and 34 deletions

View File

@ -18,7 +18,7 @@ during training. We discard the auxiliary model before run-time.
The specific example here is not necessarily a good idea --- but it shows
how an arbitrary objective function for some word can be used.
Developed and tested for spaCy 2.0.6
Developed for spaCy 2.0.6 and last tested for 2.2.2
"""
import random
import plac
@ -26,6 +26,8 @@ import spacy
import os.path
from spacy.gold import read_json_file, GoldParse
from spacy.tokens import Doc
random.seed(0)
PWD = os.path.dirname(__file__)
@ -56,22 +58,29 @@ def main(n_iter=10):
ner.add_multitask_objective(get_position_label)
nlp.add_pipe(ner)
print("Create data", len(TRAIN_DATA))
_, sents = TRAIN_DATA[0]
print("Create data, # of sentences =", len(sents) - 1) # not counting the cats attribute
optimizer = nlp.begin_training(get_gold_tuples=lambda: TRAIN_DATA)
for itn in range(n_iter):
random.shuffle(TRAIN_DATA)
losses = {}
for text, annot_brackets in TRAIN_DATA:
annotations, _ = annot_brackets
doc = nlp.make_doc(text)
gold = GoldParse.from_annot_tuples(doc, annotations[0])
nlp.update(
[doc], # batch of texts
[gold], # batch of annotations
drop=0.2, # dropout - make it harder to memorise data
sgd=optimizer, # callable to update weights
losses=losses,
)
for raw_text, annots_brackets in TRAIN_DATA:
cats = annots_brackets.pop()
for annotations, _ in annots_brackets:
annotations.append(cats) # temporarily add it here for from_annot_tuples to work
doc = Doc(nlp.vocab, words=annotations[1])
gold = GoldParse.from_annot_tuples(doc, annotations)
annotations.pop() # restore data
nlp.update(
[doc], # batch of texts
[gold], # batch of annotations
drop=0.2, # dropout - make it harder to memorise data
sgd=optimizer, # callable to update weights
losses=losses,
)
annots_brackets.append(cats) # restore data
print(losses.get("nn_labeller", 0.0), losses["ner"])
# test the trained model

View File

@ -55,22 +55,22 @@ def tags_to_entities(tags):
def merge_sents(sents):
m_deps = [[], [], [], [], [], []]
m_sents = [[], [], [], [], [], []]
m_brackets = []
m_cats = sents.pop()
i = 0
for (ids, words, tags, heads, labels, ner), brackets in sents:
m_deps[0].extend(id_ + i for id_ in ids)
m_deps[1].extend(words)
m_deps[2].extend(tags)
m_deps[3].extend(head + i for head in heads)
m_deps[4].extend(labels)
m_deps[5].extend(ner)
m_sents[0].extend(id_ + i for id_ in ids)
m_sents[1].extend(words)
m_sents[2].extend(tags)
m_sents[3].extend(head + i for head in heads)
m_sents[4].extend(labels)
m_sents[5].extend(ner)
m_brackets.extend((b["first"] + i, b["last"] + i, b["label"])
for b in brackets)
i += len(ids)
m_deps.append(m_cats)
return [(m_deps, m_brackets)]
sents.append(m_cats) # restore original data
return [[(m_sents, m_brackets)], m_cats]
_NORM_MAP = {"``": '"', "''": '"'}
@ -248,6 +248,7 @@ class GoldCorpus(object):
if self.limit and i >= self.limit:
break
i += 1
paragraph_tuples.append(cats) # restore original data
return n
def train_docs(self, nlp, gold_preproc=False, max_length=None,
@ -288,26 +289,36 @@ class GoldCorpus(object):
@classmethod
def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc, noise_level=0.0, orth_variant_level=0.0):
cats = paragraph_tuples.pop()
if raw_text is not None:
raw_text, paragraph_tuples = make_orth_variants(nlp, raw_text, paragraph_tuples, orth_variant_level=orth_variant_level)
raw_text = add_noise(raw_text, noise_level)
return [nlp.make_doc(raw_text)], paragraph_tuples
result = [nlp.make_doc(raw_text)], paragraph_tuples
else:
docs = []
raw_text, paragraph_tuples = make_orth_variants(nlp, None, paragraph_tuples, orth_variant_level=orth_variant_level)
return [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level))
result = [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level))
for (sent_tuples, brackets) in paragraph_tuples], paragraph_tuples
paragraph_tuples.append(cats)
return result
@classmethod
def _make_golds(cls, docs, paragraph_tuples, make_projective):
cats = paragraph_tuples.pop()
if len(docs) != len(paragraph_tuples):
n_annots = len(paragraph_tuples)
raise ValueError(Errors.E070.format(n_docs=len(docs), n_annots=n_annots))
return [GoldParse.from_annot_tuples(doc, sent_tuples,
make_projective=make_projective)
for doc, (sent_tuples, brackets)
in zip(docs, paragraph_tuples)]
result = []
for doc, brack_annot in zip(docs, paragraph_tuples):
if len(brack_annot) == 1:
brack_annot = brack_annot[0]
sent_tuples, brackets = brack_annot
sent_tuples.append(cats)
result.append(GoldParse.from_annot_tuples(doc, sent_tuples, make_projective=make_projective))
sent_tuples.pop()
paragraph_tuples.append(cats)
return result
def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):

View File

@ -598,10 +598,11 @@ class Language(object):
# Populate vocab
else:
for _, annots_brackets in get_gold_tuples():
_ = annots_brackets.pop()
cats = annots_brackets.pop()
for annots, _ in annots_brackets:
for word in annots[1]:
_ = self.vocab[word] # noqa: F841
annots_brackets.append(cats) # restore original data
if cfg.get("device", -1) >= 0:
util.use_gpu(cfg["device"])
if self.vocab.vectors.data.shape[1] >= 1:

View File

@ -517,7 +517,7 @@ class Tagger(Pipe):
orig_tag_map = dict(self.vocab.morphology.tag_map)
new_tag_map = OrderedDict()
for raw_text, annots_brackets in get_gold_tuples():
_ = annots_brackets.pop()
cats = annots_brackets.pop()
for annots, brackets in annots_brackets:
ids, words, tags, heads, deps, ents = annots
for tag in tags:
@ -525,6 +525,7 @@ class Tagger(Pipe):
new_tag_map[tag] = orig_tag_map[tag]
else:
new_tag_map[tag] = {POS: X}
annots_brackets.append(cats) # restore original data
cdef Vocab vocab = self.vocab
if new_tag_map:
vocab.morphology = Morphology(vocab.strings, new_tag_map,
@ -703,12 +704,14 @@ class MultitaskObjective(Tagger):
sgd=None, **kwargs):
gold_tuples = nonproj.preprocess_training_data(get_gold_tuples())
for raw_text, annots_brackets in gold_tuples:
cats = annots_brackets.pop()
for annots, brackets in annots_brackets:
ids, words, tags, heads, deps, ents = annots
for i in range(len(ids)):
label = self.make_label(i, words, tags, heads, deps, ents)
if label is not None and label not in self.labels:
self.labels[label] = len(self.labels)
annots_brackets.append(cats)
if self.model is True:
token_vector_width = util.env_opt("token_vector_width")
self.model = self.Model(len(self.labels), tok2vec=tok2vec)
@ -1035,7 +1038,7 @@ class TextCategorizer(Pipe):
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
for raw_text, annots_brackets in get_gold_tuples():
cats = annots_brackets.pop()
cats = annots_brackets[-1]
for cat in cats:
self.add_label(cat)
if self.model is True:

View File

@ -342,7 +342,7 @@ cdef class ArcEager(TransitionSystem):
actions[RIGHT][label] = 1
actions[REDUCE][label] = 1
for raw_text, sents in kwargs.get('gold_parses', []):
_ = sents.pop()
cats = sents.pop()
for (ids, words, tags, heads, labels, iob), ctnts in sents:
heads, labels = nonproj.projectivize(heads, labels)
for child, head, label in zip(ids, heads, labels):
@ -356,6 +356,7 @@ cdef class ArcEager(TransitionSystem):
elif head > child:
actions[LEFT][label] += 1
actions[SHIFT][''] += 1
sents.append(cats) # restore original data
if min_freq is not None:
for action, label_freqs in actions.items():
for label, freq in list(label_freqs.items()):

View File

@ -73,13 +73,14 @@ cdef class BiluoPushDown(TransitionSystem):
actions[action][entity_type] = 1
moves = ('M', 'B', 'I', 'L', 'U')
for raw_text, sents in kwargs.get('gold_parses', []):
_ = sents.pop()
cats = sents.pop()
for (ids, words, tags, heads, labels, biluo), _ in sents:
for i, ner_tag in enumerate(biluo):
if ner_tag != 'O' and ner_tag != '-':
_, label = ner_tag.split('-', 1)
for action in (BEGIN, IN, LAST, UNIT):
actions[action][label] += 1
sents.append(cats) # restore original data
return actions
@property

View File

@ -606,12 +606,13 @@ cdef class Parser:
doc_sample = []
gold_sample = []
for raw_text, annots_brackets in islice(get_gold_tuples(), 1000):
_ = annots_brackets.pop()
cats = annots_brackets.pop()
for annots, brackets in annots_brackets:
ids, words, tags, heads, deps, ents = annots
doc_sample.append(Doc(self.vocab, words=words))
gold_sample.append(GoldParse(doc_sample[-1], words=words, tags=tags,
heads=heads, deps=deps, entities=ents))
annots_brackets.append(cats) # restore original data
self.model.begin_training(doc_sample, gold_sample)
if pipeline is not None:
self.init_multitask_objectives(get_gold_tuples, pipeline, sgd=sgd, **cfg)

View File

@ -97,6 +97,7 @@ def preprocess_training_data(gold_tuples, label_freq_cutoff=30):
freqs = {}
for raw_text, sents in gold_tuples:
prepro_sents = []
cats = sents.pop()
for (ids, words, tags, heads, labels, iob), ctnts in sents:
proj_heads, deco_labels = projectivize(heads, labels)
# set the label to ROOT for each root dependent
@ -109,6 +110,8 @@ def preprocess_training_data(gold_tuples, label_freq_cutoff=30):
freqs[label] = freqs.get(label, 0) + 1
prepro_sents.append(
((ids, words, tags, proj_heads, deco_labels, iob), ctnts))
sents.append(cats)
prepro_sents.append(cats)
preprocessed.append((raw_text, prepro_sents))
if label_freq_cutoff > 0:
return _filter_labels(preprocessed, label_freq_cutoff, freqs)
@ -209,6 +212,7 @@ def _filter_labels(gold_tuples, cutoff, freqs):
filtered = []
for raw_text, sents in gold_tuples:
filtered_sents = []
cats = sents.pop()
for (ids, words, tags, heads, labels, iob), ctnts in sents:
filtered_labels = []
for label in labels:
@ -218,5 +222,7 @@ def _filter_labels(gold_tuples, cutoff, freqs):
filtered_labels.append(label)
filtered_sents.append(
((ids, words, tags, heads, filtered_labels, iob), ctnts))
sents.append(cats)
filtered_sents.append(cats)
filtered.append((raw_text, filtered_sents))
return filtered

View File

@ -0,0 +1,97 @@
# coding: utf8
from __future__ import unicode_literals
import srsly
from spacy.gold import GoldCorpus, json_to_tuple
from spacy.lang.en import English
from spacy.tests.util import make_tempdir
def test_issue4402():
    """Regression test for issue #4402.

    Writes the module-level ``json_data`` fixture to a temporary JSON file,
    loads it through ``GoldCorpus`` (the same file is passed for both path
    arguments) and checks that ``train_docs`` with ``gold_preproc=True``
    yields one entry per annotated sentence.
    """
    nlp = English()
    with make_tempdir() as tmpdir:
        # Write the fixture to disk so the corpus reads it like a regular
        # training file instead of relying on data from the examples dir.
        json_path = tmpdir / "test4402.json"
        srsly.write_json(json_path, json_data)

        corpus = GoldCorpus(str(json_path), str(json_path))

        train_docs = list(corpus.train_docs(nlp, gold_preproc=True, max_length=0))
        # The fixture holds 2 paragraphs with 2 sentences each, so the data
        # must have been split into 4 sentences.
        assert len(train_docs) == 4
# Fixture for test_issue4402: a single document with two paragraphs, each
# holding two sentences (token-level "orth"/"ner" annotations, no brackets)
# plus paragraph-level text classification "cats".
json_data = [
    {
        "id": 0,
        "paragraphs": [
            {
                "raw": "How should I cook bacon in an oven?\nI've heard of people cooking bacon in an oven.",
                "sentences": [
                    {
                        "tokens": [
                            {"id": 0, "orth": "How", "ner": "O"},
                            {"id": 1, "orth": "should", "ner": "O"},
                            {"id": 2, "orth": "I", "ner": "O"},
                            {"id": 3, "orth": "cook", "ner": "O"},
                            {"id": 4, "orth": "bacon", "ner": "O"},
                            {"id": 5, "orth": "in", "ner": "O"},
                            {"id": 6, "orth": "an", "ner": "O"},
                            {"id": 7, "orth": "oven", "ner": "O"},
                            {"id": 8, "orth": "?", "ner": "O"},
                        ],
                        "brackets": [],
                    },
                    {
                        "tokens": [
                            {"id": 9, "orth": "\n", "ner": "O"},
                            {"id": 10, "orth": "I", "ner": "O"},
                            {"id": 11, "orth": "'ve", "ner": "O"},
                            {"id": 12, "orth": "heard", "ner": "O"},
                            {"id": 13, "orth": "of", "ner": "O"},
                            {"id": 14, "orth": "people", "ner": "O"},
                            {"id": 15, "orth": "cooking", "ner": "O"},
                            {"id": 16, "orth": "bacon", "ner": "O"},
                            {"id": 17, "orth": "in", "ner": "O"},
                            {"id": 18, "orth": "an", "ner": "O"},
                            {"id": 19, "orth": "oven", "ner": "O"},
                            {"id": 20, "orth": ".", "ner": "O"},
                        ],
                        "brackets": [],
                    },
                ],
                "cats": [
                    {"label": "baking", "value": 1.0},
                    {"label": "not_baking", "value": 0.0},
                ],
            },
            {
                "raw": "What is the difference between white and brown eggs?\n",
                "sentences": [
                    {
                        "tokens": [
                            {"id": 0, "orth": "What", "ner": "O"},
                            {"id": 1, "orth": "is", "ner": "O"},
                            {"id": 2, "orth": "the", "ner": "O"},
                            {"id": 3, "orth": "difference", "ner": "O"},
                            {"id": 4, "orth": "between", "ner": "O"},
                            {"id": 5, "orth": "white", "ner": "O"},
                            {"id": 6, "orth": "and", "ner": "O"},
                            {"id": 7, "orth": "brown", "ner": "O"},
                            {"id": 8, "orth": "eggs", "ner": "O"},
                            {"id": 9, "orth": "?", "ner": "O"},
                        ],
                        "brackets": [],
                    },
                    {"tokens": [{"id": 10, "orth": "\n", "ner": "O"}], "brackets": []},
                ],
                "cats": [
                    {"label": "baking", "value": 0.0},
                    {"label": "not_baking", "value": 1.0},
                ],
            },
        ],
    }
]