mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
improve speed of prediction loop
This commit is contained in:
parent
bee23cd8af
commit
1de61f68d6
|
@ -76,7 +76,7 @@ def run_pipeline():
|
||||||
|
|
||||||
# write the NLP object, read back in and test again
|
# write the NLP object, read back in and test again
|
||||||
to_write_nlp = False
|
to_write_nlp = False
|
||||||
to_read_nlp = True
|
to_read_nlp = False
|
||||||
test_from_file = True
|
test_from_file = True
|
||||||
|
|
||||||
# STEP 1 : create prior probabilities from WP (run only once)
|
# STEP 1 : create prior probabilities from WP (run only once)
|
||||||
|
@ -252,22 +252,27 @@ def run_pipeline():
|
||||||
print("reading from", NLP_2_DIR)
|
print("reading from", NLP_2_DIR)
|
||||||
nlp_3 = spacy.load(NLP_2_DIR)
|
nlp_3 = spacy.load(NLP_2_DIR)
|
||||||
|
|
||||||
if test_from_file:
|
print("running toy example with NLP 3")
|
||||||
dev_limit = 5000
|
run_el_toy_example(nlp=nlp_3)
|
||||||
dev_data = training_set_creator.read_training(nlp=nlp_3,
|
|
||||||
training_dir=TRAINING_DIR,
|
|
||||||
dev=True,
|
|
||||||
limit=dev_limit)
|
|
||||||
|
|
||||||
print("Dev testing from file on", len(dev_data), "articles")
|
# testing performance with an NLP model from file
|
||||||
print()
|
if test_from_file:
|
||||||
|
nlp_2 = spacy.load(NLP_1_DIR)
|
||||||
|
nlp_3 = spacy.load(NLP_2_DIR)
|
||||||
|
el_pipe = nlp_3.get_pipe("entity_linker")
|
||||||
|
|
||||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data)
|
dev_limit = 10000
|
||||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
training_dir=TRAINING_DIR,
|
||||||
else:
|
dev=True,
|
||||||
print("running toy example with NLP 3")
|
limit=dev_limit)
|
||||||
run_el_toy_example(nlp=nlp_3)
|
|
||||||
|
print("Dev testing from file on", len(dev_data), "articles")
|
||||||
|
print()
|
||||||
|
|
||||||
|
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe)
|
||||||
|
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
||||||
|
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print("STOP", datetime.datetime.now())
|
print("STOP", datetime.datetime.now())
|
||||||
|
@ -280,7 +285,9 @@ def _measure_accuracy(data, el_pipe=None):
|
||||||
|
|
||||||
docs = [d for d, g in data if len(d) > 0]
|
docs = [d for d, g in data if len(d) > 0]
|
||||||
if el_pipe is not None:
|
if el_pipe is not None:
|
||||||
docs = el_pipe.pipe(docs)
|
print("applying el_pipe", datetime.datetime.now())
|
||||||
|
docs = list(el_pipe.pipe(docs, batch_size=10000000000))
|
||||||
|
print("done applying el_pipe", datetime.datetime.now())
|
||||||
golds = [g for d, g in data if len(d) > 0]
|
golds = [g for d, g in data if len(d) > 0]
|
||||||
|
|
||||||
for doc, gold in zip(docs, golds):
|
for doc, gold in zip(docs, golds):
|
||||||
|
|
|
@ -3,8 +3,6 @@
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import srsly
|
import srsly
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
@ -12,6 +10,7 @@ from thinc.api import chain
|
||||||
from thinc.v2v import Affine, Maxout, Softmax
|
from thinc.v2v import Affine, Maxout, Softmax
|
||||||
from thinc.misc import LayerNorm
|
from thinc.misc import LayerNorm
|
||||||
from thinc.neural.util import to_categorical
|
from thinc.neural.util import to_categorical
|
||||||
|
from thinc.neural.util import get_array_module
|
||||||
|
|
||||||
from ..cli.pretrain import get_cossim_loss
|
from ..cli.pretrain import get_cossim_loss
|
||||||
from .functions import merge_subtokens
|
from .functions import merge_subtokens
|
||||||
|
@ -1151,7 +1150,7 @@ class EntityLinker(Pipe):
|
||||||
|
|
||||||
if len(entity_encodings) > 0:
|
if len(entity_encodings) > 0:
|
||||||
context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
|
context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
|
||||||
entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
|
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||||
|
|
||||||
loss, d_scores = self.get_loss(scores=context_encodings, golds=entity_encodings, docs=None)
|
loss, d_scores = self.get_loss(scores=context_encodings, golds=entity_encodings, docs=None)
|
||||||
bp_context(d_scores, sgd=sgd)
|
bp_context(d_scores, sgd=sgd)
|
||||||
|
@ -1192,24 +1191,30 @@ class EntityLinker(Pipe):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
|
|
||||||
|
context_encodings = self.model(docs)
|
||||||
|
xp = get_array_module(context_encodings)
|
||||||
|
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
if len(doc) > 0:
|
if len(doc) > 0:
|
||||||
context_encoding = self.model([doc])
|
context_encoding = context_encodings[i]
|
||||||
context_enc_t = np.transpose(context_encoding)
|
context_enc_t = context_encoding.T
|
||||||
|
norm_1 = xp.linalg.norm(context_enc_t)
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
candidates = self.kb.get_candidates(ent.text)
|
candidates = self.kb.get_candidates(ent.text)
|
||||||
if candidates:
|
if candidates:
|
||||||
scores = []
|
prior_probs = xp.asarray([c.prior_prob for c in candidates])
|
||||||
for c in candidates:
|
prior_probs *= self.prior_weight
|
||||||
prior_prob = c.prior_prob * self.prior_weight
|
|
||||||
kb_id = c.entity_
|
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
|
||||||
entity_encoding = c.entity_vector
|
norm_2 = xp.linalg.norm(entity_encodings, axis=1)
|
||||||
sim = float(cosine(np.asarray([entity_encoding]), context_enc_t)) * self.context_weight
|
|
||||||
score = prior_prob + sim - (prior_prob*sim)
|
# cosine similarity
|
||||||
scores.append(score)
|
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
|
||||||
|
sims *= self.context_weight
|
||||||
|
scores = prior_probs + sims - (prior_probs*sims)
|
||||||
|
best_index = scores.argmax()
|
||||||
|
|
||||||
# TODO: thresholding
|
# TODO: thresholding
|
||||||
best_index = scores.index(max(scores))
|
|
||||||
best_candidate = candidates[best_index]
|
best_candidate = candidates[best_index]
|
||||||
final_entities.append(ent)
|
final_entities.append(ent)
|
||||||
final_kb_ids.append(best_candidate.entity_)
|
final_kb_ids.append(best_candidate.entity_)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user