Remove GoldParse from public API

* Move get_parses_from_example to spacy.syntax

* Get GoldParse out of Example

* Avoid expecting GoldParse input in parser

* Add Alignment to spacy.gold.align

* Update Example object

* Add comment

* Update pipeline

* Fix imports

* Simplify gold_io

* WIP on GoldCorpus

* Update test

* Xfail some gold tests

* Remove ignore_misaligned option from GoldCorpus

* Fix Example constructor

* Update test

* Fix usage of Example

* Add deprecated_get_gold method on Example

* Patch scorer

* Fix test

* Fix test

* Update tests

* Xfail a test

* Fix passing of make_projective

* Pass make_projective by default

* Hack data format in Example.from_dict

* Update tests

* Fix example.from_dict

* Update morphologizer

* Fix entity linker

* Add get_field to TokenAnnotation

* Fix Example.get_aligned

* Update test

* Fix alignment

* Fix corpus

* Fix GoldCorpus

* Handle misaligned

* Format

* Fix missing import
This commit is contained in:
Matthew Honnibal 2020-06-08 22:09:57 +02:00 committed by GitHub
parent b69fa77ccc
commit 084271c9e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 315 additions and 303 deletions

View File

@ -11,6 +11,7 @@ from thinc.api import Model, use_pytorch_for_gpu_memory
import random
from ..gold import GoldCorpus
from ..gold import Example
from .. import util
from ..errors import Errors
from ..ml import models # don't remove - required to load the built-in architectures
@ -243,7 +244,7 @@ def create_train_batches(nlp, corpus, cfg):
orth_variant_level=cfg["orth_variant_level"],
gold_preproc=cfg["gold_preproc"],
max_length=cfg["max_length"],
ignore_misaligned=True,
ignore_misaligned=True
))
if len(train_examples) == 0:
raise ValueError(Errors.E988)
@ -271,6 +272,7 @@ def create_evaluation_callback(nlp, optimizer, corpus, cfg):
nlp, gold_preproc=cfg["gold_preproc"], ignore_misaligned=True
)
)
n_words = sum(len(ex.doc) for ex in dev_examples)
start_time = timer()

View File

@ -10,4 +10,4 @@ from .iob_utils import spans_from_biluo_tags
from .iob_utils import tags_to_entities
from .gold_io import docs_to_json
from .gold_io import read_json_file, read_json_object
from .gold_io import read_json_file

View File

