account for NER labels with a hyphen in the name (#10960)

* account for NER labels with a hyphen in the name

* cleanup

* fix docstring

* add return type to helper method

* shorter method and few more occurrences

* use helper method across repo

* fix circular import

* partial revert to avoid circular import
This commit is contained in:
Sofie Van Landeghem 2022-06-17 20:02:37 +01:00 committed by GitHub
parent 6313787fb6
commit eaeca5eb6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 48 additions and 21 deletions

View File

@ -10,7 +10,7 @@ import math
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli from ._util import import_code, debug_cli
from ..training import Example from ..training import Example, remove_bilu_prefix
from ..training.initialize import get_sourced_components from ..training.initialize import get_sourced_components
from ..schemas import ConfigSchemaTraining from ..schemas import ConfigSchemaTraining
from ..pipeline._parser_internals import nonproj from ..pipeline._parser_internals import nonproj
@ -758,9 +758,9 @@ def _compile_gold(
# "Illegal" whitespace entity # "Illegal" whitespace entity
data["ws_ents"] += 1 data["ws_ents"] += 1
if label.startswith(("B-", "U-")): if label.startswith(("B-", "U-")):
combined_label = label.split("-")[1] combined_label = remove_bilu_prefix(label)
data["ner"][combined_label] += 1 data["ner"][combined_label] += 1
if sent_starts[i] == True and label.startswith(("I-", "L-")): if sent_starts[i] and label.startswith(("I-", "L-")):
data["boundary_cross_ents"] += 1 data["boundary_cross_ents"] += 1
elif label == "-": elif label == "-":
data["ner"]["-"] += 1 data["ner"]["-"] += 1
@ -908,7 +908,7 @@ def _get_examples_without_label(
for eg in data: for eg in data:
if component == "ner": if component == "ner":
labels = [ labels = [
label.split("-")[1] remove_bilu_prefix(label)
for label in eg.get_aligned_ner() for label in eg.get_aligned_ner()
if label not in ("O", "-", None) if label not in ("O", "-", None)
] ]

View File

@ -10,6 +10,7 @@ from ...strings cimport hash_string
from ...structs cimport TokenC from ...structs cimport TokenC
from ...tokens.doc cimport Doc, set_children_from_heads from ...tokens.doc cimport Doc, set_children_from_heads
from ...tokens.token cimport MISSING_DEP from ...tokens.token cimport MISSING_DEP
from ...training import split_bilu_label
from ...training.example cimport Example from ...training.example cimport Example
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC, ArcC from ._state cimport StateC, ArcC
@ -687,7 +688,7 @@ cdef class ArcEager(TransitionSystem):
return self.c[name_or_id] return self.c[name_or_id]
name = name_or_id name = name_or_id
if '-' in name: if '-' in name:
move_str, label_str = name.split('-', 1) move_str, label_str = split_bilu_label(name)
label = self.strings[label_str] label = self.strings[label_str]
else: else:
move_str = name move_str = name

View File

@ -13,6 +13,7 @@ from ...typedefs cimport weight_t, attr_t
from ...lexeme cimport Lexeme from ...lexeme cimport Lexeme
from ...attrs cimport IS_SPACE from ...attrs cimport IS_SPACE
from ...structs cimport TokenC, SpanC from ...structs cimport TokenC, SpanC
from ...training import split_bilu_label
from ...training.example cimport Example from ...training.example cimport Example
from .stateclass cimport StateClass from .stateclass cimport StateClass
from ._state cimport StateC from ._state cimport StateC
@ -182,7 +183,7 @@ cdef class BiluoPushDown(TransitionSystem):
if name == '-' or name == '' or name is None: if name == '-' or name == '' or name is None:
return Transition(clas=0, move=MISSING, label=0, score=0) return Transition(clas=0, move=MISSING, label=0, score=0)
elif '-' in name: elif '-' in name:
move_str, label_str = name.split('-', 1) move_str, label_str = split_bilu_label(name)
# Deprecated, hacky way to denote 'not this entity' # Deprecated, hacky way to denote 'not this entity'
if label_str.startswith('!'): if label_str.startswith('!'):
raise ValueError(Errors.E869.format(label=name)) raise ValueError(Errors.E869.format(label=name))

View File

@ -12,6 +12,7 @@ from ..language import Language
from ._parser_internals import nonproj from ._parser_internals import nonproj
from ._parser_internals.nonproj import DELIMITER from ._parser_internals.nonproj import DELIMITER
from ..scorer import Scorer from ..scorer import Scorer
from ..training import remove_bilu_prefix
from ..util import registry from ..util import registry
@ -314,7 +315,7 @@ cdef class DependencyParser(Parser):
# Get the labels from the model by looking at the available moves # Get the labels from the model by looking at the available moves
for move in self.move_names: for move in self.move_names:
if "-" in move: if "-" in move:
label = move.split("-")[1] label = remove_bilu_prefix(move)
if DELIMITER in label: if DELIMITER in label:
label = label.split(DELIMITER)[1] label = label.split(DELIMITER)[1]
labels.add(label) labels.add(label)

View File

@ -6,10 +6,10 @@ from thinc.api import Model, Config
from ._parser_internals.transition_system import TransitionSystem from ._parser_internals.transition_system import TransitionSystem
from .transition_parser cimport Parser from .transition_parser cimport Parser
from ._parser_internals.ner cimport BiluoPushDown from ._parser_internals.ner cimport BiluoPushDown
from ..language import Language from ..language import Language
from ..scorer import get_ner_prf, PRFScore from ..scorer import get_ner_prf, PRFScore
from ..util import registry from ..util import registry
from ..training import remove_bilu_prefix
default_model_config = """ default_model_config = """
@ -242,7 +242,7 @@ cdef class EntityRecognizer(Parser):
def labels(self): def labels(self):
# Get the labels from the model by looking at the available moves, e.g. # Get the labels from the model by looking at the available moves, e.g.
# B-PERSON, I-PERSON, L-PERSON, U-PERSON # B-PERSON, I-PERSON, L-PERSON, U-PERSON
labels = set(move.split("-")[1] for move in self.move_names labels = set(remove_bilu_prefix(move) for move in self.move_names
if move[0] in ("B", "I", "L", "U")) if move[0] in ("B", "I", "L", "U"))
return tuple(sorted(labels)) return tuple(sorted(labels))

View File

@ -10,7 +10,7 @@ from spacy.lang.it import Italian
from spacy.language import Language from spacy.language import Language
from spacy.lookups import Lookups from spacy.lookups import Lookups
from spacy.pipeline._parser_internals.ner import BiluoPushDown from spacy.pipeline._parser_internals.ner import BiluoPushDown
from spacy.training import Example, iob_to_biluo from spacy.training import Example, iob_to_biluo, split_bilu_label
from spacy.tokens import Doc, Span from spacy.tokens import Doc, Span
from spacy.vocab import Vocab from spacy.vocab import Vocab
import logging import logging
@ -110,6 +110,9 @@ def test_issue2385():
# maintain support for iob2 format # maintain support for iob2 format
tags3 = ("B-PERSON", "I-PERSON", "B-PERSON") tags3 = ("B-PERSON", "I-PERSON", "B-PERSON")
assert iob_to_biluo(tags3) == ["B-PERSON", "L-PERSON", "U-PERSON"] assert iob_to_biluo(tags3) == ["B-PERSON", "L-PERSON", "U-PERSON"]
# ensure it works with hyphens in the name
tags4 = ("B-MULTI-PERSON", "I-MULTI-PERSON", "B-MULTI-PERSON")
assert iob_to_biluo(tags4) == ["B-MULTI-PERSON", "L-MULTI-PERSON", "U-MULTI-PERSON"]
@pytest.mark.issue(2800) @pytest.mark.issue(2800)
@ -154,6 +157,19 @@ def test_issue3209():
assert ner2.move_names == move_names assert ner2.move_names == move_names
def test_labels_from_BILUO():
    """Check that a label containing a hyphen is reported intact by the NER
    component, both in its move names and in its labels."""
    nlp = English()
    ner = nlp.add_pipe("ner")
    ner.add_label("LARGE-ANIMAL")
    nlp.initialize()
    # One move per BILU action for the single label, plus the "O" move.
    expected_moves = [
        "O",
        "B-LARGE-ANIMAL",
        "I-LARGE-ANIMAL",
        "L-LARGE-ANIMAL",
        "U-LARGE-ANIMAL",
    ]
    expected_labels = {"LARGE-ANIMAL"}
    assert ner.move_names == expected_moves
    assert set(ner.labels) == expected_labels
@pytest.mark.issue(4267) @pytest.mark.issue(4267)
def test_issue4267(): def test_issue4267():
"""Test that running an entity_ruler after ner gives consistent results""" """Test that running an entity_ruler after ner gives consistent results"""
@ -298,7 +314,7 @@ def test_oracle_moves_missing_B(en_vocab):
elif tag == "O": elif tag == "O":
moves.add_action(move_types.index("O"), "") moves.add_action(move_types.index("O"), "")
else: else:
action, label = tag.split("-") action, label = split_bilu_label(tag)
moves.add_action(move_types.index("B"), label) moves.add_action(move_types.index("B"), label)
moves.add_action(move_types.index("I"), label) moves.add_action(move_types.index("I"), label)
moves.add_action(move_types.index("L"), label) moves.add_action(move_types.index("L"), label)
@ -324,7 +340,7 @@ def test_oracle_moves_whitespace(en_vocab):
elif tag == "O": elif tag == "O":
moves.add_action(move_types.index("O"), "") moves.add_action(move_types.index("O"), "")
else: else:
action, label = tag.split("-") action, label = split_bilu_label(tag)
moves.add_action(move_types.index(action), label) moves.add_action(move_types.index(action), label)
moves.get_oracle_sequence(example) moves.get_oracle_sequence(example)

View File

@ -5,6 +5,7 @@ import srsly
from spacy.tokens import Doc from spacy.tokens import Doc
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.util import make_tempdir # noqa: F401 from spacy.util import make_tempdir # noqa: F401
from spacy.training import split_bilu_label
from thinc.api import get_current_ops from thinc.api import get_current_ops
@ -40,7 +41,7 @@ def apply_transition_sequence(parser, doc, sequence):
desired state.""" desired state."""
for action_name in sequence: for action_name in sequence:
if "-" in action_name: if "-" in action_name:
move, label = action_name.split("-") move, label = split_bilu_label(action_name)
parser.add_label(label) parser.add_label(label)
with parser.step_through(doc) as stepwise: with parser.step_through(doc) as stepwise:
for transition in sequence: for transition in sequence:

View File

@ -5,6 +5,7 @@ from .augment import dont_augment, orth_variants_augmenter # noqa: F401
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401 from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401 from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401
from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401 from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401
from .iob_utils import split_bilu_label, remove_bilu_prefix # noqa: F401
from .gold_io import docs_to_json, read_json_file # noqa: F401 from .gold_io import docs_to_json, read_json_file # noqa: F401
from .batchers import minibatch_by_padded_size, minibatch_by_words # noqa: F401 from .batchers import minibatch_by_padded_size, minibatch_by_words # noqa: F401
from .loggers import console_logger # noqa: F401 from .loggers import console_logger # noqa: F401

View File

@ -3,10 +3,10 @@ from typing import Optional
import random import random
import itertools import itertools
from functools import partial from functools import partial
from pydantic import BaseModel, StrictStr
from ..util import registry from ..util import registry
from .example import Example from .example import Example
from .iob_utils import split_bilu_label
if TYPE_CHECKING: if TYPE_CHECKING:
from ..language import Language # noqa: F401 from ..language import Language # noqa: F401
@ -278,10 +278,8 @@ def make_whitespace_variant(
ent_prev = doc_dict["entities"][position - 1] ent_prev = doc_dict["entities"][position - 1]
ent_next = doc_dict["entities"][position] ent_next = doc_dict["entities"][position]
if "-" in ent_prev and "-" in ent_next: if "-" in ent_prev and "-" in ent_next:
ent_iob_prev = ent_prev.split("-")[0] ent_iob_prev, ent_type_prev = split_bilu_label(ent_prev)
ent_type_prev = ent_prev.split("-", 1)[1] ent_iob_next, ent_type_next = split_bilu_label(ent_next)
ent_iob_next = ent_next.split("-")[0]
ent_type_next = ent_next.split("-", 1)[1]
if ( if (
ent_iob_prev in ("B", "I") ent_iob_prev in ("B", "I")
and ent_iob_next in ("I", "L") and ent_iob_next in ("I", "L")

View File

@ -9,7 +9,7 @@ from ..tokens.span import Span
from ..attrs import IDS from ..attrs import IDS
from .alignment import Alignment from .alignment import Alignment
from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
from .iob_utils import biluo_tags_to_spans from .iob_utils import biluo_tags_to_spans, remove_bilu_prefix
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from ..pipeline._parser_internals import nonproj from ..pipeline._parser_internals import nonproj
from ..tokens.token cimport MISSING_DEP from ..tokens.token cimport MISSING_DEP
@ -519,7 +519,7 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
else: else:
ent_iobs.append(iob_tag.split("-")[0]) ent_iobs.append(iob_tag.split("-")[0])
if iob_tag.startswith("I") or iob_tag.startswith("B"): if iob_tag.startswith("I") or iob_tag.startswith("B"):
ent_types.append(iob_tag.split("-", 1)[1]) ent_types.append(remove_bilu_prefix(iob_tag))
else: else:
ent_types.append("") ent_types.append("")
return ent_iobs, ent_types return ent_iobs, ent_types

View File

@ -1,4 +1,4 @@
from typing import List, Dict, Tuple, Iterable, Union, Iterator from typing import List, Dict, Tuple, Iterable, Union, Iterator, cast
import warnings import warnings
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
@ -218,6 +218,14 @@ def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]:
return entities return entities
def split_bilu_label(label: str) -> Tuple[str, str]:
    """Split a prefixed tag into its prefix and the remaining label.

    Only the first hyphen acts as the separator, so a label that itself
    contains hyphens stays intact: "B-MULTI-PERSON" gives
    ("B", "MULTI-PERSON").

    label: The prefixed tag, e.g. "B-PERSON".
    Returns the (prefix, label) pair as a tuple.
    """
    # str.split returns a list; convert it so the runtime value really is the
    # tuple promised by the annotation. A tag without any hyphen produces a
    # single-element result, so unpacking it into two names still fails.
    return cast(Tuple[str, str], tuple(label.split("-", 1)))
def remove_bilu_prefix(label: str) -> str:
    """Return the part of a prefixed tag that follows its first hyphen.

    Later hyphens belong to the label itself, so "U-LARGE-ANIMAL" gives
    "LARGE-ANIMAL". A tag without any hyphen raises IndexError.
    """
    pieces = label.split("-", maxsplit=1)
    return pieces[1]
# Fallbacks to make backwards-compat easier # Fallbacks to make backwards-compat easier
offsets_from_biluo_tags = biluo_tags_to_offsets offsets_from_biluo_tags = biluo_tags_to_offsets
spans_from_biluo_tags = biluo_tags_to_spans spans_from_biluo_tags = biluo_tags_to_spans