mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
account for NER labels with a hyphen in the name (#10960)
* account for NER labels with a hyphen in the name * cleanup * fix docstring * add return type to helper method * shorter method and few more occurrences * user helper method across repo * fix circular import * partial revert to avoid circular import
This commit is contained in:
parent
6313787fb6
commit
eaeca5eb6a
|
@ -10,7 +10,7 @@ import math
|
|||
|
||||
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
||||
from ._util import import_code, debug_cli
|
||||
from ..training import Example
|
||||
from ..training import Example, remove_bilu_prefix
|
||||
from ..training.initialize import get_sourced_components
|
||||
from ..schemas import ConfigSchemaTraining
|
||||
from ..pipeline._parser_internals import nonproj
|
||||
|
@ -758,9 +758,9 @@ def _compile_gold(
|
|||
# "Illegal" whitespace entity
|
||||
data["ws_ents"] += 1
|
||||
if label.startswith(("B-", "U-")):
|
||||
combined_label = label.split("-")[1]
|
||||
combined_label = remove_bilu_prefix(label)
|
||||
data["ner"][combined_label] += 1
|
||||
if sent_starts[i] == True and label.startswith(("I-", "L-")):
|
||||
if sent_starts[i] and label.startswith(("I-", "L-")):
|
||||
data["boundary_cross_ents"] += 1
|
||||
elif label == "-":
|
||||
data["ner"]["-"] += 1
|
||||
|
@ -908,7 +908,7 @@ def _get_examples_without_label(
|
|||
for eg in data:
|
||||
if component == "ner":
|
||||
labels = [
|
||||
label.split("-")[1]
|
||||
remove_bilu_prefix(label)
|
||||
for label in eg.get_aligned_ner()
|
||||
if label not in ("O", "-", None)
|
||||
]
|
||||
|
|
|
@ -10,6 +10,7 @@ from ...strings cimport hash_string
|
|||
from ...structs cimport TokenC
|
||||
from ...tokens.doc cimport Doc, set_children_from_heads
|
||||
from ...tokens.token cimport MISSING_DEP
|
||||
from ...training import split_bilu_label
|
||||
from ...training.example cimport Example
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC, ArcC
|
||||
|
@ -687,7 +688,7 @@ cdef class ArcEager(TransitionSystem):
|
|||
return self.c[name_or_id]
|
||||
name = name_or_id
|
||||
if '-' in name:
|
||||
move_str, label_str = name.split('-', 1)
|
||||
move_str, label_str = split_bilu_label(name)
|
||||
label = self.strings[label_str]
|
||||
else:
|
||||
move_str = name
|
||||
|
|
|
@ -13,6 +13,7 @@ from ...typedefs cimport weight_t, attr_t
|
|||
from ...lexeme cimport Lexeme
|
||||
from ...attrs cimport IS_SPACE
|
||||
from ...structs cimport TokenC, SpanC
|
||||
from ...training import split_bilu_label
|
||||
from ...training.example cimport Example
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
|
@ -182,7 +183,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
|||
if name == '-' or name == '' or name is None:
|
||||
return Transition(clas=0, move=MISSING, label=0, score=0)
|
||||
elif '-' in name:
|
||||
move_str, label_str = name.split('-', 1)
|
||||
move_str, label_str = split_bilu_label(name)
|
||||
# Deprecated, hacky way to denote 'not this entity'
|
||||
if label_str.startswith('!'):
|
||||
raise ValueError(Errors.E869.format(label=name))
|
||||
|
|
|
@ -12,6 +12,7 @@ from ..language import Language
|
|||
from ._parser_internals import nonproj
|
||||
from ._parser_internals.nonproj import DELIMITER
|
||||
from ..scorer import Scorer
|
||||
from ..training import remove_bilu_prefix
|
||||
from ..util import registry
|
||||
|
||||
|
||||
|
@ -314,7 +315,7 @@ cdef class DependencyParser(Parser):
|
|||
# Get the labels from the model by looking at the available moves
|
||||
for move in self.move_names:
|
||||
if "-" in move:
|
||||
label = move.split("-")[1]
|
||||
label = remove_bilu_prefix(move)
|
||||
if DELIMITER in label:
|
||||
label = label.split(DELIMITER)[1]
|
||||
labels.add(label)
|
||||
|
|
|
@ -6,10 +6,10 @@ from thinc.api import Model, Config
|
|||
from ._parser_internals.transition_system import TransitionSystem
|
||||
from .transition_parser cimport Parser
|
||||
from ._parser_internals.ner cimport BiluoPushDown
|
||||
|
||||
from ..language import Language
|
||||
from ..scorer import get_ner_prf, PRFScore
|
||||
from ..util import registry
|
||||
from ..training import remove_bilu_prefix
|
||||
|
||||
|
||||
default_model_config = """
|
||||
|
@ -242,7 +242,7 @@ cdef class EntityRecognizer(Parser):
|
|||
def labels(self):
|
||||
# Get the labels from the model by looking at the available moves, e.g.
|
||||
# B-PERSON, I-PERSON, L-PERSON, U-PERSON
|
||||
labels = set(move.split("-")[1] for move in self.move_names
|
||||
labels = set(remove_bilu_prefix(move) for move in self.move_names
|
||||
if move[0] in ("B", "I", "L", "U"))
|
||||
return tuple(sorted(labels))
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from spacy.lang.it import Italian
|
|||
from spacy.language import Language
|
||||
from spacy.lookups import Lookups
|
||||
from spacy.pipeline._parser_internals.ner import BiluoPushDown
|
||||
from spacy.training import Example, iob_to_biluo
|
||||
from spacy.training import Example, iob_to_biluo, split_bilu_label
|
||||
from spacy.tokens import Doc, Span
|
||||
from spacy.vocab import Vocab
|
||||
import logging
|
||||
|
@ -110,6 +110,9 @@ def test_issue2385():
|
|||
# maintain support for iob2 format
|
||||
tags3 = ("B-PERSON", "I-PERSON", "B-PERSON")
|
||||
assert iob_to_biluo(tags3) == ["B-PERSON", "L-PERSON", "U-PERSON"]
|
||||
# ensure it works with hyphens in the name
|
||||
tags4 = ("B-MULTI-PERSON", "I-MULTI-PERSON", "B-MULTI-PERSON")
|
||||
assert iob_to_biluo(tags4) == ["B-MULTI-PERSON", "L-MULTI-PERSON", "U-MULTI-PERSON"]
|
||||
|
||||
|
||||
@pytest.mark.issue(2800)
|
||||
|
@ -154,6 +157,19 @@ def test_issue3209():
|
|||
assert ner2.move_names == move_names
|
||||
|
||||
|
||||
def test_labels_from_BILUO():
|
||||
"""Test that labels are inferred correctly when there's a - in label.
|
||||
"""
|
||||
nlp = English()
|
||||
ner = nlp.add_pipe("ner")
|
||||
ner.add_label("LARGE-ANIMAL")
|
||||
nlp.initialize()
|
||||
move_names = ["O", "B-LARGE-ANIMAL", "I-LARGE-ANIMAL", "L-LARGE-ANIMAL", "U-LARGE-ANIMAL"]
|
||||
labels = {"LARGE-ANIMAL"}
|
||||
assert ner.move_names == move_names
|
||||
assert set(ner.labels) == labels
|
||||
|
||||
|
||||
@pytest.mark.issue(4267)
|
||||
def test_issue4267():
|
||||
"""Test that running an entity_ruler after ner gives consistent results"""
|
||||
|
@ -298,7 +314,7 @@ def test_oracle_moves_missing_B(en_vocab):
|
|||
elif tag == "O":
|
||||
moves.add_action(move_types.index("O"), "")
|
||||
else:
|
||||
action, label = tag.split("-")
|
||||
action, label = split_bilu_label(tag)
|
||||
moves.add_action(move_types.index("B"), label)
|
||||
moves.add_action(move_types.index("I"), label)
|
||||
moves.add_action(move_types.index("L"), label)
|
||||
|
@ -324,7 +340,7 @@ def test_oracle_moves_whitespace(en_vocab):
|
|||
elif tag == "O":
|
||||
moves.add_action(move_types.index("O"), "")
|
||||
else:
|
||||
action, label = tag.split("-")
|
||||
action, label = split_bilu_label(tag)
|
||||
moves.add_action(move_types.index(action), label)
|
||||
moves.get_oracle_sequence(example)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import srsly
|
|||
from spacy.tokens import Doc
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.util import make_tempdir # noqa: F401
|
||||
from spacy.training import split_bilu_label
|
||||
from thinc.api import get_current_ops
|
||||
|
||||
|
||||
|
@ -40,7 +41,7 @@ def apply_transition_sequence(parser, doc, sequence):
|
|||
desired state."""
|
||||
for action_name in sequence:
|
||||
if "-" in action_name:
|
||||
move, label = action_name.split("-")
|
||||
move, label = split_bilu_label(action_name)
|
||||
parser.add_label(label)
|
||||
with parser.step_through(doc) as stepwise:
|
||||
for transition in sequence:
|
||||
|
|
|
@ -5,6 +5,7 @@ from .augment import dont_augment, orth_variants_augmenter # noqa: F401
|
|||
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
|
||||
from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401
|
||||
from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401
|
||||
from .iob_utils import split_bilu_label, remove_bilu_prefix # noqa: F401
|
||||
from .gold_io import docs_to_json, read_json_file # noqa: F401
|
||||
from .batchers import minibatch_by_padded_size, minibatch_by_words # noqa: F401
|
||||
from .loggers import console_logger # noqa: F401
|
||||
|
|
|
@ -3,10 +3,10 @@ from typing import Optional
|
|||
import random
|
||||
import itertools
|
||||
from functools import partial
|
||||
from pydantic import BaseModel, StrictStr
|
||||
|
||||
from ..util import registry
|
||||
from .example import Example
|
||||
from .iob_utils import split_bilu_label
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..language import Language # noqa: F401
|
||||
|
@ -278,10 +278,8 @@ def make_whitespace_variant(
|
|||
ent_prev = doc_dict["entities"][position - 1]
|
||||
ent_next = doc_dict["entities"][position]
|
||||
if "-" in ent_prev and "-" in ent_next:
|
||||
ent_iob_prev = ent_prev.split("-")[0]
|
||||
ent_type_prev = ent_prev.split("-", 1)[1]
|
||||
ent_iob_next = ent_next.split("-")[0]
|
||||
ent_type_next = ent_next.split("-", 1)[1]
|
||||
ent_iob_prev, ent_type_prev = split_bilu_label(ent_prev)
|
||||
ent_iob_next, ent_type_next = split_bilu_label(ent_next)
|
||||
if (
|
||||
ent_iob_prev in ("B", "I")
|
||||
and ent_iob_next in ("I", "L")
|
||||
|
|
|
@ -9,7 +9,7 @@ from ..tokens.span import Span
|
|||
from ..attrs import IDS
|
||||
from .alignment import Alignment
|
||||
from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
|
||||
from .iob_utils import biluo_tags_to_spans
|
||||
from .iob_utils import biluo_tags_to_spans, remove_bilu_prefix
|
||||
from ..errors import Errors, Warnings
|
||||
from ..pipeline._parser_internals import nonproj
|
||||
from ..tokens.token cimport MISSING_DEP
|
||||
|
@ -519,7 +519,7 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
|
|||
else:
|
||||
ent_iobs.append(iob_tag.split("-")[0])
|
||||
if iob_tag.startswith("I") or iob_tag.startswith("B"):
|
||||
ent_types.append(iob_tag.split("-", 1)[1])
|
||||
ent_types.append(remove_bilu_prefix(iob_tag))
|
||||
else:
|
||||
ent_types.append("")
|
||||
return ent_iobs, ent_types
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Dict, Tuple, Iterable, Union, Iterator
|
||||
from typing import List, Dict, Tuple, Iterable, Union, Iterator, cast
|
||||
import warnings
|
||||
|
||||
from ..errors import Errors, Warnings
|
||||
|
@ -218,6 +218,14 @@ def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]:
|
|||
return entities
|
||||
|
||||
|
||||
def split_bilu_label(label: str) -> Tuple[str, str]:
|
||||
return cast(Tuple[str, str], label.split("-", 1))
|
||||
|
||||
|
||||
def remove_bilu_prefix(label: str) -> str:
|
||||
return label.split("-", 1)[1]
|
||||
|
||||
|
||||
# Fallbacks to make backwards-compat easier
|
||||
offsets_from_biluo_tags = biluo_tags_to_offsets
|
||||
spans_from_biluo_tags = biluo_tags_to_spans
|
||||
|
|
Loading…
Reference in New Issue
Block a user