@ -2,6 +2,26 @@ import numpy
from ..errors import Errors, AlignmentError
class Alignment:
def __init__(self, spacy_words, gold_words):
# Do many-to-one alignment for misaligned tokens.
# If we over-segment, we'll have one gold word that covers a sequence
# of predicted words
# If we under-segment, we'll have one predicted word that covers a
# sequence of gold words.
# If we "mis-segment", we'll have a sequence of predicted words covering
# a sequence of gold words. That's many-to-many -- we don't do that
# except for NER spans where the start and end can be aligned.
cost, i2j, j2i, i2j_multi, j2i_multi = align(spacy_words, gold_words)
self.cost = cost
self.i2j = i2j
self.j2i = j2i
self.i2j_multi = i2j_multi
self.j2i_multi = j2i_multi
self.cand_to_gold = [(j if j >= 0 else None) for j in i2j]
self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]
def align(tokens_a, tokens_b):
"""Calculate alignment tables between two tokenizations.

View File

@ -28,6 +28,30 @@ class TokenAnnotation:
for b_start, b_end, b_label in brackets:
self.brackets_by_start.setdefault(b_start, []).append((b_end, b_label))
def get_field(self, field):
if field == "id":
return self.ids
elif field == "word":
return self.words
elif field == "tag":
return self.tags
elif field == "pos":
return self.pos
elif field == "morph":
return self.morphs
elif field == "lemma":
return self.lemmas
elif field == "head":
return self.heads
elif field == "dep":
return self.deps
elif field == "ner":
return self.entities
elif field == "sent_start":
return self.sent_starts
else:
raise ValueError(f"Unknown field: {field}")
@property
def brackets(self):
brackets = []

View File

@ -6,8 +6,8 @@ from pathlib import Path
import itertools
from ..tokens import Doc
from .. import util
from ..errors import Errors
from .gold_io import read_json_file, read_json_object
from ..errors import Errors, AlignmentError
from .gold_io import read_json_file, json_to_examples
from .augment import make_orth_variants, add_noise
from .example import Example
@ -43,9 +43,8 @@ class GoldCorpus(object):
if not directory.exists():
directory.mkdir()
n = 0
for i, example in enumerate(examples):
ex_dict = example.to_dict()
text = example.text
for i, ex_dict in enumerate(examples):
text = ex_dict["text"]
srsly.write_msgpack(directory / f"{i}.msg", (text, ex_dict))
n += 1
if limit and n >= limit:
@ -87,7 +86,9 @@ class GoldCorpus(object):
# TODO: proper format checks with schemas
if isinstance(first_gold_tuple, dict):
if first_gold_tuple.get("paragraphs", None):
examples = read_json_object(gold_tuples)
examples = []
for json_doc in gold_tuples:
examples.extend(json_to_examples(json_doc))
elif first_gold_tuple.get("doc_annotation", None):
examples = []
for ex_dict in gold_tuples:
@ -117,7 +118,7 @@ class GoldCorpus(object):
except KeyError as e:
msg = "Missing key {}".format(e)
raise KeyError(Errors.E996.format(file=file_name, msg=msg))
except UnboundLocalError:
except UnboundLocalError as e:
msg = "Unexpected document structure"
raise ValueError(Errors.E996.format(file=file_name, msg=msg))
@ -200,9 +201,9 @@ class GoldCorpus(object):
):
""" Setting gold_preproc will result in creating a doc per sentence """
for example in examples:
example_docs = []
if gold_preproc:
split_examples = example.split_sents()
example_golds = []
for split_example in split_examples:
split_example_docs = cls._make_docs(
nlp,
@ -211,13 +212,7 @@ class GoldCorpus(object):
noise_level=noise_level,
orth_variant_level=orth_variant_level,
)
split_example_golds = cls._make_golds(
split_example_docs,
vocab=nlp.vocab,
make_projective=make_projective,
ignore_misaligned=ignore_misaligned,
)
example_golds.extend(split_example_golds)
example_docs.extend(split_example_docs)
else:
example_docs = cls._make_docs(
nlp,
@ -226,16 +221,14 @@ class GoldCorpus(object):
noise_level=noise_level,
orth_variant_level=orth_variant_level,
)
example_golds = cls._make_golds(
example_docs,
vocab=nlp.vocab,
make_projective=make_projective,
ignore_misaligned=ignore_misaligned,
)
for ex in example_golds:
if ex.goldparse is not None:
if (not max_length) or len(ex.doc) < max_length:
yield ex
for ex in example_docs:
if (not max_length) or len(ex.doc) < max_length:
if ignore_misaligned:
try:
_ = ex._deprecated_get_gold()
except AlignmentError:
continue
yield ex
@classmethod
def _make_docs(
@ -256,22 +249,3 @@ class GoldCorpus(object):
)
var_example.doc = var_doc
return [var_example]
@classmethod
def _make_golds(
cls, examples, vocab=None, make_projective=False, ignore_misaligned=False
):
filtered_examples = []
for example in examples:
gold_parses = example.get_gold_parses(
vocab=vocab,
make_projective=make_projective,
ignore_misaligned=ignore_misaligned,
)
assert len(gold_parses) == 1
doc, gold = gold_parses[0]
if doc:
assert doc == example.doc
example.goldparse = gold
filtered_examples.append(example)
return filtered_examples

View File

@ -1,36 +1,56 @@
from .annotation import TokenAnnotation, DocAnnotation
from .align import Alignment
from ..errors import Errors, AlignmentError
from ..tokens import Doc
# We're hoping to kill this GoldParse dependency but for now match semantics.
from ..syntax.gold_parse import GoldParse
class Example:
def __init__(
self, doc_annotation=None, token_annotation=None, doc=None, goldparse=None
):
def __init__(self, doc=None, doc_annotation=None, token_annotation=None):
""" Doc can either be text, or an actual Doc """
self.doc = doc
self.doc_annotation = doc_annotation if doc_annotation else DocAnnotation()
self.token_annotation = (
token_annotation if token_annotation else TokenAnnotation()
)
self.goldparse = goldparse
self._alignment = None
@classmethod
def from_gold(cls, goldparse, doc=None):
doc_annotation = DocAnnotation(cats=goldparse.cats, links=goldparse.links)
token_annotation = goldparse.get_token_annotation()
return cls(doc_annotation, token_annotation, doc)
def _deprecated_get_gold(self, make_projective=False):
from ..syntax.gold_parse import get_parses_from_example
_, gold = get_parses_from_example(self, make_projective=make_projective)[0]
return gold
@classmethod
def from_dict(cls, example_dict, doc=None):
if example_dict is None:
raise ValueError("Example.from_dict expected dict, received None")
# TODO: This is ridiculous...
token_dict = example_dict.get("token_annotation", {})
token_annotation = TokenAnnotation.from_dict(token_dict)
doc_dict = example_dict.get("doc_annotation", {})
for key, value in example_dict.items():
if key in ("token_annotation", "doc_annotation"):
pass
elif key in ("cats", "links"):
doc_dict[key] = value
else:
token_dict[key] = value
token_annotation = TokenAnnotation.from_dict(token_dict)
doc_annotation = DocAnnotation.from_dict(doc_dict)
return cls(doc_annotation, token_annotation, doc)
return cls(
doc=doc, doc_annotation=doc_annotation, token_annotation=token_annotation
)
@property
def alignment(self):
if self._alignment is None:
if self.doc is None:
return None
spacy_words = [token.orth_ for token in self.doc]
gold_words = self.token_annotation.words
if gold_words == []:
gold_words = spacy_words
self._alignment = Alignment(spacy_words, gold_words)
return self._alignment
def to_dict(self):
""" Note that this method does NOT export the doc, only the annotations ! """
@ -46,12 +66,31 @@ class Example:
return self.doc.text
return self.doc
@property
def gold(self):
if self.goldparse is None:
doc, gold = self.get_gold_parses()[0]
self.goldparse = gold
return self.goldparse
def get_aligned(self, field):
"""Return an aligned array for a token annotation field."""
if self.doc is None:
return self.token_annotation.get_field(field)
doc = self.doc
if field == "word":
return [token.orth_ for token in doc]
gold_values = self.token_annotation.get_field(field)
alignment = self.alignment
i2j_multi = alignment.i2j_multi
gold_to_cand = alignment.gold_to_cand
cand_to_gold = alignment.cand_to_gold
output = []
for i, gold_i in enumerate(cand_to_gold):
if doc[i].text.isspace():
output.append(None)
elif gold_i is None:
if i in i2j_multi:
output.append(gold_values[i2j_multi[i]])
else:
output.append(None)
else:
output.append(gold_values[gold_i])
return output
def set_token_annotation(
self,
@ -149,55 +188,6 @@ class Example:
split_examples.append(s_example)
return split_examples
def get_gold_parses(
self, merge=True, vocab=None, make_projective=False, ignore_misaligned=False
):
"""Return a list of (doc, GoldParse) objects.
If merge is set to True, keep all Token annotations as one big list."""
d = self.doc_annotation
# merge == do not modify Example
if merge:
t = self.token_annotation
doc = self.doc
if doc is None or not isinstance(doc, Doc):
if not vocab:
raise ValueError(Errors.E998)
doc = Doc(vocab, words=t.words)
try:
gp = GoldParse.from_annotation(
doc, d, t, make_projective=make_projective
)
except AlignmentError:
if ignore_misaligned:
gp = None
else:
raise
return [(doc, gp)]
# not merging: one GoldParse per sentence, defining docs with the words
# from each sentence
else:
parses = []
split_examples = self.split_sents()
for split_example in split_examples:
if not vocab:
raise ValueError(Errors.E998)
split_doc = Doc(vocab, words=split_example.token_annotation.words)
try:
gp = GoldParse.from_annotation(
split_doc,
d,
split_example.token_annotation,
make_projective=make_projective,
)
except AlignmentError:
if ignore_misaligned:
gp = None
else:
raise
if gp is not None:
parses.append((split_doc, gp))
return parses
@classmethod
def to_example_objects(cls, examples, make_doc=None, keep_raw_text=False):
"""
@ -219,29 +209,16 @@ class Example:
else:
doc = make_doc(ex)
converted_examples.append(Example(doc=doc))
# convert Doc to Example
elif isinstance(ex, Doc):
converted_examples.append(Example(doc=ex))
# convert tuples to Example
elif isinstance(ex, tuple) and len(ex) == 2:
doc, gold = ex
gold_dict = {}
# convert string to Doc
if isinstance(doc, str) and not keep_raw_text:
doc = make_doc(doc)
# convert dict to GoldParse
if isinstance(gold, dict):
gold_dict = gold
if doc is not None or gold.get("words", None) is not None:
gold = GoldParse(doc, **gold)
else:
gold = None
if gold is not None:
converted_examples.append(
Example.from_gold(goldparse=gold, doc=doc)
)
else:
raise ValueError(Errors.E999.format(gold_dict=gold_dict))
converted_examples.append(Example.from_dict(gold, doc=doc))
# convert Doc to Example
elif isinstance(ex, Doc):
converted_examples.append(Example(doc=ex))
else:
converted_examples.append(ex)
return converted_examples

