From 6855df0e66655766a407f0728271bad701e91f8e Mon Sep 17 00:00:00 2001
From: Paul O'Leary McCann
Date: Wed, 16 Mar 2022 20:09:33 +0900
Subject: [PATCH] Skeleton for span predictor component

This should be moved into its own file, but for now just stubbing out
the methods.
---
Review notes (not part of the commit message):
- Line structure and indentation were restored; the patch had been
  collapsed onto two physical lines.
- SpanPredictor.initialize passed the label
  "CoreferenceResolver.initialize" to validate_get_examples (copy-paste
  from the class above); it now passes "SpanPredictor.initialize".
- The number of added/removed lines is unchanged, so the hunk headers
  and the diffstat below still hold.

 spacy/pipeline/coref.py | 89 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 88 insertions(+), 1 deletion(-)

diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py
index 97aa33cf2..20fdcac38 100644
--- a/spacy/pipeline/coref.py
+++ b/spacy/pipeline/coref.py
@@ -117,7 +117,6 @@ class CoreferenceResolver(TrainablePipe):
         self.span_mentions = span_mentions
         self.span_cluster_prefix = span_cluster_prefix
         self._rehearsal_model = None
-        self.loss = CategoricalCrossentropy()
 
         self.cfg = {}
 
@@ -389,3 +388,91 @@ class CoreferenceResolver(TrainablePipe):
             fname = f"coref_{field}"
             out[fname] = mean([ss[fname] for ss in scores])
         return out
+
+class SpanPredictor(TrainablePipe):
+    """Pipeline component to resolve one-token spans to full spans.
+
+    Used in coreference resolution.
+    """
+
+    def __init__(
+        self,
+        vocab: Vocab,
+        model: Model,
+        name: str = "span_predictor",
+        *,
+        input_prefix: str = "coref_head_clusters",
+        output_prefix: str = "coref_clusters",
+    ) -> None:
+        self.vocab = vocab
+        self.model = model
+        self.name = name
+        self.input_prefix = input_prefix
+        self.output_prefix = output_prefix
+
+        self.cfg = {}
+
+    def predict(self, docs: Iterable[Doc]):
+        ...
+
+    def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
+        ...
+
+    def update(
+        self,
+        examples: Iterable[Example],
+        *,
+        drop: float = 0.0,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None,
+    ) -> Dict[str, float]:
+        ...
+
+    def rehearse(
+        self,
+        examples: Iterable[Example],
+        *,
+        drop: float = 0.0,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None,
+    ) -> Dict[str, float]:
+        ...
+
+    def add_label(self, label: str) -> int:
+        """Technically this method should be implemented from TrainablePipe,
+        but it is not relevant for this component.
+        """
+        raise NotImplementedError(
+            Errors.E931.format(
+                parent="SpanPredictor", method="add_label", name=self.name
+            )
+        )
+
+    def get_loss(
+        self,
+        examples: Iterable[Example],
+        # TODO add necessary args
+    ):
+        ...
+
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable[Example]],
+        *,
+        nlp: Optional[Language] = None,
+    ) -> None:
+        validate_get_examples(get_examples, "SpanPredictor.initialize")
+
+        X = []
+        Y = []
+        for ex in islice(get_examples(), 2):
+            X.append(ex.predicted)
+            Y.append(ex.reference)
+
+        assert len(X) > 0, Errors.E923.format(name=self.name)
+        self.model.initialize(X=X, Y=Y)
+
+    def score(self, examples, **kwargs):
+        # TODO this will overlap significantly with coref, maybe factor into function
+        ...
+