Minor NEL type fixes (#10860)

* Fix TODO about typing

Fix was simple: just request an array2f.
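
For reference, the difference is roughly this, sketched with thinc's generic ops rather than the spaCy pipeline code: `ops.asarray` is only typed as a generic array, while `ops.asarray2f` is annotated as returning `Floats2d`, which is what the loss code downstream expects.

# Illustrative sketch using thinc ops directly; not the spaCy code itself.
from thinc.api import get_current_ops

ops = get_current_ops()
vectors = [[0.1, 0.2], [0.3, 0.4]]

generic = ops.asarray(vectors, dtype="float32")     # typed as a generic array
typed_2d = ops.asarray2f(vectors, dtype="float32")  # annotated as Floats2d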

* Add type ignore

Maxout has a more restrictive type than the residual layer expects (only
Floats2d vs any Floats).
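
Roughly what the type checker trips over, as a standalone sketch (the layer names and parameters mirror the diff below; the width value is made up):

# Maxout is annotated as Model[Floats2d, Floats2d], while residual() is
# declared over a broader floats TypeVar, so mypy rejects the composition
# even though it is fine at runtime; the ignore silences that one error.
from thinc.api import Maxout, residual

width = 96  # placeholder for token_width
layer = residual(Maxout(nO=width, nI=width, nP=2, dropout=0.0))  # type: ignore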

* Various cleanup

This moves a lot of lines around but doesn't change any functionality.
Details:

1. use `continue` to reduce indentation (see the sketch after this list)
2. move sentence doc building inside the conditional, since it's otherwise
   unused
3. reduce some temporary assignments
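
The `continue` change in item 1 is the usual guard-clause rewrite; a generic before/after sketch (not the spaCy code):

# Before: the whole loop body sits under the length check.
def count_nonempty_nested(docs):
    total = 0
    for doc in docs:
        if len(doc) > 0:
            total += 1
    return total

# After: skipping empty docs up front removes one level of indentation.
def count_nonempty_flat(docs):
    total = 0
    for doc in docs:
        if len(doc) == 0:
            continue
        total += 1
    return total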
Paul O'Leary McCann 2022-06-01 07:41:28 +09:00 committed by GitHub
parent 56d4055d96
commit dca2e8c644
2 changed files with 56 additions and 56 deletions

@@ -23,7 +23,7 @@ def build_nel_encoder(
             ((tok2vec >> list2ragged()) & build_span_maker())
             >> extract_spans()
             >> reduce_mean()
-            >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))
+            >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))  # type: ignore
             >> output_layer
         )
         model.set_ref("output_layer", output_layer)

@@ -355,7 +355,7 @@ class EntityLinker(TrainablePipe):
                     keep_ents.append(eidx)
                 eidx += 1

-        entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
+        entity_encodings = self.model.ops.asarray2f(entity_encodings, dtype="float32")
         selected_encodings = sentence_encodings[keep_ents]

         # if there are no matches, short circuit
@@ -368,13 +368,12 @@ class EntityLinker(TrainablePipe):
                 method="get_loss", msg="gold entities do not match up"
             )
             raise RuntimeError(err)
-        # TODO: fix typing issue here
-        gradients = self.distance.get_grad(selected_encodings, entity_encodings)  # type: ignore
+        gradients = self.distance.get_grad(selected_encodings, entity_encodings)
         # to match the input size, we need to give a zero gradient for items not in the kb
         out = self.model.ops.alloc2f(*sentence_encodings.shape)
         out[keep_ents] = gradients
-        loss = self.distance.get_loss(selected_encodings, entity_encodings)  # type: ignore
+        loss = self.distance.get_loss(selected_encodings, entity_encodings)
         loss = loss / len(entity_encodings)
         return float(loss), out
@@ -391,18 +390,21 @@ class EntityLinker(TrainablePipe):
         self.validate_kb()
         entity_count = 0
         final_kb_ids: List[str] = []
+        xp = self.model.ops.xp
         if not docs:
             return final_kb_ids
         if isinstance(docs, Doc):
             docs = [docs]
         for i, doc in enumerate(docs):
+            if len(doc) == 0:
+                continue
             sentences = [s for s in doc.sents]
-            if len(doc) > 0:
-                # Looping through each entity (TODO: rewrite)
-                for ent in doc.ents:
-                    sent = ent.sent
-                    sent_index = sentences.index(sent)
-                    assert sent_index >= 0
+            # Looping through each entity (TODO: rewrite)
+            for ent in doc.ents:
+                sent_index = sentences.index(ent.sent)
+                assert sent_index >= 0
+
+                if self.incl_context:
                     # get n_neighbour sentences, clipped to the length of the document
                     start_sentence = max(0, sent_index - self.n_sents)
                     end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
@@ -410,55 +412,53 @@ class EntityLinker(TrainablePipe):
                     end_token = sentences[end_sentence].end
                     sent_doc = doc[start_token:end_token].as_doc()
                     # currently, the context is the same for each entity in a sentence (should be refined)
-                    xp = self.model.ops.xp
-                    if self.incl_context:
-                        sentence_encoding = self.model.predict([sent_doc])[0]
-                        sentence_encoding_t = sentence_encoding.T
-                        sentence_norm = xp.linalg.norm(sentence_encoding_t)
-                    entity_count += 1
-                    if ent.label_ in self.labels_discard:
-                        # ignoring this entity - setting to NIL
-                        final_kb_ids.append(self.NIL)
-                    else:
-                        candidates = list(self.get_candidates(self.kb, ent))
-                        if not candidates:
-                            # no prediction possible for this entity - setting to NIL
-                            final_kb_ids.append(self.NIL)
-                        elif len(candidates) == 1:
-                            # shortcut for efficiency reasons: take the 1 candidate
-                            # TODO: thresholding
-                            final_kb_ids.append(candidates[0].entity_)
-                        else:
-                            random.shuffle(candidates)
-                            # set all prior probabilities to 0 if incl_prior=False
-                            prior_probs = xp.asarray([c.prior_prob for c in candidates])
-                            if not self.incl_prior:
-                                prior_probs = xp.asarray([0.0 for _ in candidates])
-                            scores = prior_probs
-                            # add in similarity from the context
-                            if self.incl_context:
-                                entity_encodings = xp.asarray(
-                                    [c.entity_vector for c in candidates]
-                                )
-                                entity_norm = xp.linalg.norm(entity_encodings, axis=1)
-                                if len(entity_encodings) != len(prior_probs):
-                                    raise RuntimeError(
-                                        Errors.E147.format(
-                                            method="predict",
-                                            msg="vectors not of equal length",
-                                        )
-                                    )
-                                # cosine similarity
-                                sims = xp.dot(entity_encodings, sentence_encoding_t) / (
-                                    sentence_norm * entity_norm
-                                )
-                                if sims.shape != prior_probs.shape:
-                                    raise ValueError(Errors.E161)
-                                scores = prior_probs + sims - (prior_probs * sims)
-                            # TODO: thresholding
-                            best_index = scores.argmax().item()
-                            best_candidate = candidates[best_index]
-                            final_kb_ids.append(best_candidate.entity_)
+                    sentence_encoding = self.model.predict([sent_doc])[0]
+                    sentence_encoding_t = sentence_encoding.T
+                    sentence_norm = xp.linalg.norm(sentence_encoding_t)
+                entity_count += 1
+                if ent.label_ in self.labels_discard:
+                    # ignoring this entity - setting to NIL
+                    final_kb_ids.append(self.NIL)
+                else:
+                    candidates = list(self.get_candidates(self.kb, ent))
+                    if not candidates:
+                        # no prediction possible for this entity - setting to NIL
+                        final_kb_ids.append(self.NIL)
+                    elif len(candidates) == 1:
+                        # shortcut for efficiency reasons: take the 1 candidate
+                        # TODO: thresholding
+                        final_kb_ids.append(candidates[0].entity_)
+                    else:
+                        random.shuffle(candidates)
+                        # set all prior probabilities to 0 if incl_prior=False
+                        prior_probs = xp.asarray([c.prior_prob for c in candidates])
+                        if not self.incl_prior:
+                            prior_probs = xp.asarray([0.0 for _ in candidates])
+                        scores = prior_probs
+                        # add in similarity from the context
+                        if self.incl_context:
+                            entity_encodings = xp.asarray(
+                                [c.entity_vector for c in candidates]
+                            )
+                            entity_norm = xp.linalg.norm(entity_encodings, axis=1)
+                            if len(entity_encodings) != len(prior_probs):
+                                raise RuntimeError(
+                                    Errors.E147.format(
+                                        method="predict",
+                                        msg="vectors not of equal length",
+                                    )
+                                )
+                            # cosine similarity
+                            sims = xp.dot(entity_encodings, sentence_encoding_t) / (
+                                sentence_norm * entity_norm
+                            )
+                            if sims.shape != prior_probs.shape:
+                                raise ValueError(Errors.E161)
+                            scores = prior_probs + sims - (prior_probs * sims)
+                        # TODO: thresholding
+                        best_index = scores.argmax().item()
+                        best_candidate = candidates[best_index]
+                        final_kb_ids.append(best_candidate.entity_)
         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
                 method="predict", msg="result variables not of equal length"