mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Formatting
This commit is contained in:
parent
6855df0e66
commit
1a79d18796
|
@ -60,6 +60,7 @@ DEFAULT_MODEL = Config().from_str(default_config)["model"]
|
||||||
|
|
||||||
DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
|
DEFAULT_CLUSTERS_PREFIX = "coref_clusters"
|
||||||
|
|
||||||
|
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"coref",
|
"coref",
|
||||||
assigns=["doc.spans"],
|
assigns=["doc.spans"],
|
||||||
|
@ -133,7 +134,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
# each item in scores includes scores and a mapping from scores to mentions
|
# each item in scores includes scores and a mapping from scores to mentions
|
||||||
ant_idxs = idxs
|
ant_idxs = idxs
|
||||||
|
|
||||||
#TODO batching
|
# TODO batching
|
||||||
xp = self.model.ops.xp
|
xp = self.model.ops.xp
|
||||||
|
|
||||||
starts = xp.arange(0, len(docs[0]))
|
starts = xp.arange(0, len(docs[0]))
|
||||||
|
@ -242,7 +243,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
if self._rehearsal_model is None:
|
if self._rehearsal_model is None:
|
||||||
return losses
|
return losses
|
||||||
validate_examples(examples, "CoreferenceResolver.rehearse")
|
validate_examples(examples, "CoreferenceResolver.rehearse")
|
||||||
#TODO test this whole function
|
# TODO test this whole function
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
if not any(len(doc) for doc in docs):
|
if not any(len(doc) for doc in docs):
|
||||||
# Handle cases where there are no tokens in any docs.
|
# Handle cases where there are no tokens in any docs.
|
||||||
|
@ -256,7 +257,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
losses[self.name] += (gradient ** 2).sum()
|
losses[self.name] += (gradient**2).sum()
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def add_label(self, label: str) -> int:
|
def add_label(self, label: str) -> int:
|
||||||
|
@ -290,13 +291,13 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
offset = 0
|
offset = 0
|
||||||
gradients = []
|
gradients = []
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
#TODO change this
|
# TODO change this
|
||||||
# 1. do not handle batching (add it back later)
|
# 1. do not handle batching (add it back later)
|
||||||
# 2. don't do index conversion (no mentions, just word indices)
|
# 2. don't do index conversion (no mentions, just word indices)
|
||||||
# 3. convert words to spans (if necessary) in gold and predictions
|
# 3. convert words to spans (if necessary) in gold and predictions
|
||||||
|
|
||||||
# massage score matrix to be shaped correctly
|
# massage score matrix to be shaped correctly
|
||||||
score_matrix = [ (score_matrix, None) ]
|
score_matrix = [(score_matrix, None)]
|
||||||
for example, (cscores, cidx) in zip(examples, score_matrix):
|
for example, (cscores, cidx) in zip(examples, score_matrix):
|
||||||
|
|
||||||
ll = cscores.shape[0]
|
ll = cscores.shape[0]
|
||||||
|
@ -306,7 +307,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
span_idxs = create_head_span_idxs(ops, len(example.predicted))
|
span_idxs = create_head_span_idxs(ops, len(example.predicted))
|
||||||
gscores = create_gold_scores(span_idxs, clusters)
|
gscores = create_gold_scores(span_idxs, clusters)
|
||||||
gscores = ops.asarray2f(gscores)
|
gscores = ops.asarray2f(gscores)
|
||||||
#top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
|
# top_gscores = xp.take_along_axis(gscores, cidx, axis=1)
|
||||||
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
|
top_gscores = xp.take_along_axis(gscores, mention_idx, axis=1)
|
||||||
# now add the placeholder
|
# now add the placeholder
|
||||||
gold_placeholder = ~top_gscores.any(axis=1).T
|
gold_placeholder = ~top_gscores.any(axis=1).T
|
||||||
|
@ -322,7 +323,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
log_norm = ops.softmax(cscores, axis=1)
|
log_norm = ops.softmax(cscores, axis=1)
|
||||||
grad = log_norm - log_marg
|
grad = log_norm - log_marg
|
||||||
gradients.append((grad, cidx))
|
gradients.append((grad, cidx))
|
||||||
total_loss += float((grad ** 2).sum())
|
total_loss += float((grad**2).sum())
|
||||||
|
|
||||||
offset = hi
|
offset = hi
|
||||||
|
|
||||||
|
@ -389,6 +390,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
out[fname] = mean([ss[fname] for ss in scores])
|
out[fname] = mean([ss[fname] for ss in scores])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class SpanPredictor(TrainablePipe):
|
class SpanPredictor(TrainablePipe):
|
||||||
"""Pipeline component to resolve one-token spans to full spans.
|
"""Pipeline component to resolve one-token spans to full spans.
|
||||||
|
|
||||||
|
@ -451,7 +453,7 @@ class SpanPredictor(TrainablePipe):
|
||||||
def get_loss(
|
def get_loss(
|
||||||
self,
|
self,
|
||||||
examples: Iterable[Example],
|
examples: Iterable[Example],
|
||||||
#TODO add necessary args
|
# TODO add necessary args
|
||||||
):
|
):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -475,4 +477,3 @@ class SpanPredictor(TrainablePipe):
|
||||||
def score(self, examples, **kwargs):
|
def score(self, examples, **kwargs):
|
||||||
# TODO this will overlap significantly with coref, maybe factor into function
|
# TODO this will overlap significantly with coref, maybe factor into function
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user