mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
Skeleton for span predictor component
This should be moved into its own file, but for now just stubbing out the methods.
This commit is contained in:
parent
0275ae29de
commit
6855df0e66
|
@ -117,7 +117,6 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
self.span_mentions = span_mentions
|
self.span_mentions = span_mentions
|
||||||
self.span_cluster_prefix = span_cluster_prefix
|
self.span_cluster_prefix = span_cluster_prefix
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
self.loss = CategoricalCrossentropy()
|
|
||||||
|
|
||||||
self.cfg = {}
|
self.cfg = {}
|
||||||
|
|
||||||
|
@ -389,3 +388,91 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
fname = f"coref_{field}"
|
fname = f"coref_{field}"
|
||||||
out[fname] = mean([ss[fname] for ss in scores])
|
out[fname] = mean([ss[fname] for ss in scores])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class SpanPredictor(TrainablePipe):
|
||||||
|
"""Pipeline component to resolve one-token spans to full spans.
|
||||||
|
|
||||||
|
Used in coreference resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab: Vocab,
|
||||||
|
model: Model,
|
||||||
|
name: str = "span_predictor",
|
||||||
|
*,
|
||||||
|
input_prefix: str = "coref_head_clusters",
|
||||||
|
output_prefix: str = "coref_clusters",
|
||||||
|
) -> None:
|
||||||
|
self.vocab = vocab
|
||||||
|
self.model = model
|
||||||
|
self.name = name
|
||||||
|
self.input_prefix = input_prefix
|
||||||
|
self.output_prefix = output_prefix
|
||||||
|
|
||||||
|
self.cfg = {}
|
||||||
|
|
||||||
|
def predict(self, docs: Iterable[Doc]):
|
||||||
|
...
|
||||||
|
|
||||||
|
def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
examples: Iterable[Example],
|
||||||
|
*,
|
||||||
|
drop: float = 0.0,
|
||||||
|
sgd: Optional[Optimizer] = None,
|
||||||
|
losses: Optional[Dict[str, float]] = None,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def rehearse(
|
||||||
|
self,
|
||||||
|
examples: Iterable[Example],
|
||||||
|
*,
|
||||||
|
drop: float = 0.0,
|
||||||
|
sgd: Optional[Optimizer] = None,
|
||||||
|
losses: Optional[Dict[str, float]] = None,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def add_label(self, label: str) -> int:
|
||||||
|
"""Technically this method should be implemented from TrainablePipe,
|
||||||
|
but it is not relevant for this component.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
Errors.E931.format(
|
||||||
|
parent="SpanPredictor", method="add_label", name=self.name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_loss(
|
||||||
|
self,
|
||||||
|
examples: Iterable[Example],
|
||||||
|
#TODO add necessary args
|
||||||
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
def initialize(
|
||||||
|
self,
|
||||||
|
get_examples: Callable[[], Iterable[Example]],
|
||||||
|
*,
|
||||||
|
nlp: Optional[Language] = None,
|
||||||
|
) -> None:
|
||||||
|
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
|
||||||
|
|
||||||
|
X = []
|
||||||
|
Y = []
|
||||||
|
for ex in islice(get_examples(), 2):
|
||||||
|
X.append(ex.predicted)
|
||||||
|
Y.append(ex.reference)
|
||||||
|
|
||||||
|
assert len(X) > 0, Errors.E923.format(name=self.name)
|
||||||
|
self.model.initialize(X=X, Y=Y)
|
||||||
|
|
||||||
|
def score(self, examples, **kwargs):
|
||||||
|
# TODO this will overlap significantly with coref, maybe factor into function
|
||||||
|
...
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user