View File

@ -3,7 +3,6 @@ import srsly
from .. import util
from ..errors import Warnings
from ..tokens import Token, Doc
from .example import Example
from .iob_utils import biluo_tags_from_offsets
@ -64,6 +63,19 @@ def docs_to_json(docs, id=0, ner_missing_tag="O"):
return json_doc
def read_json_file(loc, docs_filter=None, limit=None):
loc = util.ensure_path(loc)
if loc.is_dir():
for filename in loc.iterdir():
yield from read_json_file(loc / filename, limit=limit)
else:
for doc in json_iterate(loc):
if docs_filter is not None and not docs_filter(doc):
continue
for json_data in json_to_examples(doc):
yield json_data
def json_to_examples(doc):
"""Convert an item in the JSON-formatted training data to the format
used by GoldParse.
@ -72,7 +84,7 @@ def json_to_examples(doc):
YIELDS (Example): The reformatted data - one training example per paragraph
"""
for paragraph in doc["paragraphs"]:
example = Example(doc=paragraph.get("raw", None))
example = {"text": paragraph.get("raw", None)}
words = []
ids = []
tags = []
@ -110,39 +122,23 @@ def json_to_examples(doc):
cats = {}
for cat in paragraph.get("cats", {}):
cats[cat["label"]] = cat["value"]
example.set_token_annotation(ids=ids, words=words, tags=tags,
pos=pos, morphs=morphs, lemmas=lemmas, heads=heads,
deps=labels, entities=ner, sent_starts=sent_starts,
brackets=brackets)
example.set_doc_annotation(cats=cats)
example["token_annotation"] = dict(
ids=ids,
words=words,
tags=tags,
pos=pos,
morphs=morphs,
lemmas=lemmas,
heads=heads,
deps=labels,
entities=ner,
sent_starts=sent_starts,
brackets=brackets
)
example["doc_annotation"] = dict(cats=cats)
yield example
def read_json_file(loc, docs_filter=None, limit=None):
loc = util.ensure_path(loc)
if loc.is_dir():
for filename in loc.iterdir():
yield from read_json_file(loc / filename, limit=limit)
else:
for doc in json_iterate(loc):
if docs_filter is not None and not docs_filter(doc):
continue
for json_data in json_to_examples(doc):
yield json_data
def read_json_object(json_corpus_section):
"""Take a list of JSON-formatted documents (e.g. from an already loaded
training data file) and yield annotations in the GoldParse format.
json_corpus_section (list): The data.
YIELDS (Example): The reformatted data - one training example per paragraph
"""
for json_doc in json_corpus_section:
examples = json_to_examples(json_doc)
for ex in examples:
yield ex
def json_iterate(loc):
# We should've made these files jsonl...But since we didn't, parse out

