Fix Example details for train CLI / pipeline components (#4624)
* Switch to train_dataset() function in train CLI
* Fixes for pipe() methods in pipeline components
* Don't clobber `examples` variable with `as_example` in pipe() methods
* Remove unnecessary traversals of `examples`
* Update Parser.pipe() for Examples
* Add `as_example` kwarg to `pipe()` with implementation to return `Example`s
* Accept `Doc` or `Example` in `pipe()` with `_get_doc()` (copied from `Pipe`)
* Fixes to Example implementation in spacy.gold
* Move `make_projective` from an attribute of Example to an argument of `Example.get_gold_parses()`
* Heads of 0 are not treated as unset
* Unset heads are set to self rather than `None` (which causes problems while projectivizing)
* Check for `Doc` (not just not `None`) when creating GoldParses for a pre-merged example
* Don't clobber `examples` variable in `iter_gold_docs()`
* Add/modify gold tests for handling projectivity
* In the JSON roundtrip, compare results from `dev_dataset` rather than `train_dataset` to avoid projectivization (and other potential modifications)
* Add test for projective train vs. nonprojective dev versions of the same `Doc`
* Handle ignore_misaligned as arg rather than attr

  Move `ignore_misaligned` from an attribute of `Example` to an argument to `Example.get_gold_parses()`, which makes it parallel to `make_projective`.

  Add a test with old and new align that checks whether `ignore_misaligned` errors are raised as expected (only for new align).
* Remove unused attrs from gold.pxd

  Remove `ignore_misaligned` and `make_projective` from `gold.pxd`.
* Refer to Example.goldparse in iter_gold_docs()

  Use `Example.goldparse` in `iter_gold_docs()` instead of `Example.gold`, because a `None` `GoldParse` is generated with `ignore_misaligned` and generating it on the fly can raise an unwanted AlignmentError.
* Update test for ignore_misaligned
parent faaa832518
commit 44829950ba
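Taken together, these changes let a component's `pipe()` accept either `Doc`s or `Example`s and, with the new `as_example` flag, hand the annotated `Example`s back. A minimal usage sketch of that round-trip, under the dev-branch API shown in this diff (the loaded pipeline and the texts are illustrative, not part of this commit):

```python
# Hedged sketch of the as_example round-trip added in this commit.
import spacy
from spacy.gold import Example

nlp = spacy.load("en_core_web_sm")   # assumes any installed pipeline with a tagger
tagger = nlp.get_pipe("tagger")

examples = [Example(doc=nlp.make_doc(t)) for t in ["A test.", "Another one."]]

# pipe() now accepts Docs or Examples via _get_doc(); with as_example=True
# it yields the same Example objects with their .doc annotated in place.
for ex in tagger.pipe(examples, as_example=True):
    print([t.tag_ for t in ex.doc])
```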
@@ -340,7 +340,7 @@ def train(
     iter_since_best = 0
     best_score = 0.0
     for i in range(n_iter):
-        train_data = corpus.train_data(
+        train_data = corpus.train_dataset(
             nlp,
             noise_level=noise_level,
             orth_variant_level=orth_variant_level,
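For orientation, a hedged sketch of what the renamed generator yields; the corpus paths and surrounding setup are assumptions, only the method name and keywords appear in this diff:

```python
# Hedged sketch: iterating the renamed train_dataset() generator.
from spacy.gold import GoldCorpus
from spacy.lang.en import English

nlp = English()
corpus = GoldCorpus("train.json", "dev.json")  # placeholder paths
for example in corpus.train_dataset(nlp, noise_level=0.0,
                                    orth_variant_level=0.0):
    # each item is an Example carrying a Doc and its GoldParse
    print(example.doc, example.goldparse)
```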
@@ -58,8 +58,6 @@ cdef class Example:
     cdef public object doc
     cdef public list token_annotations
     cdef public DocAnnotation doc_annotation
-    cdef public object make_projective
-    cdef public object ignore_misaligned
     cdef public object goldparse


@@ -311,47 +311,50 @@ class GoldCorpus(object):
                                        ignore_misaligned=ignore_misaligned)
         yield from gold_examples

-    def train_dataset_without_preprocessing(self, nlp, gold_preproc=False):
-        examples = self.iter_gold_docs(nlp, self.train_examples, gold_preproc=gold_preproc)
+    def train_dataset_without_preprocessing(self, nlp, gold_preproc=False,
+                                            ignore_misaligned=False):
+        examples = self.iter_gold_docs(nlp, self.train_examples,
+                                       gold_preproc=gold_preproc,
+                                       ignore_misaligned=ignore_misaligned)
         yield from examples

     def dev_dataset(self, nlp, gold_preproc=False, ignore_misaligned=False):
-        examples = self.iter_gold_docs(nlp, self.dev_examples, gold_preproc=gold_preproc,
-                                       ignore_misaligned=ignore_misaligned)
+        examples = self.iter_gold_docs(nlp, self.dev_examples,
+                                       gold_preproc=gold_preproc,
+                                       ignore_misaligned=ignore_misaligned)
         yield from examples

     @classmethod
     def iter_gold_docs(cls, nlp, examples, gold_preproc, max_length=None,
-                       noise_level=0.0, orth_variant_level=0.0, make_projective=False,
-                       ignore_misaligned=False):
+                       noise_level=0.0, orth_variant_level=0.0,
+                       make_projective=False, ignore_misaligned=False):
         """ Setting gold_preproc will result in creating a doc per 'sentence' """
         for example in examples:
             if gold_preproc:
                 example.doc = None
             else:
                 example = example.merge_sents()
-            example.make_projective = make_projective
-            example.ignore_misaligned = ignore_misaligned
-            examples = cls._make_docs(nlp, example,
+            example_docs = cls._make_docs(nlp, example,
                     gold_preproc, noise_level=noise_level,
                     orth_variant_level=orth_variant_level)
-            examples = cls._make_golds(examples, vocab=nlp.vocab)
-            for ex in examples:
-                if ex.gold is not None:
+            example_golds = cls._make_golds(example_docs, vocab=nlp.vocab,
+                    make_projective=make_projective,
+                    ignore_misaligned=ignore_misaligned)
+            for ex in example_golds:
+                if ex.goldparse is not None:
                     if (not max_length) or len(ex.doc) < max_length:
                         yield ex

     @classmethod
     def _make_docs(cls, nlp, example, gold_preproc, noise_level=0.0, orth_variant_level=0.0):
-        var_example = make_orth_variants(nlp, example, orth_variant_level=orth_variant_level)
         # gold_preproc is not used ?!
         if example.text is not None:
+            var_example = make_orth_variants(nlp, example, orth_variant_level=orth_variant_level)
             var_text = add_noise(var_example.text, noise_level)
             var_doc = nlp.make_doc(var_text)
             var_example.doc = var_doc
             return [var_example]
         else:
+            var_example = make_orth_variants(nlp, example, orth_variant_level=orth_variant_level)
             doc_examples = []
             for token_annotation in var_example.token_annotations:
                 t_doc = Doc(nlp.vocab, words=add_noise(token_annotation.words, noise_level))
@@ -362,10 +365,13 @@ class GoldCorpus(object):
         return doc_examples

     @classmethod
-    def _make_golds(cls, examples, vocab=None):
+    def _make_golds(cls, examples, vocab=None, make_projective=False,
+                    ignore_misaligned=False):
         gold_examples = []
         for example in examples:
-            gold_parses = example.get_gold_parses(vocab=vocab)
+            gold_parses = example.get_gold_parses(vocab=vocab,
+                    make_projective=make_projective,
+                    ignore_misaligned=ignore_misaligned)
             for (doc, gold) in gold_parses:
                 ex = Example(doc=doc)
                 ex.goldparse = gold
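The two new keywords now thread end to end: `train_dataset()` / `dev_dataset()` pass them to `iter_gold_docs()`, which forwards them to `_make_golds()` and ultimately to `Example.get_gold_parses()`. A hedged sketch of the effect on the consumer side, reusing `nlp` and `corpus` from the sketch above:

```python
# Hedged sketch: misaligned examples no longer blow up dev iteration.
# iter_gold_docs() drops Examples whose goldparse is None, so with
# ignore_misaligned=True the generator simply yields fewer items.
for ex in corpus.dev_dataset(nlp, gold_preproc=False, ignore_misaligned=True):
    assert ex.goldparse is not None
```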
@@ -693,13 +699,11 @@ cdef class DocAnnotation:

 cdef class Example:
     def __init__(self, doc_annotation=None, token_annotations=None, doc=None,
-                 make_projective=False, ignore_misaligned=False, goldparse=None):
+                 goldparse=None):
         """ Doc can either be text, or an actual Doc """
         self.doc = doc
         self.doc_annotation = doc_annotation if doc_annotation else DocAnnotation()
         self.token_annotations = token_annotations if token_annotations else []
-        self.make_projective = make_projective
-        self.ignore_misaligned = ignore_misaligned
         self.goldparse = goldparse

     @classmethod
@@ -760,7 +764,7 @@ cdef class Example:
             m_ids.extend(id_ + i for id_ in t.ids)
             m_words.extend(t.words)
             m_tags.extend(t.tags)
-            m_heads.extend(head + i if head else None for head in t.heads)
+            m_heads.extend(head + i if head is not None and head >= 0 else head_i + i for head_i, head in enumerate(t.heads))
             m_deps.extend(t.deps)
             m_ents.extend(t.entities)
             m_morph.extend(t.morphology)
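This one-line change fixes the two head-merging bugs called out in the commit message: a head of 0 (a token attached to the sentence's first token) was falsy and silently became `None`, and genuinely unset heads became `None` rather than pointing at the token itself, which breaks projectivization. A standalone worked example of old vs. new behavior:

```python
# Worked example of the head-merging rule in merge_sents() (standalone sketch).
i = 3                  # token offset of this sentence in the merged doc
heads = [1, 0, None]   # per-sentence head indices; None = unset annotation

# old: `head + i if head else None` -- head 0 is falsy and becomes None
old = [h + i if h else None for h in heads]
assert old == [4, None, None]

# new: only genuinely unset heads fall back to the token itself,
# keeping the tree well-formed for projectivization
new = [h + i if h is not None and h >= 0 else idx + i
       for idx, h in enumerate(heads)]
assert new == [4, 3, 5]
```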
@@ -773,7 +777,8 @@ cdef class Example:
         return m_example


-    def get_gold_parses(self, merge=False, vocab=None):
+    def get_gold_parses(self, merge=False, vocab=None, make_projective=False,
+                        ignore_misaligned=False):
         """Return a list of (doc, GoldParse) objects.
         If merge is set to True, add all Token annotations to one big list."""
         d = self.doc_annotation
@@ -788,20 +793,20 @@ cdef class Example:
                 raise ValueError(Errors.E998)
             m_doc = Doc(vocab, words=t.words)
             try:
-                gp = GoldParse.from_annotation(m_doc, d, t, make_projective=self.make_projective)
+                gp = GoldParse.from_annotation(m_doc, d, t, make_projective=make_projective)
             except AlignmentError:
-                if self.ignore_misaligned:
+                if ignore_misaligned:
                     gp = None
                 else:
                     raise
             return [(self.doc, gp)]
         # we only have one sentence and an appropriate doc
-        elif len(self.token_annotations) == 1 and self.doc is not None:
+        elif len(self.token_annotations) == 1 and isinstance(self.doc, Doc):
             t = self.token_annotations[0]
             try:
-                gp = GoldParse.from_annotation(self.doc, d, t, make_projective=self.make_projective)
+                gp = GoldParse.from_annotation(self.doc, d, t, make_projective=make_projective)
             except AlignmentError:
-                if self.ignore_misaligned:
+                if ignore_misaligned:
                     gp = None
                 else:
                     raise
@@ -814,9 +819,9 @@ cdef class Example:
                     raise ValueError(Errors.E998)
                 t_doc = Doc(vocab, words=t.words)
                 try:
-                    gp = GoldParse.from_annotation(t_doc, d, t, make_projective=self.make_projective)
+                    gp = GoldParse.from_annotation(t_doc, d, t, make_projective=make_projective)
                 except AlignmentError:
-                    if self.ignore_misaligned:
+                    if ignore_misaligned:
                         gp = None
                     else:
                         raise
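With `make_projective` and `ignore_misaligned` now per-call arguments rather than `Example` attributes, the same `Example` can produce a projectivized parse for training and an untouched one for evaluation. A hedged sketch of the call shapes, reusing `nlp` and an `example` yielded by the corpus sketches above:

```python
# Hedged sketch: the same Example, queried two ways (mirrors _make_golds()).
train_parses = example.get_gold_parses(vocab=nlp.vocab, make_projective=True)
dev_parses = example.get_gold_parses(vocab=nlp.vocab, make_projective=False,
                                     ignore_misaligned=True)
for doc, gp in dev_parses:
    if gp is None:
        continue  # misaligned annotations are skipped, not raised
```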
@@ -61,7 +61,7 @@ class Pipe(object):
         return cls(nlp.vocab, **cfg)

     def _get_doc(self, example):
-        """ Use this method if the `example` method can be both a Doc or an Example """
+        """ Use this method if the `example` can be both a Doc or an Example """
         if isinstance(example, Doc):
             return example
         return example.doc
@@ -102,7 +102,6 @@ class Pipe(object):
         and `set_annotations()` methods.
         """
         for examples in util.minibatch(stream, size=batch_size):
-            examples = list(examples)
             docs = [self._get_doc(ex) for ex in examples]
             predictions = self.predict(docs)
             if isinstance(predictions, tuple) and len(tuple) == 2:
@@ -112,11 +111,11 @@ class Pipe(object):
                 self.set_annotations(docs, predictions)

             if as_example:
-                examples = []
+                annotated_examples = []
                 for ex, doc in zip(examples, docs):
                     ex.doc = doc
-                    examples.append(ex)
-                yield from examples
+                    annotated_examples.append(ex)
+                yield from annotated_examples
             else:
                 yield from docs

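The rename from `examples` to `annotated_examples` is the substantive fix here: the old code reassigned `examples = []` and then zipped over that same, now-empty list, so `pipe(as_example=True)` silently yielded nothing. A standalone reproduction of the bug pattern (plain Python, no spaCy required):

```python
# Standalone reproduction of the clobbering bug fixed above.
docs = ["doc1", "doc2"]
examples = ["ex1", "ex2"]

# old pattern: the accumulator shadows the input before zip() reads it
examples = []                       # clobbers the inputs!
out_old = [(ex, doc) for ex, doc in zip(examples, docs)]
assert out_old == []                # nothing would be yielded

# new pattern: a distinct accumulator keeps the inputs intact
examples = ["ex1", "ex2"]
annotated_examples = []
for ex, doc in zip(examples, docs):
    annotated_examples.append((ex, doc))
assert annotated_examples == [("ex1", "doc1"), ("ex2", "doc2")]
```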
@@ -312,11 +311,11 @@ class Tensorizer(Pipe):
             self.set_annotations(docs, tensors)

             if as_example:
-                examples = []
+                annotated_examples = []
                 for ex, doc in zip(examples, docs):
                     ex.doc = doc
-                    examples.append(ex)
-                yield from examples
+                    annotated_examples.append(ex)
+                yield from annotated_examples
             else:
                 yield from docs

@@ -434,17 +433,16 @@ class Tagger(Pipe):

     def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
         for examples in util.minibatch(stream, size=batch_size):
-            examples = list(examples)
             docs = [self._get_doc(ex) for ex in examples]
             tag_ids, tokvecs = self.predict(docs)
             self.set_annotations(docs, tag_ids, tensors=tokvecs)

             if as_example:
-                examples = []
+                annotated_examples = []
                 for ex, doc in zip(examples, docs):
                     ex.doc = doc
-                    examples.append(ex)
-                yield from examples
+                    annotated_examples.append(ex)
+                yield from annotated_examples
             else:
                 yield from docs

@@ -1000,17 +998,16 @@ class TextCategorizer(Pipe):

     def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
         for examples in util.minibatch(stream, size=batch_size):
-            examples = list(examples)
             docs = [self._get_doc(ex) for ex in examples]
             scores, tensors = self.predict(docs)
             self.set_annotations(docs, scores, tensors=tensors)

             if as_example:
-                examples = []
+                annotated_examples = []
                 for ex, doc in zip(examples, docs):
                     ex.doc = doc
-                    examples.append(ex)
-                yield from examples
+                    annotated_examples.append(ex)
+                yield from annotated_examples
             else:
                 yield from docs

@@ -1333,17 +1330,16 @@ class EntityLinker(Pipe):

     def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
         for examples in util.minibatch(stream, size=batch_size):
-            examples = list(examples)
             docs = [self._get_doc(ex) for ex in examples]
             kb_ids, tensors = self.predict(docs)
             self.set_annotations(docs, kb_ids, tensors=tensors)

             if as_example:
-                examples = []
+                annotated_examples = []
                 for ex, doc in zip(examples, docs):
                     ex.doc = doc
-                    examples.append(ex)
-                yield from examples
+                    annotated_examples.append(ex)
+                yield from annotated_examples
             else:
                 yield from docs

@@ -227,7 +227,8 @@ cdef class Parser:
         self.set_annotations([doc], states, tensors=None)
         return doc

-    def pipe(self, docs, int batch_size=256, int n_threads=-1, beam_width=None):
+    def pipe(self, docs, int batch_size=256, int n_threads=-1, beam_width=None,
+             as_example=False):
         """Process a stream of documents.

         stream: The sequence of documents to process.
@@ -240,14 +241,21 @@ cdef class Parser:
         cdef Doc doc
         for batch in util.minibatch(docs, size=batch_size):
             batch_in_order = list(batch)
-            by_length = sorted(batch_in_order, key=lambda doc: len(doc))
+            docs = [self._get_doc(ex) for ex in batch_in_order]
+            by_length = sorted(docs, key=lambda doc: len(doc))
             for subbatch in util.minibatch(by_length, size=max(batch_size//4, 2)):
                 subbatch = list(subbatch)
                 parse_states = self.predict(subbatch, beam_width=beam_width,
                                             beam_density=beam_density)
                 self.set_annotations(subbatch, parse_states, tensors=None)
-            for doc in batch_in_order:
-                yield doc
+            if as_example:
+                annotated_examples = []
+                for ex, doc in zip(batch_in_order, docs):
+                    ex.doc = doc
+                    annotated_examples.append(ex)
+                yield from annotated_examples
+            else:
+                yield from batch_in_order

     def require_model(self):
         """Raise an error if the component's model is not initialized."""
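`Parser.pipe()` now mirrors the other components: it resolves each input through `_get_doc()` and, with `as_example=True`, yields the annotated `Example`s back. A hedged sketch of the new call shape, reusing `nlp` and `examples` from the first sketch (a parser must actually be present in the loaded pipeline):

```python
# Hedged sketch: the parser round-trips Examples like the other pipes.
parser = nlp.get_pipe("parser")
for ex in parser.pipe(examples, batch_size=256, as_example=True):
    print([(t.text, t.dep_) for t in ex.doc])
```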
@@ -635,6 +643,12 @@ cdef class Parser:
         self.cfg.update(cfg)
         return sgd

+    def _get_doc(self, example):
+        """ Use this method if the `example` can be both a Doc or an Example """
+        if isinstance(example, Doc):
+            return example
+        return example.doc
+
     def to_disk(self, path, exclude=tuple(), **kwargs):
         serializers = {
             'model': lambda p: (self.model.to_disk(p) if self.model is not True else True),
@@ -1,16 +1,40 @@
 # coding: utf-8
 from __future__ import unicode_literals

+import spacy
+from spacy.errors import AlignmentError
 from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags, Example, DocAnnotation
 from spacy.gold import spans_from_biluo_tags, GoldParse, iob_to_biluo
 from spacy.gold import GoldCorpus, docs_to_json, align
 from spacy.lang.en import English
+from spacy.syntax.nonproj import is_nonproj_tree
 from spacy.tokens import Doc
 from spacy.util import compounding, minibatch
 from .util import make_tempdir
 import pytest
 import srsly

+
+@pytest.fixture
+def doc():
+    text = "Sarah's sister flew to Silicon Valley via London."
+    tags = ['NNP', 'POS', 'NN', 'VBD', 'IN', 'NNP', 'NNP', 'IN', 'NNP', '.']
+    # head of '.' is intentionally nonprojective for testing
+    heads = [2, 0, 3, 3, 3, 6, 4, 3, 7, 5]
+    deps = ['poss', 'case', 'nsubj', 'ROOT', 'prep', 'compound', 'pobj', 'prep', 'pobj', 'punct']
+    biluo_tags = ["U-PERSON", "O", "O", "O", "O", "B-LOC", "L-LOC", "O", "U-GPE", "O"]
+    cats = {"TRAVEL": 1.0, "BAKING": 0.0}
+    nlp = English()
+    doc = nlp(text)
+    for i in range(len(tags)):
+        doc[i].tag_ = tags[i]
+        doc[i].dep_ = deps[i]
+        doc[i].head = doc[heads[i]]
+    doc.ents = spans_from_biluo_tags(doc, biluo_tags)
+    doc.cats = cats
+    doc.is_tagged = True
+    doc.is_parsed = True
+    return doc
+

 def test_gold_biluo_U(en_vocab):
     words = ["I", "flew", "to", "London", "."]
@@ -98,23 +122,14 @@ def test_iob_to_biluo():
         iob_to_biluo(bad_iob)


-def test_roundtrip_docs_to_json():
-    text = "I flew to Silicon Valley via London."
-    tags = ["PRP", "VBD", "IN", "NNP", "NNP", "IN", "NNP", "."]
-    heads = [1, 1, 1, 4, 2, 1, 5, 1]
-    deps = ["nsubj", "ROOT", "prep", "compound", "pobj", "prep", "pobj", "punct"]
-    biluo_tags = ["O", "O", "O", "B-LOC", "L-LOC", "O", "U-GPE", "O"]
-    cats = {"TRAVEL": 1.0, "BAKING": 0.0}
+def test_roundtrip_docs_to_json(doc):
     nlp = English()
-    doc = nlp(text)
-    for i in range(len(tags)):
-        doc[i].tag_ = tags[i]
-        doc[i].dep_ = deps[i]
-        doc[i].head = doc[heads[i]]
-    doc.ents = spans_from_biluo_tags(doc, biluo_tags)
-    doc.cats = cats
-    doc.is_tagged = True
-    doc.is_parsed = True
+    text = doc.text
+    tags = [t.tag_ for t in doc]
+    deps = [t.dep_ for t in doc]
+    heads = [t.head.i for t in doc]
+    biluo_tags = iob_to_biluo([t.ent_iob_ + "-" + t.ent_type_ if t.ent_type_ else "O" for t in doc])
+    cats = doc.cats

     # roundtrip to JSON
     with make_tempdir() as tmpdir:
@@ -122,7 +137,7 @@ def test_roundtrip_docs_to_json():
         srsly.write_json(json_file, [docs_to_json(doc)])
         goldcorpus = GoldCorpus(train=str(json_file), dev=str(json_file))

-        reloaded_example = next(goldcorpus.train_dataset(nlp))
+        reloaded_example = next(goldcorpus.dev_dataset(nlp))
         goldparse = reloaded_example.gold

         assert len(doc) == goldcorpus.count_train()
@@ -142,7 +157,7 @@ def test_roundtrip_docs_to_json():
         srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
         goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))

-        reloaded_example = next(goldcorpus.train_dataset(nlp))
+        reloaded_example = next(goldcorpus.dev_dataset(nlp))
         goldparse = reloaded_example.gold

         assert len(doc) == goldcorpus.count_train()
@@ -166,7 +181,7 @@ def test_roundtrip_docs_to_json():
         srsly.write_jsonl(jsonl_file, goldcorpus.train_examples)
         goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))

-        reloaded_example = next(goldcorpus.train_dataset(nlp))
+        reloaded_example = next(goldcorpus.dev_dataset(nlp))
         goldparse = reloaded_example.gold

         assert len(doc) == goldcorpus.count_train()
@@ -181,6 +196,83 @@ def test_roundtrip_docs_to_json():
     assert cats["BAKING"] == goldparse.cats["BAKING"]


+def test_projective_train_vs_nonprojective_dev(doc):
+    nlp = English()
+    text = doc.text
+    deps = [t.dep_ for t in doc]
+    heads = [t.head.i for t in doc]
+
+    with make_tempdir() as tmpdir:
+        jsonl_file = tmpdir / "test.jsonl"
+        # write to JSONL train dicts
+        srsly.write_jsonl(jsonl_file, [docs_to_json(doc)])
+        goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
+
+        train_reloaded_example = next(goldcorpus.train_dataset(nlp))
+        train_goldparse = train_reloaded_example.gold
+
+        dev_reloaded_example = next(goldcorpus.dev_dataset(nlp))
+        dev_goldparse = dev_reloaded_example.gold
+
+    assert is_nonproj_tree([t.head.i for t in doc]) is True
+    assert is_nonproj_tree(train_goldparse.heads) is False
+    assert heads[:-1] == train_goldparse.heads[:-1]
+    assert heads[-1] != train_goldparse.heads[-1]
+    assert deps[:-1] == train_goldparse.labels[:-1]
+    assert deps[-1] != train_goldparse.labels[-1]
+
+    assert heads == dev_goldparse.heads
+    assert deps == dev_goldparse.labels
+
+
+def test_ignore_misaligned(doc):
+    nlp = English()
+    text = doc.text
+    deps = [t.dep_ for t in doc]
+    heads = [t.head.i for t in doc]
+
+    use_new_align = spacy.gold.USE_NEW_ALIGN
+
+    spacy.gold.USE_NEW_ALIGN = False
+    with make_tempdir() as tmpdir:
+        jsonl_file = tmpdir / "test.jsonl"
+        data = [docs_to_json(doc)]
+        data[0]["paragraphs"][0]["raw"] = text.replace("Sarah", "Jane")
+        # write to JSONL train dicts
+        srsly.write_jsonl(jsonl_file, data)
+        goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
+
+        train_reloaded_example = next(goldcorpus.train_dataset(nlp))
+
+    spacy.gold.USE_NEW_ALIGN = True
+    with make_tempdir() as tmpdir:
+        jsonl_file = tmpdir / "test.jsonl"
+        data = [docs_to_json(doc)]
+        data[0]["paragraphs"][0]["raw"] = text.replace("Sarah", "Jane")
+        # write to JSONL train dicts
+        srsly.write_jsonl(jsonl_file, data)
+        goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
+
+        with pytest.raises(AlignmentError):
+            train_reloaded_example = next(goldcorpus.train_dataset(nlp))
+
+    with make_tempdir() as tmpdir:
+        jsonl_file = tmpdir / "test.jsonl"
+        data = [docs_to_json(doc)]
+        data[0]["paragraphs"][0]["raw"] = text.replace("Sarah", "Jane")
+        # write to JSONL train dicts
+        srsly.write_jsonl(jsonl_file, data)
+        goldcorpus = GoldCorpus(str(jsonl_file), str(jsonl_file))
+
+        # doesn't raise an AlignmentError, but there is nothing to iterate over
+        # because the only example can't be aligned
+        train_reloaded_example = list(goldcorpus.train_dataset(nlp,
+            ignore_misaligned=True))
+        assert len(train_reloaded_example) == 0
+
+    spacy.gold.USE_NEW_ALIGN = use_new_align
+
+
 # xfail while we have backwards-compatible alignment
 @pytest.mark.xfail
 @pytest.mark.parametrize(