fix BiluoPushDown parsing entities

svlandeg 2020-06-18 13:00:03 +02:00
parent cd790aaa2a
commit 0c6f1f3891
3 changed files with 4 additions and 6 deletions

View File

@@ -117,7 +117,7 @@ cdef class Example:
                 i = j2i_multi[j]
                 if output[i] is None:
                     output[i] = gold_values[j]
-        if as_string:
+        if as_string and field not in ["ENT_IOB"]:
             output = [vocab.strings[o] if o is not None else o for o in output]
         return output
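
For context (not part of the diff): a minimal sketch of why ENT_IOB is excluded from the string lookup. In spaCy, ENT_IOB values are small integer codes rather than StringStore hashes, so resolving them through vocab.strings would be wrong, whereas ENT_TYPE values do resolve to label strings. The sentence and label below are invented for illustration and assume a standard spaCy install.

import spacy

nlp = spacy.blank("en")
doc = nlp("Alice moved to Berlin")
doc.ents = [doc.char_span(0, 5, label="PERSON")]

token = doc[0]
print(token.ent_iob)                      # 3 -> plain int code (B), not a StringStore hash
print(token.ent_type)                     # uint64 hash of the entity label
print(nlp.vocab.strings[token.ent_type])  # "PERSON" -> resolvable via the StringStore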

View File

@@ -72,11 +72,10 @@ cdef class BiluoPushDown(TransitionSystem):
                     actions[action][entity_type] = 1
         moves = ('M', 'B', 'I', 'L', 'U')
         for example in kwargs.get('gold_parses', []):
-            for i, ner_tag in enumerate(example.token_annotation.entities):
+            for ner_tag in example.get_aligned("ENT_TYPE", as_string=True):
                 if ner_tag != 'O' and ner_tag != '-':
-                    _, label = ner_tag.split('-', 1)
                     for action in (BEGIN, IN, LAST, UNIT):
-                        actions[action][label] += 1
+                        actions[action][ner_tag] += 1
         return actions

     @property
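
As an aside (not part of the diff), a small sketch of the data shapes behind the change: the old loop walked BILUO-tagged strings from token_annotation.entities and had to split off the prefix, while get_aligned("ENT_TYPE", as_string=True) already yields a bare entity label per token. The values below are invented, and the empty string standing in for tokens outside an entity is an assumption of this sketch.

# Hypothetical per-token annotations for a five-token sentence.
biluo_tags = ["U-PERSON", "O", "O", "B-GPE", "L-GPE"]   # old source: BILUO tags
ent_types = ["PERSON", "", "", "GPE", "GPE"]            # new source: aligned ENT_TYPE strings

# Old behaviour: strip the BILUO prefix before registering the label.
old_labels = [tag.split("-", 1)[1] for tag in biluo_tags if tag not in ("O", "-")]

# New behaviour: the aligned ENT_TYPE values are already plain labels.
new_labels = [tag for tag in ent_types if tag not in ("O", "-", "")]

assert old_labels == new_labels == ["PERSON", "GPE", "GPE"]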

View File

@@ -268,7 +268,6 @@ def test_issue1963(en_tokenizer):
     assert doc.tensor.shape == (3, 128)


-# TODO: fix
 @pytest.mark.parametrize("label", ["U-JOB-NAME"])
 def test_issue1967(label):
     config = {"learn_tokens": False, "min_action_freq": 30, "beam_width": 1, "beam_update_prob": 1.0}
@@ -284,7 +283,7 @@ def test_issue1967(label):
             "entities": [label]
         }
     )
-    ner.moves.get_actions(gold_parses=[example])
+    assert "JOB-NAME" in ner.moves.get_actions(gold_parses=[example])[1]


 def test_issue1971(en_vocab):
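
A short note on the strengthened test: get_actions returns a mapping from transition ids to per-label counts, and index 1 corresponds to the BEGIN move in spaCy's NER move set (MISSING, BEGIN, IN, LAST, UNIT, OUT, as ordered at the time), so the assertion checks that the hyphenated label "JOB-NAME" is registered intact rather than mangled by the old split on '-'. The return shape below is a hypothetical sketch with invented counts, not quoted from the library.

# Rough shape of get_actions(...) for a single "U-JOB-NAME" entity.
actions = {
    0: {},                  # MISSING
    1: {"JOB-NAME": 1},     # BEGIN
    2: {"JOB-NAME": 1},     # IN
    3: {"JOB-NAME": 1},     # LAST
    4: {"JOB-NAME": 1},     # UNIT
    5: {},                  # OUT
}
assert "JOB-NAME" in actions[1]   # the condition the updated test now asserts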