Mirror of https://github.com/explosion/spaCy.git
adding prior probability as feature in the model
parent 1c80b85241
commit c664f58246
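
This commit feeds each candidate's prior probability into the entity linking model as an extra input feature: every mention encoding becomes the concatenation of the context encoding, the entity encoding and one prior value, so the scoring head grows by one input dimension. A minimal NumPy sketch of that layout, with hypothetical widths standing in for CONTEXT_WIDTH and the KB's entity_width:

import numpy as np

context_width, entity_width = 128, 64        # hypothetical widths for illustration
context_enc = np.random.rand(context_width)  # stand-in for the tok2vec context encoding
entity_enc = np.random.rand(entity_width)    # stand-in for the KB entity vector
prior_prob = 0.3                             # candidate's prior probability

# mention encoding = context features + entity features + one prior feature
mention_enc = np.concatenate([context_enc, entity_enc, [prior_prob]])
assert mention_enc.shape[0] == entity_width + context_width + 1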
@@ -45,7 +45,7 @@ EPOCHS = 10
 DROPOUT = 0.2
 LEARN_RATE = 0.005
 L2 = 1e-6
-CONTEXT_WIDTH=128
+CONTEXT_WIDTH = 128


 def run_pipeline():
@@ -138,7 +138,9 @@ def run_pipeline():
     # STEP 6: create and train the entity linking pipe
     if train_pipe:
         print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
-        el_pipe = nlp_2.create_pipe(name='entity_linker', config={"context_width": CONTEXT_WIDTH})
+        el_pipe = nlp_2.create_pipe(name='entity_linker',
+                                    config={"context_width": CONTEXT_WIDTH,
+                                            "pretrained_vectors": nlp_2.vocab.vectors.name})
         el_pipe.set_kb(kb_2)
         nlp_2.add_pipe(el_pipe, last=True)

@@ -195,11 +197,11 @@ def run_pipeline():
             if batchnr > 0:
                 with el_pipe.model.use_params(optimizer.averages):
                     el_pipe.context_weight = 1
-                    el_pipe.prior_weight = 0
+                    el_pipe.prior_weight = 1
                     dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
                     losses['entity_linker'] = losses['entity_linker'] / batchnr
                     print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
-                          " / dev acc context avg", round(dev_acc_context, 3))
+                          " / dev acc avg", round(dev_acc_context, 3))

     # STEP 7: measure the performance of our trained pipe on an independent dev set
     if len(dev_data) and measure_performance:
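
Because the pipe now exposes both weights, the same trained model can be evaluated with or without the prior. A hedged sketch of an ablation loop one could run in this script after training, reusing _measure_accuracy, dev_data and el_pipe from above:

for prior_w, context_w in [(1, 1), (1, 0), (0, 1)]:   # prior+context, prior only, context only
    el_pipe.prior_weight = prior_w
    el_pipe.context_weight = context_w
    dev_acc, _ = _measure_accuracy(dev_data, el_pipe)
    print("prior", prior_w, "context", context_w, "dev acc", round(dev_acc, 3))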
@@ -666,15 +666,16 @@ def build_nel_encoder(embed_width, hidden_width, **cfg):
     entity_width = cfg.get("entity_width")

     with Model.define_operators({">>": chain, "**": clone}):
-        model = Affine(1, entity_width+context_width, drop_factor=0.0)\
+        model = Affine(entity_width, entity_width+context_width+1)\
+                >> Affine(1, entity_width, drop_factor=0.0)\
                 >> logistic

         # context encoder
         tok2vec = Tok2Vec(width=hidden_width, embed_size=embed_width, pretrained_vectors=pretrained_vectors,
-                          cnn_maxout_pieces=cnn_maxout_pieces, subword_features=False, conv_depth=conv_depth,
+                          cnn_maxout_pieces=cnn_maxout_pieces, subword_features=True, conv_depth=conv_depth,
                           bilstm_depth=0) >> flatten_add_lengths >> Pooling(mean_pool)\
                 >> Residual(zero_init(Maxout(hidden_width, hidden_width))) \
-                >> zero_init(Affine(context_width, hidden_width, drop_factor=0.0))
+                >> zero_init(Affine(context_width, hidden_width))

         model.tok2vec = tok2vec

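
The rebuilt scoring head maps the concatenated mention vector (entity_width + context_width + 1 features, the extra one being the prior) down to entity_width and then to a single logistic output. A rough NumPy sketch of the forward pass this Thinc composition computes, with randomly initialised placeholder weights instead of trained parameters:

import numpy as np

context_width, entity_width = 128, 64                # hypothetical widths for illustration
in_width = entity_width + context_width + 1          # +1 for the prior feature

W1 = np.random.randn(entity_width, in_width) * 0.01  # ~ Affine(entity_width, in_width)
b1 = np.zeros(entity_width)
W2 = np.random.randn(1, entity_width) * 0.01         # ~ Affine(1, entity_width)
b2 = np.zeros(1)

def score(mention_enc):
    hidden = W1 @ mention_enc + b1                   # first affine layer
    out = W2 @ hidden + b2                           # second affine layer
    return 1.0 / (1.0 + np.exp(-out))                # logistic squashes to (0, 1)

print(score(np.random.rand(in_width)))               # probability-like score for one candidate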
@@ -1132,7 +1132,8 @@ class EntityLinker(Pipe):

         context_docs = []
         entity_encodings = []
-        labels = []
+        cats = []
+        priors = []

         for doc, gold in zip(docs, golds):
             for entity in gold.links:
@@ -1143,27 +1144,33 @@ class EntityLinker(Pipe):
                 nr_neg = 0
                 for c in candidates:
                     kb_id = c.entity_
+                    entity_encoding = c.entity_vector
+                    entity_encodings.append(entity_encoding)
+                    context_docs.append(doc)
+
+                    if self.prior_weight > 0:
+                        priors.append([c.prior_prob])
+                    else:
+                        priors.append([0])
+
                     if kb_id == gold_kb:
-                        entity_encoding = c.entity_vector
-                        entity_encodings.append(entity_encoding)
-                        context_docs.append(doc)
-                        labels.append([1])
-                    else:
-                        nr_neg += 1
-                        entity_encoding = c.entity_vector
-                        entity_encodings.append(entity_encoding)
-                        context_docs.append(doc)
-                        labels.append([0])
+                        cats.append([1])
+                    else:  # elif nr_neg < 1:
+                        cats.append([0])

         if len(entity_encodings) > 0:
+            assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats)
+
             context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
             entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")

-            mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) for i in range(len(entity_encodings))]
+            mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i]
+                                 for i in range(len(entity_encodings))]
             pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
-            labels = self.model.ops.asarray(labels, dtype="float32")
+            cats = self.model.ops.asarray(cats, dtype="float32")

-            loss, d_scores = self.get_loss(prediction=pred, golds=labels, docs=None)
+            loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None)
             mention_gradient = bp_mention(d_scores, sgd=sgd)

             context_gradients = [list(x[0:self.context_width]) for x in mention_gradient]
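
In the updated training loop every candidate contributes one example: its entity encoding and context doc are always collected, the prior goes in as a feature only when prior_weight > 0, and the gold label lives in cats. A toy illustration of that bookkeeping, with made-up candidate tuples in place of the KB's Candidate objects:

candidates = [("Q1", [0.1, 0.2], 0.7), ("Q2", [0.3, 0.4], 0.2)]  # (kb_id, entity_vector, prior_prob)
gold_kb = "Q1"
prior_weight = 1

entity_encodings, priors, cats = [], [], []
for kb_id, entity_vector, prior_prob in candidates:
    entity_encodings.append(entity_vector)
    priors.append([prior_prob] if prior_weight > 0 else [0])
    cats.append([1] if kb_id == gold_kb else [0])

assert len(priors) == len(entity_encodings) == len(cats)  # mirrors the assert added in the pipe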
@@ -1221,13 +1228,19 @@ class EntityLinker(Pipe):
                 candidates = self.kb.get_candidates(ent.text)
                 if candidates:
                     random.shuffle(candidates)
-                    entity_encodings = xp.asarray([c.entity_vector for c in candidates])
-                    mention_encodings = [list(context_encoding) + list(entity_encodings[i]) for i in range(len(entity_encodings))]
-                    predictions = self.model(self.model.ops.asarray(mention_encodings, dtype="float32"))
-                    scores = (prior_probs + predictions - (xp.dot(prior_probs.T, predictions)))
+
+                    # this will set the prior probabilities to 0 (just like in training) if their weight is 0
+                    prior_probs = xp.asarray([[c.prior_prob] for c in candidates])
+                    prior_probs *= self.prior_weight
+                    scores = prior_probs
+
+                    if self.context_weight > 0:
+                        entity_encodings = xp.asarray([c.entity_vector for c in candidates])
+                        assert len(entity_encodings) == len(prior_probs)
+                        mention_encodings = [list(context_encoding) + list(entity_encodings[i])
+                                             + list(prior_probs[i])
+                                             for i in range(len(entity_encodings))]
+                        scores = self.model(self.model.ops.asarray(mention_encodings, dtype="float32"))

                     # TODO: thresholding
                     best_index = scores.argmax()
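
At prediction time the scores now default to the (possibly zeroed) priors, and the neural model only runs when context is weighted in; the prior is appended to each mention encoding exactly as during training. A hedged NumPy sketch of that control flow, with placeholder vectors and a random stand-in for the model call:

import numpy as np

prior_weight, context_weight = 1, 1
prior_probs = np.asarray([[0.7], [0.2]])      # one prior per candidate
prior_probs = prior_probs * prior_weight      # zeroed out when prior_weight == 0
scores = prior_probs                          # fallback: rank candidates by prior alone

if context_weight > 0:
    context_encoding = np.random.rand(128)    # placeholder context vector
    entity_encodings = np.random.rand(2, 64)  # placeholder entity vectors
    mention_encodings = [list(context_encoding) + list(entity_encodings[i]) + list(prior_probs[i])
                         for i in range(len(entity_encodings))]
    scores = np.random.rand(2, 1)             # stand-in for the trained model's predictions

best_index = int(scores.argmax())             # highest-scoring candidate wins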