allow missing information in deps and heads annotations

This commit is contained in:
svlandeg 2021-01-07 19:10:32 +01:00
parent 1abeca90a6
commit dd12c6c8fd
8 changed files with 71 additions and 15 deletions

View File

@@ -195,8 +195,7 @@ cdef class ArcEagerGold:
# Build the gold (reference) parse state consumed by the arc-eager oracle.
# NOTE(review): this span is a rendered diff hunk with the +/- markers
# stripped — the two single-purpose `labels = ...` lines are the removed
# version and the combined comprehension directly below them is the added
# replacement, which encodes a missing (None) label as hash 0 instead of
# interning the empty string. Do not read all three lines as live code.
def __init__(self, ArcEager moves, StateClass stcls, Example example):
self.mem = Pool()
# Aligned (head, label) annotations; projectivized for the arc-eager system.
heads, labels = example.get_aligned_parse(projectivize=True)
labels = [label if label is not None else "" for label in labels]
labels = [example.x.vocab.strings.add(label) for label in labels]
labels = [example.x.vocab.strings.add(label) if label is not None else 0 for label in labels]
sent_starts = example.get_aligned_sent_starts()
# All three aligned sequences must cover every token of the predicted doc.
assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)
@@ -783,7 +782,7 @@ cdef class ArcEager(TransitionSystem):
for i in range(self.n_moves):
print(self.get_class_name(i), is_valid[i], costs[i])
print("Gold sent starts?", is_sent_start(&gold_state, state.B(0)), is_sent_start(&gold_state, state.B(1)))
raise ValueError
raise ValueError("Could not find gold transition - see logs above.")
def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
cdef int i

View File

@@ -467,3 +467,4 @@ cdef enum symbol_t:
IDX
_
MISSING_LABEL

View File

@@ -466,6 +466,7 @@ IDS = {
"LAW": LAW,
"MORPH": MORPH,
"_": _,
"MISSING_LABEL": MISSING_LABEL,
}

View File

@@ -98,10 +98,16 @@ def test_doc_from_array_heads_in_bounds(en_vocab):
doc_from_array = Doc(en_vocab, words=words)
doc_from_array.from_array(["HEAD"], arr)
# head before start
# head before start is used to denote a missing value
arr = doc.to_array(["HEAD"])
arr[0] = -1
doc_from_array = Doc(en_vocab, words=words)
doc_from_array.from_array(["HEAD"], arr)
# other negative values are invalid
arr = doc.to_array(["HEAD"])
arr[0] = -2
doc_from_array = Doc(en_vocab, words=words)
with pytest.raises(ValueError):
doc_from_array.from_array(["HEAD"], arr)

View File

@@ -45,7 +45,17 @@ CONFLICTING_DATA = [
),
]
eps = 0.01
PARTIAL_DATA = [
(
"I like London.",
{
"heads": [1, 1, 1, None],
"deps": ["nsubj", "ROOT", "dobj", None],
},
),
]
eps = 0.1
def test_parser_root(en_vocab):
@@ -205,6 +215,32 @@ def test_parser_set_sent_starts(en_vocab):
assert token.head in sent
# Regression test for this commit: training must tolerate partial annotation,
# i.e. `heads`/`deps` entries that are None, for both the greedy and the beam
# parser. Uses PARTIAL_DATA, where the final token's head and dep are None.
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
def test_incomplete_data(pipe_name):
# Test that the parser works with incomplete information
nlp = English()
parser = nlp.add_pipe(pipe_name)
train_examples = []
for text, annotations in PARTIAL_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
# None deps mark missing values and must not be registered as labels.
for dep in annotations.get("deps", []):
if dep is not None:
parser.add_label(dep)
optimizer = nlp.initialize(get_examples=lambda: train_examples)
# 150 updates on this toy corpus; presumably enough to drive the loss
# below the 1e-4 threshold — confirm if this test turns flaky.
for i in range(150):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
assert losses[pipe_name] < 0.0001
# test the trained model
test_text = "I like securities."
doc = nlp(test_text)
# Only the arcs that were actually annotated are asserted; the token with
# missing gold data (the period) is deliberately left unchecked.
assert doc[0].dep_ == "nsubj"
assert doc[2].dep_ == "dobj"
assert doc[0].head.i == 1
assert doc[2].head.i == 1
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
def test_overfitting_IO(pipe_name):
# Simple test to try and quickly overfit the dependency parser (normal or beam)

