Skeleton for span predictor component

This should eventually be moved into its own file, but for now this commit
just stubs out the methods.
This commit is contained in:
Paul O'Leary McCann 2022-03-16 20:09:33 +09:00
parent 0275ae29de
commit 6855df0e66

View File

@ -117,7 +117,6 @@ class CoreferenceResolver(TrainablePipe):
self.span_mentions = span_mentions self.span_mentions = span_mentions
self.span_cluster_prefix = span_cluster_prefix self.span_cluster_prefix = span_cluster_prefix
self._rehearsal_model = None self._rehearsal_model = None
self.loss = CategoricalCrossentropy()
self.cfg = {} self.cfg = {}
@ -389,3 +388,91 @@ class CoreferenceResolver(TrainablePipe):
fname = f"coref_{field}" fname = f"coref_{field}"
out[fname] = mean([ss[fname] for ss in scores]) out[fname] = mean([ss[fname] for ss in scores])
return out return out
class SpanPredictor(TrainablePipe):
    """Pipeline component to resolve one-token spans to full spans.

    Used in coreference resolution.

    NOTE: this is a skeleton. Methods whose body is ``...`` are stubs that
    currently do nothing and implicitly return None.
    """

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "span_predictor",
        *,
        input_prefix: str = "coref_head_clusters",
        output_prefix: str = "coref_clusters",
    ) -> None:
        """Store the component's configuration.

        vocab: The vocabulary, stored as ``self.vocab``.
        model: The model, stored as ``self.model``; it is initialized in
            ``initialize``.
        name: The component instance name, used in error messages.
        input_prefix: Stored as ``self.input_prefix``. Presumably the prefix
            of the keys holding the one-token head clusters -- TODO confirm
            once ``predict`` / ``set_annotations`` are implemented.
        output_prefix: Stored as ``self.output_prefix``. Presumably the
            prefix of the keys the resolved clusters are written to -- TODO
            confirm.
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        self.input_prefix = input_prefix
        self.output_prefix = output_prefix

        self.cfg = {}

    def predict(self, docs: Iterable[Doc]):
        """Stub: not implemented yet."""
        ...

    def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
        """Stub: not implemented yet."""
        ...

    def update(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Stub: not implemented yet (currently returns None)."""
        ...

    def rehearse(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Stub: not implemented yet (currently returns None)."""
        ...

    def add_label(self, label: str) -> int:
        """Technically this method should be implemented from TrainablePipe,
        but it is not relevant for this component.

        RAISES (NotImplementedError): Always, with error E931.
        """
        raise NotImplementedError(
            Errors.E931.format(
                parent="SpanPredictor", method="add_label", name=self.name
            )
        )

    def get_loss(
        self,
        examples: Iterable[Example],
        # TODO add necessary args
    ):
        """Stub: not implemented yet."""
        ...

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ) -> None:
        """Initialize the model from a small sample of examples.

        get_examples: Callable returning an iterable of examples. At most
            the first two examples are consumed.
        nlp: Currently unused by this method.
        """
        # Report validation problems against this component, not the
        # CoreferenceResolver this skeleton was copied from.
        validate_get_examples(get_examples, "SpanPredictor.initialize")

        X = []
        Y = []
        # A couple of examples are enough to hand sample data to the model.
        for ex in islice(get_examples(), 2):
            X.append(ex.predicted)
            Y.append(ex.reference)

        assert len(X) > 0, Errors.E923.format(name=self.name)
        self.model.initialize(X=X, Y=Y)

    def score(self, examples, **kwargs):
        """Stub: not implemented yet."""
        # TODO this will overlap significantly with coref, maybe factor into function
        ...