mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
Skeleton for span predictor component
This should be moved into its own file, but for now just stubbing out the methods.
This commit is contained in:
parent
0275ae29de
commit
6855df0e66
|
@ -117,7 +117,6 @@ class CoreferenceResolver(TrainablePipe):
|
|||
self.span_mentions = span_mentions
|
||||
self.span_cluster_prefix = span_cluster_prefix
|
||||
self._rehearsal_model = None
|
||||
self.loss = CategoricalCrossentropy()
|
||||
|
||||
self.cfg = {}
|
||||
|
||||
|
@ -389,3 +388,91 @@ class CoreferenceResolver(TrainablePipe):
|
|||
fname = f"coref_{field}"
|
||||
out[fname] = mean([ss[fname] for ss in scores])
|
||||
return out
|
||||
|
||||
class SpanPredictor(TrainablePipe):
    """Pipeline component to resolve one-token spans to full spans.

    Used in coreference resolution.

    NOTE: this is a skeleton. `predict`, `set_annotations`, `update`,
    `rehearse`, `get_loss` and `score` are stubs that currently do nothing
    and return None; only `__init__`, `add_label` and `initialize` have
    real bodies.
    """

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "span_predictor",
        *,
        input_prefix: str = "coref_head_clusters",
        output_prefix: str = "coref_clusters",
    ) -> None:
        """Store the shared vocab, the model and the component settings.

        vocab: The vocabulary, stored as `self.vocab`.
        model: The model, stored as `self.model`; it is initialized with
            sample data in `initialize`.
        name: The component instance name, used in error messages.
        input_prefix: Stored as `self.input_prefix`. Presumably the prefix
            of the keys the head clusters are read from -- TODO confirm
            once `predict`/`set_annotations` are implemented.
        output_prefix: Stored as `self.output_prefix`. Presumably the prefix
            of the keys the resolved spans are written to -- TODO confirm.
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        self.input_prefix = input_prefix
        self.output_prefix = output_prefix

        # No serializable settings yet.
        self.cfg = {}

    def predict(self, docs: Iterable[Doc]):
        """Apply the model to a batch of docs without modifying them.

        Not implemented yet: currently a stub that returns None.
        """
        ...

    def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
        """Write the predictions in `clusters_by_doc` onto `docs`.

        Not implemented yet: currently a stub that does nothing.
        """
        ...

    def update(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Learn from a batch of examples.

        Not implemented yet: currently a stub that returns None despite the
        annotated return type.
        """
        ...

    def rehearse(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Perform a "rehearsal" update from a batch of examples.

        Not implemented yet: currently a stub that returns None despite the
        annotated return type.
        """
        ...

    def add_label(self, label: str) -> int:
        """Technically this method should be implemented from TrainablePipe,
        but it is not relevant for this component.

        RAISES (NotImplementedError): always.
        """
        raise NotImplementedError(
            Errors.E931.format(
                parent="SpanPredictor", method="add_label", name=self.name
            )
        )

    def get_loss(
        self,
        examples: Iterable[Example],
        # TODO: add necessary args
    ):
        """Compute the loss for a batch of examples.

        Not implemented yet: currently a stub that returns None.
        """
        ...

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
    ) -> None:
        """Initialize the model from a small sample of training examples.

        get_examples: Callable returning the examples; at most the first two
            are consumed to build the sample.
        nlp: The current nlp object. Unused here.
        """
        # Report validation problems under this component's own name (the
        # skeleton was copied from CoreferenceResolver and kept its name).
        validate_get_examples(get_examples, "SpanPredictor.initialize")

        # A couple of examples is enough for the model to infer its shapes.
        sample = list(islice(get_examples(), 2))
        X = [eg.predicted for eg in sample]
        Y = [eg.reference for eg in sample]

        assert len(X) > 0, Errors.E923.format(name=self.name)
        self.model.initialize(X=X, Y=Y)

    def score(self, examples, **kwargs):
        """Score a batch of examples.

        Not implemented yet: currently a stub that returns None.
        """
        # TODO this will overlap significantly with coref, maybe factor into function
        ...
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user