mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Suggestions from code review, cleanup, typing
This commit is contained in:
		
							parent
							
								
									6999436270
								
							
						
					
					
						commit
						6087da9675
					
				| 
						 | 
					@ -2,7 +2,7 @@ from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from thinc.types import Floats2d, Floats3d, Ints2d
 | 
					from thinc.types import Floats2d, Floats3d, Ints2d
 | 
				
			||||||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
 | 
					from thinc.api import Model, Config, Optimizer
 | 
				
			||||||
from thinc.api import set_dropout_rate, to_categorical
 | 
					from thinc.api import set_dropout_rate, to_categorical
 | 
				
			||||||
from itertools import islice
 | 
					from itertools import islice
 | 
				
			||||||
from statistics import mean
 | 
					from statistics import mean
 | 
				
			||||||
| 
						 | 
					@ -11,7 +11,6 @@ from .trainable_pipe import TrainablePipe
 | 
				
			||||||
from ..language import Language
 | 
					from ..language import Language
 | 
				
			||||||
from ..training import Example, validate_examples, validate_get_examples
 | 
					from ..training import Example, validate_examples, validate_get_examples
 | 
				
			||||||
from ..errors import Errors
 | 
					from ..errors import Errors
 | 
				
			||||||
from ..scorer import Scorer
 | 
					 | 
				
			||||||
from ..tokens import Doc
 | 
					from ..tokens import Doc
 | 
				
			||||||
from ..vocab import Vocab
 | 
					from ..vocab import Vocab
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -118,7 +117,7 @@ class CoreferenceResolver(TrainablePipe):
 | 
				
			||||||
        self.span_cluster_prefix = span_cluster_prefix
 | 
					        self.span_cluster_prefix = span_cluster_prefix
 | 
				
			||||||
        self._rehearsal_model = None
 | 
					        self._rehearsal_model = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.cfg = {}
 | 
					        self.cfg: Dict[str, Any] = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
 | 
					    def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
 | 
				
			||||||
        """Apply the pipeline's model to a batch of docs, without modifying them.
 | 
					        """Apply the pipeline's model to a batch of docs, without modifying them.
 | 
				
			||||||
