mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
introduce goldparse.links
This commit is contained in:
parent
a5c061f506
commit
0486ccabfd
|
@ -303,8 +303,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
|
||||||
collect_correct=True,
|
collect_correct=True,
|
||||||
collect_incorrect=True)
|
collect_incorrect=True)
|
||||||
|
|
||||||
docs = list()
|
data = []
|
||||||
golds = list()
|
|
||||||
|
|
||||||
cnt = 0
|
cnt = 0
|
||||||
next_entity_nr = 1
|
next_entity_nr = 1
|
||||||
|
@ -323,7 +322,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
|
||||||
article_doc = nlp(text)
|
article_doc = nlp(text)
|
||||||
truncated_text = text[0:min(doc_cutoff, len(text))]
|
truncated_text = text[0:min(doc_cutoff, len(text))]
|
||||||
|
|
||||||
gold_entities = dict()
|
gold_entities = list()
|
||||||
|
|
||||||
# process all positive and negative entities, collect all relevant mentions in this article
|
# process all positive and negative entities, collect all relevant mentions in this article
|
||||||
for mention, entity_pos in correct_entries[article_id].items():
|
for mention, entity_pos in correct_entries[article_id].items():
|
||||||
|
@ -337,11 +336,10 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
|
||||||
|
|
||||||
# store gold entities
|
# store gold entities
|
||||||
for match_id, start, end in matches:
|
for match_id, start, end in matches:
|
||||||
gold_entities[(start, end, entity_pos)] = 1.0
|
gold_entities.append((start, end, entity_pos))
|
||||||
|
|
||||||
gold = GoldParse(doc=article_doc, cats=gold_entities)
|
gold = GoldParse(doc=article_doc, links=gold_entities)
|
||||||
docs.append(article_doc)
|
data.append((article_doc, gold))
|
||||||
golds.append(gold)
|
|
||||||
|
|
||||||
cnt += 1
|
cnt += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -352,7 +350,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
|
||||||
print()
|
print()
|
||||||
print("Processed", cnt, "training articles, dev=" + str(dev))
|
print("Processed", cnt, "training articles, dev=" + str(dev))
|
||||||
print()
|
print()
|
||||||
return docs, golds
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
from spacy.util import minibatch, compounding
|
||||||
|
|
||||||
from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
|
from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
|
||||||
from examples.pipeline.wiki_entity_linking.train_el import EL_Model
|
from examples.pipeline.wiki_entity_linking.train_el import EL_Model
|
||||||
|
|
||||||
|
@ -26,6 +30,8 @@ TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
|
||||||
MAX_CANDIDATES = 10
|
MAX_CANDIDATES = 10
|
||||||
MIN_PAIR_OCC = 5
|
MIN_PAIR_OCC = 5
|
||||||
DOC_CHAR_CUTOFF = 300
|
DOC_CHAR_CUTOFF = 300
|
||||||
|
EPOCHS = 5
|
||||||
|
DROPOUT = 0.1
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("START", datetime.datetime.now())
|
print("START", datetime.datetime.now())
|
||||||
|
@ -115,7 +121,7 @@ if __name__ == "__main__":
|
||||||
if train_pipe:
|
if train_pipe:
|
||||||
id_to_descr = kb_creator._get_id_to_description(ENTITY_DESCR)
|
id_to_descr = kb_creator._get_id_to_description(ENTITY_DESCR)
|
||||||
|
|
||||||
docs, golds = training_set_creator.read_training(nlp=nlp,
|
train_data = training_set_creator.read_training(nlp=nlp,
|
||||||
training_dir=TRAINING_DIR,
|
training_dir=TRAINING_DIR,
|
||||||
id_to_descr=id_to_descr,
|
id_to_descr=id_to_descr,
|
||||||
doc_cutoff=DOC_CHAR_CUTOFF,
|
doc_cutoff=DOC_CHAR_CUTOFF,
|
||||||
|
@ -123,12 +129,6 @@ if __name__ == "__main__":
|
||||||
limit=10,
|
limit=10,
|
||||||
to_print=False)
|
to_print=False)
|
||||||
|
|
||||||
# for doc, gold in zip(docs, golds):
|
|
||||||
# print("doc", doc)
|
|
||||||
# for entity, label in gold.cats.items():
|
|
||||||
# print("entity", entity, label)
|
|
||||||
# print()
|
|
||||||
|
|
||||||
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
|
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
|
||||||
nlp.add_pipe(el_pipe, last=True)
|
nlp.add_pipe(el_pipe, last=True)
|
||||||
|
|
||||||
|
@ -136,6 +136,20 @@ if __name__ == "__main__":
|
||||||
with nlp.disable_pipes(*other_pipes): # only train Entity Linking
|
with nlp.disable_pipes(*other_pipes): # only train Entity Linking
|
||||||
nlp.begin_training()
|
nlp.begin_training()
|
||||||
|
|
||||||
|
for itn in range(EPOCHS):
|
||||||
|
random.shuffle(train_data)
|
||||||
|
losses = {}
|
||||||
|
batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
|
||||||
|
for batch in batches:
|
||||||
|
docs, golds = zip(*batch)
|
||||||
|
nlp.update(
|
||||||
|
docs,
|
||||||
|
golds,
|
||||||
|
drop=DROPOUT,
|
||||||
|
losses=losses,
|
||||||
|
)
|
||||||
|
print("Losses", losses)
|
||||||
|
|
||||||
### BELOW CODE IS DEPRECATED ###
|
### BELOW CODE IS DEPRECATED ###
|
||||||
|
|
||||||
# STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx
|
# STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx
|
||||||
|
|
|
@ -31,6 +31,7 @@ cdef class GoldParse:
|
||||||
cdef public list ents
|
cdef public list ents
|
||||||
cdef public dict brackets
|
cdef public dict brackets
|
||||||
cdef public object cats
|
cdef public object cats
|
||||||
|
cdef public list links
|
||||||
|
|
||||||
cdef readonly list cand_to_gold
|
cdef readonly list cand_to_gold
|
||||||
cdef readonly list gold_to_cand
|
cdef readonly list gold_to_cand
|
||||||
|
|
|
@ -427,7 +427,7 @@ cdef class GoldParse:
|
||||||
|
|
||||||
def __init__(self, doc, annot_tuples=None, words=None, tags=None,
|
def __init__(self, doc, annot_tuples=None, words=None, tags=None,
|
||||||
heads=None, deps=None, entities=None, make_projective=False,
|
heads=None, deps=None, entities=None, make_projective=False,
|
||||||
cats=None, **_):
|
cats=None, links=None, **_):
|
||||||
"""Create a GoldParse.
|
"""Create a GoldParse.
|
||||||
|
|
||||||
doc (Doc): The document the annotations refer to.
|
doc (Doc): The document the annotations refer to.
|
||||||
|
@ -450,6 +450,8 @@ cdef class GoldParse:
|
||||||
examples of a label to have the value 0.0. Labels not in the
|
examples of a label to have the value 0.0. Labels not in the
|
||||||
dictionary are treated as missing - the gradient for those labels
|
dictionary are treated as missing - the gradient for those labels
|
||||||
will be zero.
|
will be zero.
|
||||||
|
links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
|
||||||
|
representing the external ID of an entity in a knowledge base.
|
||||||
RETURNS (GoldParse): The newly constructed object.
|
RETURNS (GoldParse): The newly constructed object.
|
||||||
"""
|
"""
|
||||||
if words is None:
|
if words is None:
|
||||||
|
@ -485,6 +487,7 @@ cdef class GoldParse:
|
||||||
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
|
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
|
||||||
|
|
||||||
self.cats = {} if cats is None else dict(cats)
|
self.cats = {} if cats is None else dict(cats)
|
||||||
|
self.links = links
|
||||||
self.words = [None] * len(doc)
|
self.words = [None] * len(doc)
|
||||||
self.tags = [None] * len(doc)
|
self.tags = [None] * len(doc)
|
||||||
self.heads = [None] * len(doc)
|
self.heads = [None] * len(doc)
|
||||||
|
|
|
@ -1115,48 +1115,61 @@ class EntityLinker(Pipe):
|
||||||
self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
|
self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
|
||||||
|
|
||||||
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
|
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
|
||||||
""" docs should be a tuple of (entity_docs, article_docs, sentence_docs) TODO """
|
|
||||||
self.require_model()
|
self.require_model()
|
||||||
|
|
||||||
if len(docs) != len(golds):
|
if len(docs) != len(golds):
|
||||||
raise ValueError(Errors.E077.format(value="loss", n_docs=len(docs),
|
raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
|
||||||
n_golds=len(golds)))
|
n_golds=len(golds)))
|
||||||
|
|
||||||
entity_docs, article_docs, sentence_docs = docs
|
if isinstance(docs, Doc):
|
||||||
assert len(entity_docs) == len(article_docs) == len(sentence_docs)
|
docs = [docs]
|
||||||
|
golds = [golds]
|
||||||
|
|
||||||
if isinstance(entity_docs, Doc):
|
for doc, gold in zip(docs, golds):
|
||||||
entity_docs = [entity_docs]
|
print("doc", doc)
|
||||||
article_docs = [article_docs]
|
for entity in gold.links:
|
||||||
sentence_docs = [sentence_docs]
|
start, end, gold_kb = entity
|
||||||
|
print("entity", entity)
|
||||||
|
mention = doc[start:end].text
|
||||||
|
print("mention", mention)
|
||||||
|
candidates = self.kb.get_candidates(mention)
|
||||||
|
for c in candidates:
|
||||||
|
prior_prob = c.prior_prob
|
||||||
|
kb_id = c.entity_
|
||||||
|
print("candidate", kb_id, prior_prob)
|
||||||
|
entity_encoding = c.entity_vector
|
||||||
|
print()
|
||||||
|
|
||||||
entity_encodings = None #TODO
|
print()
|
||||||
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
|
|
||||||
sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
|
|
||||||
|
|
||||||
concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
|
# entity_encodings = None #TODO
|
||||||
range(len(article_docs))]
|
# doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
|
||||||
mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
|
# sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
|
||||||
|
#
|
||||||
loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
|
# concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
|
||||||
|
# range(len(article_docs))]
|
||||||
mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
|
# mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
|
||||||
|
#
|
||||||
# gradient : concat (doc+sent) vs. desc
|
# loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
|
||||||
sent_start = self.article_encoder.nO
|
#
|
||||||
sent_gradients = list()
|
# mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
|
||||||
doc_gradients = list()
|
#
|
||||||
for x in mention_gradient:
|
# # gradient : concat (doc+sent) vs. desc
|
||||||
doc_gradients.append(list(x[0:sent_start]))
|
# sent_start = self.article_encoder.nO
|
||||||
sent_gradients.append(list(x[sent_start:]))
|
# sent_gradients = list()
|
||||||
|
# doc_gradients = list()
|
||||||
bp_doc(doc_gradients, sgd=self.sgd_article)
|
# for x in mention_gradient:
|
||||||
bp_sent(sent_gradients, sgd=self.sgd_sent)
|
# doc_gradients.append(list(x[0:sent_start]))
|
||||||
|
# sent_gradients.append(list(x[sent_start:]))
|
||||||
if losses is not None:
|
#
|
||||||
losses.setdefault(self.name, 0.0)
|
# bp_doc(doc_gradients, sgd=self.sgd_article)
|
||||||
losses[self.name] += loss
|
# bp_sent(sent_gradients, sgd=self.sgd_sent)
|
||||||
return loss
|
#
|
||||||
|
# if losses is not None:
|
||||||
|
# losses.setdefault(self.name, 0.0)
|
||||||
|
# losses[self.name] += loss
|
||||||
|
# return loss
|
||||||
|
return None
|
||||||
|
|
||||||
def get_loss(self, docs, golds, scores):
|
def get_loss(self, docs, golds, scores):
|
||||||
loss, gradients = get_cossim_loss(scores, golds)
|
loss, gradients = get_cossim_loss(scores, golds)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user