View File

@@ -266,7 +266,7 @@ cdef class Doc:
self.push_back(lexeme, has_space)
if heads is not None:
heads = [head - i for i, head in enumerate(heads)]
heads = [head - i if head is not None else None for i, head in enumerate(heads)]
if deps and not heads:
heads = [0] * len(deps)
if sent_starts is not None:
@@ -328,7 +328,8 @@ cdef class Doc:
if annot is not heads and annot is not sent_starts and annot is not ent_iobs:
values.extend(annot)
for value in values:
self.vocab.strings.add(value)
if value is not None:
self.vocab.strings.add(value)
# if there are any other annotations, set them
if headings:
@@ -1039,7 +1040,8 @@ cdef class Doc:
# cast index to signed int
abs_head_index = <int32_t>values[col * stride + i]
abs_head_index += i
if abs_head_index < 0 or abs_head_index >= length:
# abs_head_index -1 refers to missing value
if abs_head_index < -1 or abs_head_index >= length:
raise ValueError(
Errors.E190.format(
index=i,

View File

@@ -639,13 +639,16 @@ cdef class Token:
return any(ancestor.i == self.i for ancestor in descendant.ancestors)
property head:
"""The syntactic parent, or "governor", of this token.
"""The syntactic parent, or "governor", of this token.
RETURNS (Token): The token predicted by the parser to be the head of
the current token.
the current token. Returns None if unknown.
"""
def __get__(self):
return self.doc[self.i + self.c.head]
head_i = self.i + self.c.head
if head_i == -1:
return None
return self.doc[head_i]
def __set__(self, Token new_head):
# This function sets the head of self to new_head and updates the

View File

@@ -11,6 +11,7 @@ from .alignment import Alignment
from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
from .iob_utils import biluo_tags_to_spans
from ..errors import Errors, Warnings
from ..symbols import MISSING_LABEL
from ..pipeline._parser_internals import nonproj
from ..util import logger
@@ -179,14 +180,18 @@ cdef class Example:
gold_to_cand = self.alignment.y2x
aligned_heads = [None] * self.x.length
aligned_deps = [None] * self.x.length
heads = [token.head.i for token in self.y]
heads = [token.head.i if token.head is not None else -1 for token in self.y]
deps = [token.dep_ for token in self.y]
if projectivize:
heads, deps = nonproj.projectivize(heads, deps)
proj_heads, proj_deps = nonproj.projectivize(heads, deps)
# don't touch the missing data
heads = [h if heads[i] != -1 else -1 for i, h in enumerate(proj_heads)]
MISSING = self.x.vocab.strings[MISSING_LABEL]
deps = [d if deps[i] != MISSING else MISSING for i, d in enumerate(proj_deps)]
for cand_i in range(self.x.length):
if cand_to_gold.lengths[cand_i] == 1:
gold_i = cand_to_gold[cand_i].dataXd[0, 0]
if gold_to_cand.lengths[heads[gold_i]] == 1:
if heads[gold_i] != -1 and gold_to_cand.lengths[heads[gold_i]] == 1:
aligned_heads[cand_i] = int(gold_to_cand[heads[gold_i]].dataXd[0, 0])
aligned_deps[cand_i] = deps[gold_i]
return aligned_heads, aligned_deps
@@ -329,7 +334,10 @@ def _annot2array(vocab, tok_annot, doc_annot):
pass
elif key == "HEAD":
attrs.append(key)
values.append([h-i for i, h in enumerate(value)])
values.append([h-i if h is not None else -(i+1) for i, h in enumerate(value)])
elif key == "DEP":
attrs.append(key)
values.append([vocab.strings.add(h) if h is not None else MISSING_LABEL for h in value])
elif key == "SENT_START":
attrs.append(key)
values.append(value)