mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	account for NER labels with a hyphen in the name (#10960)
* account for NER labels with a hyphen in the name * cleanup * fix docstring * add return type to helper method * shorter method and few more occurrences * user helper method across repo * fix circular import * partial revert to avoid circular import
This commit is contained in:
		
							parent
							
								
									6313787fb6
								
							
						
					
					
						commit
						eaeca5eb6a
					
				|  | @ -10,7 +10,7 @@ import math | ||||||
| 
 | 
 | ||||||
| from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides | from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides | ||||||
| from ._util import import_code, debug_cli | from ._util import import_code, debug_cli | ||||||
| from ..training import Example | from ..training import Example, remove_bilu_prefix | ||||||
| from ..training.initialize import get_sourced_components | from ..training.initialize import get_sourced_components | ||||||
| from ..schemas import ConfigSchemaTraining | from ..schemas import ConfigSchemaTraining | ||||||
| from ..pipeline._parser_internals import nonproj | from ..pipeline._parser_internals import nonproj | ||||||
|  | @ -758,9 +758,9 @@ def _compile_gold( | ||||||
|                     # "Illegal" whitespace entity |                     # "Illegal" whitespace entity | ||||||
|                     data["ws_ents"] += 1 |                     data["ws_ents"] += 1 | ||||||
|                 if label.startswith(("B-", "U-")): |                 if label.startswith(("B-", "U-")): | ||||||
|                     combined_label = label.split("-")[1] |                     combined_label = remove_bilu_prefix(label) | ||||||
|                     data["ner"][combined_label] += 1 |                     data["ner"][combined_label] += 1 | ||||||
|                 if sent_starts[i] == True and label.startswith(("I-", "L-")): |                 if sent_starts[i] and label.startswith(("I-", "L-")): | ||||||
|                     data["boundary_cross_ents"] += 1 |                     data["boundary_cross_ents"] += 1 | ||||||
|                 elif label == "-": |                 elif label == "-": | ||||||
|                     data["ner"]["-"] += 1 |                     data["ner"]["-"] += 1 | ||||||
|  | @ -908,7 +908,7 @@ def _get_examples_without_label( | ||||||
|     for eg in data: |     for eg in data: | ||||||
|         if component == "ner": |         if component == "ner": | ||||||
|             labels = [ |             labels = [ | ||||||
|                 label.split("-")[1] |                 remove_bilu_prefix(label) | ||||||
|                 for label in eg.get_aligned_ner() |                 for label in eg.get_aligned_ner() | ||||||
|                 if label not in ("O", "-", None) |                 if label not in ("O", "-", None) | ||||||
|             ] |             ] | ||||||
|  |  | ||||||
|  | @ -10,6 +10,7 @@ from ...strings cimport hash_string | ||||||
| from ...structs cimport TokenC | from ...structs cimport TokenC | ||||||
| from ...tokens.doc cimport Doc, set_children_from_heads | from ...tokens.doc cimport Doc, set_children_from_heads | ||||||
| from ...tokens.token cimport MISSING_DEP | from ...tokens.token cimport MISSING_DEP | ||||||
|  | from ...training import split_bilu_label | ||||||
| from ...training.example cimport Example | from ...training.example cimport Example | ||||||
| from .stateclass cimport StateClass | from .stateclass cimport StateClass | ||||||
| from ._state cimport StateC, ArcC | from ._state cimport StateC, ArcC | ||||||
|  | @ -687,7 +688,7 @@ cdef class ArcEager(TransitionSystem): | ||||||
|             return self.c[name_or_id] |             return self.c[name_or_id] | ||||||
|         name = name_or_id |         name = name_or_id | ||||||
|         if '-' in name: |         if '-' in name: | ||||||
|             move_str, label_str = name.split('-', 1) |             move_str, label_str = split_bilu_label(name) | ||||||
|             label = self.strings[label_str] |             label = self.strings[label_str] | ||||||
|         else: |         else: | ||||||
|             move_str = name |             move_str = name | ||||||
|  |  | ||||||
|  | @ -13,6 +13,7 @@ from ...typedefs cimport weight_t, attr_t | ||||||
| from ...lexeme cimport Lexeme | from ...lexeme cimport Lexeme | ||||||
| from ...attrs cimport IS_SPACE | from ...attrs cimport IS_SPACE | ||||||
| from ...structs cimport TokenC, SpanC | from ...structs cimport TokenC, SpanC | ||||||
|  | from ...training import split_bilu_label | ||||||
| from ...training.example cimport Example | from ...training.example cimport Example | ||||||
| from .stateclass cimport StateClass | from .stateclass cimport StateClass | ||||||
| from ._state cimport StateC | from ._state cimport StateC | ||||||
|  | @ -182,7 +183,7 @@ cdef class BiluoPushDown(TransitionSystem): | ||||||
|         if name == '-' or name == '' or name is None: |         if name == '-' or name == '' or name is None: | ||||||
|             return Transition(clas=0, move=MISSING, label=0, score=0) |             return Transition(clas=0, move=MISSING, label=0, score=0) | ||||||
|         elif '-' in name: |         elif '-' in name: | ||||||
|             move_str, label_str = name.split('-', 1) |             move_str, label_str = split_bilu_label(name) | ||||||
|             # Deprecated, hacky way to denote 'not this entity' |             # Deprecated, hacky way to denote 'not this entity' | ||||||
|             if label_str.startswith('!'): |             if label_str.startswith('!'): | ||||||
|                 raise ValueError(Errors.E869.format(label=name)) |                 raise ValueError(Errors.E869.format(label=name)) | ||||||
|  |  | ||||||
|  | @ -12,6 +12,7 @@ from ..language import Language | ||||||
| from ._parser_internals import nonproj | from ._parser_internals import nonproj | ||||||
| from ._parser_internals.nonproj import DELIMITER | from ._parser_internals.nonproj import DELIMITER | ||||||
| from ..scorer import Scorer | from ..scorer import Scorer | ||||||
|  | from ..training import remove_bilu_prefix | ||||||
| from ..util import registry | from ..util import registry | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -314,7 +315,7 @@ cdef class DependencyParser(Parser): | ||||||
|         # Get the labels from the model by looking at the available moves |         # Get the labels from the model by looking at the available moves | ||||||
|         for move in self.move_names: |         for move in self.move_names: | ||||||
|             if "-" in move: |             if "-" in move: | ||||||
|                 label = move.split("-")[1] |                 label = remove_bilu_prefix(move) | ||||||
|                 if DELIMITER in label: |                 if DELIMITER in label: | ||||||
|                     label = label.split(DELIMITER)[1] |                     label = label.split(DELIMITER)[1] | ||||||
|                 labels.add(label) |                 labels.add(label) | ||||||
|  |  | ||||||
|  | @ -6,10 +6,10 @@ from thinc.api import Model, Config | ||||||
| from ._parser_internals.transition_system import TransitionSystem | from ._parser_internals.transition_system import TransitionSystem | ||||||
| from .transition_parser cimport Parser | from .transition_parser cimport Parser | ||||||
| from ._parser_internals.ner cimport BiluoPushDown | from ._parser_internals.ner cimport BiluoPushDown | ||||||
| 
 |  | ||||||
| from ..language import Language | from ..language import Language | ||||||
| from ..scorer import get_ner_prf, PRFScore | from ..scorer import get_ner_prf, PRFScore | ||||||
| from ..util import registry | from ..util import registry | ||||||
|  | from ..training import remove_bilu_prefix | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| default_model_config = """ | default_model_config = """ | ||||||
|  | @ -242,7 +242,7 @@ cdef class EntityRecognizer(Parser): | ||||||
|     def labels(self): |     def labels(self): | ||||||
|         # Get the labels from the model by looking at the available moves, e.g. |         # Get the labels from the model by looking at the available moves, e.g. | ||||||
|         # B-PERSON, I-PERSON, L-PERSON, U-PERSON |         # B-PERSON, I-PERSON, L-PERSON, U-PERSON | ||||||
|         labels = set(move.split("-")[1] for move in self.move_names |         labels = set(remove_bilu_prefix(move) for move in self.move_names | ||||||
|                      if move[0] in ("B", "I", "L", "U")) |                      if move[0] in ("B", "I", "L", "U")) | ||||||
|         return tuple(sorted(labels)) |         return tuple(sorted(labels)) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -10,7 +10,7 @@ from spacy.lang.it import Italian | ||||||
| from spacy.language import Language | from spacy.language import Language | ||||||
| from spacy.lookups import Lookups | from spacy.lookups import Lookups | ||||||
| from spacy.pipeline._parser_internals.ner import BiluoPushDown | from spacy.pipeline._parser_internals.ner import BiluoPushDown | ||||||
| from spacy.training import Example, iob_to_biluo | from spacy.training import Example, iob_to_biluo, split_bilu_label | ||||||
| from spacy.tokens import Doc, Span | from spacy.tokens import Doc, Span | ||||||
| from spacy.vocab import Vocab | from spacy.vocab import Vocab | ||||||
| import logging | import logging | ||||||
|  | @ -110,6 +110,9 @@ def test_issue2385(): | ||||||
|     # maintain support for iob2 format |     # maintain support for iob2 format | ||||||
|     tags3 = ("B-PERSON", "I-PERSON", "B-PERSON") |     tags3 = ("B-PERSON", "I-PERSON", "B-PERSON") | ||||||
|     assert iob_to_biluo(tags3) == ["B-PERSON", "L-PERSON", "U-PERSON"] |     assert iob_to_biluo(tags3) == ["B-PERSON", "L-PERSON", "U-PERSON"] | ||||||
|  |     # ensure it works with hyphens in the name | ||||||
|  |     tags4 = ("B-MULTI-PERSON", "I-MULTI-PERSON", "B-MULTI-PERSON") | ||||||
|  |     assert iob_to_biluo(tags4) == ["B-MULTI-PERSON", "L-MULTI-PERSON", "U-MULTI-PERSON"] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.issue(2800) | @pytest.mark.issue(2800) | ||||||
|  | @ -154,6 +157,19 @@ def test_issue3209(): | ||||||
|     assert ner2.move_names == move_names |     assert ner2.move_names == move_names | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def test_labels_from_BILUO(): | ||||||
|  |     """Test that labels are inferred correctly when there's a - in label. | ||||||
|  |     """ | ||||||
|  |     nlp = English() | ||||||
|  |     ner = nlp.add_pipe("ner") | ||||||
|  |     ner.add_label("LARGE-ANIMAL") | ||||||
|  |     nlp.initialize() | ||||||
|  |     move_names = ["O", "B-LARGE-ANIMAL", "I-LARGE-ANIMAL", "L-LARGE-ANIMAL", "U-LARGE-ANIMAL"] | ||||||
|  |     labels = {"LARGE-ANIMAL"} | ||||||
|  |     assert ner.move_names == move_names | ||||||
|  |     assert set(ner.labels) == labels | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @pytest.mark.issue(4267) | @pytest.mark.issue(4267) | ||||||
| def test_issue4267(): | def test_issue4267(): | ||||||
|     """Test that running an entity_ruler after ner gives consistent results""" |     """Test that running an entity_ruler after ner gives consistent results""" | ||||||
|  | @ -298,7 +314,7 @@ def test_oracle_moves_missing_B(en_vocab): | ||||||
|         elif tag == "O": |         elif tag == "O": | ||||||
|             moves.add_action(move_types.index("O"), "") |             moves.add_action(move_types.index("O"), "") | ||||||
|         else: |         else: | ||||||
|             action, label = tag.split("-") |             action, label = split_bilu_label(tag) | ||||||
|             moves.add_action(move_types.index("B"), label) |             moves.add_action(move_types.index("B"), label) | ||||||
|             moves.add_action(move_types.index("I"), label) |             moves.add_action(move_types.index("I"), label) | ||||||
|             moves.add_action(move_types.index("L"), label) |             moves.add_action(move_types.index("L"), label) | ||||||
|  | @ -324,7 +340,7 @@ def test_oracle_moves_whitespace(en_vocab): | ||||||
|         elif tag == "O": |         elif tag == "O": | ||||||
|             moves.add_action(move_types.index("O"), "") |             moves.add_action(move_types.index("O"), "") | ||||||
|         else: |         else: | ||||||
|             action, label = tag.split("-") |             action, label = split_bilu_label(tag) | ||||||
|             moves.add_action(move_types.index(action), label) |             moves.add_action(move_types.index(action), label) | ||||||
|     moves.get_oracle_sequence(example) |     moves.get_oracle_sequence(example) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ import srsly | ||||||
| from spacy.tokens import Doc | from spacy.tokens import Doc | ||||||
| from spacy.vocab import Vocab | from spacy.vocab import Vocab | ||||||
| from spacy.util import make_tempdir  # noqa: F401 | from spacy.util import make_tempdir  # noqa: F401 | ||||||
|  | from spacy.training import split_bilu_label | ||||||
| from thinc.api import get_current_ops | from thinc.api import get_current_ops | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -40,7 +41,7 @@ def apply_transition_sequence(parser, doc, sequence): | ||||||
|     desired state.""" |     desired state.""" | ||||||
|     for action_name in sequence: |     for action_name in sequence: | ||||||
|         if "-" in action_name: |         if "-" in action_name: | ||||||
|             move, label = action_name.split("-") |             move, label = split_bilu_label(action_name) | ||||||
|             parser.add_label(label) |             parser.add_label(label) | ||||||
|     with parser.step_through(doc) as stepwise: |     with parser.step_through(doc) as stepwise: | ||||||
|         for transition in sequence: |         for transition in sequence: | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ from .augment import dont_augment, orth_variants_augmenter  # noqa: F401 | ||||||
| from .iob_utils import iob_to_biluo, biluo_to_iob  # noqa: F401 | from .iob_utils import iob_to_biluo, biluo_to_iob  # noqa: F401 | ||||||
| from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets  # noqa: F401 | from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets  # noqa: F401 | ||||||
| from .iob_utils import biluo_tags_to_spans, tags_to_entities  # noqa: F401 | from .iob_utils import biluo_tags_to_spans, tags_to_entities  # noqa: F401 | ||||||
|  | from .iob_utils import split_bilu_label, remove_bilu_prefix  # noqa: F401 | ||||||
| from .gold_io import docs_to_json, read_json_file  # noqa: F401 | from .gold_io import docs_to_json, read_json_file  # noqa: F401 | ||||||
| from .batchers import minibatch_by_padded_size, minibatch_by_words  # noqa: F401 | from .batchers import minibatch_by_padded_size, minibatch_by_words  # noqa: F401 | ||||||
| from .loggers import console_logger  # noqa: F401 | from .loggers import console_logger  # noqa: F401 | ||||||
|  |  | ||||||
|  | @ -3,10 +3,10 @@ from typing import Optional | ||||||
| import random | import random | ||||||
| import itertools | import itertools | ||||||
| from functools import partial | from functools import partial | ||||||
| from pydantic import BaseModel, StrictStr |  | ||||||
| 
 | 
 | ||||||
| from ..util import registry | from ..util import registry | ||||||
| from .example import Example | from .example import Example | ||||||
|  | from .iob_utils import split_bilu_label | ||||||
| 
 | 
 | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from ..language import Language  # noqa: F401 |     from ..language import Language  # noqa: F401 | ||||||
|  | @ -278,10 +278,8 @@ def make_whitespace_variant( | ||||||
|             ent_prev = doc_dict["entities"][position - 1] |             ent_prev = doc_dict["entities"][position - 1] | ||||||
|             ent_next = doc_dict["entities"][position] |             ent_next = doc_dict["entities"][position] | ||||||
|             if "-" in ent_prev and "-" in ent_next: |             if "-" in ent_prev and "-" in ent_next: | ||||||
|                 ent_iob_prev = ent_prev.split("-")[0] |                 ent_iob_prev, ent_type_prev = split_bilu_label(ent_prev) | ||||||
|                 ent_type_prev = ent_prev.split("-", 1)[1] |                 ent_iob_next, ent_type_next = split_bilu_label(ent_next) | ||||||
|                 ent_iob_next = ent_next.split("-")[0] |  | ||||||
|                 ent_type_next = ent_next.split("-", 1)[1] |  | ||||||
|                 if ( |                 if ( | ||||||
|                     ent_iob_prev in ("B", "I") |                     ent_iob_prev in ("B", "I") | ||||||
|                     and ent_iob_next in ("I", "L") |                     and ent_iob_next in ("I", "L") | ||||||
|  |  | ||||||
|  | @ -9,7 +9,7 @@ from ..tokens.span import Span | ||||||
| from ..attrs import IDS | from ..attrs import IDS | ||||||
| from .alignment import Alignment | from .alignment import Alignment | ||||||
| from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags | from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags | ||||||
| from .iob_utils import biluo_tags_to_spans | from .iob_utils import biluo_tags_to_spans, remove_bilu_prefix | ||||||
| from ..errors import Errors, Warnings | from ..errors import Errors, Warnings | ||||||
| from ..pipeline._parser_internals import nonproj | from ..pipeline._parser_internals import nonproj | ||||||
| from ..tokens.token cimport MISSING_DEP | from ..tokens.token cimport MISSING_DEP | ||||||
|  | @ -519,7 +519,7 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces): | ||||||
|         else: |         else: | ||||||
|             ent_iobs.append(iob_tag.split("-")[0]) |             ent_iobs.append(iob_tag.split("-")[0]) | ||||||
|             if iob_tag.startswith("I") or iob_tag.startswith("B"): |             if iob_tag.startswith("I") or iob_tag.startswith("B"): | ||||||
|                 ent_types.append(iob_tag.split("-", 1)[1]) |                 ent_types.append(remove_bilu_prefix(iob_tag)) | ||||||
|             else: |             else: | ||||||
|                 ent_types.append("") |                 ent_types.append("") | ||||||
|     return ent_iobs, ent_types |     return ent_iobs, ent_types | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| from typing import List, Dict, Tuple, Iterable, Union, Iterator | from typing import List, Dict, Tuple, Iterable, Union, Iterator, cast | ||||||
| import warnings | import warnings | ||||||
| 
 | 
 | ||||||
| from ..errors import Errors, Warnings | from ..errors import Errors, Warnings | ||||||
|  | @ -218,6 +218,14 @@ def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]: | ||||||
|     return entities |     return entities | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def split_bilu_label(label: str) -> Tuple[str, str]: | ||||||
|  |     return cast(Tuple[str, str], label.split("-", 1)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def remove_bilu_prefix(label: str) -> str: | ||||||
|  |     return label.split("-", 1)[1] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| # Fallbacks to make backwards-compat easier | # Fallbacks to make backwards-compat easier | ||||||
| offsets_from_biluo_tags = biluo_tags_to_offsets | offsets_from_biluo_tags = biluo_tags_to_offsets | ||||||
| spans_from_biluo_tags = biluo_tags_to_spans | spans_from_biluo_tags = biluo_tags_to_spans | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user