View File

@ -636,6 +636,7 @@ class Language(object):
examples (iterable): `Example` objects.
YIELDS (tuple): `Example` objects.
"""
# TODO: This is deprecated right?
for name, proc in self.pipeline:
if hasattr(proc, "preprocess_gold"):
examples = proc.preprocess_gold(examples)

View File

@ -92,7 +92,7 @@ class Morphologizer(Tagger):
guesses = scores.argmax(axis=1)
known_labels = numpy.ones((scores.shape[0], 1), dtype="f")
for ex in examples:
gold = ex.gold
gold = ex._deprecated_get_gold()
for i in range(len(gold.morphs)):
pos = gold.pos[i] if i < len(gold.pos) else ""
morph = gold.morphs[i]

View File

@ -373,7 +373,7 @@ class Tagger(Pipe):
def get_loss(self, examples, scores):
loss_func = SequenceCategoricalCrossentropy(names=self.labels)
truths = [eg.gold.tags for eg in examples]
truths = [eg.get_aligned("tag") for eg in examples]
d_scores, loss = loss_func(scores, truths)
if self.model.ops.xp.isnan(loss):
raise ValueError("nan value when computing loss")
@ -560,9 +560,9 @@ class SentenceRecognizer(Tagger):
correct = numpy.zeros((scores.shape[0],), dtype="i")
guesses = scores.argmax(axis=1)
known_labels = numpy.ones((scores.shape[0], 1), dtype="f")
for ex in examples:
gold = ex.gold
for sent_start in gold.sent_starts:
for eg in examples:
sent_starts = eg.get_aligned("sent_start")
for sent_start in sent_starts:
if sent_start is None:
correct[idx] = guesses[idx]
elif sent_start in tag_index:
@ -575,7 +575,7 @@ class SentenceRecognizer(Tagger):
d_scores = scores - to_categorical(correct, n_classes=scores.shape[1])
d_scores *= self.model.ops.asarray(known_labels)
loss = (d_scores**2).sum()
docs = [ex.doc for ex in examples]
docs = [eg.doc for eg in examples]
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
return float(loss), d_scores
@ -706,13 +706,13 @@ class MultitaskObjective(Tagger):
cdef int idx = 0
correct = numpy.zeros((scores.shape[0],), dtype="i")
guesses = scores.argmax(axis=1)
golds = [ex.gold for ex in examples]
docs = [ex.doc for ex in examples]
for i, gold in enumerate(golds):
for j in range(len(docs[i])):
# Handels alignment for tokenization differences
token_annotation = gold.get_token_annotation()
label = self.make_label(j, token_annotation)
for i, eg in enumerate(examples):
# Handles alignment for tokenization differences
doc_annots = eg.get_aligned()
for j in range(len(eg.doc)):
tok_annots = {key: values[j] for key, values in tok_annots.items()}
label = self.make_label(j, tok_annots)
if label is None or label not in self.labels:
correct[idx] = guesses[idx]
else:
@ -951,13 +951,12 @@ class TextCategorizer(Pipe):
losses[self.name] += (gradient**2).sum()
def _examples_to_truth(self, examples):
golds = [ex.gold for ex in examples]
truths = numpy.zeros((len(golds), len(self.labels)), dtype="f")
not_missing = numpy.ones((len(golds), len(self.labels)), dtype="f")
for i, gold in enumerate(golds):
truths = numpy.zeros((len(examples), len(self.labels)), dtype="f")
not_missing = numpy.ones((len(examples), len(self.labels)), dtype="f")
for i, eg in enumerate(examples):
for j, label in enumerate(self.labels):
if label in gold.cats:
truths[i, j] = gold.cats[label]
if label in eg.doc_annotation.cats:
truths[i, j] = eg.doc_annotation.cats[label]
else:
not_missing[i, j] = 0.
truths = self.model.ops.asarray(truths)
@ -1160,14 +1159,14 @@ class EntityLinker(Pipe):
# This seems simpler than other ways to get that exact output -- but
# it does run the model twice :(
predictions = self.model.predict(docs)
golds = [ex.gold for ex in examples]
for doc, gold in zip(docs, golds):
for eg in examples:
doc = eg.doc
ents_by_offset = dict()
for ent in doc.ents:
ents_by_offset[(ent.start_char, ent.end_char)] = ent
for entity, kb_dict in gold.links.items():
for entity, kb_dict in eg.doc_annotation.links.items():
if isinstance(entity, str):
entity = literal_eval(entity)
start, end = entity
@ -1188,7 +1187,10 @@ class EntityLinker(Pipe):
raise RuntimeError(Errors.E030)
set_dropout_rate(self.model, drop)
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds)
loss, d_scores = self.get_similarity_loss(
scores=sentence_encodings,
examples=examples
)
bp_context(d_scores)
if sgd is not None:
self.model.finish_update(sgd)
@ -1199,10 +1201,10 @@ class EntityLinker(Pipe):
self.set_annotations(docs, predictions)
return loss
def get_similarity_loss(self, golds, scores):
def get_similarity_loss(self, examples, scores):
entity_encodings = []
for gold in golds:
for entity, kb_dict in gold.links.items():
for eg in examples:
for entity, kb_dict in eg.doc_annotation.links.items():
for kb_id, value in kb_dict.items():
# this loss function assumes we're only using positive examples
if value:
@ -1222,7 +1224,7 @@ class EntityLinker(Pipe):
def get_loss(self, examples, scores):
cats = []
for ex in examples:
for entity, kb_dict in ex.gold.links.items():
for entity, kb_dict in ex.doc_annotation.links.items():
for kb_id, value in kb_dict.items():
cats.append([value])

View File

@ -282,7 +282,7 @@ class Scorer(object):
if isinstance(example, tuple) and len(example) == 2:
doc, gold = example
else:
gold = example.gold
gold = example._deprecated_get_gold()
doc = example.doc
if len(doc) != len(gold):

View File

@ -24,6 +24,57 @@ def is_punct_label(label):
return label == "P" or label.lower() == "punct"
def get_parses_from_example(
eg, merge=True, vocab=None, make_projective=True, ignore_misaligned=False
):
"""Return a list of (doc, GoldParse) objects.
If merge is set to True, keep all Token annotations as one big list."""
d = eg.doc_annotation
# merge == do not modify Example
if merge:
t = eg.token_annotation
doc = eg.doc
if doc is None or not isinstance(doc, Doc):
if not vocab:
raise ValueError(Errors.E998)
doc = Doc(vocab, words=t.words)
try:
gp = GoldParse.from_annotation(
doc, d, t, make_projective=make_projective
)
except AlignmentError:
if ignore_misaligned:
gp = None
else:
raise
return [(doc, gp)]
# not merging: one GoldParse per sentence, defining docs with the words
# from each sentence
else:
parses = []
split_examples = eg.split_sents()
for split_example in split_examples:
if not vocab:
raise ValueError(Errors.E998)
split_doc = Doc(vocab, words=split_example.token_annotation.words)
try:
gp = GoldParse.from_annotation(
split_doc,
d,
split_example.token_annotation,
make_projective=make_projective,
)
except AlignmentError:
if ignore_misaligned:
gp = None
else:
raise
if gp is not None:
parses.append((split_doc, gp))
return parses
cdef class GoldParse:
"""Collection for training annotations.

