From eaeca5eb6a6e233b1f1f73c47fbfaf3f51720c18 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Fri, 17 Jun 2022 20:02:37 +0100 Subject: [PATCH] account for NER labels with a hyphen in the name (#10960) * account for NER labels with a hyphen in the name * cleanup * fix docstring * add return type to helper method * shorter method and few more occurrences * user helper method across repo * fix circular import * partial revert to avoid circular import --- spacy/cli/debug_data.py | 8 +++---- .../pipeline/_parser_internals/arc_eager.pyx | 3 ++- spacy/pipeline/_parser_internals/ner.pyx | 3 ++- spacy/pipeline/dep_parser.pyx | 3 ++- spacy/pipeline/ner.pyx | 4 ++-- spacy/tests/parser/test_ner.py | 22 ++++++++++++++++--- spacy/tests/util.py | 3 ++- spacy/training/__init__.py | 1 + spacy/training/augment.py | 8 +++---- spacy/training/example.pyx | 4 ++-- spacy/training/iob_utils.py | 10 ++++++++- 11 files changed, 48 insertions(+), 21 deletions(-) diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py index 0061515c6..8a6dde955 100644 --- a/spacy/cli/debug_data.py +++ b/spacy/cli/debug_data.py @@ -10,7 +10,7 @@ import math from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides from ._util import import_code, debug_cli -from ..training import Example +from ..training import Example, remove_bilu_prefix from ..training.initialize import get_sourced_components from ..schemas import ConfigSchemaTraining from ..pipeline._parser_internals import nonproj @@ -758,9 +758,9 @@ def _compile_gold( # "Illegal" whitespace entity data["ws_ents"] += 1 if label.startswith(("B-", "U-")): - combined_label = label.split("-")[1] + combined_label = remove_bilu_prefix(label) data["ner"][combined_label] += 1 - if sent_starts[i] == True and label.startswith(("I-", "L-")): + if sent_starts[i] and label.startswith(("I-", "L-")): data["boundary_cross_ents"] += 1 elif label == "-": data["ner"]["-"] += 1 @@ -908,7 +908,7 @@ def _get_examples_without_label( for eg in data: if component == "ner": labels = [ - label.split("-")[1] + remove_bilu_prefix(label) for label in eg.get_aligned_ner() if label not in ("O", "-", None) ] diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx index d60f1c3e6..257b5ef8a 100644 --- a/spacy/pipeline/_parser_internals/arc_eager.pyx +++ b/spacy/pipeline/_parser_internals/arc_eager.pyx @@ -10,6 +10,7 @@ from ...strings cimport hash_string from ...structs cimport TokenC from ...tokens.doc cimport Doc, set_children_from_heads from ...tokens.token cimport MISSING_DEP +from ...training import split_bilu_label from ...training.example cimport Example from .stateclass cimport StateClass from ._state cimport StateC, ArcC @@ -687,7 +688,7 @@ cdef class ArcEager(TransitionSystem): return self.c[name_or_id] name = name_or_id if '-' in name: - move_str, label_str = name.split('-', 1) + move_str, label_str = split_bilu_label(name) label = self.strings[label_str] else: move_str = name diff --git a/spacy/pipeline/_parser_internals/ner.pyx b/spacy/pipeline/_parser_internals/ner.pyx index 3edeff19a..fab872f00 100644 --- a/spacy/pipeline/_parser_internals/ner.pyx +++ b/spacy/pipeline/_parser_internals/ner.pyx @@ -13,6 +13,7 @@ from ...typedefs cimport weight_t, attr_t from ...lexeme cimport Lexeme from ...attrs cimport IS_SPACE from ...structs cimport TokenC, SpanC +from ...training import split_bilu_label from ...training.example cimport Example from .stateclass cimport StateClass from ._state cimport StateC @@ -182,7 +183,7 @@ cdef class BiluoPushDown(TransitionSystem): if name == '-' or name == '' or name is None: return Transition(clas=0, move=MISSING, label=0, score=0) elif '-' in name: - move_str, label_str = name.split('-', 1) + move_str, label_str = split_bilu_label(name) # Deprecated, hacky way to denote 'not this entity' if label_str.startswith('!'): raise ValueError(Errors.E869.format(label=name)) diff --git a/spacy/pipeline/dep_parser.pyx b/spacy/pipeline/dep_parser.pyx index 50c57ee5b..e5f686158 100644 --- a/spacy/pipeline/dep_parser.pyx +++ b/spacy/pipeline/dep_parser.pyx @@ -12,6 +12,7 @@ from ..language import Language from ._parser_internals import nonproj from ._parser_internals.nonproj import DELIMITER from ..scorer import Scorer +from ..training import remove_bilu_prefix from ..util import registry @@ -314,7 +315,7 @@ cdef class DependencyParser(Parser): # Get the labels from the model by looking at the available moves for move in self.move_names: if "-" in move: - label = move.split("-")[1] + label = remove_bilu_prefix(move) if DELIMITER in label: label = label.split(DELIMITER)[1] labels.add(label) diff --git a/spacy/pipeline/ner.pyx b/spacy/pipeline/ner.pyx index 4835a8c4b..25f48c9f8 100644 --- a/spacy/pipeline/ner.pyx +++ b/spacy/pipeline/ner.pyx @@ -6,10 +6,10 @@ from thinc.api import Model, Config from ._parser_internals.transition_system import TransitionSystem from .transition_parser cimport Parser from ._parser_internals.ner cimport BiluoPushDown - from ..language import Language from ..scorer import get_ner_prf, PRFScore from ..util import registry +from ..training import remove_bilu_prefix default_model_config = """ @@ -242,7 +242,7 @@ cdef class EntityRecognizer(Parser): def labels(self): # Get the labels from the model by looking at the available moves, e.g. # B-PERSON, I-PERSON, L-PERSON, U-PERSON - labels = set(move.split("-")[1] for move in self.move_names + labels = set(remove_bilu_prefix(move) for move in self.move_names if move[0] in ("B", "I", "L", "U")) return tuple(sorted(labels)) diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py index b3b29d1f9..53bb2d554 100644 --- a/spacy/tests/parser/test_ner.py +++ b/spacy/tests/parser/test_ner.py @@ -10,7 +10,7 @@ from spacy.lang.it import Italian from spacy.language import Language from spacy.lookups import Lookups from spacy.pipeline._parser_internals.ner import BiluoPushDown -from spacy.training import Example, iob_to_biluo +from spacy.training import Example, iob_to_biluo, split_bilu_label from spacy.tokens import Doc, Span from spacy.vocab import Vocab import logging @@ -110,6 +110,9 @@ def test_issue2385(): # maintain support for iob2 format tags3 = ("B-PERSON", "I-PERSON", "B-PERSON") assert iob_to_biluo(tags3) == ["B-PERSON", "L-PERSON", "U-PERSON"] + # ensure it works with hyphens in the name + tags4 = ("B-MULTI-PERSON", "I-MULTI-PERSON", "B-MULTI-PERSON") + assert iob_to_biluo(tags4) == ["B-MULTI-PERSON", "L-MULTI-PERSON", "U-MULTI-PERSON"] @pytest.mark.issue(2800) @@ -154,6 +157,19 @@ def test_issue3209(): assert ner2.move_names == move_names +def test_labels_from_BILUO(): + """Test that labels are inferred correctly when there's a - in label. + """ + nlp = English() + ner = nlp.add_pipe("ner") + ner.add_label("LARGE-ANIMAL") + nlp.initialize() + move_names = ["O", "B-LARGE-ANIMAL", "I-LARGE-ANIMAL", "L-LARGE-ANIMAL", "U-LARGE-ANIMAL"] + labels = {"LARGE-ANIMAL"} + assert ner.move_names == move_names + assert set(ner.labels) == labels + + @pytest.mark.issue(4267) def test_issue4267(): """Test that running an entity_ruler after ner gives consistent results""" @@ -298,7 +314,7 @@ def test_oracle_moves_missing_B(en_vocab): elif tag == "O": moves.add_action(move_types.index("O"), "") else: - action, label = tag.split("-") + action, label = split_bilu_label(tag) moves.add_action(move_types.index("B"), label) moves.add_action(move_types.index("I"), label) moves.add_action(move_types.index("L"), label) @@ -324,7 +340,7 @@ def test_oracle_moves_whitespace(en_vocab): elif tag == "O": moves.add_action(move_types.index("O"), "") else: - action, label = tag.split("-") + action, label = split_bilu_label(tag) moves.add_action(move_types.index(action), label) moves.get_oracle_sequence(example) diff --git a/spacy/tests/util.py b/spacy/tests/util.py index 365ea4349..d5f3c39ff 100644 --- a/spacy/tests/util.py +++ b/spacy/tests/util.py @@ -5,6 +5,7 @@ import srsly from spacy.tokens import Doc from spacy.vocab import Vocab from spacy.util import make_tempdir # noqa: F401 +from spacy.training import split_bilu_label from thinc.api import get_current_ops @@ -40,7 +41,7 @@ def apply_transition_sequence(parser, doc, sequence): desired state.""" for action_name in sequence: if "-" in action_name: - move, label = action_name.split("-") + move, label = split_bilu_label(action_name) parser.add_label(label) with parser.step_through(doc) as stepwise: for transition in sequence: diff --git a/spacy/training/__init__.py b/spacy/training/__init__.py index a4feb01f4..71d1fa775 100644 --- a/spacy/training/__init__.py +++ b/spacy/training/__init__.py @@ -5,6 +5,7 @@ from .augment import dont_augment, orth_variants_augmenter # noqa: F401 from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401 from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401 from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401 +from .iob_utils import split_bilu_label, remove_bilu_prefix # noqa: F401 from .gold_io import docs_to_json, read_json_file # noqa: F401 from .batchers import minibatch_by_padded_size, minibatch_by_words # noqa: F401 from .loggers import console_logger # noqa: F401 diff --git a/spacy/training/augment.py b/spacy/training/augment.py index 59a39c7ee..55d780ba4 100644 --- a/spacy/training/augment.py +++ b/spacy/training/augment.py @@ -3,10 +3,10 @@ from typing import Optional import random import itertools from functools import partial -from pydantic import BaseModel, StrictStr from ..util import registry from .example import Example +from .iob_utils import split_bilu_label if TYPE_CHECKING: from ..language import Language # noqa: F401 @@ -278,10 +278,8 @@ def make_whitespace_variant( ent_prev = doc_dict["entities"][position - 1] ent_next = doc_dict["entities"][position] if "-" in ent_prev and "-" in ent_next: - ent_iob_prev = ent_prev.split("-")[0] - ent_type_prev = ent_prev.split("-", 1)[1] - ent_iob_next = ent_next.split("-")[0] - ent_type_next = ent_next.split("-", 1)[1] + ent_iob_prev, ent_type_prev = split_bilu_label(ent_prev) + ent_iob_next, ent_type_next = split_bilu_label(ent_next) if ( ent_iob_prev in ("B", "I") and ent_iob_next in ("I", "L") diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx index 3035388a6..045f0b483 100644 --- a/spacy/training/example.pyx +++ b/spacy/training/example.pyx @@ -9,7 +9,7 @@ from ..tokens.span import Span from ..attrs import IDS from .alignment import Alignment from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags -from .iob_utils import biluo_tags_to_spans +from .iob_utils import biluo_tags_to_spans, remove_bilu_prefix from ..errors import Errors, Warnings from ..pipeline._parser_internals import nonproj from ..tokens.token cimport MISSING_DEP @@ -519,7 +519,7 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces): else: ent_iobs.append(iob_tag.split("-")[0]) if iob_tag.startswith("I") or iob_tag.startswith("B"): - ent_types.append(iob_tag.split("-", 1)[1]) + ent_types.append(remove_bilu_prefix(iob_tag)) else: ent_types.append("") return ent_iobs, ent_types diff --git a/spacy/training/iob_utils.py b/spacy/training/iob_utils.py index 64492c2bc..61f83a1c3 100644 --- a/spacy/training/iob_utils.py +++ b/spacy/training/iob_utils.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Tuple, Iterable, Union, Iterator +from typing import List, Dict, Tuple, Iterable, Union, Iterator, cast import warnings from ..errors import Errors, Warnings @@ -218,6 +218,14 @@ def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]: return entities +def split_bilu_label(label: str) -> Tuple[str, str]: + return cast(Tuple[str, str], label.split("-", 1)) + + +def remove_bilu_prefix(label: str) -> str: + return label.split("-", 1)[1] + + # Fallbacks to make backwards-compat easier offsets_from_biluo_tags = biluo_tags_to_offsets spans_from_biluo_tags = biluo_tags_to_spans