mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-05 22:20:34 +03:00
Merge branch 'whatif/arrow' of https://github.com/explosion/spaCy into whatif/arrow
This commit is contained in:
commit
1720c58287
|
@ -7,6 +7,7 @@ from ..tokens.doc cimport Doc
|
|||
from ..attrs import IDS
|
||||
from .align cimport Alignment
|
||||
from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc
|
||||
from .iob_utils import spans_from_biluo_tags
|
||||
from .align import Alignment
|
||||
from ..errors import Errors, AlignmentError
|
||||
from ..syntax import nonproj
|
||||
|
@ -142,6 +143,63 @@ cdef class Example:
|
|||
aligned_deps[cand_i] = deps[gold_i]
|
||||
return aligned_heads, aligned_deps
|
||||
|
||||
def get_aligned_ner(self):
|
||||
cand_to_gold = self.alignment.cand_to_gold
|
||||
gold_to_cand = self.alignment.gold_to_cand
|
||||
i2j_multi = self.alignment.i2j_multi
|
||||
j2i_multi = self.alignment.j2i_multi
|
||||
y_tags = biluo_tags_from_offsets(
|
||||
self.y,
|
||||
[(e.start_char, e.end_char, e.label_) for e in self.y.ents]
|
||||
)
|
||||
x_tags = [None] * self.x.length
|
||||
for i in range(self.x.length):
|
||||
if self.x[i].is_space:
|
||||
pass
|
||||
elif cand_to_gold[i] is not None:
|
||||
x_tags[i] = y_tags[cand_to_gold[i]]
|
||||
elif i in i2j_multi:
|
||||
# Assign O/- for many-to-one O/- NER tags
|
||||
if y_tags[i2j_multi[i]] in ("O", "-"):
|
||||
x_tags[i] = y_tags[i2j_multi[i]]
|
||||
# Assign O/- for one-to-many O/- NER tags
|
||||
for gold_i, cand_i in enumerate(gold_to_cand):
|
||||
if y_tags[gold_i] in ("O", "-"):
|
||||
if cand_i is None and gold_i in j2i_multi:
|
||||
x_tags[j2i_multi[gold_i]] = y_tags[gold_i]
|
||||
# TODO: I'm copying this over from v2.x but this seems kind of nuts?
|
||||
# If there is entity annotation and some tokens remain unaligned,
|
||||
# align all entities at the character level to account for all
|
||||
# possible token misalignments within the entity spans
|
||||
if list(self.y.ents) and None in x_tags:
|
||||
# Get offsets based on gold words and BILUO entities
|
||||
aligned_offsets = []
|
||||
aligned_spans = []
|
||||
# Filter offsets to identify those that align with doc tokens
|
||||
for span in spans_from_biluo_tags(self.x, x_tags):
|
||||
if span and not span.text.isspace():
|
||||
aligned_offsets.append(
|
||||
(span.start_char, span.end_char, span.label_)
|
||||
)
|
||||
aligned_spans.append(span)
|
||||
# Convert back to BILUO for doc tokens and assign NER for all
|
||||
# aligned spans
|
||||
aligned_tags = biluo_tags_from_offsets(self.x, aligned_offsets, missing=None)
|
||||
for span in aligned_spans:
|
||||
for i in range(span.start, span.end):
|
||||
x_tags[i] = aligned_tags[i]
|
||||
# Prevent whitespace that isn't within entities from being tagged as
|
||||
# an entity.
|
||||
for i, token in enumerate(self.x):
|
||||
if token.is_space:
|
||||
prev_ner = x_tags[i] if i >= 1 else None
|
||||
next_ner = x_tags[i+1] if (i+1) < self.x.length else None
|
||||
if prev_ner == "O" or next_ner == "O":
|
||||
x_tags[i] = "O"
|
||||
#print("Y tags", y_tags)
|
||||
#print("X tags", x_tags)
|
||||
return x_tags
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"doc_annotation": {
|
||||
|
|
|
@ -49,6 +49,10 @@ cdef class BiluoGold:
|
|||
self.mem = Pool()
|
||||
self.c = create_gold_state(self.mem, moves, stcls, example)
|
||||
|
||||
def update(self, StateClass stcls):
|
||||
update_gold_state(&self.c, stcls)
|
||||
|
||||
|
||||
|
||||
cdef GoldNERStateC create_gold_state(
|
||||
Pool mem,
|
||||
|
@ -58,30 +62,15 @@ cdef GoldNERStateC create_gold_state(
|
|||
) except *:
|
||||
cdef GoldNERStateC gs
|
||||
gs.ner = <Transition*>mem.alloc(example.x.length, sizeof(Transition))
|
||||
ner_tags = get_aligned_ner(example)
|
||||
ner_tags = example.get_aligned_ner()
|
||||
for i, ner_tag in enumerate(ner_tags):
|
||||
gs.ner[i] = moves.lookup_transition(ner_tag)
|
||||
return gs
|
||||
|
||||
|
||||
def get_aligned_ner(Example example):
|
||||
cand_to_gold = example.alignment.cand_to_gold
|
||||
i2j_multi = example.alignment.i2j_multi
|
||||
y_tags = biluo_tags_from_offsets(
|
||||
example.y,
|
||||
[(e.start_char, e.end_char, e.label_) for e in example.y.ents]
|
||||
)
|
||||
x_tags = [None] * example.x.length
|
||||
for i in range(example.x.length):
|
||||
if example.x[i].is_space:
|
||||
cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *:
|
||||
# We don't need to update each time, unlike the parser.
|
||||
pass
|
||||
elif cand_to_gold[i] is not None:
|
||||
x_tags[i] = y_tags[cand_to_gold[i]]
|
||||
elif i in i2j_multi:
|
||||
# Assign O/- for many-to-one O/- NER tags
|
||||
if y_tags[i2j_multi[i]] in ("O", "-"):
|
||||
x_tags[i] = y_tags[i2j_multi[i]]
|
||||
return y_tags
|
||||
|
||||
|
||||
cdef do_func_t[N_MOVES] do_funcs
|
||||
|
@ -120,11 +109,12 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
for action in (BEGIN, IN, LAST, UNIT):
|
||||
actions[action][entity_type] = 1
|
||||
moves = ('M', 'B', 'I', 'L', 'U')
|
||||
for example in kwargs.get('gold_parses', []):
|
||||
for ner_tag in example.get_aligned("ENT_TYPE", as_string=True):
|
||||
if ner_tag != 'O' and ner_tag != '-':
|
||||
for example in kwargs.get('examples', []):
|
||||
for token in example.y:
|
||||
ent_type = token.ent_type_
|
||||
if ent_type:
|
||||
for action in (BEGIN, IN, LAST, UNIT):
|
||||
actions[action][ner_tag] += 1
|
||||
actions[action][ent_type] += 1
|
||||
return actions
|
||||
|
||||
@property
|
||||
|
@ -247,6 +237,37 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
self.add_action(UNIT, st._sent[i].ent_type)
|
||||
self.add_action(LAST, st._sent[i].ent_type)
|
||||
|
||||
def get_cost(self, StateClass stcls, gold, int i):
|
||||
if not isinstance(gold, BiluoGold):
|
||||
raise TypeError("Expected BiluoGold")
|
||||
cdef BiluoGold gold_ = gold
|
||||
gold_state = gold_.c
|
||||
n_gold = 0
|
||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||
cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||
else:
|
||||
cost = 9000
|
||||
return cost
|
||||
|
||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||
StateClass stcls, gold) except -1:
|
||||
if not isinstance(gold, BiluoGold):
|
||||
raise TypeError("Expected BiluoGold")
|
||||
cdef BiluoGold gold_ = gold
|
||||
gold_.update(stcls)
|
||||
gold_state = gold_.c
|
||||
n_gold = 0
|
||||
for i in range(self.n_moves):
|
||||
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||
is_valid[i] = True
|
||||
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||
n_gold += costs[i] <= 0
|
||||
else:
|
||||
is_valid[i] = False
|
||||
costs[i] = 9000
|
||||
if n_gold < 1:
|
||||
raise ValueError
|
||||
|
||||
|
||||
cdef class Missing:
|
||||
@staticmethod
|
||||
|
|
|
@ -54,7 +54,6 @@ def tsys(vocab, entity_types):
|
|||
|
||||
def test_get_oracle_moves(tsys, doc, entity_annots):
|
||||
example = Example.from_dict(doc, {"entities": entity_annots})
|
||||
tsys.preprocess_gold(example)
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
|
||||
|
@ -70,7 +69,6 @@ def test_get_oracle_moves_negative_entities(tsys, doc, entity_annots):
|
|||
ex_dict["doc_annotation"]["entities"][i] = "-"
|
||||
example = Example.from_dict(doc, ex_dict)
|
||||
|
||||
tsys.preprocess_gold(example)
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names
|
||||
|
@ -80,7 +78,6 @@ def test_get_oracle_moves_negative_entities2(tsys, vocab):
|
|||
doc = Doc(vocab, words=["A", "B", "C", "D"])
|
||||
entity_annots = ["B-!PERSON", "L-!PERSON", "B-!PERSON", "L-!PERSON"]
|
||||
example = Example.from_dict(doc, {"entities": entity_annots})
|
||||
tsys.preprocess_gold(example)
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names
|
||||
|
@ -90,7 +87,6 @@ def test_get_oracle_moves_negative_O(tsys, vocab):
|
|||
doc = Doc(vocab, words=["A", "B", "C", "D"])
|
||||
entity_annots = ["O", "!O", "O", "!O"]
|
||||
example = Example.from_dict(doc, {"entities": []})
|
||||
tsys.preprocess_gold(example)
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names
|
||||
|
@ -116,7 +112,6 @@ def test_oracle_moves_missing_B(en_vocab):
|
|||
moves.add_action(move_types.index("I"), label)
|
||||
moves.add_action(move_types.index("L"), label)
|
||||
moves.add_action(move_types.index("U"), label)
|
||||
moves.preprocess_gold(example)
|
||||
moves.get_oracle_sequence(example)
|
||||
|
||||
|
||||
|
@ -137,7 +132,6 @@ def test_oracle_moves_whitespace(en_vocab):
|
|||
else:
|
||||
action, label = tag.split("-")
|
||||
moves.add_action(move_types.index(action), label)
|
||||
moves.preprocess_gold(example)
|
||||
moves.get_oracle_sequence(example)
|
||||
|
||||
|
||||
|
|
|
@ -364,12 +364,14 @@ def test_roundtrip_docs_to_docbin(doc):
|
|||
|
||||
# roundtrip to DocBin
|
||||
with make_tempdir() as tmpdir:
|
||||
json_file = tmpdir / "roundtrip.json"
|
||||
srsly.write_json(json_file, [docs_to_json(doc)])
|
||||
goldcorpus = Corpus(str(json_file), str(json_file))
|
||||
output_file = tmpdir / "roundtrip.spacy"
|
||||
data = DocBin(docs=[doc]).to_bytes()
|
||||
with output_file.open("wb") as file_:
|
||||
file_.write(data)
|
||||
goldcorpus = Corpus(train_loc=str(output_file), dev_loc=str(output_file))
|
||||
|
||||
reloaded_example = next(goldcorpus.dev_dataset(nlp=nlp))
|
||||
assert len(doc) == goldcorpus.count_train(nlp)
|
||||
assert text == reloaded_example.reference.text
|
||||
|
@ -389,40 +391,10 @@ def test_roundtrip_docs_to_docbin(doc):
|
|||
assert cats["BAKING"] == reloaded_example.reference.cats["BAKING"]
|
||||
|
||||
|
||||
@pytest.mark.xfail # TODO do we need to do the projectivity differently?
|
||||
def test_projective_train_vs_nonprojective_dev(doc):
|
||||
nlp = English()
|
||||
deps = [t.dep_ for t in doc]
|
||||
heads = [t.head.i for t in doc]
|
||||
|
||||
with make_tempdir() as tmpdir:
|
||||
output_file = tmpdir / "roundtrip.spacy"
|
||||
data = DocBin(docs=[doc]).to_bytes()
|
||||
with output_file.open("wb") as file_:
|
||||
file_.write(data)
|
||||
goldcorpus = Corpus(train_loc=str(output_file), dev_loc=str(output_file))
|
||||
|
||||
train_reloaded_example = next(goldcorpus.train_dataset(nlp))
|
||||
train_goldparse = get_parses_from_example(train_reloaded_example)[0][1]
|
||||
|
||||
dev_reloaded_example = next(goldcorpus.dev_dataset(nlp))
|
||||
dev_goldparse = get_parses_from_example(dev_reloaded_example)[0][1]
|
||||
|
||||
assert is_nonproj_tree([t.head.i for t in doc]) is True
|
||||
assert is_nonproj_tree(train_goldparse.heads) is False
|
||||
assert heads[:-1] == train_goldparse.heads[:-1]
|
||||
assert heads[-1] != train_goldparse.heads[-1]
|
||||
assert deps[:-1] == train_goldparse.labels[:-1]
|
||||
assert deps[-1] != train_goldparse.labels[-1]
|
||||
|
||||
assert heads == dev_goldparse.heads
|
||||
assert deps == dev_goldparse.labels
|
||||
|
||||
|
||||
# Hm, not sure where misalignment check would be handled? In the components too?
|
||||
# I guess that does make sense. A text categorizer doesn't care if it's
|
||||
# misaligned...
|
||||
@pytest.mark.xfail # TODO
|
||||
@pytest.mark.xfail(reason="Outdated")
|
||||
def test_ignore_misaligned(doc):
|
||||
nlp = English()
|
||||
text = doc.text
|
||||
|
@ -453,6 +425,8 @@ def test_ignore_misaligned(doc):
|
|||
assert len(train_reloaded_example) == 0
|
||||
|
||||
|
||||
# We probably want the orth variant logic back, but this test won't be quite
|
||||
# right -- we need to go from DocBin.
|
||||
def test_make_orth_variants(doc):
|
||||
nlp = English()
|
||||
with make_tempdir() as tmpdir:
|
||||
|
@ -598,17 +572,3 @@ def test_split_sents(merged_dict):
|
|||
assert token_annotation_2["words"] == ["It", "is", "just", "me"]
|
||||
assert token_annotation_2["tags"] == ["PRON", "AUX", "ADV", "PRON"]
|
||||
assert token_annotation_2["sent_starts"] == [1, 0, 0, 0]
|
||||
|
||||
|
||||
def test_tuples_to_example(vocab, merged_dict):
|
||||
cats = {"TRAVEL": 1.0, "BAKING": 0.0}
|
||||
merged_dict = dict(merged_dict)
|
||||
merged_dict["cats"] = cats
|
||||
ex = Example.from_dict(Doc(vocab, words=merged_dict["words"]), merged_dict)
|
||||
words = [token.text for token in ex.reference]
|
||||
assert words == merged_dict["words"]
|
||||
tags = [token.tag_ for token in ex.reference]
|
||||
assert tags == merged_dict["tags"]
|
||||
sent_starts = [bool(token.is_sent_start) for token in ex.reference]
|
||||
assert sent_starts == [bool(v) for v in merged_dict["sent_starts"]]
|
||||
assert ex.reference.cats == cats
|
||||
|
|
Loading…
Reference in New Issue
Block a user