View File

@ -21,6 +21,7 @@ import warnings
from ..tokens.doc cimport Doc
from .gold_parse cimport GoldParse
from .gold_parse import get_parses_from_example
from ..typedefs cimport weight_t, class_t, hash_t
from ._parser_model cimport alloc_activations, free_activations
from ._parser_model cimport predict_states, arg_max_if_valid
@ -515,8 +516,8 @@ cdef class Parser:
good_golds = []
good_states = []
for i, eg in enumerate(whole_examples):
doc = eg.doc
gold = self.moves.preprocess_gold(eg.gold)
parses = get_parses_from_example(eg)
doc, gold = parses[0]
if gold is not None and self.moves.has_gold(gold):
good_docs.append(doc)
good_golds.append(gold)
@ -535,8 +536,12 @@ cdef class Parser:
cdef:
StateClass state
Transition action
whole_docs = [ex.doc for ex in whole_examples]
whole_golds = [ex.gold for ex in whole_examples]
whole_docs = []
whole_golds = []
for eg in whole_examples:
for doc, gold in get_parses_from_example(eg):
whole_docs.append(doc)
whole_golds.append(gold)
whole_states = self.moves.init_batch(whole_docs)
max_length = max(min_length, min(max_length, min([len(doc) for doc in whole_docs])))
max_moves = 0
@ -625,7 +630,7 @@ cdef class Parser:
doc_sample = []
gold_sample = []
for example in islice(get_examples(), 10):
parses = example.get_gold_parses(merge=False, vocab=self.vocab)
parses = get_parses_from_example(example, merge=False, vocab=self.vocab)
for doc, gold in parses:
if len(doc):
doc_sample.append(doc)

