mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 04:32:32 +03:00
Formatting
This commit is contained in:
parent
6855df0e66
commit
1a79d18796
|
@ -60,6 +60,7 @@ DEFAULT_MODEL = Config().from_str(default_config)["model"]
|
|||
|
||||
DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"coref",
|
||||
assigns=["doc.spans"],
|
||||
|
@ -133,7 +134,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
# each item in scores includes scores and a mapping from scores to mentions
|
||||
ant_idxs = idxs
|
||||
|
||||
#TODO batching
|
||||
# TODO batching
|
||||
xp = self.model.ops.xp
|
||||
|
||||
starts = xp.arange(0, len(docs[0]))
|
||||
|
@ -242,7 +243,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
if self._rehearsal_model is None:
|
||||
return losses
|
||||
validate_examples(examples, "CoreferenceResolver.rehearse")
|
||||
#TODO test this whole function
|
||||
# TODO test this whole function
|
||||
docs = [eg.predicted for eg in examples]
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
|
@ -256,7 +257,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
if losses is not None:
|
||||
losses[self.name] += (gradient ** 2).sum()
|
||||
losses[self.name] += (gradient**2).sum()
|
||||
return losses
|
||||
|
||||
def add_label(self, label: str) -> int:
|
||||
|
@ -290,13 +291,13 @@ class CoreferenceResolver(TrainablePipe):
|
|||
offset = 0
|
||||
gradients = []
|
||||
total_loss = 0
|
||||
#TODO change this
|
||||
# TODO change this
|
||||
# 1. do not handle batching (add it back later)
|
||||
# 2. don't do index conversion (no mentions, just word indices)
|
||||
# 3. convert words to spans (if necessary) in gold and predictions
|
||||
|
||||
|
||||
# massage score matrix to be shaped correctly
|
||||
score_matrix = [ (score_matrix, None) ]
|
||||
score_matrix = [(score_matrix, None)]
|
||||
for example, (cscores, cidx) in zip(examples, score_matrix):
|
||||
|
||||
ll = cscores.shape[0]
|
||||
|
@ -306,7 +307,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
span_idxs = create_head_span_idxs(ops, len(example.predicted))
|
||||
gscores = create_gold_scores(span_idxs, clusters)
|
||||
gscores = ops.asarray2f(gscores)
|
||||
#top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
|
||||
# top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
|
||||
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
|
||||
# now add the placeholder
|
||||
gold_placeholder = ~top_gscores.any(axis=1).T
|
||||
|
@ -322,7 +323,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
log_norm = ops.softmax(cscores, axis=1)
|
||||
grad = log_norm - log_marg
|
||||
gradients.append((grad, cidx))
|
||||
total_loss += float((grad ** 2).sum())
|
||||
total_loss += float((grad**2).sum())
|
||||
|
||||
offset = hi
|
||||
|
||||
|
@ -389,6 +390,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
out[fname] = mean([ss[fname] for ss in scores])
|
||||
return out
|
||||
|
||||
|
||||
class SpanPredictor(TrainablePipe):
|
||||
"""Pipeline component to resolve one-token spans to full spans.
|
||||
|
||||
|
@ -419,21 +421,21 @@ class SpanPredictor(TrainablePipe):
|
|||
...
|
||||
|
||||
def update(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
*,
|
||||
drop: float = 0.0,
|
||||
sgd: Optional[Optimizer] = None,
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
*,
|
||||
drop: float = 0.0,
|
||||
sgd: Optional[Optimizer] = None,
|
||||
losses: Optional[Dict[str, float]] = None,
|
||||
) -> Dict[str, float]:
|
||||
...
|
||||
|
||||
def rehearse(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
*,
|
||||
drop: float = 0.0,
|
||||
sgd: Optional[Optimizer] = None,
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
*,
|
||||
drop: float = 0.0,
|
||||
sgd: Optional[Optimizer] = None,
|
||||
losses: Optional[Dict[str, float]] = None,
|
||||
) -> Dict[str, float]:
|
||||
...
|
||||
|
@ -451,7 +453,7 @@ class SpanPredictor(TrainablePipe):
|
|||
def get_loss(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
#TODO add necessary args
|
||||
# TODO add necessary args
|
||||
):
|
||||
...
|
||||
|
||||
|
@ -475,4 +477,3 @@ class SpanPredictor(TrainablePipe):
|
|||
def score(self, examples, **kwargs):
|
||||
# TODO this will overlap significantly with coref, maybe factor into function
|
||||
...
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user