mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
allow missing information in deps and heads annotations
This commit is contained in:
parent
1abeca90a6
commit
dd12c6c8fd
|
@ -195,8 +195,7 @@ cdef class ArcEagerGold:
|
|||
def __init__(self, ArcEager moves, StateClass stcls, Example example):
|
||||
self.mem = Pool()
|
||||
heads, labels = example.get_aligned_parse(projectivize=True)
|
||||
labels = [label if label is not None else "" for label in labels]
|
||||
labels = [example.x.vocab.strings.add(label) for label in labels]
|
||||
labels = [example.x.vocab.strings.add(label) if label is not None else 0 for label in labels]
|
||||
sent_starts = example.get_aligned_sent_starts()
|
||||
assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
|
||||
self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)
|
||||
|
@ -783,7 +782,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
for i in range(self.n_moves):
|
||||
print(self.get_class_name(i), is_valid[i], costs[i])
|
||||
print("Gold sent starts?", is_sent_start(&gold_state, state.B(0)), is_sent_start(&gold_state, state.B(1)))
|
||||
raise ValueError
|
||||
raise ValueError("Could not find gold transition - see logs above.")
|
||||
|
||||
def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
|
||||
cdef int i
|
||||
|
|
|
@ -467,3 +467,4 @@ cdef enum symbol_t:
|
|||
|
||||
IDX
|
||||
_
|
||||
MISSING_LABEL
|
||||
|
|
|
@ -466,6 +466,7 @@ IDS = {
|
|||
"LAW": LAW,
|
||||
"MORPH": MORPH,
|
||||
"_": _,
|
||||
"MISSING_LABEL": MISSING_LABEL,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -98,10 +98,16 @@ def test_doc_from_array_heads_in_bounds(en_vocab):
|
|||
doc_from_array = Doc(en_vocab, words=words)
|
||||
doc_from_array.from_array(["HEAD"], arr)
|
||||
|
||||
# head before start
|
||||
# head before start is used to denote a missing value
|
||||
arr = doc.to_array(["HEAD"])
|
||||
arr[0] = -1
|
||||
doc_from_array = Doc(en_vocab, words=words)
|
||||
doc_from_array.from_array(["HEAD"], arr)
|
||||
|
||||
# other negative values are invalid
|
||||
arr = doc.to_array(["HEAD"])
|
||||
arr[0] = -2
|
||||
doc_from_array = Doc(en_vocab, words=words)
|
||||
with pytest.raises(ValueError):
|
||||
doc_from_array.from_array(["HEAD"], arr)
|
||||
|
||||
|
|
|
@ -45,7 +45,17 @@ CONFLICTING_DATA = [
|
|||
),
|
||||
]
|
||||
|
||||
eps = 0.01
|
||||
# Training sample with incomplete parse annotation: the last position of
# both "heads" and "deps" is None, i.e. no gold value is given there.
PARTIAL_DATA = [
    (
        "I like London.",
        {"heads": [1, 1, 1, None], "deps": ["nsubj", "ROOT", "dobj", None]},
    ),
]
|
||||
|
||||
eps = 0.1
|
||||
|
||||
|
||||
def test_parser_root(en_vocab):
|
||||
|
@ -205,6 +215,32 @@ def test_parser_set_sent_starts(en_vocab):
|
|||
assert token.head in sent
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
def test_incomplete_data(pipe_name):
    """Check that the parser can be trained on partially annotated parses.

    Builds a blank English pipeline with the parametrized parser component,
    trains it on PARTIAL_DATA, then verifies the labels and heads predicted
    for tokens 0 and 2 of an unseen sentence.
    """
    nlp = English()
    parser = nlp.add_pipe(pipe_name)

    train_examples = []
    for text, annotations in PARTIAL_DATA:
        example = Example.from_dict(nlp.make_doc(text), annotations)
        train_examples.append(example)
        for dep in annotations.get("deps", []):
            # None stands for a missing label and must not be registered
            if dep is None:
                continue
            parser.add_label(dep)

    optimizer = nlp.initialize(get_examples=lambda: train_examples)
    for _ in range(150):
        losses = {}
        nlp.update(train_examples, sgd=optimizer, losses=losses)
    # only the loss of the final update is checked
    assert losses[pipe_name] < 0.0001

    # run the trained pipeline on a sentence it has not seen
    doc = nlp("I like securities.")
    subject, obj = doc[0], doc[2]
    assert subject.dep_ == "nsubj"
    assert obj.dep_ == "dobj"
    assert subject.head.i == 1
    assert obj.head.i == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pipe_name", ["parser", "beam_parser"])
