# cython: infer_types=True, profile=True, binding=True
from collections import defaultdict
from typing import Callable, Optional

from thinc.api import Config, Model

from ._parser_internals.transition_system import TransitionSystem

from ._parser_internals.arc_eager cimport ArcEager
from .transition_parser cimport Parser

from ..language import Language
from ..scorer import Scorer
from ..training import remove_bilu_prefix
from ..util import registry
from ._parser_internals import nonproj
from ._parser_internals.nonproj import DELIMITER
from .functions import merge_subtokens

default_model_config = """
[model]
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser"
extra_state_tokens = false
hidden_width = 64
maxout_pieces = 2
use_upper = true

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""
DEFAULT_PARSER_MODEL = Config().from_str(default_model_config)["model"]
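
# Illustrative sketch (not part of this module): DEFAULT_PARSER_MODEL is the
# resolved [model] block from the config string above, so for example:
#
#     DEFAULT_PARSER_MODEL["@architectures"]  # "spacy.TransitionBasedParser.v2"
#     DEFAULT_PARSER_MODEL["hidden_width"]    # 64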


@Language.factory(
    "parser",
    assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
    default_config={
        "moves": None,
        "update_with_oracle_cut_size": 100,
        "learn_tokens": False,
        "min_action_freq": 30,
        "model": DEFAULT_PARSER_MODEL,
        "scorer": {"@scorers": "spacy.parser_scorer.v1"},
    },
    default_score_weights={
        "dep_uas": 0.5,
        "dep_las": 0.5,
        "dep_las_per_type": None,
        "sents_p": None,
        "sents_r": None,
        "sents_f": 0.0,
    },
)
def make_parser(
    nlp: Language,
    name: str,
    model: Model,
    moves: Optional[TransitionSystem],
    update_with_oracle_cut_size: int,
    learn_tokens: bool,
    min_action_freq: int,
    scorer: Optional[Callable],
):
    """Create a transition-based DependencyParser component. The dependency parser
    jointly learns sentence segmentation and labelled dependency parsing, and can
    optionally learn to merge tokens that had been over-segmented by the tokenizer.

    The parser uses a variant of the non-monotonic arc-eager transition-system
    described by Honnibal and Johnson (2014), with the addition of a "break"
    transition to perform the sentence segmentation. Nivre's pseudo-projective
    dependency transformation is used to allow the parser to predict
    non-projective parses.

    The parser is trained using an imitation learning objective. The parser follows
    the actions predicted by the current weights, and at each state, determines
    which actions are compatible with the optimal parse that could be reached
    from the current state. The weights are updated such that the scores
    assigned to the set of optimal actions are increased, while scores assigned
    to other actions are decreased. Note that more than one action may be
    optimal for a given state.

    model (Model): The model for the transition-based parser. The model needs
        to have a specific substructure of named components --- see the
        spacy.ml.tb_framework.TransitionModel for details.
    moves (Optional[TransitionSystem]): This defines how the parse-state is created,
        updated and evaluated. If 'moves' is None, a new instance is
        created with `self.TransitionSystem()`. Defaults to `None`.
    update_with_oracle_cut_size (int): During training, cut long sequences into
        shorter segments by creating intermediate states based on the gold-standard
        history. The model is not very sensitive to this parameter, so you usually
        won't need to change it. 100 is a good default.
    learn_tokens (bool): Whether to learn to merge subtokens that are split
        relative to the gold standard. Experimental.
    min_action_freq (int): The minimum frequency of labelled actions to retain.
        Rarer labelled actions have their label backed-off to "dep". While this
        primarily affects the label accuracy, it can also affect the attachment
        structure, as the labels are used to represent the pseudo-projectivity
        transformation.
    scorer (Optional[Callable]): The scoring method.
    """
    return DependencyParser(
        nlp.vocab,
        model,
        name,
        moves=moves,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        multitasks=[],
        learn_tokens=learn_tokens,
        min_action_freq=min_action_freq,
        beam_width=1,
        beam_density=0.0,
        beam_update_prob=0.0,
        # At some point in the future we can try to implement support for
        # partial annotations, perhaps only in the beam objective.
        incorrect_spans_key=None,
        scorer=scorer,
    )
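
# A minimal usage sketch (illustrative, not part of this module): the factory
# above is invoked through `nlp.add_pipe`, and any of the defaults can be
# overridden via its `config` argument, e.g.:
#
#     import spacy
#     nlp = spacy.blank("en")
#     parser = nlp.add_pipe("parser", config={"min_action_freq": 10})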


@Language.factory(
    "beam_parser",
    assigns=["token.dep", "token.head", "token.is_sent_start", "doc.sents"],
    default_config={
        "beam_width": 8,
        "beam_density": 0.01,
        "beam_update_prob": 0.5,
        "moves": None,
        "update_with_oracle_cut_size": 100,
        "learn_tokens": False,
        "min_action_freq": 30,
        "model": DEFAULT_PARSER_MODEL,
        "scorer": {"@scorers": "spacy.parser_scorer.v1"},
    },
    default_score_weights={
        "dep_uas": 0.5,
        "dep_las": 0.5,
        "dep_las_per_type": None,
        "sents_p": None,
        "sents_r": None,
        "sents_f": 0.0,
    },
)
def make_beam_parser(
    nlp: Language,
    name: str,
    model: Model,
    moves: Optional[TransitionSystem],
    update_with_oracle_cut_size: int,
    learn_tokens: bool,
    min_action_freq: int,
    beam_width: int,
    beam_density: float,
    beam_update_prob: float,
    scorer: Optional[Callable],
):
    """Create a transition-based DependencyParser component that uses beam-search.
    The dependency parser jointly learns sentence segmentation and labelled
    dependency parsing, and can optionally learn to merge tokens that had been
    over-segmented by the tokenizer.

    The parser uses a variant of the non-monotonic arc-eager transition-system
    described by Honnibal and Johnson (2014), with the addition of a "break"
    transition to perform the sentence segmentation. Nivre's pseudo-projective
    dependency transformation is used to allow the parser to predict
    non-projective parses.

    The parser is trained using a global objective. That is, it learns to assign
    probabilities to whole parses.

    model (Model): The model for the transition-based parser. The model needs
        to have a specific substructure of named components --- see the
        spacy.ml.tb_framework.TransitionModel for details.
    moves (Optional[TransitionSystem]): This defines how the parse-state is created,
        updated and evaluated. If 'moves' is None, a new instance is
        created with `self.TransitionSystem()`. Defaults to `None`.
    update_with_oracle_cut_size (int): During training, cut long sequences into
        shorter segments by creating intermediate states based on the gold-standard
        history. The model is not very sensitive to this parameter, so you usually
        won't need to change it. 100 is a good default.
    beam_width (int): The number of candidate analyses to maintain.
    beam_density (float): The minimum ratio between the scores of the first and
        last candidates in the beam. This allows the parser to avoid exploring
        candidates that are too far behind. This is mostly intended to improve
        efficiency, but it can also improve accuracy as deeper search is not
        always better.
    beam_update_prob (float): The chance of making a beam update, instead of a
        greedy update. Greedy updates are an approximation for the beam updates,
        and are faster to compute.
    learn_tokens (bool): Whether to learn to merge subtokens that are split
        relative to the gold standard. Experimental.
    min_action_freq (int): The minimum frequency of labelled actions to retain.
        Rarer labelled actions have their label backed-off to "dep". While this
        primarily affects the label accuracy, it can also affect the attachment
        structure, as the labels are used to represent the pseudo-projectivity
        transformation.
    """
    return DependencyParser(
        nlp.vocab,
        model,
        name,
        moves=moves,
        update_with_oracle_cut_size=update_with_oracle_cut_size,
        beam_width=beam_width,
        beam_density=beam_density,
        beam_update_prob=beam_update_prob,
        multitasks=[],
        learn_tokens=learn_tokens,
        min_action_freq=min_action_freq,
        # At some point in the future we can try to implement support for
        # partial annotations, perhaps only in the beam objective.
        incorrect_spans_key=None,
        scorer=scorer,
    )
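
# Illustrative sketch (not part of this module): the beam parser is added the
# same way, with the beam-specific settings exposed through the config, e.g.:
#
#     parser = nlp.add_pipe("beam_parser", config={"beam_width": 16})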


def parser_score(examples, **kwargs):
    """Score a batch of examples.

    examples (Iterable[Example]): The examples to score.
    RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans
        and Scorer.score_deps.

    DOCS: https://spacy.io/api/dependencyparser#score
    """
    def has_sents(doc):
        return doc.has_annotation("SENT_START")

    def dep_getter(token, attr):
        dep = getattr(token, attr)
        dep = token.vocab.strings.as_string(dep).lower()
        return dep

    results = {}
    results.update(Scorer.score_spans(examples, "sents", has_annotation=has_sents, **kwargs))
    kwargs.setdefault("getter", dep_getter)
    kwargs.setdefault("ignore_labels", ("p", "punct"))
    results.update(Scorer.score_deps(examples, "dep", **kwargs))
    del results["sents_per_type"]
    return results
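
# Sketch of the scorer's output shape (illustrative; the key names follow the
# score weights registered in the factories above):
#
#     scores = parser_score(examples)
#     # e.g. {"sents_p": ..., "sents_r": ..., "sents_f": ...,
#     #       "dep_uas": ..., "dep_las": ..., "dep_las_per_type": ...}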


@registry.scorers("spacy.parser_scorer.v1")
def make_parser_scorer():
    return parser_score


cdef class DependencyParser(Parser):
    """Pipeline component for dependency parsing.

    DOCS: https://spacy.io/api/dependencyparser
    """
    TransitionSystem = ArcEager

    def __init__(
        self,
        vocab,
        model,
        name="parser",
        moves=None,
        *,
        update_with_oracle_cut_size=100,
        min_action_freq=30,
        learn_tokens=False,
        beam_width=1,
        beam_density=0.0,
        beam_update_prob=0.0,
        multitasks=tuple(),
        incorrect_spans_key=None,
        scorer=parser_score,
    ):
        """Create a DependencyParser."""
        super().__init__(
            vocab,
            model,
            name,
            moves,
            update_with_oracle_cut_size=update_with_oracle_cut_size,
            min_action_freq=min_action_freq,
            learn_tokens=learn_tokens,
            beam_width=beam_width,
            beam_density=beam_density,
            beam_update_prob=beam_update_prob,
            multitasks=multitasks,
            incorrect_spans_key=incorrect_spans_key,
            scorer=scorer,
        )

    @property
    def postprocesses(self):
        output = [nonproj.deprojectivize]
        if self.cfg.get("learn_tokens"):
            output.append(merge_subtokens)
        return tuple(output)

    def add_multitask_objective(self, mt_component):
        self._multitasks.append(mt_component)

    def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
        # TODO: transfer self.model.get_ref("tok2vec") to the multitask's model?
        for labeller in self._multitasks:
            labeller.model.set_dim("nO", len(self.labels))
            if labeller.model.has_ref("output_layer"):
                labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
            labeller.initialize(get_examples, nlp=nlp)

    @property
    def labels(self):
        labels = set()
        # Get the labels from the model by looking at the available moves
        for move in self.move_names:
            if "-" in move:
                label = remove_bilu_prefix(move)
                if DELIMITER in label:
                    label = label.split(DELIMITER)[1]
                labels.add(label)
        return tuple(sorted(labels))
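
    # Illustrative walk-through of the extraction above (move names are
    # examples, assuming the "||" DELIMITER used by the pseudo-projective
    # transform): a move such as "L-amod" yields the label "amod", while a
    # projectivized move such as "L-amod||dobj" is reduced to "dobj".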

    def scored_parses(self, beams):
        """Return two lists of dictionaries, one entry per beam/doc that was
        processed: the first containing (i, head) keys, the second containing
        (i, label) keys.
        """
        head_scores = []
        label_scores = []
        for beam in beams:
            score_head_dict = defaultdict(float)
            score_label_dict = defaultdict(float)
            for score, parses in self.moves.get_beam_parses(beam):
                for head, i, label in parses:
                    score_head_dict[(i, head)] += score
                    score_label_dict[(i, label)] += score
            head_scores.append(score_head_dict)
            label_scores.append(score_label_dict)
        return head_scores, label_scores
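
    # Illustrative shape of the return value: for a single beam over one doc,
    # head_scores[0] might be {(0, 1): 0.9, (2, 1): 0.7, ...} and
    # label_scores[0] might be {(0, "amod"): 0.9, ...}, where each value sums
    # the scores of the candidate parses containing that (token, head) or
    # (token, label) pair.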

    def _ensure_labels_are_added(self, docs):
        # This gives the parser a chance to add labels it's missing for a batch
        # of documents. However, this isn't desirable for the dependency parser,
        # because we instead have a label frequency cut-off and back off rare
        # labels to 'dep'.
        pass