| 
						 | 
					@ -154,6 +153,7 @@ class CoreferenceResolver(TrainablePipe):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DOCS: https://spacy.io/api/coref#set_annotations (TODO)
 | 
					        DOCS: https://spacy.io/api/coref#set_annotations (TODO)
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					        docs = list(docs)
 | 
				
			||||||
        if len(docs) != len(clusters_by_doc):
 | 
					        if len(docs) != len(clusters_by_doc):
 | 
				
			||||||
            raise ValueError(
 | 
					            raise ValueError(
 | 
				
			||||||
                "Found coref clusters incompatible with the "
 | 
					                "Found coref clusters incompatible with the "
 | 
				
			||||||
| 
						 | 
					@ -219,49 +219,8 @@ class CoreferenceResolver(TrainablePipe):
 | 
				
			||||||
        losses[self.name] += total_loss
 | 
					        losses[self.name] += total_loss
 | 
				
			||||||
        return losses
 | 
					        return losses
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def rehearse(
 | 
					    def rehearse(self, examples, *, sgd=None, losses=None, **config):
 | 
				
			||||||
        self,
 | 
					        raise NotImplementedError
 | 
				
			||||||
        examples: Iterable[Example],
 | 
					 | 
				
			||||||
        *,
 | 
					 | 
				
			||||||
        drop: float = 0.0,
 | 
					 | 
				
			||||||
        sgd: Optional[Optimizer] = None,
 | 
					 | 
				
			||||||
        losses: Optional[Dict[str, float]] = None,
 | 
					 | 
				
			||||||
    ) -> Dict[str, float]:
 | 
					 | 
				
			||||||
        """Perform a "rehearsal" update from a batch of data. Rehearsal updates
 | 
					 | 
				
			||||||
        teach the current model to make predictions similar to an initial model,
 | 
					 | 
				
			||||||
        to try to address the "catastrophic forgetting" problem. This feature is
 | 
					 | 
				
			||||||
        experimental.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        examples (Iterable[Example]): A batch of Example objects.
 | 
					 | 
				
			||||||
        drop (float): The dropout rate.
 | 
					 | 
				
			||||||
        sgd (thinc.api.Optimizer): The optimizer.
 | 
					 | 
				
			||||||
        losses (Dict[str, float]): Optional record of the loss during training.
 | 
					 | 
				
			||||||
            Updated using the component name as the key.
 | 
					 | 
				
			||||||
        RETURNS (Dict[str, float]): The updated losses dictionary.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        DOCS: https://spacy.io/api/coref#rehearse (TODO)
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        if losses is not None:
 | 
					 | 
				
			||||||
            losses.setdefault(self.name, 0.0)
 | 
					 | 
				
			||||||
        if self._rehearsal_model is None:
 | 
					 | 
				
			||||||
            return losses
 | 
					 | 
				
			||||||
        validate_examples(examples, "CoreferenceResolver.rehearse")
 | 
					 | 
				
			||||||
        # TODO test this whole function
 | 
					 | 
				
			||||||
        docs = [eg.predicted for eg in examples]
 | 
					 | 
				
			||||||
        if not any(len(doc) for doc in docs):
 | 
					 | 
				
			||||||
            # Handle cases where there are no tokens in any docs.
 | 
					 | 
				
			||||||
            return losses
 | 
					 | 
				
			||||||
        set_dropout_rate(self.model, drop)
 | 
					 | 
				
			||||||
        scores, bp_scores = self.model.begin_update(docs)
 | 
					 | 
				
			||||||
        # TODO below
 | 
					 | 
				
			||||||
        target = self._rehearsal_model(examples)
 | 
					 | 
				
			||||||
        gradient = scores - target
 | 
					 | 
				
			||||||
        bp_scores(gradient)
 | 
					 | 
				
			||||||
        if sgd is not None:
 | 
					 | 
				
			||||||
            self.finish_update(sgd)
 | 
					 | 
				
			||||||
        if losses is not None:
 | 
					 | 
				
			||||||
            losses[self.name] += (gradient**2).sum()
 | 
					 | 
				
			||||||
        return losses
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def add_label(self, label: str) -> int:
 | 
					    def add_label(self, label: str) -> int:
 | 
				
			||||||
        """Technically this method should be implemented from TrainablePipe,
 | 
					        """Technically this method should be implemented from TrainablePipe,
 | 
				
			||||||
| 
						 | 
					@ -276,7 +235,7 @@ class CoreferenceResolver(TrainablePipe):
 | 
				
			||||||
    def get_loss(
 | 
					    def get_loss(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        examples: Iterable[Example],
 | 
					        examples: Iterable[Example],
 | 
				
			||||||
        score_matrix: List[Tuple[Floats2d, Ints2d]],
 | 
					        score_matrix: Floats2d,
 | 
				
			||||||
        mention_idx: Ints2d,
 | 
					        mention_idx: Ints2d,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        """Find the loss and gradient of loss for the batch of documents and
 | 
					        """Find the loss and gradient of loss for the batch of documents and
 | 
				
			||||||
| 
						 | 
					@ -293,13 +252,13 @@ class CoreferenceResolver(TrainablePipe):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # TODO if there is more than one example, give an error
 | 
					        # TODO if there is more than one example, give an error
 | 
				
			||||||
        # (or actually rework this to take multiple things)
 | 
					        # (or actually rework this to take multiple things)
 | 
				
			||||||
        example = examples[0]
 | 
					        example = list(examples)[0]
 | 
				
			||||||
        cscores = score_matrix
 | 
					 | 
				
			||||||
        cidx = mention_idx
 | 
					        cidx = mention_idx
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        clusters = get_clusters_from_doc(example.reference)
 | 
					        clusters = get_clusters_from_doc(example.reference)
 | 
				
			||||||
        span_idxs = create_head_span_idxs(ops, len(example.predicted))
 | 
					        span_idxs = create_head_span_idxs(ops, len(example.predicted))
 | 
				
			||||||
        gscores = create_gold_scores(span_idxs, clusters)
 | 
					        gscores = create_gold_scores(span_idxs, clusters)
 | 
				
			||||||
 | 
					        # TODO fix type here. This is bools but asarray2f wants ints.
 | 
				
			||||||
        gscores = ops.asarray2f(gscores)
 | 
					        gscores = ops.asarray2f(gscores)
 | 
				
			||||||
        # top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
 | 
					        # top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
 | 
				
			||||||
        top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
 | 
					        top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
 | 
				
			||||||
| 
						 | 
					@ -313,8 +272,8 @@ class CoreferenceResolver(TrainablePipe):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        with warnings.catch_warnings():
 | 
					        with warnings.catch_warnings():
 | 
				
			||||||
            warnings.filterwarnings("ignore", category=RuntimeWarning)
 | 
					            warnings.filterwarnings("ignore", category=RuntimeWarning)
 | 
				
			||||||
            log_marg = ops.softmax(cscores + ops.xp.log(top_gscores), axis=1)
 | 
					            log_marg = ops.softmax(score_matrix + ops.xp.log(top_gscores), axis=1)
 | 
				
			||||||
        log_norm = ops.softmax(cscores, axis=1)
 | 
					        log_norm = ops.softmax(score_matrix, axis=1)
 | 
				
			||||||
        grad = log_norm - log_marg
 | 
					        grad = log_norm - log_marg
 | 
				
			||||||
        # gradients.append((grad, cidx))
 | 
					        # gradients.append((grad, cidx))
 | 
				
			||||||
        loss = float((grad**2).sum())
 | 
					        loss = float((grad**2).sum())
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user