|
||||
def test_overfitting_IO(pipe_name):
|
||||
# Simple test to try and quickly overfit the dependency parser (normal or beam)
|
||||
|
|
|
@ -266,7 +266,7 @@ cdef class Doc:
|
|||
self.push_back(lexeme, has_space)
|
||||
|
||||
if heads is not None:
|
||||
heads = [head - i for i, head in enumerate(heads)]
|
||||
heads = [head - i if head is not None else None for i, head in enumerate(heads)]
|
||||
if deps and not heads:
|
||||
heads = [0] * len(deps)
|
||||
if sent_starts is not None:
|
||||
|
@ -328,7 +328,8 @@ cdef class Doc:
|
|||
if annot is not heads and annot is not sent_starts and annot is not ent_iobs:
|
||||
values.extend(annot)
|
||||
for value in values:
|
||||
self.vocab.strings.add(value)
|
||||
if value is not None:
|
||||
self.vocab.strings.add(value)
|
||||
|
||||
# if there are any other annotations, set them
|
||||
if headings:
|
||||
|
@ -1039,7 +1040,8 @@ cdef class Doc:
|
|||
# cast index to signed int
|
||||
abs_head_index = <int32_t>values[col * stride + i]
|
||||
abs_head_index += i
|
||||
if abs_head_index < 0 or abs_head_index >= length:
|
||||
# abs_head_index -1 refers to missing value
|
||||
if abs_head_index < -1 or abs_head_index >= length:
|
||||
raise ValueError(
|
||||
Errors.E190.format(
|
||||
index=i,
|
||||
|
|
|
@ -639,13 +639,16 @@ cdef class Token:
|
|||
return any(ancestor.i == self.i for ancestor in descendant.ancestors)
|
||||
|
||||
property head:
|
||||
"""The syntactic parent, or "governor", of this token.
|
||||
"""The syntactic parent, or "governor", of this token.
|
||||
|
||||
RETURNS (Token): The token predicted by the parser to be the head of
|
||||
the current token.
|
||||
the current token. Returns None if unknown.
|
||||
"""
|
||||
def __get__(self):
|
||||
return self.doc[self.i + self.c.head]
|
||||
head_i = self.i + self.c.head
|
||||
if head_i == -1:
|
||||
return None
|
||||
return self.doc[head_i]
|
||||
|
||||
def __set__(self, Token new_head):
|
||||
# This function sets the head of self to new_head and updates the
|
||||
|
|
|
@ -11,6 +11,7 @@ from .alignment import Alignment
|
|||
from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
|
||||
from .iob_utils import biluo_tags_to_spans
|
||||
from ..errors import Errors, Warnings
|
||||
from ..symbols import MISSING_LABEL
|
||||
from ..pipeline._parser_internals import nonproj
|
||||
from ..util import logger
|
||||
|
||||
|
@ -179,14 +180,18 @@ cdef class Example:
|
|||
gold_to_cand = self.alignment.y2x
|
||||
aligned_heads = [None] * self.x.length
|
||||
aligned_deps = [None] * self.x.length
|
||||
heads = [token.head.i for token in self.y]
|
||||
heads = [token.head.i if token.head is not None else -1 for token in self.y]
|
||||
deps = [token.dep_ for token in self.y]
|
||||
if projectivize:
|
||||
heads, deps = nonproj.projectivize(heads, deps)
|
||||
proj_heads, proj_deps = nonproj.projectivize(heads, deps)
|
||||
# don't touch the missing data
|
||||
heads = [h if heads[i] != -1 else -1 for i, h in enumerate(proj_heads)]
|
||||
MISSING = self.x.vocab.strings[MISSING_LABEL]
|
||||
deps = [d if deps[i] != MISSING else MISSING for i, d in enumerate(proj_deps)]
|
||||
for cand_i in range(self.x.length):
|
||||
if cand_to_gold.lengths[cand_i] == 1:
|
||||
gold_i = cand_to_gold[cand_i].dataXd[0, 0]
|
||||
if gold_to_cand.lengths[heads[gold_i]] == 1:
|
||||
if heads[gold_i] != -1 and gold_to_cand.lengths[heads[gold_i]] == 1:
|
||||
aligned_heads[cand_i] = int(gold_to_cand[heads[gold_i]].dataXd[0, 0])
|
||||
aligned_deps[cand_i] = deps[gold_i]
|
||||
return aligned_heads, aligned_deps
|
||||
|
@ -329,7 +334,10 @@ def _annot2array(vocab, tok_annot, doc_annot):
|
|||
pass
|
||||
elif key == "HEAD":
|
||||
attrs.append(key)
|
||||
values.append([h-i for i, h in enumerate(value)])
|
||||
values.append([h-i if h is not None else -(i+1) for i, h in enumerate(value)])
|
||||
elif key == "DEP":
|
||||
attrs.append(key)
|
||||
values.append([vocab.strings.add(h) if h is not None else MISSING_LABEL for h in value])
|
||||
elif key == "SENT_START":
|
||||
attrs.append(key)
|
||||
values.append(value)
|
||||
|
|
Loading…
Reference in New Issue
Block a user