various fixes in scripts - needs to be further tested

svlandeg 2020-06-17 12:05:58 +02:00
parent 3c4f9e4cc4
commit f7ad8e8c83
9 changed files with 63 additions and 126 deletions
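
The diffs below all apply the same refactor: instead of building GoldParse objects, the scripts keep their annotations as plain dicts and wrap them in Example objects via Example.from_dict before passing them to nlp.update. A rough sketch of that pattern follows (annotation values are illustrative, and nlp, optimizer and losses are assumed to come from the surrounding training loop, as in the scripts below):

    from spacy.gold import Example
    from spacy.tokens import Doc

    # plain-dict annotations replace the old GoldParse object
    annotations = {"words": ["I", "like", "London"],
                   "entities": ["O", "O", "U-LOC"]}
    doc = Doc(nlp.vocab, words=annotations["words"])
    example = Example.from_dict(doc, annotations)
    nlp.update(examples=[example], sgd=optimizer, losses=losses)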

View File

@@ -14,7 +14,7 @@ import spacy
 import spacy.util
 from bin.ud import conll17_ud_eval
 from spacy.tokens import Token, Doc
-from spacy.gold import GoldParse, Example
+from spacy.gold import Example
 from spacy.util import compounding, minibatch, minibatch_by_words
 from spacy.syntax.nonproj import projectivize
 from spacy.matcher import Matcher
@@ -83,11 +83,11 @@ def read_data(
             sent["heads"].append(head)
             sent["deps"].append("ROOT" if dep == "root" else dep)
             sent["spaces"].append(space_after == "_")
-        sent["entities"] = ["-"] * len(sent["words"])
+        sent["entities"] = ["-"] * len(sent["words"])  # TODO: doc-level format
         sent["heads"], sent["deps"] = projectivize(sent["heads"], sent["deps"])
         if oracle_segments:
             docs.append(Doc(nlp.vocab, words=sent["words"], spaces=sent["spaces"]))
-            golds.append(GoldParse(docs[-1], **sent))
+            golds.append(sent)
             assert golds[-1].morphology is not None
         sent_annots.append(sent)
@@ -151,28 +151,27 @@ def read_conllu(file_):

 def _make_gold(nlp, text, sent_annots, drop_deps=0.0):
     # Flatten the conll annotations, and adjust the head indices
-    flat = defaultdict(list)
+    gold = defaultdict(list)
     sent_starts = []
     for sent in sent_annots:
-        flat["heads"].extend(len(flat["words"])+head for head in sent["heads"])
+        gold["heads"].extend(len(gold["words"])+head for head in sent["heads"])
         for field in ["words", "tags", "deps", "morphology", "entities", "spaces"]:
-            flat[field].extend(sent[field])
+            gold[field].extend(sent[field])
         sent_starts.append(True)
         sent_starts.extend([False] * (len(sent["words"]) - 1))
     # Construct text if necessary
-    assert len(flat["words"]) == len(flat["spaces"])
+    assert len(gold["words"]) == len(gold["spaces"])
     if text is None:
         text = "".join(
-            word + " " * space for word, space in zip(flat["words"], flat["spaces"])
+            word + " " * space for word, space in zip(gold["words"], gold["spaces"])
         )
     doc = nlp.make_doc(text)
-    flat.pop("spaces")
-    gold = GoldParse(doc, **flat)
-    gold.sent_starts = sent_starts
+    gold.pop("spaces")
+    gold["sent_starts"] = sent_starts
     for i in range(len(gold.heads)):
         if random.random() < drop_deps:
-            gold.heads[i] = None
-            gold.labels[i] = None
+            gold["heads"][i] = None
+            gold["labels"][i] = None
     return doc, gold
@@ -183,15 +182,10 @@ def _make_gold(nlp, text, sent_annots, drop_deps=0.0):

 def golds_to_gold_data(docs, golds):
-    """Get out the training data format used by begin_training, given the
-    GoldParse objects."""
+    """Get out the training data format used by begin_training"""
     data = []
     for doc, gold in zip(docs, golds):
-        example = Example(doc=doc)
-        example.add_doc_annotation(cats=gold.cats)
-        token_annotation_dict = gold.orig.to_dict()
-        example.add_token_annotation(**token_annotation_dict)
-        example.goldparse = gold
+        example = Example.from_dict(doc, gold)
         data.append(example)
     return data
@@ -359,8 +353,8 @@ def initialize_pipeline(nlp, examples, config, device):
         nlp.parser.add_multitask_objective("tag")
     if config.multitask_sent:
         nlp.parser.add_multitask_objective("sent_start")
-    for ex in examples:
-        gold = ex.gold
+    for eg in examples:
+        gold = eg.gold
         for tag in gold.tags:
             if tag is not None:
                 nlp.tagger.add_label(tag)
@@ -541,7 +535,7 @@ def main(
         else:
             batches = minibatch(examples, size=batch_sizes)
         losses = {}
-        n_train_words = sum(len(ex.doc) for ex in examples)
+        n_train_words = sum(len(eg.doc) for eg in examples)
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
             for batch in batches:
                 pbar.update(sum(len(ex.doc) for ex in batch))

View File

@@ -12,7 +12,7 @@ import tqdm
 import spacy
 import spacy.util
 from spacy.tokens import Token, Doc
-from spacy.gold import GoldParse, Example
+from spacy.gold import Example
 from spacy.syntax.nonproj import projectivize
 from collections import defaultdict
 from spacy.matcher import Matcher
@@ -33,31 +33,6 @@ random.seed(0)
 numpy.random.seed(0)

-def minibatch_by_words(examples, size=5000):
-    random.shuffle(examples)
-    if isinstance(size, int):
-        size_ = itertools.repeat(size)
-    else:
-        size_ = size
-    examples = iter(examples)
-    while True:
-        batch_size = next(size_)
-        batch = []
-        while batch_size >= 0:
-            try:
-                example = next(examples)
-            except StopIteration:
-                if batch:
-                    yield batch
-                return
-            batch_size -= len(example.doc)
-            batch.append(example)
-        if batch:
-            yield batch
-        else:
-            break
-
 ################
 # Data reading #
 ################
@@ -110,7 +85,7 @@ def read_data(
         sent["heads"], sent["deps"] = projectivize(sent["heads"], sent["deps"])
         if oracle_segments:
             docs.append(Doc(nlp.vocab, words=sent["words"], spaces=sent["spaces"]))
-            golds.append(GoldParse(docs[-1], **sent))
+            golds.append(sent)
         sent_annots.append(sent)

         if raw_text and max_doc_length and len(sent_annots) >= max_doc_length:
@@ -159,20 +134,19 @@ def read_conllu(file_):

 def _make_gold(nlp, text, sent_annots):
     # Flatten the conll annotations, and adjust the head indices
-    flat = defaultdict(list)
+    gold = defaultdict(list)
     for sent in sent_annots:
-        flat["heads"].extend(len(flat["words"]) + head for head in sent["heads"])
+        gold["heads"].extend(len(gold["words"]) + head for head in sent["heads"])
         for field in ["words", "tags", "deps", "entities", "spaces"]:
-            flat[field].extend(sent[field])
+            gold[field].extend(sent[field])
     # Construct text if necessary
-    assert len(flat["words"]) == len(flat["spaces"])
+    assert len(gold["words"]) == len(gold["spaces"])
     if text is None:
         text = "".join(
-            word + " " * space for word, space in zip(flat["words"], flat["spaces"])
+            word + " " * space for word, space in zip(gold["words"], gold["spaces"])
         )
     doc = nlp.make_doc(text)
-    flat.pop("spaces")
-    gold = GoldParse(doc, **flat)
+    gold.pop("spaces")
     return doc, gold
@@ -182,15 +156,10 @@ def _make_gold(nlp, text, sent_annots):

 def golds_to_gold_data(docs, golds):
-    """Get out the training data format used by begin_training, given the
-    GoldParse objects."""
+    """Get out the training data format used by begin_training."""
     data = []
     for doc, gold in zip(docs, golds):
-        example = Example(doc=doc)
-        example.add_doc_annotation(cats=gold.cats)
-        token_annotation_dict = gold.orig.to_dict()
-        example.add_token_annotation(**token_annotation_dict)
-        example.goldparse = gold
+        example = Example.from_dict(doc, gold)
         data.append(example)
     return data
@@ -313,15 +282,15 @@ def initialize_pipeline(nlp, examples, config):
         nlp.parser.add_multitask_objective("sent_start")
     nlp.parser.moves.add_action(2, "subtok")
     nlp.add_pipe(nlp.create_pipe("tagger"))
-    for ex in examples:
-        for tag in ex.gold.tags:
+    for eg in examples:
+        for tag in eg.gold.tags:
             if tag is not None:
                 nlp.tagger.add_label(tag)
     # Replace labels that didn't make the frequency cutoff
     actions = set(nlp.parser.labels)
     label_set = set([act.split("-")[1] for act in actions if "-" in act])
-    for ex in examples:
-        gold = ex.gold
+    for eg in examples:
+        gold = eg.gold
         for i, label in enumerate(gold.labels):
             if label is not None and label not in label_set:
                 gold.labels[i] = label.split("||")[0]
@@ -415,13 +384,12 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
     optimizer = initialize_pipeline(nlp, examples, config)

    for i in range(config.nr_epoch):
-        docs = [nlp.make_doc(example.doc.text) for example in examples]
-        batches = minibatch_by_words(examples, size=config.batch_size)
+        batches = spacy.minibatch_by_words(examples, size=config.batch_size)
         losses = {}
-        n_train_words = sum(len(doc) for doc in docs)
+        n_train_words = sum(len(eg.reference.doc) for eg in examples)
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
             for batch in batches:
-                pbar.update(sum(len(ex.doc) for ex in batch))
+                pbar.update(sum(len(eg.reference.doc) for eg in batch))
                 nlp.update(
                     examples=batch, sgd=optimizer, drop=config.dropout, losses=losses,
                 )

View File

@@ -24,8 +24,10 @@ import random
 import plac
 import spacy
 import os.path
+from spacy.gold.example import Example
 from spacy.tokens import Doc
-from spacy.gold import read_json_file, GoldParse
+from spacy.gold import read_json_file
 random.seed(0)
@@ -59,17 +61,15 @@ def main(n_iter=10):
     print(nlp.pipeline)
     print("Create data", len(TRAIN_DATA))
-    optimizer = nlp.begin_training(get_examples=lambda: TRAIN_DATA)
+    optimizer = nlp.begin_training()
     for itn in range(n_iter):
         random.shuffle(TRAIN_DATA)
         losses = {}
-        for example in TRAIN_DATA:
-            for token_annotation in example.token_annotations:
-                doc = Doc(nlp.vocab, words=token_annotation.words)
-                gold = GoldParse.from_annotation(doc, example.doc_annotation, token_annotation)
+        for example_dict in TRAIN_DATA:
+            doc = Doc(nlp.vocab, words=example_dict["words"])
+            example = Example.from_dict(doc, example_dict)
             nlp.update(
-                examples=[(doc, gold)],  # 1 example
+                examples=[example],  # 1 example
                 drop=0.2,  # dropout - make it harder to memorise data
                 sgd=optimizer,  # callable to update weights
                 losses=losses,
@@ -77,9 +77,9 @@ def main(n_iter=10):
         print(losses.get("nn_labeller", 0.0), losses["ner"])

     # test the trained model
-    for example in TRAIN_DATA:
-        if example.text is not None:
-            doc = nlp(example.text)
+    for example_dict in TRAIN_DATA:
+        if "text" in example_dict:
+            doc = nlp(example_dict["text"])
             print("Entities", [(ent.text, ent.label_) for ent in doc.ents])
             print("Tokens", [(t.text, t.ent_type_, t.ent_iob) for t in doc])

View File

@@ -1,8 +1,7 @@
 import re

 from ...gold import Example
-from ...gold import iob_to_biluo, spans_from_biluo_tags, biluo_tags_from_offsets
-from ...gold import TokenAnnotation
+from ...gold import iob_to_biluo, spans_from_biluo_tags
 from ...language import Language
 from ...tokens import Doc, Token
 from .conll_ner2json import n_sents_info
@@ -42,10 +41,10 @@ def conllu2json(
     )
     has_ner_tags = has_ner(input_data, MISC_NER_PATTERN)
     for i, example in enumerate(conll_data):
-        raw += example.text
+        raw += example.predicted.text
         sentences.append(
             generate_sentence(
-                example.token_annotation,
+                example,
                 has_ner_tags,
                 MISC_NER_PATTERN,
                 ner_map=ner_map,
@@ -268,36 +267,14 @@ def example_from_conllu_sentence(
     doc = merge_conllu_subtokens(lines, doc)

     # create Example from custom Doc annotation
-    ids, words, tags, heads, deps = [], [], [], [], []
-    pos, lemmas, morphs, spaces = [], [], [], []
+    words, spaces = [], []
     for i, t in enumerate(doc):
-        ids.append(i)
         words.append(t._.merged_orth)
-        if append_morphology and t._.merged_morph:
-            tags.append(t.tag_ + "__" + t._.merged_morph)
-        else:
-            tags.append(t.tag_)
-        pos.append(t.pos_)
-        morphs.append(t._.merged_morph)
-        lemmas.append(t._.merged_lemma)
-        heads.append(t.head.i)
-        deps.append(t.dep_)
         spaces.append(t._.merged_spaceafter)
-    ent_offsets = [(e.start_char, e.end_char, e.label_) for e in doc.ents]
-    ents = biluo_tags_from_offsets(doc, ent_offsets)
-    example = Example(doc=Doc(vocab, words=words, spaces=spaces))
-    example.token_annotation = TokenAnnotation(
-        ids=ids,
-        words=words,
-        tags=tags,
-        pos=pos,
-        morphs=morphs,
-        lemmas=lemmas,
-        heads=heads,
-        deps=deps,
-        entities=ents,
-    )
-    return example
+        if append_morphology and t._.merged_morph:
+            t.tag_ = t.tag_ + "__" + t._.merged_morph
+    return Example(predicted=Doc(vocab, words=words, spaces=spaces), reference=doc)


 def merge_conllu_subtokens(lines, doc):

View File

@@ -69,6 +69,7 @@ def docs_to_json(docs, id=0, ner_missing_tag="O"):
 def read_json_file(loc, docs_filter=None, limit=None):
+    """Read Example dictionaries from a json file or directory."""
     loc = util.ensure_path(loc)
     if loc.is_dir():
         for filename in loc.iterdir():
@@ -105,7 +106,7 @@ def json_to_annotations(doc):
         sent_start_i = len(words)
         for i, token in enumerate(sent["tokens"]):
             words.append(token["orth"])
-            spaces.append(token["space"])
+            spaces.append(token.get("space", True))
             ids.append(token.get('id', sent_start_i + i))
             tags.append(token.get('tag', "-"))
             pos.append(token.get("pos", ""))

View File

@@ -804,7 +804,6 @@ class Language(object):
         cleanup=False,
         component_cfg=None,
         n_process=1,
-        as_example=False,
     ):
         """Process texts as a stream, and yield `Doc` objects in order.
@@ -837,8 +836,7 @@
                 batch_size=batch_size,
                 disable=disable,
                 n_process=n_process,
-                component_cfg=component_cfg,
-                as_example=as_example,
+                component_cfg=component_cfg
             )
             for doc, context in zip(docs, contexts):
                 yield (doc, context)

View File

@@ -26,7 +26,7 @@ def test_sentencizer_pipe():
         sent_starts = [t.is_sent_start for t in doc]
         assert sent_starts == [True, False, True, False, False, False, False]
         assert len(list(doc.sents)) == 2
-    for ex in nlp.pipe(texts, as_example=True):
+    for ex in nlp.pipe(texts):
         doc = ex.doc
         assert doc.is_sentenced
         sent_starts = [t.is_sent_start for t in doc]

View File

@@ -1,5 +1,4 @@
 import pytest
-from spacy.gold import Example

 from .util import get_random_doc

View File

@@ -205,7 +205,7 @@ cdef class Tokenizer:
         doc.c[doc.length - 1].spacy = string[-1] == " " and not in_ws
         return doc

-    def pipe(self, texts, batch_size=1000, n_threads=-1, as_example=False):
+    def pipe(self, texts, batch_size=1000, n_threads=-1):
         """Tokenize a stream of texts.

         texts: A sequence of unicode texts.