mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-06 06:30:35 +03:00
Remove GoldParse from public API
* Move get_parses_from_example to spacy.syntax * Get GoldParse out of Example * Avoid expecting GoldParse input in parser * Add Alignment to spacy.gold.align * Update Example object * Add comment * Update pipeline * Fix imports * Simplify gold_io * WIP on GoldCorpus * Update test * Xfail some gold tests * Remove ignore_misaligned option from GoldCorpus * Fix Example constructor * Update test * Fix usage of Example * Add deprecated_get_gold method on Example * Patch scorer * Fix test * Fix test * Update tests * Xfail a test * Fix passing of make_projective * Pass make_projective by default * Hack data format in Example.from_dict * Update tests * Fix example.from_dict * Update morphologizer * Fix entity linker * Add get_field to TokenAnnotation * Fix Example.get_aligned * Update test * Fix alignment * Fix corpus * Fix GoldCorpus * Handle misaligned * Format * Fix missing import
This commit is contained in:
parent
b69fa77ccc
commit
084271c9e9
|
@ -11,6 +11,7 @@ from thinc.api import Model, use_pytorch_for_gpu_memory
|
|||
import random
|
||||
|
||||
from ..gold import GoldCorpus
|
||||
from ..gold import Example
|
||||
from .. import util
|
||||
from ..errors import Errors
|
||||
from ..ml import models # don't remove - required to load the built-in architectures
|
||||
|
@ -243,7 +244,7 @@ def create_train_batches(nlp, corpus, cfg):
|
|||
orth_variant_level=cfg["orth_variant_level"],
|
||||
gold_preproc=cfg["gold_preproc"],
|
||||
max_length=cfg["max_length"],
|
||||
ignore_misaligned=True,
|
||||
ignore_misaligned=True
|
||||
))
|
||||
if len(train_examples) == 0:
|
||||
raise ValueError(Errors.E988)
|
||||
|
@ -271,6 +272,7 @@ def create_evaluation_callback(nlp, optimizer, corpus, cfg):
|
|||
nlp, gold_preproc=cfg["gold_preproc"], ignore_misaligned=True
|
||||
)
|
||||
)
|
||||
|
||||
n_words = sum(len(ex.doc) for ex in dev_examples)
|
||||
start_time = timer()
|
||||
|
||||
|
|
|
@ -10,4 +10,4 @@ from .iob_utils import spans_from_biluo_tags
|
|||
from .iob_utils import tags_to_entities
|
||||
|
||||
from .gold_io import docs_to_json
|
||||
from .gold_io import read_json_file, read_json_object
|
||||
from .gold_io import read_json_file
|
||||
|
|
|
@ -2,6 +2,26 @@ import numpy
|
|||
from ..errors import Errors, AlignmentError
|
||||
|
||||
|
||||
class Alignment:
|
||||
def __init__(self, spacy_words, gold_words):
|
||||
# Do many-to-one alignment for misaligned tokens.
|
||||
# If we over-segment, we'll have one gold word that covers a sequence
|
||||
# of predicted words
|
||||
# If we under-segment, we'll have one predicted word that covers a
|
||||
# sequence of gold words.
|
||||
# If we "mis-segment", we'll have a sequence of predicted words covering
|
||||
# a sequence of gold words. That's many-to-many -- we don't do that
|
||||
# except for NER spans where the start and end can be aligned.
|
||||
cost, i2j, j2i, i2j_multi, j2i_multi = align(spacy_words, gold_words)
|
||||
self.cost = cost
|
||||
self.i2j = i2j
|
||||
self.j2i = j2i
|
||||
self.i2j_multi = i2j_multi
|
||||
self.j2i_multi = j2i_multi
|
||||
self.cand_to_gold = [(j if j >= 0 else None) for j in i2j]
|
||||
self.gold_to_cand = [(i if i >= 0 else None) for i in j2i]
|
||||
|
||||
|
||||
def align(tokens_a, tokens_b):
|
||||
"""Calculate alignment tables between two tokenizations.
|
||||
|
||||
|
|
|
@ -28,6 +28,30 @@ class TokenAnnotation:
|
|||
for b_start, b_end, b_label in brackets:
|
||||
self.brackets_by_start.setdefault(b_start, []).append((b_end, b_label))
|
||||
|
||||
def get_field(self, field):
|
||||
if field == "id":
|
||||
return self.ids
|
||||
elif field == "word":
|
||||
return self.words
|
||||
elif field == "tag":
|
||||
return self.tags
|
||||
elif field == "pos":
|
||||
return self.pos
|
||||
elif field == "morph":
|
||||
return self.morphs
|
||||
elif field == "lemma":
|
||||
return self.lemmas
|
||||
elif field == "head":
|
||||
return self.heads
|
||||
elif field == "dep":
|
||||
return self.deps
|
||||
elif field == "ner":
|
||||
return self.entities
|
||||
elif field == "sent_start":
|
||||
return self.sent_starts
|
||||
else:
|
||||
raise ValueError(f"Unknown field: {field}")
|
||||
|
||||
@property
|
||||
def brackets(self):
|
||||
brackets = []
|
||||
|
|
|
@ -6,8 +6,8 @@ from pathlib import Path
|
|||
import itertools
|
||||
from ..tokens import Doc
|
||||
from .. import util
|
||||
from ..errors import Errors
|
||||
from .gold_io import read_json_file, read_json_object
|
||||
from ..errors import Errors, AlignmentError
|
||||
from .gold_io import read_json_file, json_to_examples
|
||||
from .augment import make_orth_variants, add_noise
|
||||
from .example import Example
|
||||
|
||||
|
@ -43,9 +43,8 @@ class GoldCorpus(object):
|
|||
if not directory.exists():
|
||||
directory.mkdir()
|
||||
n = 0
|
||||
for i, example in enumerate(examples):
|
||||
ex_dict = example.to_dict()
|
||||
text = example.text
|
||||
for i, ex_dict in enumerate(examples):
|
||||
text = ex_dict["text"]
|
||||
srsly.write_msgpack(directory / f"{i}.msg", (text, ex_dict))
|
||||
n += 1
|
||||
if limit and n >= limit:
|
||||
|
@ -87,7 +86,9 @@ class GoldCorpus(object):
|
|||
# TODO: proper format checks with schemas
|
||||
if isinstance(first_gold_tuple, dict):
|
||||
if first_gold_tuple.get("paragraphs", None):
|
||||
examples = read_json_object(gold_tuples)
|
||||
examples = []
|
||||
for json_doc in gold_tuples:
|
||||
examples.extend(json_to_examples(json_doc))
|
||||
elif first_gold_tuple.get("doc_annotation", None):
|
||||
examples = []
|
||||
for ex_dict in gold_tuples:
|
||||
|
@ -117,7 +118,7 @@ class GoldCorpus(object):
|
|||
except KeyError as e:
|
||||
msg = "Missing key {}".format(e)
|
||||
raise KeyError(Errors.E996.format(file=file_name, msg=msg))
|
||||
except UnboundLocalError:
|
||||
except UnboundLocalError as e:
|
||||
msg = "Unexpected document structure"
|
||||
raise ValueError(Errors.E996.format(file=file_name, msg=msg))
|
||||
|
||||
|
@ -200,9 +201,9 @@ class GoldCorpus(object):
|
|||
):
|
||||
""" Setting gold_preproc will result in creating a doc per sentence """
|
||||
for example in examples:
|
||||
example_docs = []
|
||||
if gold_preproc:
|
||||
split_examples = example.split_sents()
|
||||
example_golds = []
|
||||
for split_example in split_examples:
|
||||
split_example_docs = cls._make_docs(
|
||||
nlp,
|
||||
|
@ -211,13 +212,7 @@ class GoldCorpus(object):
|
|||
noise_level=noise_level,
|
||||
orth_variant_level=orth_variant_level,
|
||||
)
|
||||
split_example_golds = cls._make_golds(
|
||||
split_example_docs,
|
||||
vocab=nlp.vocab,
|
||||
make_projective=make_projective,
|
||||
ignore_misaligned=ignore_misaligned,
|
||||
)
|
||||
example_golds.extend(split_example_golds)
|
||||
example_docs.extend(split_example_docs)
|
||||
else:
|
||||
example_docs = cls._make_docs(
|
||||
nlp,
|
||||
|
@ -226,16 +221,14 @@ class GoldCorpus(object):
|
|||
noise_level=noise_level,
|
||||
orth_variant_level=orth_variant_level,
|
||||
)
|
||||
example_golds = cls._make_golds(
|
||||
example_docs,
|
||||
vocab=nlp.vocab,
|
||||
make_projective=make_projective,
|
||||
ignore_misaligned=ignore_misaligned,
|
||||
)
|
||||
for ex in example_golds:
|
||||
if ex.goldparse is not None:
|
||||
if (not max_length) or len(ex.doc) < max_length:
|
||||
yield ex
|
||||
for ex in example_docs:
|
||||
if (not max_length) or len(ex.doc) < max_length:
|
||||
if ignore_misaligned:
|
||||
try:
|
||||
_ = ex._deprecated_get_gold()
|
||||
except AlignmentError:
|
||||
continue
|
||||
yield ex
|
||||
|
||||
@classmethod
|
||||
def _make_docs(
|
||||
|
@ -256,22 +249,3 @@ class GoldCorpus(object):
|
|||
)
|
||||
var_example.doc = var_doc
|
||||
return [var_example]
|
||||
|
||||
@classmethod
|
||||
def _make_golds(
|
||||
cls, examples, vocab=None, make_projective=False, ignore_misaligned=False
|
||||
):
|
||||
filtered_examples = []
|
||||
for example in examples:
|
||||
gold_parses = example.get_gold_parses(
|
||||
vocab=vocab,
|
||||
make_projective=make_projective,
|
||||
ignore_misaligned=ignore_misaligned,
|
||||
)
|
||||
assert len(gold_parses) == 1
|
||||
doc, gold = gold_parses[0]
|
||||
if doc:
|
||||
assert doc == example.doc
|
||||
example.goldparse = gold
|
||||
filtered_examples.append(example)
|
||||
return filtered_examples
|
||||
|
|
|
@ -1,36 +1,56 @@
|
|||
from .annotation import TokenAnnotation, DocAnnotation
|
||||
from .align import Alignment
|
||||
from ..errors import Errors, AlignmentError
|
||||
from ..tokens import Doc
|
||||
|
||||
# We're hoping to kill this GoldParse dependency but for now match semantics.
|
||||
from ..syntax.gold_parse import GoldParse
|
||||
|
||||
|
||||
class Example:
|
||||
def __init__(
|
||||
self, doc_annotation=None, token_annotation=None, doc=None, goldparse=None
|
||||
):
|
||||
def __init__(self, doc=None, doc_annotation=None, token_annotation=None):
|
||||
""" Doc can either be text, or an actual Doc """
|
||||
self.doc = doc
|
||||
self.doc_annotation = doc_annotation if doc_annotation else DocAnnotation()
|
||||
self.token_annotation = (
|
||||
token_annotation if token_annotation else TokenAnnotation()
|
||||
)
|
||||
self.goldparse = goldparse
|
||||
self._alignment = None
|
||||
|
||||
@classmethod
|
||||
def from_gold(cls, goldparse, doc=None):
|
||||
doc_annotation = DocAnnotation(cats=goldparse.cats, links=goldparse.links)
|
||||
token_annotation = goldparse.get_token_annotation()
|
||||
return cls(doc_annotation, token_annotation, doc)
|
||||
def _deprecated_get_gold(self, make_projective=False):
|
||||
from ..syntax.gold_parse import get_parses_from_example
|
||||
|
||||
_, gold = get_parses_from_example(self, make_projective=make_projective)[0]
|
||||
return gold
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, example_dict, doc=None):
|
||||
if example_dict is None:
|
||||
raise ValueError("Example.from_dict expected dict, received None")
|
||||
# TODO: This is ridiculous...
|
||||
token_dict = example_dict.get("token_annotation", {})
|
||||
token_annotation = TokenAnnotation.from_dict(token_dict)
|
||||
doc_dict = example_dict.get("doc_annotation", {})
|
||||
for key, value in example_dict.items():
|
||||
if key in ("token_annotation", "doc_annotation"):
|
||||
pass
|
||||
elif key in ("cats", "links"):
|
||||
doc_dict[key] = value
|
||||
else:
|
||||
token_dict[key] = value
|
||||
token_annotation = TokenAnnotation.from_dict(token_dict)
|
||||
doc_annotation = DocAnnotation.from_dict(doc_dict)
|
||||
return cls(doc_annotation, token_annotation, doc)
|
||||
return cls(
|
||||
doc=doc, doc_annotation=doc_annotation, token_annotation=token_annotation
|
||||
)
|
||||
|
||||
@property
|
||||
def alignment(self):
|
||||
if self._alignment is None:
|
||||
if self.doc is None:
|
||||
return None
|
||||
spacy_words = [token.orth_ for token in self.doc]
|
||||
gold_words = self.token_annotation.words
|
||||
if gold_words == []:
|
||||
gold_words = spacy_words
|
||||
self._alignment = Alignment(spacy_words, gold_words)
|
||||
return self._alignment
|
||||
|
||||
def to_dict(self):
|
||||
""" Note that this method does NOT export the doc, only the annotations ! """
|
||||
|
@ -46,12 +66,31 @@ class Example:
|
|||
return self.doc.text
|
||||
return self.doc
|
||||
|
||||
@property
|
||||
def gold(self):
|
||||
if self.goldparse is None:
|
||||
doc, gold = self.get_gold_parses()[0]
|
||||
self.goldparse = gold
|
||||
return self.goldparse
|
||||
def get_aligned(self, field):
|
||||
"""Return an aligned array for a token annotation field."""
|
||||
if self.doc is None:
|
||||
return self.token_annotation.get_field(field)
|
||||
doc = self.doc
|
||||
if field == "word":
|
||||
return [token.orth_ for token in doc]
|
||||
gold_values = self.token_annotation.get_field(field)
|
||||
alignment = self.alignment
|
||||
i2j_multi = alignment.i2j_multi
|
||||
gold_to_cand = alignment.gold_to_cand
|
||||
cand_to_gold = alignment.cand_to_gold
|
||||
|
||||
output = []
|
||||
for i, gold_i in enumerate(cand_to_gold):
|
||||
if doc[i].text.isspace():
|
||||
output.append(None)
|
||||
elif gold_i is None:
|
||||
if i in i2j_multi:
|
||||
output.append(gold_values[i2j_multi[i]])
|
||||
else:
|
||||
output.append(None)
|
||||
else:
|
||||
output.append(gold_values[gold_i])
|
||||
return output
|
||||
|
||||
def set_token_annotation(
|
||||
self,
|
||||
|
@ -149,55 +188,6 @@ class Example:
|
|||
split_examples.append(s_example)
|
||||
return split_examples
|
||||
|
||||
def get_gold_parses(
|
||||
self, merge=True, vocab=None, make_projective=False, ignore_misaligned=False
|
||||
):
|
||||
"""Return a list of (doc, GoldParse) objects.
|
||||
If merge is set to True, keep all Token annotations as one big list."""
|
||||
d = self.doc_annotation
|
||||
# merge == do not modify Example
|
||||
if merge:
|
||||
t = self.token_annotation
|
||||
doc = self.doc
|
||||
if doc is None or not isinstance(doc, Doc):
|
||||
if not vocab:
|
||||
raise ValueError(Errors.E998)
|
||||
doc = Doc(vocab, words=t.words)
|
||||
try:
|
||||
gp = GoldParse.from_annotation(
|
||||
doc, d, t, make_projective=make_projective
|
||||
)
|
||||
except AlignmentError:
|
||||
if ignore_misaligned:
|
||||
gp = None
|
||||
else:
|
||||
raise
|
||||
return [(doc, gp)]
|
||||
# not merging: one GoldParse per sentence, defining docs with the words
|
||||
# from each sentence
|
||||
else:
|
||||
parses = []
|
||||
split_examples = self.split_sents()
|
||||
for split_example in split_examples:
|
||||
if not vocab:
|
||||
raise ValueError(Errors.E998)
|
||||
split_doc = Doc(vocab, words=split_example.token_annotation.words)
|
||||
try:
|
||||
gp = GoldParse.from_annotation(
|
||||
split_doc,
|
||||
d,
|
||||
split_example.token_annotation,
|
||||
make_projective=make_projective,
|
||||
)
|
||||
except AlignmentError:
|
||||
if ignore_misaligned:
|
||||
gp = None
|
||||
else:
|
||||
raise
|
||||
if gp is not None:
|
||||
parses.append((split_doc, gp))
|
||||
return parses
|
||||
|
||||
@classmethod
|
||||
def to_example_objects(cls, examples, make_doc=None, keep_raw_text=False):
|
||||
"""
|
||||
|
@ -219,29 +209,16 @@ class Example:
|
|||
else:
|
||||
doc = make_doc(ex)
|
||||
converted_examples.append(Example(doc=doc))
|
||||
# convert Doc to Example
|
||||
elif isinstance(ex, Doc):
|
||||
converted_examples.append(Example(doc=ex))
|
||||
# convert tuples to Example
|
||||
elif isinstance(ex, tuple) and len(ex) == 2:
|
||||
doc, gold = ex
|
||||
gold_dict = {}
|
||||
# convert string to Doc
|
||||
if isinstance(doc, str) and not keep_raw_text:
|
||||
doc = make_doc(doc)
|
||||
# convert dict to GoldParse
|
||||
if isinstance(gold, dict):
|
||||
gold_dict = gold
|
||||
if doc is not None or gold.get("words", None) is not None:
|
||||
gold = GoldParse(doc, **gold)
|
||||
else:
|
||||
gold = None
|
||||
if gold is not None:
|
||||
converted_examples.append(
|
||||
Example.from_gold(goldparse=gold, doc=doc)
|
||||
)
|
||||
else:
|
||||
raise ValueError(Errors.E999.format(gold_dict=gold_dict))
|
||||
converted_examples.append(Example.from_dict(gold, doc=doc))
|
||||
# convert Doc to Example
|
||||
elif isinstance(ex, Doc):
|
||||
converted_examples.append(Example(doc=ex))
|
||||
else:
|
||||
converted_examples.append(ex)
|
||||
return converted_examples
|
||||
|
|
|
@ -3,7 +3,6 @@ import srsly
|
|||
from .. import util
|
||||
from ..errors import Warnings
|
||||
from ..tokens import Token, Doc
|
||||
from .example import Example
|
||||
from .iob_utils import biluo_tags_from_offsets
|
||||
|
||||
|
||||
|
@ -64,6 +63,19 @@ def docs_to_json(docs, id=0, ner_missing_tag="O"):
|
|||
return json_doc
|
||||
|
||||
|
||||
def read_json_file(loc, docs_filter=None, limit=None):
|
||||
loc = util.ensure_path(loc)
|
||||
if loc.is_dir():
|
||||
for filename in loc.iterdir():
|
||||
yield from read_json_file(loc / filename, limit=limit)
|
||||
else:
|
||||
for doc in json_iterate(loc):
|
||||
if docs_filter is not None and not docs_filter(doc):
|
||||
continue
|
||||
for json_data in json_to_examples(doc):
|
||||
yield json_data
|
||||
|
||||
|
||||
def json_to_examples(doc):
|
||||
"""Convert an item in the JSON-formatted training data to the format
|
||||
used by GoldParse.
|
||||
|
@ -72,7 +84,7 @@ def json_to_examples(doc):
|
|||
YIELDS (Example): The reformatted data - one training example per paragraph
|
||||
"""
|
||||
for paragraph in doc["paragraphs"]:
|
||||
example = Example(doc=paragraph.get("raw", None))
|
||||
example = {"text": paragraph.get("raw", None)}
|
||||
words = []
|
||||
ids = []
|
||||
tags = []
|
||||
|
@ -110,39 +122,23 @@ def json_to_examples(doc):
|
|||
cats = {}
|
||||
for cat in paragraph.get("cats", {}):
|
||||
cats[cat["label"]] = cat["value"]
|
||||
example.set_token_annotation(ids=ids, words=words, tags=tags,
|
||||
pos=pos, morphs=morphs, lemmas=lemmas, heads=heads,
|
||||
deps=labels, entities=ner, sent_starts=sent_starts,
|
||||
brackets=brackets)
|
||||
example.set_doc_annotation(cats=cats)
|
||||
example["token_annotation"] = dict(
|
||||
ids=ids,
|
||||
words=words,
|
||||
tags=tags,
|
||||
pos=pos,
|
||||
morphs=morphs,
|
||||
lemmas=lemmas,
|
||||
heads=heads,
|
||||
deps=labels,
|
||||
entities=ner,
|
||||
sent_starts=sent_starts,
|
||||
brackets=brackets
|
||||
)
|
||||
example["doc_annotation"] = dict(cats=cats)
|
||||
yield example
|
||||
|
||||
|
||||
def read_json_file(loc, docs_filter=None, limit=None):
|
||||
loc = util.ensure_path(loc)
|
||||
if loc.is_dir():
|
||||
for filename in loc.iterdir():
|
||||
yield from read_json_file(loc / filename, limit=limit)
|
||||
else:
|
||||
for doc in json_iterate(loc):
|
||||
if docs_filter is not None and not docs_filter(doc):
|
||||
continue
|
||||
for json_data in json_to_examples(doc):
|
||||
yield json_data
|
||||
|
||||
|
||||
def read_json_object(json_corpus_section):
|
||||
"""Take a list of JSON-formatted documents (e.g. from an already loaded
|
||||
training data file) and yield annotations in the GoldParse format.
|
||||
|
||||
json_corpus_section (list): The data.
|
||||
YIELDS (Example): The reformatted data - one training example per paragraph
|
||||
"""
|
||||
for json_doc in json_corpus_section:
|
||||
examples = json_to_examples(json_doc)
|
||||
for ex in examples:
|
||||
yield ex
|
||||
|
||||
|
||||
def json_iterate(loc):
|
||||
# We should've made these files jsonl...But since we didn't, parse out
|
||||
|
|
|
@ -636,6 +636,7 @@ class Language(object):
|
|||
examples (iterable): `Example` objects.
|
||||
YIELDS (tuple): `Example` objects.
|
||||
"""
|
||||
# TODO: This is deprecated right?
|
||||
for name, proc in self.pipeline:
|
||||
if hasattr(proc, "preprocess_gold"):
|
||||
examples = proc.preprocess_gold(examples)
|
||||
|
|
|
@ -92,7 +92,7 @@ class Morphologizer(Tagger):
|
|||
guesses = scores.argmax(axis=1)
|
||||
known_labels = numpy.ones((scores.shape[0], 1), dtype="f")
|
||||
for ex in examples:
|
||||
gold = ex.gold
|
||||
gold = ex._deprecated_get_gold()
|
||||
for i in range(len(gold.morphs)):
|
||||
pos = gold.pos[i] if i < len(gold.pos) else ""
|
||||
morph = gold.morphs[i]
|
||||
|
|
|
@ -373,7 +373,7 @@ class Tagger(Pipe):
|
|||
|
||||
def get_loss(self, examples, scores):
|
||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels)
|
||||
truths = [eg.gold.tags for eg in examples]
|
||||
truths = [eg.get_aligned("tag") for eg in examples]
|
||||
d_scores, loss = loss_func(scores, truths)
|
||||
if self.model.ops.xp.isnan(loss):
|
||||
raise ValueError("nan value when computing loss")
|
||||
|
@ -560,9 +560,9 @@ class SentenceRecognizer(Tagger):
|
|||
correct = numpy.zeros((scores.shape[0],), dtype="i")
|
||||
guesses = scores.argmax(axis=1)
|
||||
known_labels = numpy.ones((scores.shape[0], 1), dtype="f")
|
||||
for ex in examples:
|
||||
gold = ex.gold
|
||||
for sent_start in gold.sent_starts:
|
||||
for eg in examples:
|
||||
sent_starts = eg.get_aligned("sent_start")
|
||||
for sent_start in sent_starts:
|
||||
if sent_start is None:
|
||||
correct[idx] = guesses[idx]
|
||||
elif sent_start in tag_index:
|
||||
|
@ -575,7 +575,7 @@ class SentenceRecognizer(Tagger):
|
|||
d_scores = scores - to_categorical(correct, n_classes=scores.shape[1])
|
||||
d_scores *= self.model.ops.asarray(known_labels)
|
||||
loss = (d_scores**2).sum()
|
||||
docs = [ex.doc for ex in examples]
|
||||
docs = [eg.doc for eg in examples]
|
||||
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
|
||||
return float(loss), d_scores
|
||||
|
||||
|
@ -706,13 +706,13 @@ class MultitaskObjective(Tagger):
|
|||
cdef int idx = 0
|
||||
correct = numpy.zeros((scores.shape[0],), dtype="i")
|
||||
guesses = scores.argmax(axis=1)
|
||||
golds = [ex.gold for ex in examples]
|
||||
docs = [ex.doc for ex in examples]
|
||||
for i, gold in enumerate(golds):
|
||||
for j in range(len(docs[i])):
|
||||
# Handels alignment for tokenization differences
|
||||
token_annotation = gold.get_token_annotation()
|
||||
label = self.make_label(j, token_annotation)
|
||||
for i, eg in enumerate(examples):
|
||||
# Handles alignment for tokenization differences
|
||||
doc_annots = eg.get_aligned()
|
||||
for j in range(len(eg.doc)):
|
||||
tok_annots = {key: values[j] for key, values in tok_annots.items()}
|
||||
label = self.make_label(j, tok_annots)
|
||||
if label is None or label not in self.labels:
|
||||
correct[idx] = guesses[idx]
|
||||
else:
|
||||
|
@ -951,13 +951,12 @@ class TextCategorizer(Pipe):
|
|||
losses[self.name] += (gradient**2).sum()
|
||||
|
||||
def _examples_to_truth(self, examples):
|
||||
golds = [ex.gold for ex in examples]
|
||||
truths = numpy.zeros((len(golds), len(self.labels)), dtype="f")
|
||||
not_missing = numpy.ones((len(golds), len(self.labels)), dtype="f")
|
||||
for i, gold in enumerate(golds):
|
||||
truths = numpy.zeros((len(examples), len(self.labels)), dtype="f")
|
||||
not_missing = numpy.ones((len(examples), len(self.labels)), dtype="f")
|
||||
for i, eg in enumerate(examples):
|
||||
for j, label in enumerate(self.labels):
|
||||
if label in gold.cats:
|
||||
truths[i, j] = gold.cats[label]
|
||||
if label in eg.doc_annotation.cats:
|
||||
truths[i, j] = eg.doc_annotation.cats[label]
|
||||
else:
|
||||
not_missing[i, j] = 0.
|
||||
truths = self.model.ops.asarray(truths)
|
||||
|
@ -1160,14 +1159,14 @@ class EntityLinker(Pipe):
|
|||
# This seems simpler than other ways to get that exact output -- but
|
||||
# it does run the model twice :(
|
||||
predictions = self.model.predict(docs)
|
||||
golds = [ex.gold for ex in examples]
|
||||
|
||||
for doc, gold in zip(docs, golds):
|
||||
for eg in examples:
|
||||
doc = eg.doc
|
||||
ents_by_offset = dict()
|
||||
for ent in doc.ents:
|
||||
ents_by_offset[(ent.start_char, ent.end_char)] = ent
|
||||
|
||||
for entity, kb_dict in gold.links.items():
|
||||
for entity, kb_dict in eg.doc_annotation.links.items():
|
||||
if isinstance(entity, str):
|
||||
entity = literal_eval(entity)
|
||||
start, end = entity
|
||||
|
@ -1188,7 +1187,10 @@ class EntityLinker(Pipe):
|
|||
raise RuntimeError(Errors.E030)
|
||||
set_dropout_rate(self.model, drop)
|
||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
|
||||
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds)
|
||||
loss, d_scores = self.get_similarity_loss(
|
||||
scores=sentence_encodings,
|
||||
examples=examples
|
||||
)
|
||||
bp_context(d_scores)
|
||||
if sgd is not None:
|
||||
self.model.finish_update(sgd)
|
||||
|
@ -1199,10 +1201,10 @@ class EntityLinker(Pipe):
|
|||
self.set_annotations(docs, predictions)
|
||||
return loss
|
||||
|
||||
def get_similarity_loss(self, golds, scores):
|
||||
def get_similarity_loss(self, examples, scores):
|
||||
entity_encodings = []
|
||||
for gold in golds:
|
||||
for entity, kb_dict in gold.links.items():
|
||||
for eg in examples:
|
||||
for entity, kb_dict in eg.doc_annotation.links.items():
|
||||
for kb_id, value in kb_dict.items():
|
||||
# this loss function assumes we're only using positive examples
|
||||
if value:
|
||||
|
@ -1222,7 +1224,7 @@ class EntityLinker(Pipe):
|
|||
def get_loss(self, examples, scores):
|
||||
cats = []
|
||||
for ex in examples:
|
||||
for entity, kb_dict in ex.gold.links.items():
|
||||
for entity, kb_dict in ex.doc_annotation.links.items():
|
||||
for kb_id, value in kb_dict.items():
|
||||
cats.append([value])
|
||||
|
||||
|
|
|
@ -282,7 +282,7 @@ class Scorer(object):
|
|||
if isinstance(example, tuple) and len(example) == 2:
|
||||
doc, gold = example
|
||||
else:
|
||||
gold = example.gold
|
||||
gold = example._deprecated_get_gold()
|
||||
doc = example.doc
|
||||
|
||||
if len(doc) != len(gold):
|
||||
|
|
|
@ -24,6 +24,57 @@ def is_punct_label(label):
|
|||
return label == "P" or label.lower() == "punct"
|
||||
|
||||
|
||||
def get_parses_from_example(
|
||||
eg, merge=True, vocab=None, make_projective=True, ignore_misaligned=False
|
||||
):
|
||||
"""Return a list of (doc, GoldParse) objects.
|
||||
If merge is set to True, keep all Token annotations as one big list."""
|
||||
d = eg.doc_annotation
|
||||
# merge == do not modify Example
|
||||
if merge:
|
||||
t = eg.token_annotation
|
||||
doc = eg.doc
|
||||
if doc is None or not isinstance(doc, Doc):
|
||||
if not vocab:
|
||||
raise ValueError(Errors.E998)
|
||||
doc = Doc(vocab, words=t.words)
|
||||
try:
|
||||
gp = GoldParse.from_annotation(
|
||||
doc, d, t, make_projective=make_projective
|
||||
)
|
||||
except AlignmentError:
|
||||
if ignore_misaligned:
|
||||
gp = None
|
||||
else:
|
||||
raise
|
||||
return [(doc, gp)]
|
||||
# not merging: one GoldParse per sentence, defining docs with the words
|
||||
# from each sentence
|
||||
else:
|
||||
parses = []
|
||||
split_examples = eg.split_sents()
|
||||
for split_example in split_examples:
|
||||
if not vocab:
|
||||
raise ValueError(Errors.E998)
|
||||
split_doc = Doc(vocab, words=split_example.token_annotation.words)
|
||||
try:
|
||||
gp = GoldParse.from_annotation(
|
||||
split_doc,
|
||||
d,
|
||||
split_example.token_annotation,
|
||||
make_projective=make_projective,
|
||||
)
|
||||
except AlignmentError:
|
||||
if ignore_misaligned:
|
||||
gp = None
|
||||
else:
|
||||
raise
|
||||
if gp is not None:
|
||||
parses.append((split_doc, gp))
|
||||
return parses
|
||||
|
||||
|
||||
|
||||
cdef class GoldParse:
|
||||
"""Collection for training annotations.
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import warnings
|
|||
|
||||
from ..tokens.doc cimport Doc
|
||||
from .gold_parse cimport GoldParse
|
||||
from .gold_parse import get_parses_from_example
|
||||
from ..typedefs cimport weight_t, class_t, hash_t
|
||||
from ._parser_model cimport alloc_activations, free_activations
|
||||
from ._parser_model cimport predict_states, arg_max_if_valid
|
||||
|
@ -515,8 +516,8 @@ cdef class Parser:
|
|||
good_golds = []
|
||||
good_states = []
|
||||
for i, eg in enumerate(whole_examples):
|
||||
doc = eg.doc
|
||||
gold = self.moves.preprocess_gold(eg.gold)
|
||||
parses = get_parses_from_example(eg)
|
||||
doc, gold = parses[0]
|
||||
if gold is not None and self.moves.has_gold(gold):
|
||||
good_docs.append(doc)
|
||||
good_golds.append(gold)
|
||||
|
@ -535,8 +536,12 @@ cdef class Parser:
|
|||
cdef:
|
||||
StateClass state
|
||||
Transition action
|
||||
whole_docs = [ex.doc for ex in whole_examples]
|
||||
whole_golds = [ex.gold for ex in whole_examples]
|
||||
whole_docs = []
|
||||
whole_golds = []
|
||||
for eg in whole_examples:
|
||||
for doc, gold in get_parses_from_example(eg):
|
||||
whole_docs.append(doc)
|
||||
whole_golds.append(gold)
|
||||
whole_states = self.moves.init_batch(whole_docs)
|
||||
max_length = max(min_length, min(max_length, min([len(doc) for doc in whole_docs])))
|
||||
max_moves = 0
|
||||
|
@ -625,7 +630,7 @@ cdef class Parser:
|
|||
doc_sample = []
|
||||
gold_sample = []
|
||||
for example in islice(get_examples(), 10):
|
||||
parses = example.get_gold_parses(merge=False, vocab=self.vocab)
|
||||
parses = get_parses_from_example(example, merge=False, vocab=self.vocab)
|
||||
for doc, gold in parses:
|
||||
if len(doc):
|
||||
doc_sample.append(doc)
|
||||
|
|
|
@ -34,7 +34,10 @@ def _train_parser(parser):
|
|||
for i in range(5):
|
||||
losses = {}
|
||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
||||
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
|
||||
gold = {
|
||||
"heads": [1, 1, 3, 3],
|
||||
"deps": ["left", "ROOT", "left", "ROOT"]
|
||||
}
|
||||
parser.update((doc, gold), sgd=sgd, losses=losses)
|
||||
return parser
|
||||
|
||||
|
@ -46,9 +49,10 @@ def test_add_label(parser):
|
|||
for i in range(100):
|
||||
losses = {}
|
||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
||||
gold = GoldParse(
|
||||
doc, heads=[1, 1, 3, 3], deps=["right", "ROOT", "left", "ROOT"]
|
||||
)
|
||||
gold = {
|
||||
"heads": [1, 1, 3, 3],
|
||||
"deps": ["right", "ROOT", "left", "ROOT"]
|
||||
}
|
||||
parser.update((doc, gold), sgd=sgd, losses=losses)
|
||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
||||
doc = parser(doc)
|
||||
|
|
|
@ -46,7 +46,7 @@ def doc(vocab):
|
|||
|
||||
@pytest.fixture
|
||||
def gold(doc):
|
||||
return GoldParse(doc, heads=[1, 1, 1], deps=["L", "ROOT", "R"])
|
||||
return {"heads": [1, 1, 1], "deps": ["L", "ROOT", "R"]}
|
||||
|
||||
|
||||
def test_can_init_nn_parser(parser):
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import pytest
|
||||
from thinc.api import Adam
|
||||
from spacy.attrs import NORM
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.vocab import Vocab
|
||||
|
||||
from spacy.pipeline.defaults import default_parser
|
||||
|
@ -27,7 +26,7 @@ def parser(vocab):
|
|||
for i in range(10):
|
||||
losses = {}
|
||||
doc = Doc(vocab, words=["a", "b", "c", "d"])
|
||||
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
|
||||
gold = dict(heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
|
||||
parser.update((doc, gold), sgd=sgd, losses=losses)
|
||||
return parser
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
from spacy.errors import AlignmentError
|
||||
from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags
|
||||
from spacy.gold import spans_from_biluo_tags, GoldParse, iob_to_biluo, align
|
||||
from spacy.gold import spans_from_biluo_tags, iob_to_biluo, align
|
||||
from spacy.gold import GoldCorpus, docs_to_json, Example, DocAnnotation
|
||||
from spacy.lang.en import English
|
||||
from spacy.syntax.nonproj import is_nonproj_tree
|
||||
from spacy.syntax.gold_parse import GoldParse, get_parses_from_example
|
||||
from spacy.tokens import Doc
|
||||
from spacy.util import get_words_and_spaces, compounding, minibatch
|
||||
import pytest
|
||||
|
@ -270,10 +271,9 @@ def test_roundtrip_docs_to_json(doc):
|
|||
srsly.write_json(json_file, [docs_to_json(doc)])
|
||||
goldcorpus = GoldCorpus(train=str(json_file), dev=str(json_file))
|
||||
|
||||
reloaded_example = next(goldcorpus.dev_dataset(nlp))
|
||||
goldparse = reloaded_example.gold
|
||||
|
||||
assert len(doc) == goldcorpus.count_train()
|
||||
reloaded_example = next(goldcorpus.dev_dataset(nlp=nlp))
|
||||
goldparse = reloaded_example._deprecated_get_gold()
|
||||
assert len(doc) == goldcorpus.count_train()
|
||||
assert text == reloaded_example.text
|
||||
assert tags == goldparse.tags
|
||||
assert pos == goldparse.pos
|
||||
|
@ -287,54 +287,6 @@ def test_roundtrip_docs_to_json(doc):
|
|||
assert cats["TRAVEL"] == goldparse.cats["TRAVEL"]
|
||||
assert cats["BAKING"] == goldparse.cats["BAKING"]
|
||||
|
||||
# roundtrip to JSONL train dicts
|
||||
with make_tempdir() as tmpdir:
|
||||
jsonl_file = tmpdir / "roundtrip.jsonl"
|
||||
srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
|
||||
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
|
||||
|
||||
reloaded_example = next(goldcorpus.dev_dataset(nlp))
|
||||
goldparse = reloaded_example.gold
|
||||
|
||||
assert len(doc) == goldcorpus.count_train()
|
||||
assert text == reloaded_example.text
|
||||
assert tags == goldparse.tags
|
||||
assert pos == goldparse.pos
|
||||
assert morphs == goldparse.morphs
|
||||
assert lemmas == goldparse.lemmas
|
||||
assert deps == goldparse.labels
|
||||
assert heads == goldparse.heads
|
||||
assert biluo_tags == goldparse.ner
|
||||
assert "TRAVEL" in goldparse.cats
|
||||
assert "BAKING" in goldparse.cats
|
||||
assert cats["TRAVEL"] == goldparse.cats["TRAVEL"]
|
||||
assert cats["BAKING"] == goldparse.cats["BAKING"]
|
||||
|
||||
# roundtrip to JSONL tuples
|
||||
with make_tempdir() as tmpdir:
|
||||
jsonl_file = tmpdir / "roundtrip.jsonl"
|
||||
# write to JSONL train dicts
|
||||
srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
|
||||
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
|
||||
# load and rewrite as JSONL tuples
|
||||
srsly.write_jsonl(jsonl_file, goldcorpus.train_examples)
|
||||
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
|
||||
|
||||
reloaded_example = next(goldcorpus.dev_dataset(nlp))
|
||||
goldparse = reloaded_example.gold
|
||||
|
||||
assert len(doc) == goldcorpus.count_train()
|
||||
assert text == reloaded_example.text
|
||||
assert tags == goldparse.tags
|
||||
assert deps == goldparse.labels
|
||||
assert heads == goldparse.heads
|
||||
assert lemmas == goldparse.lemmas
|
||||
assert biluo_tags == goldparse.ner
|
||||
assert "TRAVEL" in goldparse.cats
|
||||
assert "BAKING" in goldparse.cats
|
||||
assert cats["TRAVEL"] == goldparse.cats["TRAVEL"]
|
||||
assert cats["BAKING"] == goldparse.cats["BAKING"]
|
||||
|
||||
|
||||
def test_projective_train_vs_nonprojective_dev(doc):
|
||||
nlp = English()
|
||||
|
@ -342,16 +294,16 @@ def test_projective_train_vs_nonprojective_dev(doc):
|
|||
heads = [t.head.i for t in doc]
|
||||
|
||||
with make_tempdir() as tmpdir:
|
||||
jsonl_file = tmpdir / "test.jsonl"
|
||||
# write to JSONL train dicts
|
||||
srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
|
||||
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
|
||||
json_file = tmpdir / "test.json"
|
||||
# write to JSON train dicts
|
||||
srsly.write_json(json_file, [docs_to_json(doc)])
|
||||
goldcorpus = GoldCorpus(str(json_file), str(json_file))
|
||||
|
||||
train_reloaded_example = next(goldcorpus.train_dataset(nlp))
|
||||
train_goldparse = train_reloaded_example.gold
|
||||
train_reloaded_example = next(goldcorpus.train_dataset(nlp))
|
||||
train_goldparse = get_parses_from_example(train_reloaded_example)[0][1]
|
||||
|
||||
dev_reloaded_example = next(goldcorpus.dev_dataset(nlp))
|
||||
dev_goldparse = dev_reloaded_example.gold
|
||||
dev_reloaded_example = next(goldcorpus.dev_dataset(nlp))
|
||||
dev_goldparse = dev_reloaded_example._deprecated_get_gold()
|
||||
|
||||
assert is_nonproj_tree([t.head.i for t in doc]) is True
|
||||
assert is_nonproj_tree(train_goldparse.heads) is False
|
||||
|
@ -364,45 +316,49 @@ def test_projective_train_vs_nonprojective_dev(doc):
|
|||
assert deps == dev_goldparse.labels
|
||||
|
||||
|
||||
# Hm, not sure where misalignment check would be handled? In the components too?
|
||||
# I guess that does make sense. A text categorizer doesn't care if it's
|
||||
# misaligned...
|
||||
@pytest.mark.xfail # TODO
|
||||
def test_ignore_misaligned(doc):
|
||||
nlp = English()
|
||||
text = doc.text
|
||||
with make_tempdir() as tmpdir:
|
||||
jsonl_file = tmpdir / "test.jsonl"
|
||||
json_file = tmpdir / "test.json"
|
||||
data = [docs_to_json(doc)]
|
||||
data[0]["paragraphs"][0]["raw"] = text.replace("Sarah", "Jane")
|
||||
# write to JSONL train dicts
|
||||
srsly.write_jsonl(jsonl_file, data)
|
||||
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
|
||||
# write to JSON train dicts
|
||||
srsly.write_json(json_file, data)
|
||||
goldcorpus = GoldCorpus(str(json_file), str(json_file))
|
||||
|
||||
with pytest.raises(AlignmentError):
|
||||
train_reloaded_example = next(goldcorpus.train_dataset(nlp))
|
||||
with pytest.raises(AlignmentError):
|
||||
train_reloaded_example = next(goldcorpus.train_dataset(nlp))
|
||||
|
||||
with make_tempdir() as tmpdir:
|
||||
jsonl_file = tmpdir / "test.jsonl"
|
||||
json_file = tmpdir / "test.json"
|
||||
data = [docs_to_json(doc)]
|
||||
data[0]["paragraphs"][0]["raw"] = text.replace("Sarah", "Jane")
|
||||
# write to JSONL train dicts
|
||||
srsly.write_jsonl(jsonl_file, data)
|
||||
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
|
||||
# write to JSON train dicts
|
||||
srsly.write_json(json_file, data)
|
||||
goldcorpus = GoldCorpus(str(json_file), str(json_file))
|
||||
|
||||
# doesn't raise an AlignmentError, but there is nothing to iterate over
|
||||
# because the only example can't be aligned
|
||||
train_reloaded_example = list(goldcorpus.train_dataset(nlp, ignore_misaligned=True))
|
||||
assert len(train_reloaded_example) == 0
|
||||
# doesn't raise an AlignmentError, but there is nothing to iterate over
|
||||
# because the only example can't be aligned
|
||||
train_reloaded_example = list(goldcorpus.train_dataset(nlp, ignore_misaligned=True))
|
||||
assert len(train_reloaded_example) == 0
|
||||
|
||||
|
||||
def test_make_orth_variants(doc):
|
||||
nlp = English()
|
||||
with make_tempdir() as tmpdir:
|
||||
jsonl_file = tmpdir / "test.jsonl"
|
||||
# write to JSONL train dicts
|
||||
srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
|
||||
goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
|
||||
json_file = tmpdir / "test.json"
|
||||
# write to JSON train dicts
|
||||
srsly.write_json(json_file, [docs_to_json(doc)])
|
||||
goldcorpus = GoldCorpus(str(json_file), str(json_file))
|
||||
|
||||
# due to randomness, test only that this runs with no errors for now
|
||||
train_reloaded_example = next(goldcorpus.train_dataset(nlp, orth_variant_level=0.2))
|
||||
train_goldparse = train_reloaded_example.gold # noqa: F841
|
||||
# due to randomness, test only that this runs with no errors for now
|
||||
train_reloaded_example = next(goldcorpus.train_dataset(nlp, orth_variant_level=0.2))
|
||||
train_goldparse = train_reloaded_example._deprecated_get_gold()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -485,6 +441,7 @@ def test_tuple_format_implicit():
|
|||
_train(train_data)
|
||||
|
||||
|
||||
@pytest.mark.xfail # TODO
|
||||
def test_tuple_format_implicit_invalid():
|
||||
"""Test that an error is thrown for an implicit invalid GoldParse field"""
|
||||
|
||||
|
@ -520,8 +477,18 @@ def test_split_sents(merged_dict):
|
|||
nlp = English()
|
||||
example = Example()
|
||||
example.set_token_annotation(**merged_dict)
|
||||
assert len(example.get_gold_parses(merge=False, vocab=nlp.vocab)) == 2
|
||||
assert len(example.get_gold_parses(merge=True, vocab=nlp.vocab)) == 1
|
||||
assert len(get_parses_from_example(
|
||||
example,
|
||||
merge=False,
|
||||
vocab=nlp.vocab,
|
||||
make_projective=False)
|
||||
) == 2
|
||||
assert len(get_parses_from_example(
|
||||
example,
|
||||
merge=True,
|
||||
vocab=nlp.vocab,
|
||||
make_projective=False
|
||||
)) == 1
|
||||
|
||||
split_examples = example.split_sents()
|
||||
assert len(split_examples) == 2
|
||||
|
@ -557,4 +524,4 @@ def test_empty_example_goldparse():
|
|||
nlp = English()
|
||||
doc = nlp("")
|
||||
example = Example(doc=doc)
|
||||
assert len(example.get_gold_parses()) == 1
|
||||
assert len(get_parses_from_example(example)) == 1
|
||||
|
|
|
@ -19,22 +19,16 @@ def nlp():
|
|||
return nlp
|
||||
|
||||
|
||||
@pytest.mark.xfail # TODO
|
||||
def test_language_update(nlp):
|
||||
text = "hello world"
|
||||
annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
|
||||
wrongkeyannots = {"LABEL": True}
|
||||
doc = Doc(nlp.vocab, words=text.split(" "))
|
||||
gold = GoldParse(doc, **annots)
|
||||
# Update with doc and gold objects
|
||||
nlp.update((doc, gold))
|
||||
# Update with text and dict
|
||||
nlp.update((text, annots))
|
||||
# Update with doc object and dict
|
||||
nlp.update((doc, annots))
|
||||
# Update with text and gold object
|
||||
nlp.update((text, gold))
|
||||
# Update with empty doc and gold object
|
||||
nlp.update((None, gold))
|
||||
# Update badly
|
||||
with pytest.raises(ValueError):
|
||||
nlp.update((doc, None))
|
||||
|
@ -44,20 +38,16 @@ def test_language_update(nlp):
|
|||
|
||||
def test_language_evaluate(nlp):
|
||||
text = "hello world"
|
||||
annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
|
||||
annots = {
|
||||
"doc_annotation": {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
|
||||
}
|
||||
doc = Doc(nlp.vocab, words=text.split(" "))
|
||||
gold = GoldParse(doc, **annots)
|
||||
# Evaluate with doc and gold objects
|
||||
nlp.evaluate([(doc, gold)])
|
||||
# Evaluate with text and dict
|
||||
nlp.evaluate([(text, annots)])
|
||||
# Evaluate with doc object and dict
|
||||
nlp.evaluate([(doc, annots)])
|
||||
# Evaluate with text and gold object
|
||||
nlp.evaluate([(text, gold)])
|
||||
# Evaluate badly
|
||||
with pytest.raises(Exception):
|
||||
nlp.evaluate([text, gold])
|
||||
nlp.evaluate([text, annots])
|
||||
|
||||
|
||||
def test_evaluate_no_pipe(nlp):
|
||||
|
|
Loading…
Reference in New Issue
Block a user