Example class for training data (#4543)

* OrigAnnot class instead of gold.orig_annot list of zipped tuples

* from_orig to replace from_annot_tuples

* rename to RawAnnot

* some unit tests for GoldParse creation and internal format

* removing orig_annot and switching to lists instead of tuple

* rewriting tuples to use RawAnnot (+ debug statements, WIP)

* fix pop() changing the data

* small fixes

* pop-append fixes

* return RawAnnot for existing GoldParse to have uniform interface

* clean up imports

* fix merge_sents

* add unit test for 4402 with new structure (not working yet)

* introduce DocAnnot

* typo fixes

* add unit test for merge_sents

* rename from_orig to from_raw

* fixing unit tests

* fix nn parser

* read_annots to produce text, doc_annot pairs

* _make_golds fix

* rename golds_to_gold_annots

* small fixes

* fix encoding

* have golds_to_gold_annots use DocAnnot

* missed a spot

* merge_sents as function in DocAnnot

* allow specifying only part of the token-level annotations

* refactor with Example class + underlying dicts

* pipeline components to work with Example objects (wip)

* input checking

* fix yielding

* fix calls to update

* small fixes

* fix scorer unit test with new format

* fix kwargs order

* fixes for ud and conllu scripts

* fix reading data for conllu script

* add in proper errors (not fixed numbering yet to avoid merge conflicts)

* fixing few more small bugs

* fix EL script
This commit is contained in:
Sofie Van Landeghem 2019-11-11 17:35:27 +01:00 committed by Matthew Honnibal
parent 56ad3a3988
commit e48a09df4e
48 changed files with 1178 additions and 716 deletions

View File

@ -13,23 +13,12 @@ import srsly
import spacy
import spacy.util
from spacy.tokens import Token, Doc
from spacy.gold import GoldParse
from spacy.util import compounding, minibatch_by_words
from spacy.syntax.nonproj import projectivize
from spacy.matcher import Matcher
# from spacy.morphology import Fused_begin, Fused_inside
from spacy import displacy
from collections import defaultdict, Counter
from timeit import default_timer as timer
Fused_begin = None
Fused_inside = None
import itertools
import random
import numpy.random
from . import conll17_ud_eval
from spacy import lang
@ -268,7 +257,7 @@ def load_nlp(experiments_dir, corpus):
return nlp
def initialize_pipeline(nlp, docs, golds, config, device):
def initialize_pipeline(nlp, examples, config, device):
nlp.add_pipe(nlp.create_pipe("parser"))
return nlp

View File

@ -7,24 +7,20 @@ from __future__ import unicode_literals
import plac
from pathlib import Path
import re
import sys
import json
import spacy
import spacy.util
from bin.ud import conll17_ud_eval
from spacy.tokens import Token, Doc
from spacy.gold import GoldParse
from spacy.gold import GoldParse, Example
from spacy.util import compounding, minibatch, minibatch_by_words
from spacy.syntax.nonproj import projectivize
from spacy.matcher import Matcher
from spacy import displacy
from collections import defaultdict, Counter
from timeit import default_timer as timer
from collections import defaultdict
import itertools
import random
import numpy.random
from spacy import lang
from spacy.lang import zh
@ -56,7 +52,7 @@ def read_data(
max_doc_length=None,
limit=None,
):
"""Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
"""Read the CONLLU format into Example objects. If raw_text=True,
include Doc objects created using nlp.make_doc and then aligned against
the gold-standard sequences. If oracle_segments=True, include Doc objects
created from the gold-standard segments. At least one must be True."""
@ -101,15 +97,16 @@ def read_data(
docs.append(doc)
golds.append(gold)
if limit and len(docs) >= limit:
return docs, golds
return golds_to_gold_data(docs, golds)
if raw_text and sent_annots:
doc, gold = _make_gold(nlp, None, sent_annots)
docs.append(doc)
golds.append(gold)
if limit and len(docs) >= limit:
return docs, golds
return docs, golds
return golds_to_gold_data(docs, golds)
return golds_to_gold_data(docs, golds)
def _parse_morph_string(morph_string):
if morph_string == '_':
@ -123,6 +120,7 @@ def _parse_morph_string(morph_string):
output.append('%s_%s' % (key, value.lower()))
return set(output)
def read_conllu(file_):
docs = []
sent = []
@ -183,16 +181,18 @@ def _make_gold(nlp, text, sent_annots, drop_deps=0.0):
#############################
def golds_to_gold_tuples(docs, golds):
"""Get out the annoying 'tuples' format used by begin_training, given the
def golds_to_gold_data(docs, golds):
"""Get out the training data format used by begin_training, given the
GoldParse objects."""
tuples = []
data = []
for doc, gold in zip(docs, golds):
text = doc.text
ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
sents = [((ids, words, tags, heads, labels, iob), [])]
tuples.append((text, sents))
return tuples
example = Example(doc=doc)
example.add_doc_annotation(cats=gold.cats)
token_annotation_dict = gold.orig.to_dict()
example.add_token_annotation(**token_annotation_dict)
example.goldparse = gold
data.append(example)
return data
##############
@ -348,7 +348,7 @@ def load_nlp(corpus, config, vectors=None):
return nlp
def initialize_pipeline(nlp, docs, golds, config, device):
def initialize_pipeline(nlp, examples, config, device):
nlp.add_pipe(nlp.create_pipe("tagger", config={"set_morphology": False}))
nlp.add_pipe(nlp.create_pipe("morphologizer"))
nlp.add_pipe(nlp.create_pipe("parser"))
@ -356,14 +356,15 @@ def initialize_pipeline(nlp, docs, golds, config, device):
nlp.parser.add_multitask_objective("tag")
if config.multitask_sent:
nlp.parser.add_multitask_objective("sent_start")
for gold in golds:
for ex in examples:
gold = ex.gold
for tag in gold.tags:
if tag is not None:
nlp.tagger.add_label(tag)
if torch is not None and device != -1:
torch.set_default_tensor_type("torch.cuda.FloatTensor")
optimizer = nlp.begin_training(
lambda: golds_to_gold_tuples(docs, golds),
lambda: examples,
device=device,
subword_features=config.subword_features,
conv_depth=config.conv_depth,
@ -504,20 +505,20 @@ def main(
print("Train and evaluate", corpus, "using lang", paths.lang)
nlp = load_nlp(paths.lang, config, vectors=vectors_dir)
docs, golds = read_data(
examples = read_data(
nlp,
paths.train.conllu.open(),
paths.train.text.open(),
paths.train.conllu.open(encoding="utf8"),
paths.train.text.open(encoding="utf8"),
max_doc_length=config.max_doc_length,
limit=limit,
)
optimizer = initialize_pipeline(nlp, docs, golds, config, gpu_device)
optimizer = initialize_pipeline(nlp, examples, config, gpu_device)
batch_sizes = compounding(config.min_batch_size, config.max_batch_size, 1.001)
beam_prob = compounding(0.2, 0.8, 1.001)
for i in range(config.nr_epoch):
docs, golds = read_data(
examples = read_data(
nlp,
paths.train.conllu.open(encoding="utf8"),
paths.train.text.open(encoding="utf8"),
@ -526,22 +527,19 @@ def main(
oracle_segments=use_oracle_segments,
raw_text=not use_oracle_segments,
)
Xs = list(zip(docs, golds))
random.shuffle(Xs)
random.shuffle(examples)
if config.batch_by_words:
batches = minibatch_by_words(Xs, size=batch_sizes)
batches = minibatch_by_words(examples, size=batch_sizes)
else:
batches = minibatch(Xs, size=batch_sizes)
batches = minibatch(examples, size=batch_sizes)
losses = {}
n_train_words = sum(len(doc) for doc in docs)
n_train_words = sum(len(ex.doc) for ex in examples)
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
for batch in batches:
batch_docs, batch_gold = zip(*batch)
pbar.update(sum(len(doc) for doc in batch_docs))
pbar.update(sum(len(ex.doc) for ex in batch))
nlp.parser.cfg["beam_update_prob"] = next(beam_prob)
nlp.update(
batch_docs,
batch_gold,
batch,
sgd=optimizer,
drop=config.dropout,
losses=losses,

View File

@ -46,7 +46,7 @@ def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_fre
" cf. https://spacy.io/usage/models#languages."
)
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
logger.info("Filtering entities with fewer than {} mentions or no description".format(min_entity_freq))
entity_frequencies = io.read_entity_to_count(entity_freq_path)
# filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(

View File

@ -131,10 +131,8 @@ def main(
with nlp.disable_pipes(*other_pipes):
for batch in batches:
try:
docs, golds = zip(*batch)
nlp.update(
docs=docs,
golds=golds,
examples=batch,
sgd=optimizer,
drop=dropout,
losses=losses,

View File

@ -11,10 +11,9 @@ import json
import spacy
import spacy.util
from spacy.tokens import Token, Doc
from spacy.gold import GoldParse
from spacy.gold import GoldParse, Example
from spacy.syntax.nonproj import projectivize
from collections import defaultdict, Counter
from timeit import default_timer as timer
from collections import defaultdict
from spacy.matcher import Matcher
import itertools
@ -33,25 +32,25 @@ random.seed(0)
numpy.random.seed(0)
def minibatch_by_words(items, size=5000):
random.shuffle(items)
def minibatch_by_words(examples, size=5000):
random.shuffle(examples)
if isinstance(size, int):
size_ = itertools.repeat(size)
else:
size_ = size
items = iter(items)
examples = iter(examples)
while True:
batch_size = next(size_)
batch = []
while batch_size >= 0:
try:
doc, gold = next(items)
example = next(examples)
except StopIteration:
if batch:
yield batch
return
batch_size -= len(doc)
batch.append((doc, gold))
batch_size -= len(example.doc)
batch.append(example)
if batch:
yield batch
else:
@ -78,7 +77,7 @@ def read_data(
max_doc_length=None,
limit=None,
):
"""Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
"""Read the CONLLU format into Example objects. If raw_text=True,
include Doc objects created using nlp.make_doc and then aligned against
the gold-standard sequences. If oracle_segments=True, include Doc objects
created from the gold-standard segments. At least one must be True."""
@ -119,15 +118,15 @@ def read_data(
docs.append(doc)
golds.append(gold)
if limit and len(docs) >= limit:
return docs, golds
return golds_to_gold_data(docs, golds)
if raw_text and sent_annots:
doc, gold = _make_gold(nlp, None, sent_annots)
docs.append(doc)
golds.append(gold)
if limit and len(docs) >= limit:
return docs, golds
return docs, golds
return golds_to_gold_data(docs, golds)
return golds_to_gold_data(docs, golds)
def read_conllu(file_):
@ -181,16 +180,18 @@ def _make_gold(nlp, text, sent_annots):
#############################
def golds_to_gold_tuples(docs, golds):
"""Get out the annoying 'tuples' format used by begin_training, given the
def golds_to_gold_data(docs, golds):
"""Get out the training data format used by begin_training, given the
GoldParse objects."""
tuples = []
data = []
for doc, gold in zip(docs, golds):
text = doc.text
ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
sents = [((ids, words, tags, heads, labels, iob), [])]
tuples.append((text, sents))
return tuples
example = Example(doc=doc)
example.add_doc_annotation(cats=gold.cats)
token_annotation_dict = gold.orig.to_dict()
example.add_token_annotation(**token_annotation_dict)
example.goldparse = gold
data.append(example)
return data
##############
@ -290,9 +291,9 @@ def get_token_conllu(token, i):
return "\n".join(lines)
Token.set_extension("get_conllu_lines", method=get_token_conllu)
Token.set_extension("begins_fused", default=False)
Token.set_extension("inside_fused", default=False)
Token.set_extension("get_conllu_lines", method=get_token_conllu, force=True)
Token.set_extension("begins_fused", default=False, force=True)
Token.set_extension("inside_fused", default=False, force=True)
##################
@ -308,7 +309,7 @@ def load_nlp(corpus, config):
return nlp
def initialize_pipeline(nlp, docs, golds, config):
def initialize_pipeline(nlp, examples, config):
nlp.add_pipe(nlp.create_pipe("parser"))
if config.multitask_tag:
nlp.parser.add_multitask_objective("tag")
@ -316,18 +317,19 @@ def initialize_pipeline(nlp, docs, golds, config):
nlp.parser.add_multitask_objective("sent_start")
nlp.parser.moves.add_action(2, "subtok")
nlp.add_pipe(nlp.create_pipe("tagger"))
for gold in golds:
for tag in gold.tags:
for ex in examples:
for tag in ex.gold.tags:
if tag is not None:
nlp.tagger.add_label(tag)
# Replace labels that didn't make the frequency cutoff
actions = set(nlp.parser.labels)
label_set = set([act.split("-")[1] for act in actions if "-" in act])
for gold in golds:
for ex in examples:
gold = ex.gold
for i, label in enumerate(gold.labels):
if label is not None and label not in label_set:
gold.labels[i] = label.split("||")[0]
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
return nlp.begin_training(lambda: examples)
########################
@ -401,28 +403,26 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
print("Train and evaluate", corpus, "using lang", paths.lang)
nlp = load_nlp(paths.lang, config)
docs, golds = read_data(
examples = read_data(
nlp,
paths.train.conllu.open(),
paths.train.text.open(),
paths.train.conllu.open(encoding="utf8"),
paths.train.text.open(encoding="utf8"),
max_doc_length=config.max_doc_length,
limit=limit,
)
optimizer = initialize_pipeline(nlp, docs, golds, config)
optimizer = initialize_pipeline(nlp, examples, config)
for i in range(config.nr_epoch):
docs = [nlp.make_doc(doc.text) for doc in docs]
batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
docs = [nlp.make_doc(example.doc.text) for example in examples]
batches = minibatch_by_words(examples, size=config.batch_size)
losses = {}
n_train_words = sum(len(doc) for doc in docs)
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
for batch in batches:
batch_docs, batch_gold = zip(*batch)
pbar.update(sum(len(doc) for doc in batch_docs))
pbar.update(sum(len(ex.doc) for ex in batch))
nlp.update(
batch_docs,
batch_gold,
examples=batch,
sgd=optimizer,
drop=config.dropout,
losses=losses,

View File

@ -31,14 +31,13 @@ random.seed(0)
PWD = os.path.dirname(__file__)
TRAIN_DATA = list(read_json_file(
os.path.join(PWD, "ner_example_data", "ner-sent-per-line.json")))
TRAIN_DATA = list(read_json_file(os.path.join(PWD, "training-data.json")))
def get_position_label(i, words, tags, heads, labels, ents):
def get_position_label(i, token_annotation):
"""Return labels indicating the position of the word in the document.
"""
if len(words) < 20:
if len(token_annotation.words) < 20:
return "short-doc"
elif i == 0:
return "first-word"
@ -46,7 +45,7 @@ def get_position_label(i, words, tags, heads, labels, ents):
return "early-word"
elif i < 20:
return "mid-word"
elif i == len(words) - 1:
elif i == len(token_annotation.words) - 1:
return "last-word"
else:
return "late-word"
@ -60,17 +59,17 @@ def main(n_iter=10):
print(nlp.pipeline)
print("Create data", len(TRAIN_DATA))
optimizer = nlp.begin_training(get_gold_tuples=lambda: TRAIN_DATA)
optimizer = nlp.begin_training(get_examples=lambda: TRAIN_DATA)
for itn in range(n_iter):
random.shuffle(TRAIN_DATA)
losses = {}
for text, annot_brackets in TRAIN_DATA:
for annotations, _ in annot_brackets:
doc = Doc(nlp.vocab, words=annotations[1])
gold = GoldParse.from_annot_tuples(doc, annotations)
for example in TRAIN_DATA:
for token_annotation in example.token_annotations:
doc = Doc(nlp.vocab, words=token_annotation.words)
gold = GoldParse.from_annotation(doc, example.doc_annotation, token_annotation)
nlp.update(
[doc], # batch of texts
[gold], # batch of annotations
examples=[(doc, gold)], # 1 example
drop=0.2, # dropout - make it harder to memorise data
sgd=optimizer, # callable to update weights
losses=losses,
@ -78,9 +77,9 @@ def main(n_iter=10):
print(losses.get("nn_labeller", 0.0), losses["ner"])
# test the trained model
for text, _ in TRAIN_DATA:
if text is not None:
doc = nlp(text)
for example in TRAIN_DATA:
if example.text is not None:
doc = nlp(example.text)
print("Entities", [(ent.text, ent.label_) for ent in doc.ents])
print("Tokens", [(t.text, t.ent_type_, t.ent_iob) for t in doc])

View File

@ -116,7 +116,7 @@ def train_tensorizer(nlp, texts, dropout, n_iter):
losses = {}
for i, batch in enumerate(minibatch(tqdm.tqdm(texts))):
docs = [nlp.make_doc(text) for text in batch]
tensorizer.update(docs, None, losses=losses, sgd=optimizer, drop=dropout)
tensorizer.update((docs, None), losses=losses, sgd=optimizer, drop=dropout)
print(losses)
return optimizer
@ -147,8 +147,7 @@ def train_textcat(nlp, n_texts, n_iter=10):
# batch up the examples using spaCy's minibatch
batches = minibatch(tqdm.tqdm(train_data), size=2)
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
nlp.update(batch, sgd=optimizer, drop=0.2, losses=losses)
with textcat.model.use_params(optimizer.averages):
# evaluate on the dev data split off in load_data()
scores = evaluate_textcat(nlp.tokenizer, textcat, dev_texts, dev_cats)

View File

@ -74,8 +74,7 @@ def main(model_name, unlabelled_loc):
# batch up the examples using spaCy's minibatch
raw_batches = minibatch(raw_docs, size=4)
for batch in minibatch(TRAIN_DATA, size=sizes):
docs, golds = zip(*batch)
nlp.update(docs, golds, sgd=optimizer, drop=dropout, losses=losses)
nlp.update(batch, sgd=optimizer, drop=dropout, losses=losses)
raw_batch = list(next(raw_batches))
nlp.rehearse(raw_batch, sgd=optimizer, losses=r_losses)
print("Losses", losses)

View File

@ -108,10 +108,8 @@ def main(kb_path, vocab_path=None, output_dir=None, n_iter=50):
# batch up the examples using spaCy's minibatch
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(
texts, # batch of texts
annotations, # batch of annotations
batch,
drop=0.2, # dropout - make it harder to memorise data
losses=losses,
sgd=optimizer,

View File

@ -133,8 +133,7 @@ def main(model=None, output_dir=None, n_iter=15):
# batch up the examples using spaCy's minibatch
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, losses=losses)
nlp.update(batch, sgd=optimizer, losses=losses)
print("Losses", losses)
# test the trained model

View File

@ -67,10 +67,8 @@ def main(model=None, output_dir=None, n_iter=100):
# batch up the examples using spaCy's minibatch
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(
texts, # batch of texts
annotations, # batch of annotations
batch,
drop=0.5, # dropout - make it harder to memorise data
losses=losses,
)

View File

@ -104,8 +104,7 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=30):
batches = minibatch(TRAIN_DATA, size=sizes)
losses = {}
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, drop=0.35, losses=losses)
nlp.update(batch, sgd=optimizer, drop=0.35, losses=losses)
print("Losses", losses)
# test the trained model

View File

@ -74,8 +74,7 @@ def main(model=None, output_dir=None, n_iter=15):
# batch up the examples using spaCy's minibatch
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, losses=losses)
nlp.update(batch, sgd=optimizer, losses=losses)
print("Losses", losses)
# test the trained model

View File

@ -65,8 +65,7 @@ def main(lang="en", output_dir=None, n_iter=25):
# batch up the examples using spaCy's minibatch
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, losses=losses)
nlp.update(batch, sgd=optimizer, losses=losses)
print("Losses", losses)
# test the trained model

View File

@ -82,8 +82,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
random.shuffle(train_data)
batches = minibatch(train_data, size=batch_sizes)
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
nlp.update(batch, sgd=optimizer, drop=0.2, losses=losses)
with textcat.model.use_params(optimizer.averages):
# evaluate on the dev data split off in load_data()
scores = evaluate(nlp.tokenizer, textcat, dev_texts, dev_cats)

View File

@ -3,6 +3,7 @@ from __future__ import unicode_literals
import re
from spacy.gold import Example
from ...gold import iob_to_biluo
@ -19,15 +20,15 @@ def conllu2json(input_data, n_sents=10, use_morphology=False, lang=None, **_):
# by @katarkor
docs = []
sentences = []
conll_tuples = read_conllx(input_data, use_morphology=use_morphology)
conll_data = read_conllx(input_data, use_morphology=use_morphology)
checked_for_ner = False
has_ner_tags = False
for i, (raw_text, tokens) in enumerate(conll_tuples):
sentence, brackets = tokens[0]
for i, example in enumerate(conll_data):
for token_annotation in example.token_annotations:
if not checked_for_ner:
has_ner_tags = is_ner(sentence[5][0])
has_ner_tags = is_ner(token_annotation.entities[0])
checked_for_ner = True
sentences.append(generate_sentence(sentence, has_ner_tags))
sentences.append(generate_sentence(token_annotation, has_ner_tags))
# Real-sized documents could be extracted using the comments on the
# conluu document
if len(sentences) % n_sents == 0:
@ -52,15 +53,15 @@ def is_ner(tag):
def read_conllx(input_data, use_morphology=False, n=0):
""" Yield example data points, one for each sentence """
i = 0
for sent in input_data.strip().split("\n\n"):
lines = sent.strip().split("\n")
if lines:
while lines[0].startswith("#"):
lines.pop(0)
tokens = []
ids, words, tags, heads, deps, ents = [], [], [], [], [], []
for line in lines:
parts = line.split("\t")
id_, word, lemma, pos, tag, morph, head, dep, _1, iob = parts
if "-" in id_ or "." in id_:
@ -72,14 +73,22 @@ def read_conllx(input_data, use_morphology=False, n=0):
tag = pos if tag == "_" else tag
tag = tag + "__" + morph if use_morphology else tag
iob = iob if iob else "O"
tokens.append((id_, word, tag, head, dep, iob))
ids.append(id_)
words.append(word)
tags.append(tag)
heads.append(head)
deps.append(dep)
ents.append(iob)
except: # noqa: E722
print(line)
raise
tuples = [list(t) for t in zip(*tokens)]
yield (None, [[tuples, []]])
example = Example(doc=None)
example.add_token_annotation(ids=ids, words=words, tags=tags,
heads=heads, deps=deps, entities=ents)
yield example
i += 1
if n >= 1 and i >= n:
if 1 <= n <= i:
break
@ -107,20 +116,19 @@ def simplify_tags(iob):
return new_iob
def generate_sentence(sent, has_ner_tags):
(id_, word, tag, head, dep, iob) = sent
def generate_sentence(token_annotation, has_ner_tags):
sentence = {}
tokens = []
if has_ner_tags:
iob = simplify_tags(iob)
iob = simplify_tags(token_annotation.entities)
biluo = iob_to_biluo(iob)
for i, id in enumerate(id_):
for i, id in enumerate(token_annotation.ids):
token = {}
token["id"] = id
token["orth"] = word[i]
token["tag"] = tag[i]
token["head"] = head[i] - id
token["dep"] = dep[i]
token["orth"] = token_annotation.words[i]
token["tag"] = token_annotation.tags[i]
token["head"] = token_annotation.heads[i] - id
token["dep"] = token_annotation.deps[i]
if has_ner_tags:
token["ner"] = biluo[i]
tokens.append(token)

View File

@ -80,16 +80,16 @@ def debug_data(
with msg.loading("Loading corpus..."):
corpus = GoldCorpus(train_path, dev_path)
try:
train_docs = list(corpus.train_docs(nlp))
train_docs_unpreprocessed = list(
corpus.train_docs_without_preprocessing(nlp)
train_dataset = list(corpus.train_dataset(nlp))
train_dataset_unpreprocessed = list(
corpus.train_dataset_without_preprocessing(nlp)
)
except ValueError as e:
loading_train_error_message = "Training data cannot be loaded: {}".format(
str(e)
)
try:
dev_docs = list(corpus.dev_docs(nlp))
dev_dataset = list(corpus.dev_dataset(nlp))
except ValueError as e:
loading_dev_error_message = "Development data cannot be loaded: {}".format(
str(e)
@ -102,10 +102,10 @@ def debug_data(
sys.exit(1)
msg.good("Corpus is loadable")
# Create all gold data here to avoid iterating over the train_docs constantly
gold_train_data = _compile_gold(train_docs, pipeline)
gold_train_unpreprocessed_data = _compile_gold(train_docs_unpreprocessed, pipeline)
gold_dev_data = _compile_gold(dev_docs, pipeline)
# Create all gold data here to avoid iterating over the train_dataset constantly
gold_train_data = _compile_gold(train_dataset, pipeline)
gold_train_unpreprocessed_data = _compile_gold(train_dataset_unpreprocessed, pipeline)
gold_dev_data = _compile_gold(dev_dataset, pipeline)
train_texts = gold_train_data["texts"]
dev_texts = gold_dev_data["texts"]
@ -118,19 +118,19 @@ def debug_data(
msg.text("Starting with base model '{}'".format(base_model))
else:
msg.text("Starting with blank model '{}'".format(lang))
msg.text("{} training docs".format(len(train_docs)))
msg.text("{} evaluation docs".format(len(dev_docs)))
msg.text("{} training docs".format(len(train_dataset)))
msg.text("{} evaluation docs".format(len(gold_dev_data)))
overlap = len(train_texts.intersection(dev_texts))
if overlap:
msg.warn("{} training examples also in evaluation data".format(overlap))
else:
msg.good("No overlap between training and evaluation data")
if not base_model and len(train_docs) < BLANK_MODEL_THRESHOLD:
if not base_model and len(train_dataset) < BLANK_MODEL_THRESHOLD:
text = "Low number of examples to train from a blank model ({})".format(
len(train_docs)
len(train_dataset)
)
if len(train_docs) < BLANK_MODEL_MIN_THRESHOLD:
if len(train_dataset) < BLANK_MODEL_MIN_THRESHOLD:
msg.fail(text)
else:
msg.warn(text)
@ -238,7 +238,7 @@ def debug_data(
has_low_data_warning = True
with msg.loading("Analyzing label distribution..."):
neg_docs = _get_examples_without_label(train_docs, label)
neg_docs = _get_examples_without_label(train_dataset, label)
if neg_docs == 0:
msg.warn(
"No examples for texts WITHOUT new label '{}'".format(label)
@ -358,7 +358,7 @@ def debug_data(
msg.info(
"Found {} sentence{} with an average length of {:.1f} words.".format(
gold_train_data["n_sents"],
"s" if len(train_docs) > 1 else "",
"s" if len(train_dataset) > 1 else "",
gold_train_data["n_words"] / gold_train_data["n_sents"],
)
)
@ -536,7 +536,7 @@ def _load_file(file_path, msg):
)
def _compile_gold(train_docs, pipeline):
def _compile_gold(examples, pipeline):
data = {
"ner": Counter(),
"cats": Counter(),
@ -553,7 +553,9 @@ def _compile_gold(train_docs, pipeline):
"n_cats_multilabel": 0,
"texts": set(),
}
for doc, gold in train_docs:
for example in examples:
gold = example.gold
doc = example.doc
valid_words = [x for x in gold.words if x is not None]
data["words"].update(valid_words)
data["n_words"] += len(valid_words)
@ -598,8 +600,8 @@ def _format_labels(labels, counts=False):
def _get_examples_without_label(data, label):
count = 0
for doc, gold in data:
labels = [label.split("-")[1] for label in gold.ner if label not in ("O", "-")]
for ex in data:
labels = [label.split("-")[1] for label in ex.gold.ner if label not in ("O", "-")]
if label not in labels:
count += 1
return count

View File

@ -45,11 +45,11 @@ def evaluate(
msg.fail("Visualization output directory not found", displacy_path, exits=1)
corpus = GoldCorpus(data_path, data_path)
nlp = util.load_model(model)
dev_docs = list(corpus.dev_docs(nlp, gold_preproc=gold_preproc))
dev_dataset = list(corpus.dev_dataset(nlp, gold_preproc=gold_preproc))
begin = timer()
scorer = nlp.evaluate(dev_docs, verbose=False)
scorer = nlp.evaluate(dev_dataset, verbose=False)
end = timer()
nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs)
nwords = sum(len(ex.doc) for ex in dev_dataset)
results = {
"Time": "%.2f s" % (end - begin),
"Words": nwords,
@ -66,7 +66,7 @@ def evaluate(
msg.table(results, title="Results")
if displacy_path:
docs, golds = zip(*dev_docs)
docs = [ex.doc for ex in dev_dataset]
render_deps = "parser" in nlp.meta.get("pipeline", [])
render_ents = "ner" in nlp.meta.get("pipeline", [])
render_parses(

View File

@ -14,6 +14,7 @@ from thinc.neural.util import prefer_gpu
from wasabi import Printer
import srsly
from spacy.gold import Example
from ..errors import Errors
from ..tokens import Doc
from ..attrs import ID, HEAD
@ -221,7 +222,7 @@ def pretrain(
skip_counter = 0
for epoch in range(epoch_start, n_iter + epoch_start):
for batch_id, batch in enumerate(
util.minibatch_by_words(((text, None) for text in texts), size=batch_size)
util.minibatch_by_words((Example(doc=text) for text in texts), size=batch_size)
):
docs, count = make_docs(
nlp,

View File

@ -236,7 +236,7 @@ def train(
optimizer = create_default_optimizer(Model.ops)
else:
# Start with a blank model, call begin_training
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
optimizer = nlp.begin_training(lambda: corpus.train_examples, device=use_gpu)
nlp._optimizer = None
@ -261,7 +261,7 @@ def train(
"problem with two labels.".format(textcat_positive_label),
exits=1,
)
train_docs = corpus.train_docs(
train_data = corpus.train_data(
nlp,
noise_level=noise_level,
gold_preproc=gold_preproc,
@ -271,9 +271,9 @@ def train(
train_labels = set()
if textcat_multilabel:
multilabel_found = False
for text, gold in train_docs:
train_labels.update(gold.cats.keys())
if list(gold.cats.values()).count(1.0) != 1:
for ex in train_data:
train_labels.update(ex.gold.cats.keys())
if list(ex.gold.cats.values()).count(1.0) != 1:
multilabel_found = True
if not multilabel_found and not base_model:
msg.warn(
@ -283,9 +283,9 @@ def train(
"mutually-exclusive classes."
)
if not textcat_multilabel:
for text, gold in train_docs:
train_labels.update(gold.cats.keys())
if list(gold.cats.values()).count(1.0) != 1 and not base_model:
for ex in train_data:
train_labels.update(ex.gold.cats.keys())
if list(ex.gold.cats.values()).count(1.0) != 1 and not base_model:
msg.warn(
"Some textcat training instances do not have exactly "
"one positive label. Modifying training options to "
@ -341,7 +341,7 @@ def train(
iter_since_best = 0
best_score = 0.0
for i in range(n_iter):
train_docs = corpus.train_docs(
train_data = corpus.train_data(
nlp,
noise_level=noise_level,
orth_variant_level=orth_variant_level,
@ -357,13 +357,11 @@ def train(
words_seen = 0
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
losses = {}
for batch in util.minibatch_by_words(train_docs, size=batch_sizes):
for batch in util.minibatch_by_words(train_data, size=batch_sizes):
if not batch:
continue
docs, golds = zip(*batch)
nlp.update(
docs,
golds,
batch,
sgd=optimizer,
drop=next(dropout_rates),
losses=losses,
@ -373,6 +371,7 @@ def train(
# which use unlabelled data to reduce overfitting.
raw_batch = list(next(raw_batches))
nlp.rehearse(raw_batch, sgd=optimizer, losses=losses)
docs = [ex.doc for ex in batch]
if not int(os.environ.get("LOG_FRIENDLY", 0)):
pbar.update(sum(len(doc) for doc in docs))
words_seen += sum(len(doc) for doc in docs)
@ -385,16 +384,16 @@ def train(
for name, component in nlp_loaded.pipeline:
if hasattr(component, "cfg"):
component.cfg["beam_width"] = beam_width
dev_docs = list(
corpus.dev_docs(
dev_dataset = list(
corpus.dev_dataset(
nlp_loaded,
gold_preproc=gold_preproc,
ignore_misaligned=True,
)
)
nwords = sum(len(doc_gold[0]) for doc_gold in dev_docs)
nwords = sum(len(ex.doc) for ex in dev_dataset)
start_time = timer()
scorer = nlp_loaded.evaluate(dev_docs, verbose=verbose)
scorer = nlp_loaded.evaluate(dev_dataset, verbose=verbose)
end_time = timer()
if use_gpu < 0:
gpu_wps = None
@ -406,15 +405,15 @@ def train(
for name, component in nlp_loaded.pipeline:
if hasattr(component, "cfg"):
component.cfg["beam_width"] = beam_width
dev_docs = list(
corpus.dev_docs(
dev_dataset = list(
corpus.dev_dataset(
nlp_loaded,
gold_preproc=gold_preproc,
ignore_misaligned=True,
)
)
start_time = timer()
scorer = nlp_loaded.evaluate(dev_docs, verbose=verbose)
scorer = nlp_loaded.evaluate(dev_dataset, verbose=verbose)
end_time = timer()
cpu_wps = nwords / (end_time - start_time)
acc_loc = output_path / ("model%d" % i) / "accuracy.json"

View File

@ -530,6 +530,12 @@ class Errors(object):
"{obj}.{attr}\nAttribute '{attr}' does not exist on {obj}.")
E186 = ("'{tok_a}' and '{tok_b}' are different texts.")
# TODO: fix numbering after merging develop into master
E998 = ("Can only create GoldParse's from Example's without a Doc, "
"if get_gold_parses() is called with a Vocab object.")
E999 = ("Encountered an unexpected format for the dictionary holding "
"gold annotations: {gold_dict}")
@add_codes
class TempErrors(object):

View File

@ -1,6 +1,6 @@
from cymem.cymem cimport Pool
from .structs cimport TokenC
from spacy.tokens import Doc
from .typedefs cimport attr_t
from .syntax.transition_system cimport Transition
@ -19,6 +19,7 @@ cdef class GoldParse:
cdef Pool mem
cdef GoldParseC c
cdef readonly TokenAnnotation orig
cdef int length
cdef public int loss
@ -29,13 +30,36 @@ cdef class GoldParse:
cdef public list labels
cdef public dict orths
cdef public list ner
cdef public list ents
cdef public dict brackets
cdef public object cats
cdef public dict cats
cdef public dict links
cdef readonly list cand_to_gold
cdef readonly list gold_to_cand
cdef readonly list orig_annot
cdef class TokenAnnotation:
cdef public list ids
cdef public list words
cdef public list tags
cdef public list heads
cdef public list deps
cdef public list entities
cdef public list morphology
cdef public list brackets
cdef class DocAnnotation:
cdef public object cats
cdef public object links
cdef class Example:
cdef public object doc
cdef public list token_annotations
cdef public DocAnnotation doc_annotation
cdef public object make_projective
cdef public object ignore_misaligned
cdef public object goldparse

View File

@ -14,11 +14,8 @@ import srsly
from .syntax import nonproj
from .tokens import Doc, Span
from .errors import Errors, AlignmentError
from .compat import path2str
from .compat import path2str, basestring_
from . import util
from .util import minibatch, itershuffle
from libc.stdio cimport FILE, fopen, fclose, fread, fwrite, feof, fseek
USE_NEW_ALIGN = False
@ -54,25 +51,6 @@ def tags_to_entities(tags):
return entities
def merge_sents(sents):
m_deps = [[], [], [], [], [], []]
m_cats = {}
m_brackets = []
i = 0
for (ids, words, tags, heads, labels, ner), (cats, brackets) in sents:
m_deps[0].extend(id_ + i for id_ in ids)
m_deps[1].extend(words)
m_deps[2].extend(tags)
m_deps[3].extend(head + i for head in heads)
m_deps[4].extend(labels)
m_deps[5].extend(ner)
m_brackets.extend((b["first"] + i, b["last"] + i, b["label"])
for b in brackets)
m_cats.update(cats)
i += len(ids)
return [(m_deps, (m_cats, m_brackets))]
_ALIGNMENT_NORM_MAP = [("``", "'"), ("''", "'"), ('"', "'"), ("`", "'")]
@ -211,14 +189,14 @@ class GoldCorpus(object):
def __init__(self, train, dev, gold_preproc=False, limit=None):
"""Create a GoldCorpus.
train_path (unicode or Path): File or directory of training data.
dev_path (unicode or Path): File or directory of development data.
train (unicode or Path): File or directory of training data.
dev (unicode or Path): File or directory of development data.
RETURNS (GoldCorpus): The newly created object.
"""
self.limit = limit
if isinstance(train, str) or isinstance(train, Path):
train = self.read_tuples(self.walk_corpus(train))
dev = self.read_tuples(self.walk_corpus(dev))
train = self.read_examples(self.walk_corpus(train))
dev = self.read_examples(self.walk_corpus(dev))
# Write temp directory with one doc per file, so we can shuffle and stream
self.tmp_dir = Path(tempfile.mkdtemp())
self.write_msgpack(self.tmp_dir / "train", train, limit=self.limit)
@ -228,13 +206,15 @@ class GoldCorpus(object):
shutil.rmtree(path2str(self.tmp_dir))
@staticmethod
def write_msgpack(directory, doc_tuples, limit=0):
def write_msgpack(directory, examples, limit=0):
if not directory.exists():
directory.mkdir()
n = 0
for i, doc_tuple in enumerate(doc_tuples):
srsly.write_msgpack(directory / "{}.msg".format(i), [doc_tuple])
n += len(doc_tuple[1])
for i, example in enumerate(examples):
ex_dict = example.to_dict()
text = example.text
srsly.write_msgpack(directory / "{}.msg".format(i), (text, ex_dict))
n += len(example.token_annotations)
if limit and n >= limit:
break
@ -259,128 +239,144 @@ class GoldCorpus(object):
return locs
@staticmethod
def read_tuples(locs, limit=0):
def read_examples(locs, limit=0):
""" Yield training examples """
i = 0
for loc in locs:
loc = util.ensure_path(loc)
if loc.parts[-1].endswith("json"):
gold_tuples = read_json_file(loc)
examples = read_json_file(loc)
elif loc.parts[-1].endswith("jsonl"):
gold_tuples = srsly.read_jsonl(loc)
first_gold_tuple = next(gold_tuples)
gold_tuples = itertools.chain([first_gold_tuple], gold_tuples)
# TODO: proper format checks with schemas
if isinstance(first_gold_tuple, dict):
gold_tuples = read_json_object(gold_tuples)
if first_gold_tuple.get("paragraphs", None):
examples = read_json_object(gold_tuples)
elif first_gold_tuple.get("doc_annotation", None):
examples = []
for ex_dict in gold_tuples:
doc = ex_dict.get("doc", None)
if doc is None:
doc = ex_dict.get("text", None)
examples.append(Example.from_dict(ex_dict, doc=doc))
elif loc.parts[-1].endswith("msg"):
gold_tuples = srsly.read_msgpack(loc)
text, ex_dict = srsly.read_msgpack(loc)
examples = [Example.from_dict(ex_dict, doc=text)]
else:
supported = ("json", "jsonl", "msg")
raise ValueError(Errors.E124.format(path=path2str(loc), formats=supported))
for item in gold_tuples:
yield item
i += len(item[1])
for example in examples:
yield example
i += len(example.token_annotations)
if limit and i >= limit:
return
@property
def dev_tuples(self):
def dev_examples(self):
locs = (self.tmp_dir / "dev").iterdir()
yield from self.read_tuples(locs, limit=self.limit)
yield from self.read_examples(locs, limit=self.limit)
@property
def train_tuples(self):
def train_examples(self):
locs = (self.tmp_dir / "train").iterdir()
yield from self.read_tuples(locs, limit=self.limit)
yield from self.read_examples(locs, limit=self.limit)
def count_train(self):
# TODO: should this count words or sentences ?
n = 0
i = 0
for raw_text, paragraph_tuples in self.train_tuples:
for sent_tuples, brackets in paragraph_tuples:
n += len(sent_tuples[1])
for example in self.train_examples:
for token_annotation in example.token_annotations:
n += len(token_annotation.words)
if self.limit and i >= self.limit:
break
i += 1
return n
def train_docs(self, nlp, gold_preproc=False, max_length=None,
def train_dataset(self, nlp, gold_preproc=False, max_length=None,
noise_level=0.0, orth_variant_level=0.0,
ignore_misaligned=False):
locs = list((self.tmp_dir / 'train').iterdir())
random.shuffle(locs)
train_tuples = self.read_tuples(locs, limit=self.limit)
gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc,
train_examples = self.read_examples(locs, limit=self.limit)
gold_examples = self.iter_gold_docs(nlp, train_examples, gold_preproc,
max_length=max_length,
noise_level=noise_level,
orth_variant_level=orth_variant_level,
make_projective=True,
ignore_misaligned=ignore_misaligned)
yield from gold_docs
yield from gold_examples
def train_docs_without_preprocessing(self, nlp, gold_preproc=False):
gold_docs = self.iter_gold_docs(nlp, self.train_tuples, gold_preproc=gold_preproc)
yield from gold_docs
def train_dataset_without_preprocessing(self, nlp, gold_preproc=False):
examples = self.iter_gold_docs(nlp, self.train_examples, gold_preproc=gold_preproc)
yield from examples
def dev_docs(self, nlp, gold_preproc=False, ignore_misaligned=False):
gold_docs = self.iter_gold_docs(nlp, self.dev_tuples, gold_preproc=gold_preproc,
def dev_dataset(self, nlp, gold_preproc=False, ignore_misaligned=False):
examples = self.iter_gold_docs(nlp, self.dev_examples, gold_preproc=gold_preproc,
ignore_misaligned=ignore_misaligned)
yield from gold_docs
yield from examples
@classmethod
def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None,
def iter_gold_docs(cls, nlp, examples, gold_preproc, max_length=None,
noise_level=0.0, orth_variant_level=0.0, make_projective=False,
ignore_misaligned=False):
for raw_text, paragraph_tuples in tuples:
""" Setting gold_preproc will result in creating a doc per 'sentence' """
for example in examples:
if gold_preproc:
raw_text = None
example.doc = None
else:
paragraph_tuples = merge_sents(paragraph_tuples)
docs, paragraph_tuples = cls._make_docs(nlp, raw_text,
paragraph_tuples, gold_preproc, noise_level=noise_level,
example = example.merge_sents()
example.make_projective = make_projective
example.ignore_misaligned = ignore_misaligned
examples = cls._make_docs(nlp, example,
gold_preproc, noise_level=noise_level,
orth_variant_level=orth_variant_level)
golds = cls._make_golds(docs, paragraph_tuples, make_projective,
ignore_misaligned=ignore_misaligned)
for doc, gold in zip(docs, golds):
if gold is not None:
if (not max_length) or len(doc) < max_length:
yield doc, gold
examples = cls._make_golds(examples, vocab=nlp.vocab)
for ex in examples:
if ex.gold is not None:
if (not max_length) or len(ex.doc) < max_length:
yield ex
@classmethod
def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc, noise_level=0.0, orth_variant_level=0.0):
if raw_text is not None:
raw_text, paragraph_tuples = make_orth_variants(nlp, raw_text, paragraph_tuples, orth_variant_level=orth_variant_level)
raw_text = add_noise(raw_text, noise_level)
return [nlp.make_doc(raw_text)], paragraph_tuples
def _make_docs(cls, nlp, example, gold_preproc, noise_level=0.0, orth_variant_level=0.0):
# gold_preproc is not used ?!
if example.text is not None:
var_example = make_orth_variants(nlp, example, orth_variant_level=orth_variant_level)
var_text = add_noise(var_example.text, noise_level)
var_doc = nlp.make_doc(var_text)
var_example.doc = var_doc
return [var_example]
else:
docs = []
raw_text, paragraph_tuples = make_orth_variants(nlp, None, paragraph_tuples, orth_variant_level=orth_variant_level)
return [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level))
for (sent_tuples, brackets) in paragraph_tuples], paragraph_tuples
var_example = make_orth_variants(nlp, example, orth_variant_level=orth_variant_level)
doc_examples = []
for token_annotation in var_example.token_annotations:
t_doc = Doc(nlp.vocab, words=add_noise(token_annotation.words, noise_level))
doc_example = Example(doc_annotation=example.doc_annotation,
token_annotations=[token_annotation],
doc=t_doc)
doc_examples.append(doc_example)
return doc_examples
@classmethod
def _make_golds(cls, docs, paragraph_tuples, make_projective, ignore_misaligned=False):
if len(docs) != len(paragraph_tuples):
n_annots = len(paragraph_tuples)
raise ValueError(Errors.E070.format(n_docs=len(docs), n_annots=n_annots))
golds = []
for doc, (sent_tuples, (cats, brackets)) in zip(docs, paragraph_tuples):
try:
gold = GoldParse.from_annot_tuples(doc, sent_tuples, cats=cats,
make_projective=make_projective)
except AlignmentError:
if ignore_misaligned:
gold = None
else:
raise
golds.append(gold)
return golds
def _make_golds(cls, examples, vocab=None):
gold_examples = []
for example in examples:
gold_parses = example.get_gold_parses(vocab=vocab)
for (doc, gold) in gold_parses:
ex = Example(doc=doc)
ex.goldparse = gold
gold_examples.append(ex)
return gold_examples
def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
def make_orth_variants(nlp, example, orth_variant_level=0.0):
if random.random() >= orth_variant_level:
return raw, paragraph_tuples
return example
if not example.token_annotations:
return example
raw = example.text
if random.random() >= 0.5:
lower = True
if raw is not None:
@ -388,9 +384,15 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
ndsv = nlp.Defaults.single_orth_variants
ndpv = nlp.Defaults.paired_orth_variants
# modify words in paragraph_tuples
variant_paragraph_tuples = []
for sent_tuples, brackets in paragraph_tuples:
ids, words, tags, heads, labels, ner = sent_tuples
variant_example = Example(doc=raw)
for token_annotation in example.token_annotations:
words = token_annotation.words
tags = token_annotation.tags
if not words or not tags:
# add the unmodified annotation
token_dict = token_annotation.to_dict()
variant_example.add_token_annotation(**token_dict)
else:
if lower:
words = [w.lower() for w in words]
# single variants
@ -419,7 +421,10 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
pair_idx = pair.index(words[word_idx])
words[word_idx] = punct_choices[punct_idx][pair_idx]
variant_paragraph_tuples.append(((ids, words, tags, heads, labels, ner), brackets))
token_dict = token_annotation.to_dict()
token_dict["words"] = words
token_dict["tags"] = tags
variant_example.add_token_annotation(**token_dict)
# modify raw to match variant_paragraph_tuples
if raw is not None:
variants = []
@ -437,9 +442,8 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
while raw_idx < len(raw) and re.match("\s", raw[raw_idx]):
variant_raw += raw[raw_idx]
raw_idx += 1
for sent_tuples, brackets in variant_paragraph_tuples:
ids, words, tags, heads, labels, ner = sent_tuples
for word in words:
for token_annotation in variant_example.token_annotations:
for word in token_annotation.words:
match_found = False
# add identical word
if word not in variants and raw[raw_idx:].startswith(word):
@ -457,13 +461,14 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
# something went wrong, abort
# (add a warning message?)
if not match_found:
return raw, paragraph_tuples
return example
# add following whitespace
while raw_idx < len(raw) and re.match("\s", raw[raw_idx]):
variant_raw += raw[raw_idx]
raw_idx += 1
return variant_raw, variant_paragraph_tuples
return raw, variant_paragraph_tuples
variant_example.doc = variant_raw
return variant_example
return variant_example
def add_noise(orig, noise_level):
@ -488,30 +493,27 @@ def _corrupt(c, noise_level):
def read_json_object(json_corpus_section):
"""Take a list of JSON-formatted documents (e.g. from an already loaded
training data file) and yield tuples in the GoldParse format.
training data file) and yield annotations in the GoldParse format.
json_corpus_section (list): The data.
YIELDS (tuple): The reformatted data.
YIELDS (Example): The reformatted data - one training example per paragraph
"""
for json_doc in json_corpus_section:
tuple_doc = json_to_tuple(json_doc)
for tuple_paragraph in tuple_doc:
yield tuple_paragraph
examples = json_to_examples(json_doc)
for ex in examples:
yield ex
def json_to_tuple(doc):
"""Convert an item in the JSON-formatted training data to the tuple format
def json_to_examples(doc):
"""Convert an item in the JSON-formatted training data to the format
used by GoldParse.
doc (dict): One entry in the training data.
YIELDS (tuple): The reformatted data.
YIELDS (Example): The reformatted data - one training example per paragraph
"""
paragraphs = []
for paragraph in doc["paragraphs"]:
sents = []
cats = {}
for cat in paragraph.get("cats", {}):
cats[cat["label"]] = cat["value"]
example = Example(doc=paragraph.get("raw", None))
for sent in paragraph["sentences"]:
words = []
ids = []
@ -529,11 +531,14 @@ def json_to_tuple(doc):
if labels[-1].lower() == "root":
labels[-1] = "ROOT"
ner.append(token.get("ner", "-"))
sents.append([
[ids, words, tags, heads, labels, ner],
[cats, sent.get("brackets", [])]])
if sents:
yield [paragraph.get("raw", None), sents]
example.add_token_annotation(ids=ids, words=words, tags=tags,
heads=heads, deps=labels, entities=ner,
brackets=sent.get("brackets", []))
cats = {}
for cat in paragraph.get("cats", {}):
cats[cat["label"]] = cat["value"]
example.add_doc_annotation(cats=cats)
yield example
def read_json_file(loc, docs_filter=None, limit=None):
@ -545,8 +550,8 @@ def read_json_file(loc, docs_filter=None, limit=None):
for doc in _json_iterate(loc):
if docs_filter is not None and not docs_filter(doc):
continue
for json_tuple in json_to_tuple(doc):
yield json_tuple
for json_data in json_to_examples(doc):
yield json_data
def _json_iterate(loc):
@ -639,21 +644,254 @@ def _consume_ent(tags):
return [start] + middle + [end]
cdef class TokenAnnotation:
def __init__(self, ids=None, words=None, tags=None, heads=None, deps=None, entities=None, morphology=None, brackets=None):
self.ids = ids if ids else []
self.words = words if words else []
self.tags = tags if tags else []
self.heads = heads if heads else []
self.deps = deps if deps else []
self.entities = entities if entities else []
self.brackets = brackets if brackets else []
self.morphology = morphology if morphology else []
@classmethod
def from_dict(cls, token_dict):
return cls(ids=token_dict.get("ids", None),
words=token_dict.get("words", None),
tags=token_dict.get("tags", None),
heads=token_dict.get("heads", None),
deps=token_dict.get("deps", None),
entities=token_dict.get("entities", None),
morphology=token_dict.get("morphology", None),
brackets=token_dict.get("brackets", None))
def to_dict(self):
return {"ids": self.ids,
"words": self.words,
"tags": self.tags,
"heads": self.heads,
"deps": self.deps,
"entities": self.entities,
"morphology": self.morphology,
"brackets": self.brackets}
cdef class DocAnnotation:
def __init__(self, cats=None, links=None):
self.cats = cats if cats else {}
self.links = links if links else {}
@classmethod
def from_dict(cls, doc_dict):
return cls(cats=doc_dict.get("cats", None), links=doc_dict.get("links", None))
def to_dict(self):
return {"cats": self.cats, "links": self.links}
cdef class Example:
def __init__(self, doc_annotation=None, token_annotations=None, doc=None,
make_projective=False, ignore_misaligned=False, goldparse=None):
""" Doc can either be text, or an actual Doc """
self.doc = doc
self.doc_annotation = doc_annotation if doc_annotation else DocAnnotation()
self.token_annotations = token_annotations if token_annotations else []
self.make_projective = make_projective
self.ignore_misaligned = ignore_misaligned
self.goldparse = goldparse
@classmethod
def from_gold(cls, goldparse, doc=None):
doc_annotation = DocAnnotation(cats=goldparse.cats, links=goldparse.links)
token_annotation = goldparse.get_token_annotation()
return cls(doc_annotation, [token_annotation], doc)
@classmethod
def from_dict(cls, example_dict, doc=None):
token_dicts = example_dict["token_annotations"]
token_annotations = [TokenAnnotation.from_dict(t) for t in token_dicts]
doc_dict = example_dict["doc_annotation"]
doc_annotation = DocAnnotation.from_dict(doc_dict)
return cls(doc_annotation, token_annotations, doc)
def to_dict(self):
""" Note that this method does NOT export the doc, only the annotations ! """
token_dicts = [t.to_dict() for t in self.token_annotations]
doc_dict = self.doc_annotation.to_dict()
return {"token_annotations": token_dicts, "doc_annotation": doc_dict}
@property
def text(self):
if self.doc is None:
return None
if isinstance(self.doc, Doc):
return self.doc.text
return self.doc
@property
def gold(self):
if self.goldparse is None:
doc, gold = self.get_gold_parses(merge=True)[0]
self.goldparse = gold
return self.goldparse
def add_token_annotation(self, ids=None, words=None, tags=None, heads=None,
deps=None, entities=None, morphology=None, brackets=None):
t = TokenAnnotation(ids=ids, words=words, tags=tags,
heads=heads, deps=deps, entities=entities,
morphology=morphology, brackets=brackets)
self.token_annotations.append(t)
def add_doc_annotation(self, cats=None, links=None):
if cats:
self.doc_annotation.cats.update(cats)
if links:
self.doc_annotation.links.update(links)
def merge_sents(self):
""" Merge the list of token annotations into one object and return this new object """
m_example = Example(doc=self.doc, doc_annotation=self.doc_annotation)
m_ids, m_words, m_tags, m_heads, m_deps, m_ents, m_morph = [], [], [], [], [], [], []
m_brackets = []
i = 0
for t in self.token_annotations:
m_ids.extend(id_ + i for id_ in t.ids)
m_words.extend(t.words)
m_tags.extend(t.tags)
m_heads.extend(head + i if head else None for head in t.heads)
m_deps.extend(t.deps)
m_ents.extend(t.entities)
m_morph.extend(t.morphology)
m_brackets.extend((b["first"] + i, b["last"] + i, b["label"])
for b in t.brackets)
i += len(t.ids)
m_example.add_token_annotation(ids=m_ids, words=m_words, tags=m_tags,
heads=m_heads, deps=m_deps, entities=m_ents,
morphology=m_morph, brackets=m_brackets)
return m_example
def get_gold_parses(self, merge=False, vocab=None):
"""Return a list of (doc, GoldParse) objects.
If merge is set to True, add all Token annotations to one big list."""
d = self.doc_annotation
# merging different sentences
if merge:
merged_example = self.merge_sents()
assert(len(merged_example.token_annotations)) == 1
t = merged_example.token_annotations[0]
m_doc = merged_example.doc
if not m_doc:
if not vocab:
raise ValueError(Errors.E998)
m_doc = Doc(vocab, words=t.words)
try:
gp = GoldParse.from_annotation(m_doc, d, t, make_projective=self.make_projective)
except AlignmentError:
if self.ignore_misaligned:
gp = None
else:
raise
return [(self.doc, gp)]
# we only have one sentence and an appropriate doc
elif len(self.token_annotations) == 1 and self.doc is not None:
t = self.token_annotations[0]
try:
gp = GoldParse.from_annotation(self.doc, d, t, make_projective=self.make_projective)
except AlignmentError:
if self.ignore_misaligned:
gp = None
else:
raise
return [(self.doc, gp)]
# not merging: one GoldParse per 'sentence', defining docs with the words from each sentence
else:
parses = []
for t in self.token_annotations:
if not vocab:
raise ValueError(Errors.E998)
t_doc = Doc(vocab, words=t.words)
try:
gp = GoldParse.from_annotation(t_doc, d, t, make_projective=self.make_projective)
except AlignmentError:
if self.ignore_misaligned:
gp = None
else:
raise
if gp is not None:
parses.append((t_doc, gp))
return parses
@classmethod
def to_example_objects(cls, examples, make_doc=None, keep_raw_text=False):
"""
Return a list of Example objects, from a variety of input formats.
make_doc needs to be provided when the examples contain text strings and keep_raw_text=False
"""
if isinstance(examples, Example):
return [examples]
if isinstance(examples, tuple):
examples = [examples]
converted_examples = []
for ex in examples:
# convert string to Doc to Example
if isinstance(ex, basestring_):
if keep_raw_text:
converted_examples.append(Example(doc=ex))
else:
doc = make_doc(ex)
converted_examples.append(Example(doc=doc))
# convert Doc to Example
elif isinstance(ex, Doc):
converted_examples.append(Example(doc=ex))
# convert tuples to Example
elif isinstance(ex, tuple) and len(ex) == 2:
doc, gold = ex
gold_dict = {}
# convert string to Doc
if isinstance(doc, basestring_) and not keep_raw_text:
doc = make_doc(doc)
# convert dict to GoldParse
if isinstance(gold, dict):
gold_dict = gold
if doc is not None or gold.get("words", None) is not None:
gold = GoldParse(doc, **gold)
else:
gold = None
if gold is not None:
converted_examples.append(Example.from_gold(goldparse=gold, doc=doc))
else:
raise ValueError(Errors.E999.format(gold_dict=gold_dict))
else:
converted_examples.append(ex)
return converted_examples
cdef class GoldParse:
"""Collection for training annotations.
DOCS: https://spacy.io/api/goldparse
"""
@classmethod
def from_annot_tuples(cls, doc, annot_tuples, cats=None, make_projective=False):
_, words, tags, heads, deps, entities = annot_tuples
return cls(doc, words=words, tags=tags, heads=heads, deps=deps,
entities=entities, cats=cats,
def from_annotation(cls, doc, doc_annotation, token_annotation, make_projective=False):
return cls(doc, words=token_annotation.words, tags=token_annotation.tags,
heads=token_annotation.heads, deps=token_annotation.deps, entities=token_annotation.entities,
morphology=token_annotation.morphology, cats=doc_annotation.cats, links=doc_annotation.links,
make_projective=make_projective)
def __init__(self, doc, annot_tuples=None, words=None, tags=None, morphology=None,
def get_token_annotation(self):
ids = None
if self.words:
ids = list(range(len(self.words)))
return TokenAnnotation(ids=ids, words=self.words, tags=self.tags,
heads=self.heads, deps=self.labels, entities=self.ner,
morphology=self.morphology)
def __init__(self, doc, words=None, tags=None, morphology=None,
heads=None, deps=None, entities=None, make_projective=False,
cats=None, links=None, **_):
cats=None, links=None):
"""Create a GoldParse. The fields will not be initialized if len(doc) is zero.
doc (Doc): The document the annotations refer to.
@ -688,19 +926,19 @@ cdef class GoldParse:
self.length = len(doc)
self.cats = {} if cats is None else dict(cats)
self.links = links
self.links = {} if links is None else dict(links)
# avoid allocating memory if the doc does not contain any tokens
if self.length > 0:
if words is None:
if not words:
words = [token.text for token in doc]
if tags is None:
if not tags:
tags = [None for _ in words]
if heads is None:
if not heads:
heads = [None for _ in words]
if deps is None:
if not deps:
deps = [None for _ in words]
if morphology is None:
if not morphology:
morphology = [None for _ in words]
if entities is None:
entities = ["-" for _ in words]
@ -710,7 +948,7 @@ cdef class GoldParse:
# Translate the None values to '-', to make processing easier.
# See Issue #2603
entities = [(ent if ent is not None else "-") for ent in entities]
if not isinstance(entities[0], basestring):
if not isinstance(entities[0], basestring_):
# Assume we have entities specified by character offset.
entities = biluo_tags_from_offsets(doc, entities)
@ -745,8 +983,9 @@ cdef class GoldParse:
self.cand_to_gold = [(j if j >= 0 else None) for j in i2j]
self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]
annot_tuples = (range(len(words)), words, tags, heads, deps, entities)
self.orig_annot = list(zip(*annot_tuples))
self.orig = TokenAnnotation(ids=list(range(len(words))), words=words, tags=tags,
heads=heads, deps=deps, entities=entities, morphology=morphology,
brackets=[])
for i, gold_i in enumerate(self.cand_to_gold):
if doc[i].text.isspace():

View File

@ -3,6 +3,8 @@ from __future__ import absolute_import, unicode_literals
import random
import itertools
from spacy.gold import Example
from spacy.util import minibatch
import weakref
import functools
@ -409,7 +411,7 @@ class Language(object):
def __call__(self, text, disable=[], component_cfg=None):
"""Apply the pipeline to some text. The text can span multiple sentences,
and can contain arbtrary whitespace. Alignment into the original string
and can contain arbitrary whitespace. Alignment into the original string
is preserved.
text (unicode): The text to be processed.
@ -452,30 +454,10 @@ class Language(object):
def make_doc(self, text):
return self.tokenizer(text)
def _format_docs_and_golds(self, docs, golds):
"""Format golds and docs before update models."""
expected_keys = ("words", "tags", "heads", "deps", "entities", "cats", "links")
gold_objs = []
doc_objs = []
for doc, gold in zip(docs, golds):
if isinstance(doc, basestring_):
doc = self.make_doc(doc)
if not isinstance(gold, GoldParse):
unexpected = [k for k in gold if k not in expected_keys]
if unexpected:
err = Errors.E151.format(unexp=unexpected, exp=expected_keys)
raise ValueError(err)
gold = GoldParse(doc, **gold)
doc_objs.append(doc)
gold_objs.append(gold)
return doc_objs, gold_objs
def update(self, docs, golds, drop=0.0, sgd=None, losses=None, component_cfg=None):
def update(self, examples, drop=0.0, sgd=None, losses=None, component_cfg=None):
"""Update the models in the pipeline.
docs (iterable): A batch of `Doc` objects.
golds (iterable): A batch of `GoldParse` objects.
examples (iterable): A batch of `Example` or `Doc` objects.
drop (float): The dropout rate.
sgd (callable): An optimizer.
losses (dict): Dictionary to update with the loss, keyed by component.
@ -484,18 +466,16 @@ class Language(object):
DOCS: https://spacy.io/api/language#update
"""
if len(docs) != len(golds):
raise IndexError(Errors.E009.format(n_docs=len(docs), n_golds=len(golds)))
if len(docs) == 0:
if len(examples) == 0:
return
examples = Example.to_example_objects(examples, make_doc=self.make_doc)
if sgd is None:
if self._optimizer is None:
self._optimizer = create_default_optimizer(Model.ops)
sgd = self._optimizer
# Allow dict of args to GoldParse, instead of GoldParse objects.
docs, golds = self._format_docs_and_golds(docs, golds)
grads = {}
grads = {}
def get_grads(W, dW, key=None):
grads[key] = (W, dW)
@ -512,18 +492,18 @@ class Language(object):
grads = {}
kwargs = component_cfg.get(name, {})
kwargs.setdefault("drop", drop)
proc.update(docs, golds, sgd=get_grads, losses=losses, **kwargs)
proc.update(examples, sgd=get_grads, losses=losses, **kwargs)
for key, (W, dW) in grads.items():
sgd(W, dW, key=key)
def rehearse(self, docs, sgd=None, losses=None, config=None):
def rehearse(self, examples, sgd=None, losses=None, config=None):
"""Make a "rehearsal" update to the models in the pipeline, to prevent
forgetting. Rehearsal updates run an initial copy of the model over some
data, and update the model so its current predictions are more like the
initial ones. This is useful for keeping a pretrained model on-track,
even if you're updating it with a smaller set of examples.
docs (iterable): A batch of `Doc` objects.
examples (iterable): A batch of `Doc` objects.
drop (float): The dropout rate.
sgd (callable): An optimizer.
RETURNS (dict): Results from the update.
@ -531,22 +511,18 @@ class Language(object):
EXAMPLE:
>>> raw_text_batches = minibatch(raw_texts)
>>> for labelled_batch in minibatch(zip(train_docs, train_golds)):
>>> docs, golds = zip(*train_docs)
>>> nlp.update(docs, golds)
>>> nlp.update(labelled_batch)
>>> raw_batch = [nlp.make_doc(text) for text in next(raw_text_batches)]
>>> nlp.rehearse(raw_batch)
"""
# TODO: document
if len(docs) == 0:
if len(examples) == 0:
return
examples = Example.to_example_objects(examples, make_doc=self.make_doc)
if sgd is None:
if self._optimizer is None:
self._optimizer = create_default_optimizer(Model.ops)
sgd = self._optimizer
docs = list(docs)
for i, doc in enumerate(docs):
if isinstance(doc, basestring_):
docs[i] = self.make_doc(doc)
pipes = list(self.pipeline)
random.shuffle(pipes)
if config is None:
@ -563,44 +539,45 @@ class Language(object):
if not hasattr(proc, "rehearse"):
continue
grads = {}
proc.rehearse(docs, sgd=get_grads, losses=losses, **config.get(name, {}))
proc.rehearse(examples, sgd=get_grads, losses=losses, **config.get(name, {}))
for key, (W, dW) in grads.items():
sgd(W, dW, key=key)
return losses
def preprocess_gold(self, docs_golds):
def preprocess_gold(self, examples):
"""Can be called before training to pre-process gold data. By default,
it handles nonprojectivity and adds missing tags to the tag map.
docs_golds (iterable): Tuples of `Doc` and `GoldParse` objects.
YIELDS (tuple): Tuples of preprocessed `Doc` and `GoldParse` objects.
examples (iterable): `Example` objects.
YIELDS (tuple): `Example` objects.
"""
for name, proc in self.pipeline:
if hasattr(proc, "preprocess_gold"):
docs_golds = proc.preprocess_gold(docs_golds)
for doc, gold in docs_golds:
yield doc, gold
examples = proc.preprocess_gold(examples)
for ex in examples:
yield ex
def begin_training(self, get_gold_tuples=None, sgd=None, component_cfg=None, **cfg):
def begin_training(self, get_examples=None, sgd=None, component_cfg=None, **cfg):
"""Allocate models, pre-process training data and acquire a trainer and
optimizer. Used as a contextmanager.
get_gold_tuples (function): Function returning gold data
get_examples (function): Function returning example training data (TODO: document format change since 3.0)
component_cfg (dict): Config parameters for specific components.
**cfg: Config parameters.
RETURNS: An optimizer.
DOCS: https://spacy.io/api/language#begin_training
"""
if get_gold_tuples is None:
get_gold_tuples = lambda: []
# TODO: throw warning when get_gold_tuples is provided instead of get_examples
if get_examples is None:
get_examples = lambda: []
# Populate vocab
else:
for _, annots_brackets in get_gold_tuples():
_ = annots_brackets.pop()
for annots, _ in annots_brackets:
for word in annots[1]:
for example in get_examples():
for token_annotation in example.token_annotations:
for word in token_annotation.words:
_ = self.vocab[word] # noqa: F841
if cfg.get("device", -1) >= 0:
util.use_gpu(cfg["device"])
if self.vocab.vectors.data.shape[1] >= 1:
@ -618,7 +595,7 @@ class Language(object):
kwargs = component_cfg.get(name, {})
kwargs.update(cfg)
proc.begin_training(
get_gold_tuples,
get_examples,
pipeline=self.pipeline,
sgd=self._optimizer,
**kwargs
@ -650,11 +627,11 @@ class Language(object):
return self._optimizer
def evaluate(
self, docs_golds, verbose=False, batch_size=256, scorer=None, component_cfg=None
self, examples, verbose=False, batch_size=256, scorer=None, component_cfg=None
):
"""Evaluate a model's pipeline components.
docs_golds (iterable): Tuples of `Doc` and `GoldParse` objects.
examples (iterable): `Example` objects.
verbose (bool): Print debugging information.
batch_size (int): Batch size to use.
scorer (Scorer): Optional `Scorer` to use. If not passed in, a new one
@ -665,30 +642,24 @@ class Language(object):
DOCS: https://spacy.io/api/language#evaluate
"""
examples = Example.to_example_objects(examples, make_doc=self.make_doc)
if scorer is None:
scorer = Scorer(pipeline=self.pipeline)
if component_cfg is None:
component_cfg = {}
docs, golds = zip(*docs_golds)
docs = [
self.make_doc(doc) if isinstance(doc, basestring_) else doc for doc in docs
]
golds = list(golds)
for name, pipe in self.pipeline:
kwargs = component_cfg.get(name, {})
kwargs.setdefault("batch_size", batch_size)
if not hasattr(pipe, "pipe"):
docs = _pipe(pipe, docs, kwargs)
examples = _pipe(pipe, examples, kwargs)
else:
docs = pipe.pipe(docs, **kwargs)
for doc, gold in zip(docs, golds):
if not isinstance(gold, GoldParse):
gold = GoldParse(doc, **gold)
examples = pipe.pipe(examples, as_example=True, **kwargs)
for ex in examples:
if verbose:
print(doc)
print(ex.doc)
kwargs = component_cfg.get("scorer", {})
kwargs.setdefault("verbose", verbose)
scorer.score(doc, gold, **kwargs)
scorer.score(ex, **kwargs)
return scorer
@contextmanager
@ -733,6 +704,7 @@ class Language(object):
cleanup=False,
component_cfg=None,
n_process=1,
as_example=False
):
"""Process texts as a stream, and yield `Doc` objects in order.
@ -770,6 +742,7 @@ class Language(object):
batch_size=batch_size,
disable=disable,
component_cfg=component_cfg,
as_example=False
)
for doc, context in izip(docs, contexts):
yield (doc, context)
@ -1095,15 +1068,15 @@ class DisabledPipes(list):
self[:] = []
def _pipe(docs, proc, kwargs):
def _pipe(examples, proc, kwargs):
# We added some args for pipe that __call__ doesn't expect.
kwargs = dict(kwargs)
for arg in ["n_threads", "batch_size"]:
if arg in kwargs:
kwargs.pop(arg)
for doc in docs:
doc = proc(doc, **kwargs)
yield doc
for ex in examples:
ex = proc(ex, **kwargs)
yield ex
def _apply_pipes(make_doc, pipes, reciever, sender):

View File

@ -97,18 +97,19 @@ class Morphologizer(Pipe):
if doc[j].morph.pos != 0:
doc.c[j].pos = doc[j].morph.pos
def update(self, docs, golds, drop=0., sgd=None, losses=None):
def update(self, examples, drop=0., sgd=None, losses=None):
if losses is not None and self.name not in losses:
losses[self.name] = 0.
docs = [self._get_doc(ex) for ex in examples]
tag_scores, bp_tag_scores = self.model.begin_update(docs, drop=drop)
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
loss, d_tag_scores = self.get_loss(examples, tag_scores)
bp_tag_scores(d_tag_scores, sgd=sgd)
if losses is not None:
losses[self.name] += loss
def get_loss(self, docs, golds, scores):
def get_loss(self, examples, scores):
guesses = []
for doc_scores in scores:
guesses.append(scores_to_guesses(doc_scores, self.model.softmax.out_sizes))
@ -122,7 +123,9 @@ class Morphologizer(Pipe):
# Do this on CPU, as we can't vectorize easily.
target = numpy.zeros(scores.shape, dtype='f')
field_sizes = self.model.softmax.out_sizes
for doc, gold in zip(docs, golds):
for example in examples:
doc = example.doc
gold = example.gold
for t, features in enumerate(gold.morphology):
if features is None:
target[idx] = scores[idx]
@ -146,6 +149,7 @@ class Morphologizer(Pipe):
scores = self.model.ops.asarray(scores, dtype='f')
d_scores = scores - target
loss = (d_scores**2).sum()
docs = [self._get_doc(ex) for ex in examples]
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
return float(loss), d_scores

View File

@ -13,6 +13,7 @@ from thinc.misc import LayerNorm
from thinc.neural.util import to_categorical
from thinc.neural.util import get_array_module
from spacy.gold import Example
from ..tokens.doc cimport Doc
from ..syntax.nn_parser cimport Parser
from ..syntax.ner cimport BiluoPushDown
@ -59,11 +60,17 @@ class Pipe(object):
def from_nlp(cls, nlp, **cfg):
return cls(nlp.vocab, **cfg)
def _get_doc(self, example):
""" Use this method if the `example` method can be both a Doc or an Example """
if isinstance(example, Doc):
return example
return example.doc
def __init__(self, vocab, model=True, **cfg):
"""Create a new pipe instance."""
raise NotImplementedError
def __call__(self, doc):
def __call__(self, example):
"""Apply the pipe to one document. The document is
modified in-place, and returned.
@ -71,12 +78,16 @@ class Pipe(object):
and `set_annotations()` methods.
"""
self.require_model()
doc = self._get_doc(example)
predictions = self.predict([doc])
if isinstance(predictions, tuple) and len(predictions) == 2:
scores, tensors = predictions
self.set_annotations([doc], scores, tensors=tensors)
else:
self.set_annotations([doc], predictions)
if isinstance(example, Example):
example.doc = doc
return example
return doc
def require_model(self):
@ -84,20 +95,29 @@ class Pipe(object):
if getattr(self, "model", None) in (None, True, False):
raise ValueError(Errors.E109.format(name=self.name))
def pipe(self, stream, batch_size=128, n_threads=-1):
def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
"""Apply the pipe to a stream of documents.
Both __call__ and pipe should delegate to the `predict()`
and `set_annotations()` methods.
"""
for docs in util.minibatch(stream, size=batch_size):
docs = list(docs)
for examples in util.minibatch(stream, size=batch_size):
examples = list(examples)
docs = [self._get_doc(ex) for ex in examples]
predictions = self.predict(docs)
if isinstance(predictions, tuple) and len(tuple) == 2:
scores, tensors = predictions
self.set_annotations(docs, scores, tensors=tensors)
else:
self.set_annotations(docs, predictions)
if as_example:
examples = []
for ex, doc in zip(examples, docs):
ex.doc = doc
examples.append(ex)
yield from examples
else:
yield from docs
def predict(self, docs):
@ -111,7 +131,7 @@ class Pipe(object):
"""Modify a batch of documents, using pre-computed scores."""
raise NotImplementedError
def update(self, docs, golds, drop=0.0, sgd=None, losses=None):
def update(self, examples, drop=0.0, sgd=None, losses=None):
"""Learn from a batch of documents and gold-standard information,
updating the pipe's model.
@ -119,12 +139,12 @@ class Pipe(object):
"""
pass
def rehearse(self, docs, sgd=None, losses=None, **config):
def rehearse(self, examples, sgd=None, losses=None, **config):
pass
def get_loss(self, docs, golds, scores):
def get_loss(self, examples, scores):
"""Find the loss and gradient of loss for the batch of
documents and their predicted scores."""
examples (with embedded docs) and their predicted scores."""
raise NotImplementedError
def add_label(self, label):
@ -140,7 +160,7 @@ class Pipe(object):
return create_default_optimizer(self.model.ops, **self.cfg.get("optimizer", {}))
def begin_training(
self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs
self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs
):
"""Initialize the pipe for training, using data exampes if available.
If no model has been initialized yet, the model is added."""
@ -264,28 +284,40 @@ class Tensorizer(Pipe):
self.cfg = dict(cfg)
self.cfg.setdefault("cnn_maxout_pieces", 3)
def __call__(self, doc):
def __call__(self, example):
"""Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM
model. Vectors are set to the `Doc.tensor` attribute.
docs (Doc or iterable): One or more documents to add vectors to.
RETURNS (dict or None): Intermediate computations.
"""
doc = self._get_doc(example)
tokvecses = self.predict([doc])
self.set_annotations([doc], tokvecses)
if isinstance(example, Example):
example.doc = doc
return example
return doc
def pipe(self, stream, batch_size=128, n_threads=-1):
def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
"""Process `Doc` objects as a stream.
stream (iterator): A sequence of `Doc` objects to process.
batch_size (int): Number of `Doc` objects to group.
YIELDS (iterator): A sequence of `Doc` objects, in order of input.
stream (iterator): A sequence of `Doc` or `Example` objects to process.
batch_size (int): Number of `Doc` or `Example` objects to group.
YIELDS (iterator): A sequence of `Doc` or `Example` objects, in order of input.
"""
for docs in util.minibatch(stream, size=batch_size):
docs = list(docs)
for examples in util.minibatch(stream, size=batch_size):
docs = [self._get_doc(ex) for ex in examples]
tensors = self.predict(docs)
self.set_annotations(docs, tensors)
if as_example:
examples = []
for ex, doc in zip(examples, docs):
ex.doc = doc
examples.append(ex)
yield from examples
else:
yield from docs
def predict(self, docs):
@ -310,7 +342,7 @@ class Tensorizer(Pipe):
raise ValueError(Errors.E076.format(rows=tensor.shape[0], words=len(doc)))
doc.tensor = tensor
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
def update(self, examples, state=None, drop=0.0, sgd=None, losses=None):
"""Update the model.
docs (iterable): A batch of `Doc` objects.
@ -320,17 +352,16 @@ class Tensorizer(Pipe):
RETURNS (dict): Results from the update.
"""
self.require_model()
if isinstance(docs, Doc):
docs = [docs]
examples = Example.to_example_objects(examples)
inputs = []
bp_inputs = []
for tok2vec in self.input_models:
tensor, bp_tensor = tok2vec.begin_update(docs, drop=drop)
tensor, bp_tensor = tok2vec.begin_update([ex.doc for ex in examples], drop=drop)
inputs.append(tensor)
bp_inputs.append(bp_tensor)
inputs = self.model.ops.xp.hstack(inputs)
scores, bp_scores = self.model.begin_update(inputs, drop=drop)
loss, d_scores = self.get_loss(docs, golds, scores)
loss, d_scores = self.get_loss(examples, scores)
d_inputs = bp_scores(d_scores, sgd=sgd)
d_inputs = self.model.ops.xp.split(d_inputs, len(self.input_models), axis=1)
for d_input, bp_input in zip(d_inputs, bp_inputs):
@ -340,18 +371,19 @@ class Tensorizer(Pipe):
losses[self.name] += loss
return loss
def get_loss(self, docs, golds, prediction):
ids = self.model.ops.flatten([doc.to_array(ID).ravel() for doc in docs])
def get_loss(self, examples, prediction):
examples = Example.to_example_objects(examples)
ids = self.model.ops.flatten([ex.doc.to_array(ID).ravel() for ex in examples])
target = self.vocab.vectors.data[ids]
d_scores = (prediction - target) / prediction.shape[0]
loss = (d_scores ** 2).sum()
return loss, d_scores
def begin_training(self, gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs):
"""Allocate models, pre-process training data and acquire an
optimizer.
gold_tuples (iterable): Gold-standard training data.
get_examples (iterable): Gold-standard training data.
pipeline (list): The pipeline the model is part of.
"""
if pipeline is not None:
@ -391,16 +423,29 @@ class Tagger(Pipe):
else:
return chain(self.model.tok2vec, flatten)
def __call__(self, doc):
def __call__(self, example):
doc = self._get_doc(example)
tags, tokvecs = self.predict([doc])
self.set_annotations([doc], tags, tensors=tokvecs)
if isinstance(example, Example):
example.doc = doc
return example
return doc
def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size):
docs = list(docs)
def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
for examples in util.minibatch(stream, size=batch_size):
examples = list(examples)
docs = [self._get_doc(ex) for ex in examples]
tag_ids, tokvecs = self.predict(docs)
self.set_annotations(docs, tag_ids, tensors=tokvecs)
if as_example:
examples = []
for ex, doc in zip(examples, docs):
ex.doc = doc
examples.append(ex)
yield from examples
else:
yield from docs
def predict(self, docs):
@ -452,47 +497,51 @@ class Tagger(Pipe):
doc.extend_tensor(tensors[i])
doc.is_tagged = True
def update(self, docs, golds, drop=0., sgd=None, losses=None):
def update(self, examples, drop=0., sgd=None, losses=None):
self.require_model()
examples = Example.to_example_objects(examples)
if losses is not None and self.name not in losses:
losses[self.name] = 0.
if not any(len(doc) for doc in docs):
if not any(len(ex.doc) if ex.doc else 0 for ex in examples):
# Handle cases where there are no tokens in any docs.
return
tag_scores, bp_tag_scores = self.model.begin_update(docs, drop=drop)
loss, d_tag_scores = self.get_loss(docs, golds, tag_scores)
tag_scores, bp_tag_scores = self.model.begin_update([ex.doc for ex in examples], drop=drop)
loss, d_tag_scores = self.get_loss(examples, tag_scores)
bp_tag_scores(d_tag_scores, sgd=sgd)
if losses is not None:
losses[self.name] += loss
def rehearse(self, docs, drop=0., sgd=None, losses=None):
def rehearse(self, examples, drop=0., sgd=None, losses=None):
"""Perform a 'rehearsal' update, where we try to match the output of
an initial model.
"""
if self._rehearsal_model is None:
return
examples = Example.to_example_objects(examples)
docs = [ex.doc for ex in examples]
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
return
guesses, backprop = self.model.begin_update(docs, drop=drop)
target = self._rehearsal_model(docs)
target = self._rehearsal_model(examples)
gradient = guesses - target
backprop(gradient, sgd=sgd)
if losses is not None:
losses.setdefault(self.name, 0.0)
losses[self.name] += (gradient**2).sum()
def get_loss(self, docs, golds, scores):
def get_loss(self, examples, scores):
scores = self.model.ops.flatten(scores)
tag_index = {tag: i for i, tag in enumerate(self.labels)}
cdef int idx = 0
correct = numpy.zeros((scores.shape[0],), dtype="i")
guesses = scores.argmax(axis=1)
known_labels = numpy.ones((scores.shape[0], 1), dtype="f")
for gold in golds:
for ex in examples:
gold = ex.gold
for tag in gold.tags:
if tag is None:
correct[idx] = guesses[idx]
@ -506,20 +555,20 @@ class Tagger(Pipe):
d_scores = scores - to_categorical(correct, nb_classes=scores.shape[1])
d_scores *= self.model.ops.asarray(known_labels)
loss = (d_scores**2).sum()
docs = [ex.doc for ex in examples]
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
return float(loss), d_scores
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,
**kwargs):
lemma_tables = ["lemma_rules", "lemma_index", "lemma_exc", "lemma_lookup"]
if not any(table in self.vocab.lookups for table in lemma_tables):
user_warning(Warnings.W022)
orig_tag_map = dict(self.vocab.morphology.tag_map)
new_tag_map = OrderedDict()
for raw_text, annots_brackets in get_gold_tuples():
for annots, brackets in annots_brackets:
ids, words, tags, heads, deps, ents = annots
for tag in tags:
for example in get_examples():
for token_annotation in example.token_annotations:
for tag in token_annotation.tags:
if tag in orig_tag_map:
new_tag_map[tag] = orig_tag_map[tag]
else:
@ -698,14 +747,14 @@ class MultitaskObjective(Tagger):
def set_annotations(self, docs, dep_ids, tensors=None):
pass
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, tok2vec=None,
def begin_training(self, get_examples=lambda: [], pipeline=None, tok2vec=None,
sgd=None, **kwargs):
gold_tuples = nonproj.preprocess_training_data(get_gold_tuples())
for raw_text, annots_brackets in gold_tuples:
for annots, brackets in annots_brackets:
ids, words, tags, heads, deps, ents = annots
for i in range(len(ids)):
label = self.make_label(i, words, tags, heads, deps, ents)
gold_examples = nonproj.preprocess_training_data(get_examples())
# for raw_text, doc_annot in gold_tuples:
for example in gold_examples:
for token_annotation in example.token_annotations:
for i in range(len(token_annotation.ids)):
label = self.make_label(i, token_annotation)
if label is not None and label not in self.labels:
self.labels[label] = len(self.labels)
if self.model is True:
@ -735,18 +784,17 @@ class MultitaskObjective(Tagger):
scores = self.model.softmax(tokvecs)
return tokvecs, scores
def get_loss(self, docs, golds, scores):
if len(docs) != len(golds):
raise ValueError(Errors.E077.format(value="loss", n_docs=len(docs),
n_golds=len(golds)))
def get_loss(self, examples, scores):
cdef int idx = 0
correct = numpy.zeros((scores.shape[0],), dtype="i")
guesses = scores.argmax(axis=1)
golds = [ex.gold for ex in examples]
docs = [ex.doc for ex in examples]
for i, gold in enumerate(golds):
for j in range(len(docs[i])):
# Handes alignment for tokenization differences
label = self.make_label(j, gold.words, gold.tags,
gold.heads, gold.labels, gold.ents)
# Handels alignment for tokenization differences
token_annotation = gold.get_token_annotation()
label = self.make_label(j, token_annotation)
if label is None or label not in self.labels:
correct[idx] = guesses[idx]
else:
@ -758,39 +806,39 @@ class MultitaskObjective(Tagger):
return float(loss), d_scores
@staticmethod
def make_dep(i, words, tags, heads, deps, ents):
if deps[i] is None or heads[i] is None:
def make_dep(i, token_annotation):
if token_annotation.deps[i] is None or token_annotation.heads[i] is None:
return None
return deps[i]
return token_annotation.deps[i]
@staticmethod
def make_tag(i, words, tags, heads, deps, ents):
return tags[i]
def make_tag(i, token_annotation):
return token_annotation.tags[i]
@staticmethod
def make_ent(i, words, tags, heads, deps, ents):
if ents is None:
def make_ent(i, token_annotation):
if token_annotation.entities is None:
return None
return ents[i]
return token_annotation.entities[i]
@staticmethod
def make_dep_tag_offset(i, words, tags, heads, deps, ents):
if deps[i] is None or heads[i] is None:
def make_dep_tag_offset(i, token_annotation):
if token_annotation.deps[i] is None or token_annotation.heads[i] is None:
return None
offset = heads[i] - i
offset = token_annotation.heads[i] - i
offset = min(offset, 2)
offset = max(offset, -2)
return "%s-%s:%d" % (deps[i], tags[i], offset)
return "%s-%s:%d" % (token_annotation.deps[i], token_annotation.tags[i], offset)
@staticmethod
def make_ent_tag(i, words, tags, heads, deps, ents):
if ents is None or ents[i] is None:
def make_ent_tag(i, token_annotation):
if token_annotation.entities is None or token_annotation.entities[i] is None:
return None
else:
return "%s-%s" % (tags[i], ents[i])
return "%s-%s" % (token_annotation.tags[i], token_annotation.entities[i])
@staticmethod
def make_sent_start(target, words, tags, heads, deps, ents, cache=True, _cache={}):
def make_sent_start(target, token_annotation, cache=True, _cache={}):
"""A multi-task objective for representing sentence boundaries,
using BILU scheme. (O is impossible)
@ -799,6 +847,8 @@ class MultitaskObjective(Tagger):
of gold data. You can pass cache=False if you know the cache will
do the wrong thing.
"""
words = token_annotation.words
heads = token_annotation.heads
assert len(words) == len(heads)
assert target < len(words), (target, len(words))
if cache:
@ -857,7 +907,7 @@ class ClozeMultitask(Pipe):
def set_annotations(self, docs, dep_ids, tensors=None):
pass
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None,
def begin_training(self, get_examples=lambda: [], pipeline=None,
tok2vec=None, sgd=None, **kwargs):
link_vectors_to_models(self.vocab)
if self.model is True:
@ -874,25 +924,26 @@ class ClozeMultitask(Pipe):
vectors = self.model.output_layer(tokvecs)
return tokvecs, vectors
def get_loss(self, docs, vectors, prediction):
def get_loss(self, examples, vectors, prediction):
# The simplest way to implement this would be to vstack the
# token.vector values, but that's a bit inefficient, especially on GPU.
# Instead we fetch the index into the vectors table for each of our tokens,
# and look them up all at once. This prevents data copying.
ids = self.model.ops.flatten([doc.to_array(ID).ravel() for doc in docs])
ids = self.model.ops.flatten([ex.doc.to_array(ID).ravel() for ex in examples])
target = vectors[ids]
loss, gradient = get_cossim_loss(prediction, target, ignore_zeros=True)
return float(loss), gradient
def update(self, docs, golds, drop=0., sgd=None, losses=None):
def update(self, examples, drop=0., sgd=None, losses=None):
pass
def rehearse(self, docs, drop=0., sgd=None, losses=None):
def rehearse(self, examples, drop=0., sgd=None, losses=None):
self.require_model()
examples = Example.to_example_objects(examples)
if losses is not None and self.name not in losses:
losses[self.name] = 0.
predictions, bp_predictions = self.model.begin_update(docs, drop=drop)
loss, d_predictions = self.get_loss(docs, self.vocab.vectors.data, predictions)
predictions, bp_predictions = self.model.begin_update([ex.doc for ex in examples], drop=drop)
loss, d_predictions = self.get_loss(examples, self.vocab.vectors.data, predictions)
bp_predictions(d_predictions, sgd=sgd)
if losses is not None:
@ -947,11 +998,20 @@ class TextCategorizer(Pipe):
def labels(self, value):
self.cfg["labels"] = tuple(value)
def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size):
docs = list(docs)
def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
for examples in util.minibatch(stream, size=batch_size):
examples = list(examples)
docs = [self._get_doc(ex) for ex in examples]
scores, tensors = self.predict(docs)
self.set_annotations(docs, scores, tensors=tensors)
if as_example:
examples = []
for ex, doc in zip(examples, docs):
ex.doc = doc
examples.append(ex)
yield from examples
else:
yield from docs
def predict(self, docs):
@ -973,33 +1033,37 @@ class TextCategorizer(Pipe):
for j, label in enumerate(self.labels):
doc.cats[label] = float(scores[i, j])
def update(self, docs, golds, state=None, drop=0., sgd=None, losses=None):
def update(self, examples, state=None, drop=0., sgd=None, losses=None):
self.require_model()
if not any(len(doc) for doc in docs):
examples = Example.to_example_objects(examples)
if not any(len(ex.doc) if ex.doc else 0 for ex in examples):
# Handle cases where there are no tokens in any docs.
return
scores, bp_scores = self.model.begin_update(docs, drop=drop)
loss, d_scores = self.get_loss(docs, golds, scores)
scores, bp_scores = self.model.begin_update([ex.doc for ex in examples], drop=drop)
loss, d_scores = self.get_loss(examples, scores)
bp_scores(d_scores, sgd=sgd)
if losses is not None:
losses.setdefault(self.name, 0.0)
losses[self.name] += loss
def rehearse(self, docs, drop=0., sgd=None, losses=None):
def rehearse(self, examples, drop=0., sgd=None, losses=None):
if self._rehearsal_model is None:
return
examples = Example.to_example_objects(examples)
docs=[ex.doc for ex in examples]
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
return
scores, bp_scores = self.model.begin_update(docs, drop=drop)
target = self._rehearsal_model(docs)
target = self._rehearsal_model(examples)
gradient = scores - target
bp_scores(gradient, sgd=sgd)
if losses is not None:
losses.setdefault(self.name, 0.0)
losses[self.name] += (gradient**2).sum()
def get_loss(self, docs, golds, scores):
def get_loss(self, examples, scores):
golds = [ex.gold for ex in examples]
truths = numpy.zeros((len(golds), len(self.labels)), dtype="f")
not_missing = numpy.ones((len(golds), len(self.labels)), dtype="f")
for i, gold in enumerate(golds):
@ -1032,10 +1096,9 @@ class TextCategorizer(Pipe):
self.labels = tuple(list(self.labels) + [label])
return 1
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
for raw_text, annot_brackets in get_gold_tuples():
for _, (cats, _2) in annot_brackets:
for cat in cats:
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs):
for example in get_examples():
for cat in example.doc_annotation.cats:
self.add_label(cat)
if self.model is True:
self.cfg["pretrained_vectors"] = kwargs.get("pretrained_vectors")
@ -1074,10 +1137,10 @@ cdef class DependencyParser(Parser):
labeller = MultitaskObjective(self.vocab, target=target)
self._multitasks.append(labeller)
def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg):
def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg):
for labeller in self._multitasks:
tok2vec = self.model.tok2vec
labeller.begin_training(get_gold_tuples, pipeline=pipeline,
labeller.begin_training(get_examples, pipeline=pipeline,
tok2vec=tok2vec, sgd=sgd)
def __reduce__(self):
@ -1116,10 +1179,10 @@ cdef class EntityRecognizer(Parser):
labeller = MultitaskObjective(self.vocab, target=target)
self._multitasks.append(labeller)
def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg):
def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg):
for labeller in self._multitasks:
tok2vec = self.model.tok2vec
labeller.begin_training(get_gold_tuples, pipeline=pipeline,
labeller.begin_training(get_examples, pipeline=pipeline,
tok2vec=tok2vec)
def __reduce__(self):
@ -1175,7 +1238,7 @@ class EntityLinker(Pipe):
if getattr(self, "kb", None) in (None, True, False):
raise ValueError(Errors.E139.format(name=self.name))
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs):
self.require_kb()
self.cfg["entity_width"] = self.kb.entity_vector_length
@ -1187,25 +1250,18 @@ class EntityLinker(Pipe):
return sgd
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
def update(self, examples, state=None, drop=0.0, sgd=None, losses=None):
self.require_model()
self.require_kb()
if losses is not None:
losses.setdefault(self.name, 0.0)
if not docs or not golds:
if not examples:
return 0
if len(docs) != len(golds):
raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
n_golds=len(golds)))
if isinstance(docs, Doc):
docs = [docs]
golds = [golds]
examples = Example.to_example_objects(examples)
sentence_docs = []
docs = [ex.doc for ex in examples]
golds = [ex.gold for ex in examples]
for doc, gold in zip(docs, golds):
ents_by_offset = dict()
@ -1219,19 +1275,19 @@ class EntityLinker(Pipe):
ent = ents_by_offset[(start, end)]
for kb_id, value in kb_dict.items():
# Currently only training on the positive instances
# Currently only training on the positive instances - we assume there is at least 1 per doc/gold
if value:
sentence_docs.append(ent.sent.as_doc())
sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds)
bp_context(d_scores, sgd=sgd)
if losses is not None:
losses[self.name] += loss
return loss
def get_similarity_loss(self, docs, golds, scores):
def get_similarity_loss(self, golds, scores):
entity_encodings = []
for gold in golds:
for entity, kb_dict in gold.links.items():
@ -1244,16 +1300,16 @@ class EntityLinker(Pipe):
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
if scores.shape != entity_encodings.shape:
raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up"))
raise RuntimeError(Errors.E147.format(method="get_similarity_loss", msg="gold entities do not match up"))
loss, gradients = get_cossim_loss(yh=scores, y=entity_encodings)
loss = loss / len(entity_encodings)
return loss, gradients
def get_loss(self, docs, golds, scores):
def get_loss(self, examples, scores):
cats = []
for gold in golds:
for entity, kb_dict in gold.links.items():
for ex in examples:
for entity, kb_dict in ex.gold.links.items():
for kb_id, value in kb_dict.items():
cats.append([value])
@ -1266,16 +1322,29 @@ class EntityLinker(Pipe):
loss = loss / len(cats)
return loss, d_scores
def __call__(self, doc):
def __call__(self, example):
doc = self._get_doc(example)
kb_ids, tensors = self.predict([doc])
self.set_annotations([doc], kb_ids, tensors=tensors)
if isinstance(example, Example):
example.doc = doc
return example
return doc
def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size):
docs = list(docs)
def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
for examples in util.minibatch(stream, size=batch_size):
examples = list(examples)
docs = [self._get_doc(ex) for ex in examples]
kb_ids, tensors = self.predict(docs)
self.set_annotations(docs, kb_ids, tensors=tensors)
if as_example:
examples = []
for ex, doc in zip(examples, docs):
ex.doc = doc
examples.append(ex)
yield from examples
else:
yield from docs
def predict(self, docs):
@ -1408,7 +1477,7 @@ class EntityLinker(Pipe):
util.from_disk(path, deserialize, exclude)
return self
def rehearse(self, docs, sgd=None, losses=None, **config):
def rehearse(self, examples, sgd=None, losses=None, **config):
raise NotImplementedError
def add_label(self, label):
@ -1416,7 +1485,7 @@ class EntityLinker(Pipe):
@component("sentencizer", assigns=["token.is_sent_start", "doc.sents"])
class Sentencizer(object):
class Sentencizer(Pipe):
"""Segment the Doc into sentences using a rule-based strategy.
DOCS: https://spacy.io/api/sentencizer
@ -1451,14 +1520,15 @@ class Sentencizer(object):
def from_nlp(cls, nlp, **cfg):
return cls(**cfg)
def __call__(self, doc):
def __call__(self, example):
"""Apply the sentencizer to a Doc and set Token.is_sent_start.
doc (Doc): The document to process.
RETURNS (Doc): The processed Doc.
example (Doc or Example): The document to process.
RETURNS (Doc or Example): The processed Doc or Example.
DOCS: https://spacy.io/api/sentencizer#call
"""
doc = self._get_doc(example)
start = 0
seen_period = False
for i, token in enumerate(doc):
@ -1472,6 +1542,9 @@ class Sentencizer(object):
seen_period = True
if start < len(doc):
doc[start].is_sent_start = True
if isinstance(example, Example):
example.doc = doc
return example
return doc
def to_bytes(self, **kwargs):

View File

@ -3,7 +3,7 @@ from __future__ import division, print_function, unicode_literals
import numpy as np
from .gold import tags_to_entities, GoldParse
from .gold import tags_to_entities, GoldParse, DocAnnotation
from .errors import Errors
@ -217,11 +217,10 @@ class Scorer(object):
"textcats_per_cat": self.textcats_per_cat,
}
def score(self, doc, gold, verbose=False, punct_labels=("p", "punct")):
def score(self, example, verbose=False, punct_labels=("p", "punct")):
"""Update the evaluation scores from a single Doc / GoldParse pair.
doc (Doc): The predicted annotations.
gold (GoldParse): The correct annotations.
example (Example): The predicted annotations + correct annotations.
verbose (bool): Print debugging information.
punct_labels (tuple): Dependency labels for punctuation. Used to
evaluate dependency attachments to punctuation if `eval_punct` is
@ -229,15 +228,22 @@ class Scorer(object):
DOCS: https://spacy.io/api/scorer#score
"""
if isinstance(example, tuple) and len(example) == 2:
doc, gold = example
else:
gold = example.gold
doc = example.doc
if len(doc) != len(gold):
gold = GoldParse.from_annot_tuples(
doc, tuple(zip(*gold.orig_annot)) + (gold.cats,)
)
doc_annotation = DocAnnotation(cats=gold.cats)
token_annotation = gold.orig
gold = GoldParse.from_annotation(doc, doc_annotation, [token_annotation])
orig = gold.orig
gold_deps = set()
gold_deps_per_dep = {}
gold_tags = set()
gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot]))
for id_, word, tag, head, dep, ner in gold.orig_annot:
gold_ents = set(tags_to_entities(orig.entities))
for id_, tag, head, dep in zip(orig.ids, orig.tags, orig.heads, orig.deps):
gold_tags.add((id_, tag))
if dep not in (None, "") and dep.lower() not in punct_labels:
gold_deps.add((id_, head, dep.lower()))
@ -272,7 +278,7 @@ class Scorer(object):
if token.dep_.lower() not in cand_deps_per_dep:
cand_deps_per_dep[token.dep_.lower()] = set()
cand_deps_per_dep[token.dep_.lower()].add((gold_i, gold_head, token.dep_.lower()))
if "-" not in [token[-1] for token in gold.orig_annot]:
if "-" not in orig.entities:
# Find all NER labels in gold and doc
ent_labels = set([x[0] for x in gold_ents] + [k.label_ for k in doc.ents])
# Set up all labels for per type scoring and prepare gold per type
@ -336,7 +342,7 @@ class Scorer(object):
Errors.E162.format(model_labels=model_labels, eval_labels=eval_labels)
)
if verbose:
gold_words = [item[1] for item in gold.orig_annot]
gold_words = orig.words
for w_id, h_id, dep in cand_deps - gold_deps:
print("F", gold_words[w_id], dep, gold_words[h_id])
for w_id, h_id, dep in gold_deps - cand_deps:

View File

@ -341,10 +341,10 @@ cdef class ArcEager(TransitionSystem):
for label in kwargs.get('right_labels', []):
actions[RIGHT][label] = 1
actions[REDUCE][label] = 1
for raw_text, sents in kwargs.get('gold_parses', []):
for (ids, words, tags, heads, labels, iob), ctnts in sents:
heads, labels = nonproj.projectivize(heads, labels)
for child, head, label in zip(ids, heads, labels):
for example in kwargs.get('gold_parses', []):
for token_annotation in example.token_annotations:
heads, labels = nonproj.projectivize(token_annotation.heads, token_annotation.deps)
for child, head, label in zip(token_annotation.ids, heads, labels):
if label.upper() == 'ROOT' :
label = 'ROOT'
if head == child:
@ -397,7 +397,9 @@ cdef class ArcEager(TransitionSystem):
self.strings[state.safe_get(i).dep]))
else:
predicted.add((i, state.H(i), 'ROOT'))
id_, word, tag, head, dep, ner = gold.orig_annot[gold.cand_to_gold[i]]
id_ = gold.orig.ids[gold.cand_to_gold[i]]
head = gold.orig.heads[gold.cand_to_gold[i]]
dep = gold.orig.deps[gold.cand_to_gold[i]]
truth.add((id_, head, dep))
return truth == predicted

View File

@ -72,9 +72,9 @@ cdef class BiluoPushDown(TransitionSystem):
for action in (BEGIN, IN, LAST, UNIT):
actions[action][entity_type] = 1
moves = ('M', 'B', 'I', 'L', 'U')
for raw_text, sents in kwargs.get('gold_parses', []):
for (ids, words, tags, heads, labels, biluo), _ in sents:
for i, ner_tag in enumerate(biluo):
for example in kwargs.get('gold_parses', []):
for token_annotation in example.token_annotations:
for i, ner_tag in enumerate(token_annotation.entities):
if ner_tag != 'O' and ner_tag != '-':
_, label = ner_tag.split('-', 1)
for action in (BEGIN, IN, LAST, UNIT):

View File

@ -27,6 +27,7 @@ from thinc.neural.util import get_array_module
from thinc.linalg cimport Vec, VecVec
import srsly
from spacy.gold import Example
from ._parser_model cimport alloc_activations, free_activations
from ._parser_model cimport predict_states, arg_max_if_valid
from ._parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
@ -193,7 +194,7 @@ cdef class Parser:
# Defined in subclasses, to avoid circular import
raise NotImplementedError
def init_multitask_objectives(self, get_gold_tuples, pipeline, **cfg):
def init_multitask_objectives(self, get_examples, pipeline, **cfg):
'''Setup models for secondary objectives, to benefit from multi-task
learning. This method is intended to be overridden by subclasses.
@ -203,9 +204,9 @@ cdef class Parser:
'''
pass
def preprocess_gold(self, docs_golds):
for doc, gold in docs_golds:
yield doc, gold
def preprocess_gold(self, examples):
for ex in examples:
yield ex
def use_params(self, params):
# Can't decorate cdef class :(. Workaround.
@ -411,35 +412,31 @@ cdef class Parser:
beam.check_done(_beam_utils.check_final_state, NULL)
return [b for b in beams if not b.is_done]
def update(self, docs, golds, drop=0., sgd=None, losses=None):
def update(self, examples, drop=0., sgd=None, losses=None):
self.require_model()
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
docs = [docs]
golds = [golds]
if len(docs) != len(golds):
raise ValueError(Errors.E077.format(value='update', n_docs=len(docs),
n_golds=len(golds)))
examples = Example.to_example_objects(examples)
if losses is None:
losses = {}
losses.setdefault(self.name, 0.)
for multitask in self._multitasks:
multitask.update(docs, golds, drop=drop, sgd=sgd)
multitask.update(examples, drop=drop, sgd=sgd)
# The probability we use beam update, instead of falling back to
# a greedy update
beam_update_prob = self.cfg.get('beam_update_prob', 0.5)
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob:
return self.update_beam(docs, golds, self.cfg.get('beam_width', 1),
return self.update_beam(examples, self.cfg.get('beam_width', 1),
drop=drop, sgd=sgd, losses=losses,
beam_density=self.cfg.get('beam_density', 0.001))
# Chop sequences into lengths of this many transitions, to make the
# batch uniform length.
cut_gold = numpy.random.choice(range(20, 100))
states, golds, max_steps = self._init_gold_batch(docs, golds, max_length=cut_gold)
states, golds, max_steps = self._init_gold_batch(examples, max_length=cut_gold)
states_golds = [(s, g) for (s, g) in zip(states, golds)
if not s.is_final() and g is not None]
# Prepare the stepwise model, and get the callback for finishing the batch
model, finish_update = self.model.begin_update(docs, drop=drop)
model, finish_update = self.model.begin_update([ex.doc for ex in examples], drop=drop)
for _ in range(max_steps):
if not states_golds:
break
@ -454,19 +451,19 @@ cdef class Parser:
finish_update(golds, sgd=sgd)
return losses
def rehearse(self, docs, sgd=None, losses=None, **cfg):
def rehearse(self, examples, sgd=None, losses=None, **cfg):
"""Perform a "rehearsal" update, to prevent catastrophic forgetting."""
if isinstance(docs, Doc):
docs = [docs]
examples = Example.to_example_objects(examples)
if losses is None:
losses = {}
for multitask in self._multitasks:
if hasattr(multitask, 'rehearse'):
multitask.rehearse(docs, losses=losses, sgd=sgd)
multitask.rehearse(examples, losses=losses, sgd=sgd)
if self._rehearsal_model is None:
return None
losses.setdefault(self.name, 0.)
docs = [ex.doc for ex in examples]
states = self.moves.init_batch(docs)
# This is pretty dirty, but the NER can resize itself in init_batch,
# if labels are missing. We therefore have to check whether we need to
@ -494,15 +491,20 @@ cdef class Parser:
losses[self.name] += loss / n_scores
return losses
def update_beam(self, docs, golds, width, drop=0., sgd=None, losses=None,
def update_beam(self, examples, width, drop=0., sgd=None, losses=None,
beam_density=0.0):
examples = Example.to_example_objects(examples)
docs = [ex.doc for ex in examples]
golds = [ex.gold for ex in examples]
new_golds = []
lengths = [len(d) for d in docs]
states = self.moves.init_batch(docs)
for gold in golds:
self.moves.preprocess_gold(gold)
new_golds.append(gold)
model, finish_update = self.model.begin_update(docs, drop=drop)
states_d_scores, backprops, beams = _beam_utils.update_beam(
self.moves, self.nr_feature, 10000, states, golds, model.state2vec,
self.moves, self.nr_feature, 10000, states, new_golds, model.state2vec,
model.vec2scores, width, drop=drop, losses=losses,
beam_density=beam_density)
for i, d_scores in enumerate(states_d_scores):
@ -522,7 +524,7 @@ cdef class Parser:
for beam in beams:
_beam_utils.cleanup_beam(beam)
def _init_gold_batch(self, whole_docs, whole_golds, min_length=5, max_length=500):
def _init_gold_batch(self, whole_examples, min_length=5, max_length=500):
"""Make a square batch, of length equal to the shortest doc. A long
doc will get multiple states. Let's say we have a doc of length 2*N,
where N is the shortest doc. We'll make two states, one representing
@ -530,6 +532,8 @@ cdef class Parser:
cdef:
StateClass state
Transition action
whole_docs = [ex.doc for ex in whole_examples]
whole_golds = [ex.gold for ex in whole_examples]
whole_states = self.moves.init_batch(whole_docs)
max_length = max(min_length, min(max_length, min([len(doc) for doc in whole_docs])))
max_moves = 0
@ -592,14 +596,14 @@ cdef class Parser:
return create_default_optimizer(self.model.ops,
**self.cfg.get('optimizer', {}))
def begin_training(self, get_gold_tuples, pipeline=None, sgd=None, **cfg):
def begin_training(self, get_examples, pipeline=None, sgd=None, **cfg):
if 'model' in cfg:
self.model = cfg['model']
if not hasattr(get_gold_tuples, '__call__'):
gold_tuples = get_gold_tuples
get_gold_tuples = lambda: gold_tuples
if not hasattr(get_examples, '__call__'):
gold_tuples = get_examples
get_examples = lambda: gold_tuples
cfg.setdefault('min_action_freq', 30)
actions = self.moves.get_actions(gold_parses=get_gold_tuples(),
actions = self.moves.get_actions(gold_parses=get_examples(),
min_freq=cfg.get('min_action_freq', 30),
learn_tokens=self.cfg.get("learn_tokens", False))
for action, labels in self.moves.labels.items():
@ -615,15 +619,14 @@ cdef class Parser:
sgd = self.create_optimizer()
doc_sample = []
gold_sample = []
for raw_text, annots_brackets in islice(get_gold_tuples(), 1000):
for annots, brackets in annots_brackets:
ids, words, tags, heads, deps, ents = annots
doc_sample.append(Doc(self.vocab, words=words))
gold_sample.append(GoldParse(doc_sample[-1], words=words, tags=tags,
heads=heads, deps=deps, entities=ents))
for example in islice(get_examples(), 1000):
parses = example.get_gold_parses(merge=False, vocab=self.vocab)
for doc, gold in parses:
doc_sample.append(doc)
gold_sample.append(gold)
self.model.begin_training(doc_sample, gold_sample)
if pipeline is not None:
self.init_multitask_objectives(get_gold_tuples, pipeline, sgd=sgd, **cfg)
self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **cfg)
link_vectors_to_models(self.vocab)
else:
if sgd is None:

View File

@ -9,6 +9,7 @@ from __future__ import unicode_literals
from copy import copy
from spacy.gold import Example
from ..tokens.doc cimport Doc, set_children_from_heads
from ..errors import Errors
@ -77,39 +78,42 @@ def decompose(label):
def is_decorated(label):
return DELIMITER in label
def count_decorated_labels(gold_tuples):
def count_decorated_labels(gold_data):
freqs = {}
for raw_text, sents in gold_tuples:
for (ids, words, tags, heads, labels, iob), ctnts in sents:
proj_heads, deco_labels = projectivize(heads, labels)
for example in gold_data:
for token_annotation in example.token_annotations:
proj_heads, deco_deps = projectivize(token_annotation.heads, token_annotation.deps)
# set the label to ROOT for each root dependent
deco_labels = ['ROOT' if head == i else deco_labels[i]
deco_deps = ['ROOT' if head == i else deco_deps[i]
for i, head in enumerate(proj_heads)]
# count label frequencies
for label in deco_labels:
for label in deco_deps:
if is_decorated(label):
freqs[label] = freqs.get(label, 0) + 1
return freqs
def preprocess_training_data(gold_tuples, label_freq_cutoff=30):
def preprocess_training_data(gold_data, label_freq_cutoff=30):
preprocessed = []
freqs = {}
for raw_text, sents in gold_tuples:
prepro_sents = []
for (ids, words, tags, heads, labels, iob), ctnts in sents:
proj_heads, deco_labels = projectivize(heads, labels)
for example in gold_data:
new_example = Example(doc=example.doc)
for token_annotation in example.token_annotations:
proj_heads, deco_deps = projectivize(token_annotation.heads, token_annotation.deps)
# set the label to ROOT for each root dependent
deco_labels = ['ROOT' if head == i else deco_labels[i]
deco_deps = ['ROOT' if head == i else deco_deps[i]
for i, head in enumerate(proj_heads)]
# count label frequencies
if label_freq_cutoff > 0:
for label in deco_labels:
for label in deco_deps:
if is_decorated(label):
freqs[label] = freqs.get(label, 0) + 1
prepro_sents.append(
((ids, words, tags, proj_heads, deco_labels, iob), ctnts))
preprocessed.append((raw_text, prepro_sents))
# TODO: the code would be less ugly when changing heads and deps in-place, but is this OK upstream ?
proj_token_dict = token_annotation.to_dict()
proj_token_dict["heads"] = proj_heads
proj_token_dict["deps"] = deco_deps
new_example.add_token_annotation(**proj_token_dict)
preprocessed.append(new_example)
if label_freq_cutoff > 0:
return _filter_labels(preprocessed, label_freq_cutoff, freqs)
return preprocessed
@ -203,20 +207,21 @@ def _find_new_head(token, headlabel):
return token.head
def _filter_labels(gold_tuples, cutoff, freqs):
def _filter_labels(examples, cutoff, freqs):
# throw away infrequent decorated labels
# can't learn them reliably anyway and keeps label set smaller
filtered = []
for raw_text, sents in gold_tuples:
filtered_sents = []
for (ids, words, tags, heads, labels, iob), ctnts in sents:
for example in examples:
new_example = Example(doc=example.doc)
for token_annotation in example.token_annotations:
filtered_labels = []
for label in labels:
for label in token_annotation.deps:
if is_decorated(label) and freqs.get(label, 0) < cutoff:
filtered_labels.append(decompose(label)[0])
else:
filtered_labels.append(label)
filtered_sents.append(
((ids, words, tags, heads, filtered_labels, iob), ctnts))
filtered.append((raw_text, filtered_sents))
filtered_token_dict = token_annotation.to_dict()
filtered_token_dict["deps"] = filtered_labels
new_example.add_token_annotation(**filtered_token_dict)
filtered.append(new_example)
return filtered

View File

@ -37,7 +37,7 @@ def _train_parser(parser):
losses = {}
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
parser.update([doc], [gold], sgd=sgd, losses=losses)
parser.update((doc, gold), sgd=sgd, losses=losses)
return parser
@ -51,7 +51,7 @@ def test_add_label(parser):
gold = GoldParse(
doc, heads=[1, 1, 3, 3], deps=["right", "ROOT", "left", "ROOT"]
)
parser.update([doc], [gold], sgd=sgd, losses=losses)
parser.update((doc, gold), sgd=sgd, losses=losses)
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
doc = parser(doc)
assert doc[0].dep_ == "right"

View File

@ -130,18 +130,25 @@ annot_tuples = [
def test_get_oracle_actions():
ids, words, tags, heads, deps, ents = [], [], [], [], [], []
for id_, word, tag, head, dep, ent in annot_tuples:
ids.append(id_)
words.append(word)
tags.append(tag)
heads.append(head)
deps.append(dep)
ents.append(ent)
doc = Doc(Vocab(), words=[t[1] for t in annot_tuples])
parser = DependencyParser(doc.vocab)
parser.moves.add_action(0, "")
parser.moves.add_action(1, "")
parser.moves.add_action(1, "")
parser.moves.add_action(4, "ROOT")
for i, (id_, word, tag, head, dep, ent) in enumerate(annot_tuples):
for i, (head, dep) in enumerate(zip(heads, deps)):
if head > i:
parser.moves.add_action(2, dep)
elif head < i:
parser.moves.add_action(3, dep)
ids, words, tags, heads, deps, ents = zip(*annot_tuples)
heads, deps = projectivize(heads, deps)
gold = GoldParse(doc, words=words, tags=tags, heads=heads, deps=deps)
parser.moves.preprocess_gold(gold)

View File

@ -67,7 +67,7 @@ def test_update_doc(parser, model, doc, gold):
def optimize(weights, gradient, key=None):
weights -= 0.001 * gradient
parser.update([doc], [gold], sgd=optimize)
parser.update((doc, gold), sgd=optimize)
@pytest.mark.xfail
@ -83,4 +83,4 @@ def test_update_doc_beam(parser, model, doc, gold):
def optimize(weights, gradient, key=None):
weights -= 0.001 * gradient
parser.update_beam([doc], [gold], sgd=optimize)
parser.update_beam((doc, gold), sgd=optimize)

View File

@ -30,7 +30,7 @@ def parser(vocab):
losses = {}
doc = Doc(vocab, words=["a", "b", "c", "d"])
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
parser.update([doc], [gold], sgd=sgd, losses=losses)
parser.update((doc, gold), sgd=sgd, losses=losses)
return parser

View File

@ -24,7 +24,7 @@ def test_simple_train():
("bbbbbbbbb", 0.0),
("aaaaaa", 1),
]:
nlp.update([text], [{"cats": {"answer": answer}}])
nlp.update((text, {"cats": {"answer": answer}}))
doc = nlp("aaa")
assert "answer" in doc.cats
assert doc.cats["answer"] >= 0.5

View File

@ -451,7 +451,7 @@ def test_issue999(train_data):
for itn in range(100):
random.shuffle(TRAIN_DATA)
for raw_text, entity_offsets in TRAIN_DATA:
nlp.update([raw_text], [{"entities": entity_offsets}])
nlp.update((raw_text, {"entities": entity_offsets}))
with make_tempdir() as model_dir:
nlp.to_disk(model_dir)

View File

@ -5,6 +5,8 @@ import pytest
import gc
import numpy
import copy
from spacy.gold import Example
from spacy.lang.en import English
from spacy.lang.en.stop_words import STOP_WORDS
from spacy.lang.lex_attrs import is_stop
@ -270,9 +272,9 @@ def test_issue1963(en_tokenizer):
@pytest.mark.parametrize("label", ["U-JOB-NAME"])
def test_issue1967(label):
ner = EntityRecognizer(Vocab())
entry = ([0], ["word"], ["tag"], [0], ["dep"], [label])
gold_parses = [(None, [(entry, None)])]
ner.moves.get_actions(gold_parses=gold_parses)
example = Example(doc=None)
example.add_token_annotation(ids=[0], words=["word"], tags=["tag"], heads=[0], deps=["dep"], entities=[label])
ner.moves.get_actions(gold_parses=[example])
def test_issue1971(en_vocab):

View File

@ -157,7 +157,7 @@ def test_issue2800():
losses = {}
random.shuffle(train_data)
for statement, entities in train_data:
nlp.update([statement], [entities], sgd=optimizer, losses=losses, drop=0.5)
nlp.update((statement, entities), sgd=optimizer, losses=losses, drop=0.5)
def test_issue2822(it_tokenizer):

View File

@ -41,10 +41,8 @@ def test_issue3611():
batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(
docs=texts,
golds=annotations,
examples=batch,
sgd=optimizer,
drop=0.1,
losses=losses,

View File

@ -41,10 +41,8 @@ def test_issue4030():
batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(
docs=texts,
golds=annotations,
examples=batch,
sgd=optimizer,
drop=0.1,
losses=losses,

View File

@ -19,5 +19,4 @@ def test_issue4348():
losses = {}
batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, losses=losses)
nlp.update(batch, sgd=optimizer, losses=losses)

View File

@ -11,15 +11,14 @@ from spacy.tests.util import make_tempdir
def test_issue4402():
nlp = English()
with make_tempdir() as tmpdir:
print("temp", tmpdir)
json_path = tmpdir / "test4402.json"
srsly.write_json(json_path, json_data)
corpus = GoldCorpus(str(json_path), str(json_path))
train_docs = list(corpus.train_docs(nlp, gold_preproc=True, max_length=0))
train_data = list(corpus.train_dataset(nlp, gold_preproc=True, max_length=0))
# assert that the data got split into 4 sentences
assert len(train_docs) == 4
assert len(train_data) == 4
json_data = [

View File

@ -1,11 +1,12 @@
# coding: utf-8
from __future__ import unicode_literals
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags, Example, DocAnnotation
from spacy.gold import spans_from_biluo_tags, GoldParse, iob_to_biluo
from spacy.gold import GoldCorpus, docs_to_json, align
from spacy.lang.en import English
from spacy.tokens import Doc
from spacy.util import compounding, minibatch
from .util import make_tempdir
import pytest
import srsly
@ -119,12 +120,13 @@ def test_roundtrip_docs_to_json():
with make_tempdir() as tmpdir:
json_file = tmpdir / "roundtrip.json"
srsly.write_json(json_file, [docs_to_json(doc)])
goldcorpus = GoldCorpus(str(json_file), str(json_file))
goldcorpus = GoldCorpus(train=str(json_file), dev=str(json_file))
reloaded_doc, goldparse = next(goldcorpus.train_docs(nlp))
reloaded_example = next(goldcorpus.train_dataset(nlp))
goldparse = reloaded_example.gold
assert len(doc) == goldcorpus.count_train()
assert text == reloaded_doc.text
assert text == reloaded_example.text
assert tags == goldparse.tags
assert deps == goldparse.labels
assert heads == goldparse.heads
@ -140,10 +142,11 @@ def test_roundtrip_docs_to_json():
srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
reloaded_doc, goldparse = next(goldcorpus.train_docs(nlp))
reloaded_example = next(goldcorpus.train_dataset(nlp))
goldparse = reloaded_example.gold
assert len(doc) == goldcorpus.count_train()
assert text == reloaded_doc.text
assert text == reloaded_example.text
assert tags == goldparse.tags
assert deps == goldparse.labels
assert heads == goldparse.heads
@ -160,13 +163,14 @@ def test_roundtrip_docs_to_json():
srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
# load and rewrite as JSONL tuples
srsly.write_jsonl(jsonl_file, goldcorpus.train_tuples)
srsly.write_jsonl(jsonl_file, goldcorpus.train_examples)
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
reloaded_doc, goldparse = next(goldcorpus.train_docs(nlp))
reloaded_example = next(goldcorpus.train_dataset(nlp))
goldparse = reloaded_example.gold
assert len(doc) == goldcorpus.count_train()
assert text == reloaded_doc.text
assert text == reloaded_example.text
assert tags == goldparse.tags
assert deps == goldparse.labels
assert heads == goldparse.heads
@ -217,3 +221,144 @@ def test_goldparse_startswith_space(en_tokenizer):
assert g.words == [" ", "a"]
assert g.ner == [None, "U-DATE"]
assert g.labels == [None, "ROOT"]
def test_gold_constructor():
"""Test that the GoldParse constructor works fine"""
nlp = English()
doc = nlp("This is a sentence")
gold = GoldParse(doc, cats={"cat1": 1.0, "cat2": 0.0})
assert gold.cats["cat1"]
assert not gold.cats["cat2"]
assert gold.words == ["This", "is", "a", "sentence"]
def test_gold_orig_annot():
nlp = English()
doc = nlp("This is a sentence")
gold = GoldParse(doc, cats={"cat1": 1.0, "cat2": 0.0})
assert gold.orig.words == ["This", "is", "a", "sentence"]
assert gold.cats["cat1"]
doc_annotation = DocAnnotation(cats={"cat1": 0.0, "cat2": 1.0})
gold2 = GoldParse.from_annotation(doc, doc_annotation, gold.orig)
assert gold2.orig.words == ["This", "is", "a", "sentence"]
assert not gold2.cats["cat1"]
def test_tuple_format_implicit():
"""Test tuple format with implicit GoldParse creation"""
train_data = [
("Uber blew through $1 million a week", {"entities": [(0, 4, "ORG")]}),
(
"Spotify steps up Asia expansion",
{"entities": [(0, 8, "ORG"), (17, 21, "LOC")]},
),
("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
]
_train(train_data)
def test_tuple_format_implicit_invalid():
"""Test that an error is thrown for an implicit invalid GoldParse field"""
train_data = [
("Uber blew through $1 million a week", {"frumble": [(0, 4, "ORG")]}),
(
"Spotify steps up Asia expansion",
{"entities": [(0, 8, "ORG"), (17, 21, "LOC")]},
),
("Google rebrands its business apps", {"entities": [(0, 6, "ORG")]}),
]
with pytest.raises(TypeError):
_train(train_data)
def _train(train_data):
nlp = English()
ner = nlp.create_pipe("ner")
ner.add_label("ORG")
ner.add_label("LOC")
nlp.add_pipe(ner)
optimizer = nlp.begin_training()
for i in range(5):
losses = {}
batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
nlp.update(batch, sgd=optimizer, losses=losses)
tokens_1 = {
"ids": [1, 2, 3],
"words": ["Hi", "there", "everyone"],
"tags": ["INTJ", "ADV", "PRON"],
}
tokens_2 = {
"ids": [1, 2, 3, 4],
"words": ["It", "is", "just", "me"],
"tags": ["PRON", "AUX", "ADV", "PRON"],
}
text0 = "Hi there everyone It is just me"
def test_merge_sents():
nlp = English()
example = Example()
example.add_token_annotation(**tokens_1)
example.add_token_annotation(**tokens_2)
assert len(example.get_gold_parses(merge=False, vocab=nlp.vocab)) == 2
assert len(example.get_gold_parses(merge=True, vocab=nlp.vocab)) == 1 # this shouldn't change the original object
merged_example = example.merge_sents()
token_annotation_1 = example.token_annotations[0]
assert token_annotation_1.ids == [1, 2, 3]
assert token_annotation_1.words == ["Hi", "there", "everyone"]
assert token_annotation_1.tags == ["INTJ", "ADV", "PRON"]
token_annotation_m = merged_example.token_annotations[0]
assert token_annotation_m.ids == [1, 2, 3, 4, 5, 6, 7]
assert token_annotation_m.words == ["Hi", "there", "everyone", "It", "is", "just", "me"]
assert token_annotation_m.tags == ["INTJ", "ADV", "PRON", "PRON", "AUX", "ADV", "PRON"]
def test_tuples_to_example():
ex = Example()
ex.add_token_annotation(**tokens_1)
ex.add_token_annotation(**tokens_2)
ex.add_doc_annotation(cats={"TRAVEL": 1.0, "BAKING": 0.0})
ex_dict = ex.to_dict()
token_dicts = [
{
"ids": [1, 2, 3],
"words": ["Hi", "there", "everyone"],
"tags": ["INTJ", "ADV", "PRON"],
"heads": [],
"deps": [],
"entities": [],
"morphology": [],
"brackets": [],
},
{
"ids": [1, 2, 3, 4],
"words": ["It", "is", "just", "me"],
"tags": ["PRON", "AUX", "ADV", "PRON"],
"heads": [],
"deps": [],
"entities": [],
"morphology": [],
"brackets": [],
},
]
doc_dict = {"cats": {"TRAVEL": 1.0, "BAKING": 0.0}, "links": {}}
assert ex_dict == {"token_annotations": token_dicts, "doc_annotation": doc_dict}

View File

@ -31,20 +31,20 @@ def test_language_update(nlp):
doc = Doc(nlp.vocab, words=text.split(" "))
gold = GoldParse(doc, **annots)
# Update with doc and gold objects
nlp.update([doc], [gold])
nlp.update((doc, gold))
# Update with text and dict
nlp.update([text], [annots])
nlp.update((text, annots))
# Update with doc object and dict
nlp.update([doc], [annots])
nlp.update((doc, annots))
# Update with text and gold object
nlp.update([text], [gold])
nlp.update((text, gold))
# Update with empty doc and gold object
nlp.update((None, gold))
# Update badly
with pytest.raises(IndexError):
nlp.update([doc], [])
with pytest.raises(IndexError):
nlp.update([], [gold])
with pytest.raises(ValueError):
nlp.update([text], [wrongkeyannots])
nlp.update((doc, None))
with pytest.raises(TypeError):
nlp.update((text, wrongkeyannots))
def test_language_evaluate(nlp):

View File

@ -4,7 +4,7 @@ from __future__ import unicode_literals
from numpy.testing import assert_almost_equal, assert_array_almost_equal
import pytest
from pytest import approx
from spacy.gold import GoldParse
from spacy.gold import Example, GoldParse
from spacy.scorer import Scorer, ROCAUCScore
from spacy.scorer import _roc_auc_score, _roc_curve
from .util import get_doc
@ -40,7 +40,7 @@ def test_las_per_type(en_vocab):
deps=annot["deps"],
)
gold = GoldParse(doc, heads=annot["heads"], deps=annot["deps"])
scorer.score(doc, gold)
scorer.score((doc, gold))
results = scorer.scores
assert results["uas"] == 100
@ -63,7 +63,7 @@ def test_las_per_type(en_vocab):
)
gold = GoldParse(doc, heads=annot["heads"], deps=annot["deps"])
doc[0].dep_ = "compound"
scorer.score(doc, gold)
scorer.score((doc, gold))
results = scorer.scores
assert results["uas"] == 100
@ -85,8 +85,9 @@ def test_ner_per_type(en_vocab):
words=input_.split(" "),
ents=[[0, 1, "CARDINAL"], [2, 3, "CARDINAL"]],
)
gold = GoldParse(doc, entities=annot["entities"])
scorer.score(doc, gold)
ex = Example(doc=doc)
ex.add_token_annotation(entities=annot["entities"])
scorer.score(ex)
results = scorer.scores
assert results["ents_p"] == 100
@ -105,8 +106,9 @@ def test_ner_per_type(en_vocab):
words=input_.split(" "),
ents=[[0, 1, "ORG"], [5, 6, "GPE"], [6, 7, "ORG"]],
)
gold = GoldParse(doc, entities=annot["entities"])
scorer.score(doc, gold)
ex = Example(doc=doc)
ex.add_token_annotation(entities=annot["entities"])
scorer.score(ex)
results = scorer.scores
assert results["ents_p"] == approx(66.66666)

View File

@ -158,7 +158,7 @@ cdef class Tokenizer:
doc.c[doc.length - 1].spacy = string[-1] == " " and not in_ws
return doc
def pipe(self, texts, batch_size=1000, n_threads=-1):
def pipe(self, texts, batch_size=1000, n_threads=-1, as_example=False):
"""Tokenize a stream of texts.
texts: A sequence of unicode texts.

View File

@ -616,31 +616,25 @@ def decaying(start, stop, decay):
curr -= decay
def minibatch_by_words(items, size, tuples=True, count_words=len):
def minibatch_by_words(examples, size, tuples=True, count_words=len):
"""Create minibatches of a given number of words."""
if isinstance(size, int):
size_ = itertools.repeat(size)
else:
size_ = size
items = iter(items)
examples = iter(examples)
while True:
batch_size = next(size_)
batch = []
while batch_size >= 0:
try:
if tuples:
doc, gold = next(items)
else:
doc = next(items)
example = next(examples)
except StopIteration:
if batch:
yield batch
return
batch_size -= count_words(doc)
if tuples:
batch.append((doc, gold))
else:
batch.append(doc)
batch_size -= count_words(example.doc)
batch.append(example)
if batch:
yield batch