Mirror of https://github.com/explosion/spaCy.git
Minor NEL type fixes (#10860)
* Fix TODO about typing

  The fix was simple: just request an array2f.

* Add type ignore

  Maxout has a more restrictive type than the residual layer expects (only Floats2d vs any Floats).

* Various cleanup

  This moves a lot of lines around but doesn't change any functionality. Details:

  1. use `continue` to reduce indentation (see the sketch below)
  2. move sentence doc building inside the conditional, since it's otherwise unused
  3. reduce some temporary assignments
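A minimal sketch of the `continue` cleanup from item 1, using hypothetical `docs` and `handle_entity` stand-ins rather than the actual pipeline code:

```python
# Hypothetical stand-ins; the real code iterates over spaCy Doc objects.
docs = [["ent1", "ent2"], [], ["ent3"]]

def handle_entity(ent):
    print(ent)

# Before: the whole loop body nests under a length check.
for doc in docs:
    if len(doc) > 0:
        for ent in doc:
            handle_entity(ent)

# After: an early `continue` skips empty docs and keeps the main
# logic one indentation level shallower.
for doc in docs:
    if len(doc) == 0:
        continue
    for ent in doc:
        handle_entity(ent)
```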
This commit is contained in:
parent 56d4055d96
commit dca2e8c644
@@ -23,7 +23,7 @@ def build_nel_encoder(
             ((tok2vec >> list2ragged()) & build_span_maker())
             >> extract_spans()
             >> reduce_mean()
-            >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))
+            >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))  # type: ignore
             >> output_layer
         )
         model.set_ref("output_layer", output_layer)
@@ -355,7 +355,7 @@ class EntityLinker(TrainablePipe):
                     keep_ents.append(eidx)

                 eidx += 1
-        entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
+        entity_encodings = self.model.ops.asarray2f(entity_encodings, dtype="float32")
         selected_encodings = sentence_encodings[keep_ents]

         # if there are no matches, short circuit
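The `asarray` → `asarray2f` change above is the "just request an array2f" fix from the commit message: `Ops.asarray` returns a generically typed array, while `Ops.asarray2f` is annotated to return `Floats2d`, which is what the loss code downstream expects. A small sketch, assuming Thinc's `NumpyOps`:

```python
from thinc.api import NumpyOps

ops = NumpyOps()
rows = [[1.0, 2.0], [3.0, 4.0]]

# asarray returns a loosely typed array, which previously forced a
# `# type: ignore` where a Floats2d was expected downstream.
loose = ops.asarray(rows, dtype="float32")

# asarray2f is annotated to return Floats2d, so the ignore can be dropped.
typed = ops.asarray2f(rows, dtype="float32")
print(typed.shape)  # (2, 2)
```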
@@ -368,13 +368,12 @@ class EntityLinker(TrainablePipe):
                 method="get_loss", msg="gold entities do not match up"
             )
             raise RuntimeError(err)
-        # TODO: fix typing issue here
-        gradients = self.distance.get_grad(selected_encodings, entity_encodings)
+        gradients = self.distance.get_grad(selected_encodings, entity_encodings)  # type: ignore
         # to match the input size, we need to give a zero gradient for items not in the kb
         out = self.model.ops.alloc2f(*sentence_encodings.shape)
         out[keep_ents] = gradients
-        loss = self.distance.get_loss(selected_encodings, entity_encodings)  # type: ignore
+        loss = self.distance.get_loss(selected_encodings, entity_encodings)
         loss = loss / len(entity_encodings)
         return float(loss), out
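For context on the `get_grad`/`get_loss` pair above: `self.distance` is a Thinc loss object (assuming here that the EntityLinker wires it up as a cosine distance, as the surrounding code suggests). A rough, self-contained sketch with made-up vectors:

```python
from thinc.api import CosineDistance, NumpyOps

ops = NumpyOps()
# Assumed configuration; illustrative only.
distance = CosineDistance(ignore_zeros=False, normalize=False)

# Made-up encodings: two predicted rows vs. two gold rows.
selected_encodings = ops.asarray2f([[1.0, 0.0], [0.0, 1.0]])
entity_encodings = ops.asarray2f([[1.0, 0.0], [1.0, 0.0]])

gradients = distance.get_grad(selected_encodings, entity_encodings)
loss = distance.get_loss(selected_encodings, entity_encodings)
print(float(loss) / len(entity_encodings))  # mean loss per entity
```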
@@ -391,18 +390,21 @@ class EntityLinker(TrainablePipe):
         self.validate_kb()
         entity_count = 0
         final_kb_ids: List[str] = []
+        xp = self.model.ops.xp
         if not docs:
             return final_kb_ids
         if isinstance(docs, Doc):
             docs = [docs]
         for i, doc in enumerate(docs):
+            if len(doc) == 0:
+                continue
             sentences = [s for s in doc.sents]
-            if len(doc) > 0:
-                # Looping through each entity (TODO: rewrite)
-                for ent in doc.ents:
-                    sent = ent.sent
-                    sent_index = sentences.index(sent)
-                    assert sent_index >= 0
+            # Looping through each entity (TODO: rewrite)
+            for ent in doc.ents:
+                sent_index = sentences.index(ent.sent)
+                assert sent_index >= 0
+
+                if self.incl_context:
                     # get n_neighbour sentences, clipped to the length of the document
                     start_sentence = max(0, sent_index - self.n_sents)
                     end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
@@ -410,55 +412,53 @@ class EntityLinker(TrainablePipe):
                     end_token = sentences[end_sentence].end
                     sent_doc = doc[start_token:end_token].as_doc()
                     # currently, the context is the same for each entity in a sentence (should be refined)
-                    xp = self.model.ops.xp
-                    if self.incl_context:
-                        sentence_encoding = self.model.predict([sent_doc])[0]
-                        sentence_encoding_t = sentence_encoding.T
-                        sentence_norm = xp.linalg.norm(sentence_encoding_t)
-                    entity_count += 1
-                    if ent.label_ in self.labels_discard:
-                        # ignoring this entity - setting to NIL
-                        final_kb_ids.append(self.NIL)
-                    else:
-                        candidates = list(self.get_candidates(self.kb, ent))
-                        if not candidates:
-                            # no prediction possible for this entity - setting to NIL
-                            final_kb_ids.append(self.NIL)
-                        elif len(candidates) == 1:
-                            # shortcut for efficiency reasons: take the 1 candidate
-                            # TODO: thresholding
-                            final_kb_ids.append(candidates[0].entity_)
-                        else:
-                            random.shuffle(candidates)
-                            # set all prior probabilities to 0 if incl_prior=False
-                            prior_probs = xp.asarray([c.prior_prob for c in candidates])
-                            if not self.incl_prior:
-                                prior_probs = xp.asarray([0.0 for _ in candidates])
-                            scores = prior_probs
-                            # add in similarity from the context
-                            if self.incl_context:
-                                entity_encodings = xp.asarray(
-                                    [c.entity_vector for c in candidates]
-                                )
-                                entity_norm = xp.linalg.norm(entity_encodings, axis=1)
-                                if len(entity_encodings) != len(prior_probs):
-                                    raise RuntimeError(
-                                        Errors.E147.format(
-                                            method="predict",
-                                            msg="vectors not of equal length",
-                                        )
-                                    )
-                                # cosine similarity
-                                sims = xp.dot(entity_encodings, sentence_encoding_t) / (
-                                    sentence_norm * entity_norm
-                                )
-                                if sims.shape != prior_probs.shape:
-                                    raise ValueError(Errors.E161)
-                                scores = prior_probs + sims - (prior_probs * sims)
-                            # TODO: thresholding
-                            best_index = scores.argmax().item()
-                            best_candidate = candidates[best_index]
-                            final_kb_ids.append(best_candidate.entity_)
+                    sentence_encoding = self.model.predict([sent_doc])[0]
+                    sentence_encoding_t = sentence_encoding.T
+                    sentence_norm = xp.linalg.norm(sentence_encoding_t)
+                entity_count += 1
+                if ent.label_ in self.labels_discard:
+                    # ignoring this entity - setting to NIL
+                    final_kb_ids.append(self.NIL)
+                else:
+                    candidates = list(self.get_candidates(self.kb, ent))
+                    if not candidates:
+                        # no prediction possible for this entity - setting to NIL
+                        final_kb_ids.append(self.NIL)
+                    elif len(candidates) == 1:
+                        # shortcut for efficiency reasons: take the 1 candidate
+                        # TODO: thresholding
+                        final_kb_ids.append(candidates[0].entity_)
+                    else:
+                        random.shuffle(candidates)
+                        # set all prior probabilities to 0 if incl_prior=False
+                        prior_probs = xp.asarray([c.prior_prob for c in candidates])
+                        if not self.incl_prior:
+                            prior_probs = xp.asarray([0.0 for _ in candidates])
+                        scores = prior_probs
+                        # add in similarity from the context
+                        if self.incl_context:
+                            entity_encodings = xp.asarray(
+                                [c.entity_vector for c in candidates]
+                            )
+                            entity_norm = xp.linalg.norm(entity_encodings, axis=1)
+                            if len(entity_encodings) != len(prior_probs):
+                                raise RuntimeError(
+                                    Errors.E147.format(
+                                        method="predict",
+                                        msg="vectors not of equal length",
+                                    )
+                                )
+                            # cosine similarity
+                            sims = xp.dot(entity_encodings, sentence_encoding_t) / (
+                                sentence_norm * entity_norm
+                            )
+                            if sims.shape != prior_probs.shape:
+                                raise ValueError(Errors.E161)
+                            scores = prior_probs + sims - (prior_probs * sims)
+                        # TODO: thresholding
+                        best_index = scores.argmax().item()
+                        best_candidate = candidates[best_index]
+                        final_kb_ids.append(best_candidate.entity_)
         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
                 method="predict", msg="result variables not of equal length"