mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	* Support a cfg field in transition system * Make NER 'has gold' check use right alignment for span * Pass 'negative_samples_key' property into NER transition system * Add field for negative samples to NER transition system * Check neg_key in NER has_gold * Support negative examples in NER oracle * Test for negative examples in NER * Fix name of config variable in NER * Remove vestiges of old-style partial annotation * Remove obsolete tests * Add comment noting lack of support for negative samples in parser * Additions to "neg examples" PR (#8201) * add custom error and test for deprecated format * add test for unlearning an entity * add break also for Begin's cost * add negative_samples_key property on Parser * rename * extend docs & fix some older docs issues * add subclass constructors, clean up tests, fix docs * add flaky test with ValueError if gold parse was not found * remove ValueError if n_gold == 0 * fix docstring * Hack in environment variables to try out training * Remove hack * Remove NER hack, and support 'negative O' samples * Fix O oracle * Fix transition parser * Remove 'not O' from oracle * Fix NER oracle * check for spans in both gold.ents and gold.spans and raise if so, to prevent memory access violation * use set instead of list in consistency check Co-authored-by: svlandeg <sofie.vanlandeghem@gmail.com> Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
		
			
				
	
	
		
			253 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
			
		
		
	
	
			253 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
| # cython: infer_types=True, profile=True, binding=True
 | |
| from collections import defaultdict
 | |
| from typing import Optional, Iterable
 | |
| from thinc.api import Model, Config
 | |
| 
 | |
| from ._parser_internals.transition_system import TransitionSystem
 | |
| from .transition_parser cimport Parser
 | |
| from ._parser_internals.ner cimport BiluoPushDown
 | |
| 
 | |
| from ..language import Language
 | |
| from ..scorer import get_ner_prf, PRFScore
 | |
| from ..training import validate_examples
 | |
| 
 | |
| 
 | |
# Default model settings for the NER factories below, written as a config
# string. The text is parsed once, when this module is imported.
default_model_config = """
[model]
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = true

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""
# Only the "model" section of the parsed config is used as the default value
# of the "model" setting in the factories' default_config.
DEFAULT_NER_MODEL = Config().from_str(default_model_config)["model"]
 | |
| 
 | |
| 
 | |
@Language.factory(
    "ner",
    assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
    default_config={
        "moves": None,
        "update_with_oracle_cut_size": 100,
        "model": DEFAULT_NER_MODEL,
        "incorrect_spans_key": None,
    },
    default_score_weights={
        "ents_f": 1.0,
        "ents_p": 0.0,
        "ents_r": 0.0,
        "ents_per_type": None,
    },
)
def make_ner(
    nlp: Language,
    name: str,
    model: Model,
    moves: Optional[TransitionSystem],
    update_with_oracle_cut_size: int,
    incorrect_spans_key: Optional[str] = None,
):
    """Build the transition-based EntityRecognizer registered as "ner".

    The entity recognizer finds labelled spans of tokens that do not overlap.

    A few assumptions are baked into the transition-based algorithm. They
    work well for "traditional" named entity recognition, but they do not
    suit every span identification problem:

    * The loss function optimizes whole-entity accuracy, so the component is
      likely to perform poorly if inter-annotator agreement on boundary
      tokens is low.
    * The most decisive information about an entity is assumed to be close to
      its initial tokens, so long entities that are characterised by tokens
      in their middle are likely to be handled poorly.

    The component is always constructed with beam_width=1, beam_density=0.0,
    beam_update_prob=0.0 and an empty list of multitask objectives.

    nlp (Language): The pipeline object; its vocab is given to the component.
    name (str): The name of the component instance.
    model (Model): The model for the transition-based parser. It needs a
        specific substructure of named components --- see
        spacy.ml.tb_framework.TransitionModel for details.
    moves (Optional[TransitionSystem]): Defines how the parse-state is
        created, updated and evaluated. If 'moves' is None, a new instance is
        created with `self.TransitionSystem()`. Defaults to `None`.
    update_with_oracle_cut_size (int): During training, long sequences are cut
        into shorter segments by creating intermediate states based on the
        gold-standard history. The model is not very sensitive to this
        setting, so it rarely needs changing. 100 is a good default.
    incorrect_spans_key (Optional[str]): Identifies spans that are known to
        be incorrect entity annotations. The incorrect entity annotations can
        be stored in the span group, under this key.
    RETURNS (EntityRecognizer): The newly constructed component.
    """
    # Settings this factory pins instead of exposing them through the config.
    fixed_settings = {
        "multitasks": [],
        "beam_width": 1,
        "beam_density": 0.0,
        "beam_update_prob": 0.0,
    }
    return EntityRecognizer(
        nlp.vocab,
        model,
        name,
        moves=moves,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        incorrect_spans_key=incorrect_spans_key,
        **fixed_settings,
    )
 | |
| 
 | |
@Language.factory(
    "beam_ner",
    assigns=["doc.ents", "token.ent_iob", "token.ent_type"],
    default_config={
        "moves": None,
        "update_with_oracle_cut_size": 100,
        "model": DEFAULT_NER_MODEL,
        "beam_density": 0.01,
        "beam_update_prob": 0.5,
        "beam_width": 32,
        "incorrect_spans_key": None,
    },
    default_score_weights={
        "ents_f": 1.0,
        "ents_p": 0.0,
        "ents_r": 0.0,
        "ents_per_type": None,
    },
)
def make_beam_ner(
    nlp: Language,
    name: str,
    model: Model,
    moves: Optional[TransitionSystem],
    update_with_oracle_cut_size: int,
    beam_width: int,
    beam_density: float,
    beam_update_prob: float,
    incorrect_spans_key: Optional[str] = None,
):
    """Build the beam-search EntityRecognizer registered as "beam_ner".

    The entity recognizer finds labelled spans of tokens that do not overlap.

    A few assumptions are baked into the transition-based algorithm. They
    work well for "traditional" named entity recognition, but they do not
    suit every span identification problem:

    * The loss function optimizes whole-entity accuracy, so the component is
      likely to perform poorly if inter-annotator agreement on boundary
      tokens is low.
    * The most decisive information about an entity is assumed to be close to
      its initial tokens, so long entities that are characterised by tokens
      in their middle are likely to be handled poorly.

    The component is always constructed with an empty list of multitask
    objectives.

    nlp (Language): The pipeline object; its vocab is given to the component.
    name (str): The name of the component instance.
    model (Model): The model for the transition-based parser. It needs a
        specific substructure of named components --- see
        spacy.ml.tb_framework.TransitionModel for details.
    moves (Optional[TransitionSystem]): Defines how the parse-state is
        created, updated and evaluated. If 'moves' is None, a new instance is
        created with `self.TransitionSystem()`. Defaults to `None`.
    update_with_oracle_cut_size (int): During training, long sequences are cut
        into shorter segments by creating intermediate states based on the
        gold-standard history. The model is not very sensitive to this
        setting, so it rarely needs changing. 100 is a good default.
    beam_width (int): The number of candidate analyses to maintain.
    beam_density (float): The minimum ratio between the scores of the first
        and last candidates in the beam. It lets the parser avoid exploring
        candidates that are too far behind. This is mostly an efficiency
        measure, but it can also improve accuracy, as deeper search is not
        always better.
    beam_update_prob (float): The chance of making a beam update, instead of
        a greedy update. Greedy updates approximate the beam updates and are
        faster to compute.
    incorrect_spans_key (Optional[str]): Optional key into span groups of
        entities known to be non-entities.
    RETURNS (EntityRecognizer): The newly constructed component.
    """
    # Beam-search settings come straight from the component config.
    beam_settings = {
        "beam_width": beam_width,
        "beam_density": beam_density,
        "beam_update_prob": beam_update_prob,
    }
    return EntityRecognizer(
        nlp.vocab,
        model,
        name,
        moves=moves,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        incorrect_spans_key=incorrect_spans_key,
        multitasks=[],
        **beam_settings,
    )
 | |
| 
 | |
| 
 | |
cdef class EntityRecognizer(Parser):
    """Pipeline component for named entity recognition.

    DOCS: https://spacy.io/api/entityrecognizer
    """
    # Transition system class used by this component.
    TransitionSystem = BiluoPushDown

    def __init__(
        self,
        vocab,
        model,
        name="ner",
        moves=None,
        *,
        update_with_oracle_cut_size=100,
        beam_width=1,
        beam_density=0.0,
        beam_update_prob=0.0,
        multitasks=tuple(),
        incorrect_spans_key=None,
    ):
        """Create an EntityRecognizer.

        All arguments are forwarded to the Parser constructor, together with
        min_action_freq=1 and learn_tokens=False (neither setting is relevant
        for NER).

        vocab: The vocabulary to use (the factories pass `nlp.vocab`).
        model: The model for the transition-based parser.
        name (str): The component instance name. Defaults to "ner".
        moves: The transition system, or None. Defaults to None.
        update_with_oracle_cut_size (int): During training, cut long
            sequences into shorter segments by creating intermediate states
            based on the gold-standard history. Defaults to 100.
        beam_width (int): The number of candidate analyses to maintain.
            Defaults to 1.
        beam_density (float): The minimum ratio between the scores of the
            first and last candidates in the beam. Defaults to 0.0.
        beam_update_prob (float): The chance of making a beam update, instead
            of a greedy update. Defaults to 0.0.
        multitasks: Multi-task objective components. Defaults to an empty
            tuple.
        incorrect_spans_key (Optional[str]): Key of the span group holding
            spans that are known to be incorrect entity annotations.
            Defaults to None.
        """
        super().__init__(
            vocab,
            model,
            name,
            moves,
            update_with_oracle_cut_size=update_with_oracle_cut_size,
            min_action_freq=1,   # not relevant for NER
            learn_tokens=False,  # not relevant for NER
            beam_width=beam_width,
            beam_density=beam_density,
            beam_update_prob=beam_update_prob,
            multitasks=multitasks,
            incorrect_spans_key=incorrect_spans_key,
        )

    def add_multitask_objective(self, mt_component):
        """Register another component as a multi-task objective. Experimental.

        mt_component: The component to append to this component's list of
            multi-task objectives.
        """
        self._multitasks.append(mt_component)

    def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
        """Setup multi-task objective components. Experimental and internal.

        For every registered multi-task component, the output dimension "nO"
        of its model (and of its "output_layer" reference, if it has one) is
        set to the number of entity labels before the component is
        initialized.

        get_examples: Passed on to each multi-task component's `initialize`.
        nlp: Passed on to each multi-task component's `initialize`.
        **cfg: Accepted but not used by this method.
        """
        # TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
        for labeller in self._multitasks:
            labeller.model.set_dim("nO", len(self.labels))
            if labeller.model.has_ref("output_layer"):
                labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
            labeller.initialize(get_examples, nlp=nlp)

    @property
    def labels(self):
        """The entity labels known to the component.

        RETURNS (Tuple[str, ...]): The sorted, de-duplicated labels taken
            from the names of the available B/I/L/U moves.
        """
        # Get the labels from the model by looking at the available moves, e.g.
        # B-PERSON, I-PERSON, L-PERSON, U-PERSON
        # Split on the first hyphen only: the label itself may contain
        # hyphens (e.g. "B-WORK-OF-ART" must yield "WORK-OF-ART", not "WORK").
        labels = set(
            move.split("-", 1)[1]
            for move in self.move_names
            if move[0] in ("B", "I", "L", "U")
        )
        return tuple(sorted(labels))

    def score(self, examples, **kwargs):
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        **kwargs: Accepted but not used by this method.
        RETURNS (Dict[str, Any]): The NER precision, recall and f-scores.

        DOCS: https://spacy.io/api/entityrecognizer#score
        """
        validate_examples(examples, "EntityRecognizer.score")
        return get_ner_prf(examples)

    def scored_ents(self, beams):
        """Return a list with one dictionary per beam/doc that was processed.
        Each dictionary maps (start, end, label) tuples to their score, summed
        over all parses in the beam that contain that entity.

        beams: The beams to read the candidate parses from.
        RETURNS (List[Dict[Tuple, float]]): The entity scores, in the same
            order as the given beams.
        """
        entity_scores = []
        for beam in beams:
            # Entities missing from a parse simply contribute nothing; the
            # defaultdict starts every entity's total at 0.0.
            score_dict = defaultdict(float)
            for score, ents in self.moves.get_beam_parses(beam):
                for start, end, label in ents:
                    score_dict[(start, end, label)] += score
            entity_scores.append(score_dict)
        return entity_scores
 |