View File

@ -34,7 +34,10 @@ def _train_parser(parser):
for i in range(5):
losses = {}
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
gold = {
"heads": [1, 1, 3, 3],
"deps": ["left", "ROOT", "left", "ROOT"]
}
parser.update((doc, gold), sgd=sgd, losses=losses)
return parser
@ -46,9 +49,10 @@ def test_add_label(parser):
for i in range(100):
losses = {}
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
gold = GoldParse(
doc, heads=[1, 1, 3, 3], deps=["right", "ROOT", "left", "ROOT"]
)
gold = {
"heads": [1, 1, 3, 3],
"deps": ["right", "ROOT", "left", "ROOT"]
}
parser.update((doc, gold), sgd=sgd, losses=losses)
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
doc = parser(doc)

View File

@ -46,7 +46,7 @@ def doc(vocab):
@pytest.fixture
def gold(doc):
return GoldParse(doc, heads=[1, 1, 1], deps=["L", "ROOT", "R"])
return {"heads": [1, 1, 1], "deps": ["L", "ROOT", "R"]}
def test_can_init_nn_parser(parser):

View File

@ -1,7 +1,6 @@
import pytest
from thinc.api import Adam
from spacy.attrs import NORM
from spacy.gold import GoldParse
from spacy.vocab import Vocab
from spacy.pipeline.defaults import default_parser
@ -27,7 +26,7 @@ def parser(vocab):
for i in range(10):
losses = {}
doc = Doc(vocab, words=["a", "b", "c", "d"])
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
gold = dict(heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
parser.update((doc, gold), sgd=sgd, losses=losses)
return parser

View File

@ -1,9 +1,10 @@
from spacy.errors import AlignmentError
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
from spacy.gold import spans_from_biluo_tags, GoldParse, iob_to_biluo, align
from spacy.gold import spans_from_biluo_tags, iob_to_biluo, align
from spacy.gold import GoldCorpus, docs_to_json, Example, DocAnnotation
from spacy.lang.en import English
from spacy.syntax.nonproj import is_nonproj_tree
from spacy.syntax.gold_parse import GoldParse, get_parses_from_example
from spacy.tokens import Doc
from spacy.util import get_words_and_spaces, compounding, minibatch
import pytest
@ -270,10 +271,9 @@ def test_roundtrip_docs_to_json(doc):
srsly.write_json(json_file, [docs_to_json(doc)])
goldcorpus = GoldCorpus(train=str(json_file), dev=str(json_file))
reloaded_example = next(goldcorpus.dev_dataset(nlp))
goldparse = reloaded_example.gold
assert len(doc) == goldcorpus.count_train()
reloaded_example = next(goldcorpus.dev_dataset(nlp=nlp))
goldparse = reloaded_example._deprecated_get_gold()
assert len(doc) == goldcorpus.count_train()
assert text == reloaded_example.text
assert tags == goldparse.tags
assert pos == goldparse.pos
@ -287,54 +287,6 @@ def test_roundtrip_docs_to_json(doc):
assert cats["TRAVEL"] == goldparse.cats["TRAVEL"]
assert cats["BAKING"] == goldparse.cats["BAKING"]
# roundtrip to JSONL train dicts
with make_tempdir() as tmpdir:
jsonl_file = tmpdir / "roundtrip.jsonl"
srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
reloaded_example = next(goldcorpus.dev_dataset(nlp))
goldparse = reloaded_example.gold
assert len(doc) == goldcorpus.count_train()
assert text == reloaded_example.text
assert tags == goldparse.tags
assert pos == goldparse.pos
assert morphs == goldparse.morphs
assert lemmas == goldparse.lemmas
assert deps == goldparse.labels
assert heads == goldparse.heads
assert biluo_tags == goldparse.ner
assert "TRAVEL" in goldparse.cats
assert "BAKING" in goldparse.cats
assert cats["TRAVEL"] == goldparse.cats["TRAVEL"]
assert cats["BAKING"] == goldparse.cats["BAKING"]
# roundtrip to JSONL tuples
with make_tempdir() as tmpdir:
jsonl_file = tmpdir / "roundtrip.jsonl"
# write to JSONL train dicts
srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
# load and rewrite as JSONL tuples
srsly.write_jsonl(jsonl_file, goldcorpus.train_examples)
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
reloaded_example = next(goldcorpus.dev_dataset(nlp))
goldparse = reloaded_example.gold
assert len(doc) == goldcorpus.count_train()
assert text == reloaded_example.text
assert tags == goldparse.tags
assert deps == goldparse.labels
assert heads == goldparse.heads
assert lemmas == goldparse.lemmas
assert biluo_tags == goldparse.ner
assert "TRAVEL" in goldparse.cats
assert "BAKING" in goldparse.cats
assert cats["TRAVEL"] == goldparse.cats["TRAVEL"]
assert cats["BAKING"] == goldparse.cats["BAKING"]
def test_projective_train_vs_nonprojective_dev(doc):
nlp = English()
@ -342,16 +294,16 @@ def test_projective_train_vs_nonprojective_dev(doc):
heads = [t.head.i for t in doc]
with make_tempdir() as tmpdir:
jsonl_file = tmpdir / "test.jsonl"
# write to JSONL train dicts
srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
json_file = tmpdir / "test.json"
# write to JSON train dicts
srsly.write_json(json_file, [docs_to_json(doc)])
goldcorpus = GoldCorpus(str(json_file), str(json_file))
train_reloaded_example = next(goldcorpus.train_dataset(nlp))
train_goldparse = train_reloaded_example.gold
train_reloaded_example = next(goldcorpus.train_dataset(nlp))
train_goldparse = get_parses_from_example(train_reloaded_example)[0][1]
dev_reloaded_example = next(goldcorpus.dev_dataset(nlp))
dev_goldparse = dev_reloaded_example.gold
dev_reloaded_example = next(goldcorpus.dev_dataset(nlp))
dev_goldparse = dev_reloaded_example._deprecated_get_gold()
assert is_nonproj_tree([t.head.i for t in doc]) is True
assert is_nonproj_tree(train_goldparse.heads) is False
@ -364,45 +316,49 @@ def test_projective_train_vs_nonprojective_dev(doc):
assert deps == dev_goldparse.labels
# Hm, not sure where misalignment check would be handled? In the components too?
# I guess that does make sense. A text categorizer doesn't care if it's
# misaligned...
@pytest.mark.xfail # TODO
def test_ignore_misaligned(doc):
nlp = English()
text = doc.text
with make_tempdir() as tmpdir:
jsonl_file = tmpdir / "test.jsonl"
json_file = tmpdir / "test.json"
data = [docs_to_json(doc)]
data[0]["paragraphs"][0]["raw"] = text.replace("Sarah", "Jane")
# write to JSONL train dicts
srsly.write_jsonl(jsonl_file, data)
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
# write to JSON train dicts
srsly.write_json(json_file, data)
goldcorpus = GoldCorpus(str(json_file), str(json_file))
with pytest.raises(AlignmentError):
train_reloaded_example = next(goldcorpus.train_dataset(nlp))
with pytest.raises(AlignmentError):
train_reloaded_example = next(goldcorpus.train_dataset(nlp))
with make_tempdir() as tmpdir:
jsonl_file = tmpdir / "test.jsonl"
json_file = tmpdir / "test.json"
data = [docs_to_json(doc)]
data[0]["paragraphs"][0]["raw"] = text.replace("Sarah", "Jane")
# write to JSONL train dicts
srsly.write_jsonl(jsonl_file, data)
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
# write to JSON train dicts
srsly.write_json(json_file, data)
goldcorpus = GoldCorpus(str(json_file), str(json_file))
# doesn't raise an AlignmentError, but there is nothing to iterate over
# because the only example can't be aligned
train_reloaded_example = list(goldcorpus.train_dataset(nlp, ignore_misaligned=True))
assert len(train_reloaded_example) == 0
# doesn't raise an AlignmentError, but there is nothing to iterate over
# because the only example can't be aligned
train_reloaded_example = list(goldcorpus.train_dataset(nlp, ignore_misaligned=True))
assert len(train_reloaded_example) == 0
def test_make_orth_variants(doc):
nlp = English()
with make_tempdir() as tmpdir:
jsonl_file = tmpdir / "test.jsonl"
# write to JSONL train dicts
srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
json_file = tmpdir / "test.json"
# write to JSON train dicts
srsly.write_json(json_file, [docs_to_json(doc)])
goldcorpus = GoldCorpus(str(json_file), str(json_file))
# due to randomness, test only that this runs with no errors for now
train_reloaded_example = next(goldcorpus.train_dataset(nlp, orth_variant_level=0.2))
train_goldparse = train_reloaded_example.gold # noqa: F841
# due to randomness, test only that this runs with no errors for now
train_reloaded_example = next(goldcorpus.train_dataset(nlp, orth_variant_level=0.2))
train_goldparse = train_reloaded_example._deprecated_get_gold()
@pytest.mark.parametrize(
@ -485,6 +441,7 @@ def test_tuple_format_implicit():
_train(train_data)
@pytest.mark.xfail # TODO
def test_tuple_format_implicit_invalid():
"""Test that an error is thrown for an implicit invalid GoldParse field"""
@ -520,8 +477,18 @@ def test_split_sents(merged_dict):
nlp = English()
example = Example()
example.set_token_annotation(**merged_dict)
assert len(example.get_gold_parses(merge=False, vocab=nlp.vocab)) == 2
assert len(example.get_gold_parses(merge=True, vocab=nlp.vocab)) == 1
assert len(get_parses_from_example(
example,
merge=False,
vocab=nlp.vocab,
make_projective=False)
) == 2
assert len(get_parses_from_example(
example,
merge=True,
vocab=nlp.vocab,
make_projective=False
)) == 1
split_examples = example.split_sents()
assert len(split_examples) == 2
@ -557,4 +524,4 @@ def test_empty_example_goldparse():
nlp = English()
doc = nlp("")
example = Example(doc=doc)
assert len(example.get_gold_parses()) == 1
assert len(get_parses_from_example(example)) == 1

View File

@ -19,22 +19,16 @@ def nlp():
return nlp
@pytest.mark.xfail # TODO
def test_language_update(nlp):
text = "hello world"
annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
wrongkeyannots = {"LABEL": True}
doc = Doc(nlp.vocab, words=text.split(" "))
gold = GoldParse(doc, **annots)
# Update with doc and gold objects
nlp.update((doc, gold))
# Update with text and dict
nlp.update((text, annots))
# Update with doc object and dict
nlp.update((doc, annots))
# Update with text and gold object
nlp.update((text, gold))
# Update with empty doc and gold object
nlp.update((None, gold))
# Update badly
with pytest.raises(ValueError):
nlp.update((doc, None))
@ -44,20 +38,16 @@ def test_language_update(nlp):
def test_language_evaluate(nlp):
text = "hello world"
annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
annots = {
"doc_annotation": {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
}
doc = Doc(nlp.vocab, words=text.split(" "))
gold = GoldParse(doc, **annots)
# Evaluate with doc and gold objects
nlp.evaluate([(doc, gold)])
# Evaluate with text and dict
nlp.evaluate([(text, annots)])
# Evaluate with doc object and dict
nlp.evaluate([(doc, annots)])
# Evaluate with text and gold object
nlp.evaluate([(text, gold)])
# Evaluate badly
with pytest.raises(Exception):
nlp.evaluate([text, gold])
nlp.evaluate([text, annots])
def test_evaluate_no_pipe(nlp):