diff --git a/examples/pipeline/wiki_entity_linking/train_descriptions.py b/examples/pipeline/wiki_entity_linking/train_descriptions.py
index 92859fd84..bf4bcbc3d 100644
--- a/examples/pipeline/wiki_entity_linking/train_descriptions.py
+++ b/examples/pipeline/wiki_entity_linking/train_descriptions.py
@@ -1,8 +1,6 @@
 # coding: utf-8
 from random import shuffle

-from examples.pipeline.wiki_entity_linking import kb_creator
-
 import numpy as np

 from spacy._ml import zero_init, create_default_optimizer
diff --git a/examples/pipeline/wiki_entity_linking/training_set_creator.py b/examples/pipeline/wiki_entity_linking/training_set_creator.py
index 38a86058d..fc620a1d3 100644
--- a/examples/pipeline/wiki_entity_linking/training_set_creator.py
+++ b/examples/pipeline/wiki_entity_linking/training_set_creator.py
@@ -19,17 +19,15 @@ Process Wikipedia interlinks to generate a training dataset for the EL algorithm

 ENTITY_FILE = "gold_entities.csv"


-def create_training(kb, entity_def_input, training_output):
-    if not kb:
-        raise ValueError("kb should be defined")
+def create_training(entity_def_input, training_output):
     wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
-    _process_wikipedia_texts(kb, wp_to_id, training_output, limit=100000000)  # TODO: full dataset
+    _process_wikipedia_texts(wp_to_id, training_output, limit=100000000)  # TODO: full dataset 100000000


-def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
+def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
     """
     Read the XML wikipedia data to parse out training data:
-    raw text data + positive and negative instances
+    raw text data + positive instances
     """
     title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
@@ -43,8 +41,9 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
         _write_training_entity(outputfile=entityfile,
                                article_id="article_id",
                                alias="alias",
-                               entity="entity",
-                               correct="correct")
+                               entity="WD_id",
+                               start="start",
+                               end="end")

         with bz2.open(wp.ENWIKI_DUMP, mode='rb') as file:
             line = file.readline()
@@ -75,14 +74,11 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
                 elif clean_line == "</page>":
                     if article_id:
                         try:
-                            _process_wp_text(kb, wp_to_id, entityfile, article_id, article_text.strip(), training_output)
-                            # on a previous run, an error occurred after 46M lines and 2h
+                            _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(), training_output)
                         except Exception as e:
                             print("Error processing article", article_id, article_title, e)
                     else:
-                        print("Done processing a page, but couldn't find an article_id ?")
-                        print(article_title)
-                        print(article_text)
+                        print("Done processing a page, but couldn't find an article_id ?", article_title)
                     article_text = ""
                     article_title = None
                     article_id = None
@@ -122,7 +118,14 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
 text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')
-            candidates = kb.get_candidates(alias)
+        if open_read > 2:
+            reading_special_case = True

-            # as training data, we only store entities that are sufficiently ambiguous
-            if len(candidates) > 1:
-                _write_training_article(article_id=article_id, clean_text=clean_text, training_output=training_output)
-                # print("alias", alias)
+        if open_read == 2 and reading_text:
+            reading_text = False
+            reading_entity = True
+            reading_mention = False

-                # print all incorrect candidates
-                for c in candidates:
-                    if entity != c.entity_:
+        # we just finished reading an entity
+        if open_read == 0 and not reading_text:
+            if '#' in entity_buffer or entity_buffer.startswith(':'):
+                reading_special_case = True
+            # Ignore cases with nested structures like File: handles etc
+            if not reading_special_case:
+                if not mention_buffer:
+                    mention_buffer = entity_buffer
+                start = len(final_text)
+                end = start + len(mention_buffer)
+                qid = wp_to_id.get(entity_buffer, None)
+                if qid:
                     _write_training_entity(outputfile=entityfile,
                                            article_id=article_id,
-                                           alias=alias,
-                                           entity=c.entity_,
-                                           correct="0")
+                                           alias=mention_buffer,
+                                           entity=qid,
+                                           start=start,
+                                           end=end)
+                    found_entities = True
+                final_text += mention_buffer

-            # print the one correct candidate
-            _write_training_entity(outputfile=entityfile,
-                                   article_id=article_id,
-                                   alias=alias,
-                                   entity=entity,
-                                   correct="1")
+            entity_buffer = ""
+            mention_buffer = ""

-            # print("gold entity", entity)
-            # print()
+            reading_text = True
+            reading_entity = False
+            reading_mention = False
+            reading_special_case = False

-    # _run_ner_depr(nlp, clean_text, article_dict)
-    # print()
+    if found_entities:
+        _write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)


 info_regex = re.compile(r'{[^{]*?}')
-interwiki_regex = re.compile(r'\[\[([^|]*?)]]')
-interwiki_2_regex = re.compile(r'\[\[[^|]*?\|([^|]*?)]]')
-htlm_regex = re.compile(r'<!--[^!]*-->')
+htlm_regex = re.compile(r'<!--[^-]*-->')
 category_regex = re.compile(r'\[\[Category:[^\[]*]]')
 file_regex = re.compile(r'\[\[File:[^[\]]+]]')
 ref_regex = re.compile(r'<ref.*?>')  # non-greedy
@@ -215,12 +242,6 @@ def _get_clean_wp_text(article_text):
         try_again = False
         previous_length = len(clean_text)

-        # remove simple interwiki links (no alternative name)
-        clean_text = interwiki_regex.sub(r'\1', clean_text)
-
-        # remove simple interwiki links by picking the alternative name
-        clean_text = interwiki_2_regex.sub(r'\1', clean_text)
-
         # remove HTML comments
         clean_text = htlm_regex.sub('', clean_text)
@@ -265,43 +286,34 @@ def _write_training_article(article_id, clean_text, training_output):
         outputfile.write(clean_text)


-def _write_training_entity(outputfile, article_id, alias, entity, correct):
-    outputfile.write(article_id + "|" + alias + "|" + entity + "|" + correct + "\n")
+def _write_training_entity(outputfile, article_id, alias, entity, start, end):
+    outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")


-def read_training_entities(training_output, collect_correct=True, collect_incorrect=False):
+def read_training_entities(training_output):
     entityfile_loc = training_output + "/" + ENTITY_FILE

-    incorrect_entries_per_article = dict()
-    correct_entries_per_article = dict()
+    entries_per_article = dict()
+
     with open(entityfile_loc, mode='r', encoding='utf8') as file:
         for line in file:
             fields = line.replace('\n', "").split(sep='|')
             article_id = fields[0]
             alias = fields[1]
-            entity = fields[2]
-            correct = fields[3]
+            wp_title = fields[2]
+            start = fields[3]
+            end = fields[4]

-            if correct == "1" and collect_correct:
-                entry_dict = correct_entries_per_article.get(article_id, dict())
-                if alias in entry_dict:
-                    raise ValueError("Found alias", alias, "multiple times for article", article_id, "in", ENTITY_FILE)
-                entry_dict[alias] = entity
-                correct_entries_per_article[article_id] = entry_dict
+            entries_by_offset = entries_per_article.get(article_id, dict())
+            entries_by_offset[start + "-" + end] = (alias, wp_title)

-            if correct == "0" and collect_incorrect:
-                entry_dict = incorrect_entries_per_article.get(article_id, dict())
-                entities = entry_dict.get(alias, set())
-                entities.add(entity)
-                entry_dict[alias] = entities
-                incorrect_entries_per_article[article_id] = entry_dict
+            entries_per_article[article_id] = entries_by_offset

-    return correct_entries_per_article, incorrect_entries_per_article
+    return entries_per_article


 def read_training(nlp, training_dir, dev, limit, to_print):
-    correct_entries, incorrect_entries = read_training_entities(training_output=training_dir,
-                                                                collect_correct=True,
-                                                                collect_incorrect=True)
+    # This method will provide training examples that correspond to the entity annotations found by the nlp object
+    entries_per_article = read_training_entities(training_output=training_dir)

     data = []
@@ -320,36 +332,33 @@ def read_training(nlp, training_dir, dev, limit, to_print):
                         text = file.read()
                         article_doc = nlp(text)

+                        entries_by_offset = entries_per_article.get(article_id, dict())
+
                         gold_entities = list()
+                        for ent in article_doc.ents:
+                            start = ent.start_char
+                            end = ent.end_char

-                        # process all positive and negative entities, collect all relevant mentions in this article
-                        for mention, entity_pos in correct_entries[article_id].items():
-                            # find all matches in the doc for the mentions
-                            # TODO: fix this - doesn't look like all entities are found
-                            matcher = PhraseMatcher(nlp.vocab)
-                            patterns = list(nlp.tokenizer.pipe([mention]))
+                            entity_tuple = entries_by_offset.get(str(start) + "-" + str(end), None)
+                            if entity_tuple:
+                                alias, wp_title = entity_tuple
+                                if ent.text != alias:
+                                    print("Non-matching entity in", article_id, start, end)
+                                else:
+                                    gold_entities.append((start, end, wp_title))

-                            matcher.add("TerminologyList", None, *patterns)
-                            matches = matcher(article_doc)
-
-                            # store gold entities
-                            for match_id, start, end in matches:
-                                gold_entities.append((start, end, entity_pos))
-
-                        gold = GoldParse(doc=article_doc, links=gold_entities)
-                        data.append((article_doc, gold))
+                        if gold_entities:
+                            gold = GoldParse(doc=article_doc, links=gold_entities)
+                            data.append((article_doc, gold))

                         cnt += 1
                except Exception as e:
                    print("Problem parsing article", article_id)
                    print(e)
+                   raise e

     if to_print:
         print()
         print("Processed", cnt, "training articles, dev=" + str(dev))
         print()
     return data
-
-
-
-
diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
index d5002e26f..faea93f53 100644
--- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
+++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
@@ -30,8 +30,8 @@ TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'

 MAX_CANDIDATES = 10
 MIN_PAIR_OCC = 5
-DOC_CHAR_CUTOFF = 300
-EPOCHS = 2
+DOC_SENT_CUTOFF = 2
+EPOCHS = 10
 DROPOUT = 0.1
@@ -46,14 +46,14 @@ def run_pipeline():
     # one-time methods to create KB and write to file
     to_create_prior_probs = False
     to_create_entity_counts = False
-    to_create_kb = True
+    to_create_kb = False

     # TODO: entity_defs should also contain entities not in the KB
     # read KB back in from file
-    to_read_kb = True
-    to_test_kb = True
+    to_read_kb = False
+    to_test_kb = False

     # create training dataset
-    create_wp_training = False
+    create_wp_training = True

     # train the EL pipe
     train_pipe = False
@@ -103,9 +103,6 @@ def run_pipeline():
     # STEP 4 : read KB back in from file
     if to_read_kb:
         print("STEP 4: to_read_kb", datetime.datetime.now())
-        # my_vocab = Vocab()
-        # my_vocab.from_disk(VOCAB_DIR)
-        # my_kb = KnowledgeBase(vocab=my_vocab, entity_vector_length=64)
         nlp_2 = spacy.load(NLP_1_DIR)
         kb_2 = KnowledgeBase(vocab=nlp_2.vocab,
                              entity_vector_length=DESC_WIDTH)
         kb_2.load_bulk(KB_FILE)
@@ -121,13 +118,13 @@ def run_pipeline():
     # STEP 5: create a training dataset from WP
     if create_wp_training:
         print("STEP 5: create training dataset", datetime.datetime.now())
-        training_set_creator.create_training(kb=kb_2, entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
+        training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)

     # STEP 6: create the entity linking pipe
     if train_pipe:
         print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
-        train_limit = 10
-        dev_limit = 5
+        train_limit = 50
+        dev_limit = 10
         print("Training on", train_limit, "articles")
         print("Dev testing on", dev_limit, "articles")
         print()
@@ -144,7 +141,7 @@ def run_pipeline():
                                                        limit=dev_limit,
                                                        to_print=False)

-        el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_CHAR_CUTOFF})
+        el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF})
         el_pipe.set_kb(kb_2)
         nlp_2.add_pipe(el_pipe, last=True)
@@ -199,10 +196,14 @@ def run_pipeline():
             el_pipe.prior_weight = 0
             dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
             train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
-            print("train/dev acc context:", round(train_acc_1_0, 2), round(dev_acc_1_0, 2))
             print()

+            # reset for follow-up tests
+            el_pipe.context_weight = 1
+            el_pipe.prior_weight = 1
+
+
     if to_test_pipeline:
         print()
         print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
@@ -215,17 +216,10 @@ def run_pipeline():
         print("STEP 9: testing NLP IO", datetime.datetime.now())
         print()
         print("writing to", NLP_2_DIR)
-        print(" vocab len nlp_2", len(nlp_2.vocab))
-        print(" vocab len kb_2", len(kb_2.vocab))
         nlp_2.to_disk(NLP_2_DIR)
         print()
         print("reading from", NLP_2_DIR)
         nlp_3 = spacy.load(NLP_2_DIR)
-        print(" vocab len nlp_3", len(nlp_3.vocab))
-
-        for pipe_name, pipe in nlp_3.pipeline:
-            if pipe_name == "entity_linker":
-                print(" vocab len kb_3", len(pipe.kb.vocab))

         print()
         print("running toy example with NLP 2")
@@ -253,9 +247,10 @@ def _measure_accuracy(data, el_pipe):
         for ent in doc.ents:
             if ent.label_ == "PERSON":  # TODO: expand to other types
                 pred_entity = ent.kb_id_
-                start = ent.start
-                end = ent.end
+                start = ent.start_char
+                end = ent.end_char
                 gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
+                # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
                 if gold_entity is not None:
                     if gold_entity == pred_entity:
                         correct += 1
@@ -285,7 +280,8 @@ def run_el_toy_example(nlp):
     print()

     # Q4426480 is her husband, Q3568763 her tutor
-    text = "Ada Lovelace loved her husband William King dearly. " \
+    text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine."\
+           "Ada Lovelace loved her husband William King dearly. " \
           "Ada Lovelace was tutored by her favorite physics tutor William King."

     doc = nlp(text)
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index e73ff6a0e..5d82da7ee 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -1074,6 +1074,9 @@ class EntityLinker(Pipe):

     @classmethod
     def Model(cls, **cfg):
+        if "entity_width" not in cfg:
+            raise ValueError("entity_width not found")
+
         embed_width = cfg.get("embed_width", 300)
         hidden_width = cfg.get("hidden_width", 32)
         article_width = cfg.get("article_width", 128)
@@ -1095,7 +1098,10 @@ class EntityLinker(Pipe):
         self.mention_encoder = True
         self.kb = None
         self.cfg = dict(cfg)
-        self.doc_cutoff = self.cfg.get("doc_cutoff", 150)
+        self.doc_cutoff = self.cfg.get("doc_cutoff", 5)
+        self.sgd_article = None
+        self.sgd_sent = None
+        self.sgd_mention = None

     def set_kb(self, kb):
         self.kb = kb
@@ -1126,6 +1132,12 @@ class EntityLinker(Pipe):
         self.require_model()
         self.require_kb()

+        if losses is not None:
+            losses.setdefault(self.name, 0.0)
+
+        if not docs or not golds:
+            return 0
+
         if len(docs) != len(golds):
             raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
                                                 n_golds=len(golds)))
@@ -1141,21 +1153,30 @@ class EntityLinker(Pipe):

         for doc, gold in zip(docs, golds):
             for entity in gold.links:
                 start, end, gold_kb = entity
-                mention = doc[start:end]
-                sentence = mention.sent
-                first_par = doc[0:self.doc_cutoff].as_doc()
+                mention = doc.text[start:end]
+                sent_start = 0
+                sent_end = len(doc)
+                first_par_end = len(doc)
+                for index, sent in enumerate(doc.sents):
+                    if start >= sent.start_char and end <= sent.end_char:
+                        sent_start = sent.start
+                        sent_end = sent.end
+                    if index == self.doc_cutoff-1:
+                        first_par_end = sent.end
+                sentence = doc[sent_start:sent_end].as_doc()
+                first_par = doc[0:first_par_end].as_doc()

-                candidates = self.kb.get_candidates(mention.text)
+                candidates = self.kb.get_candidates(mention)
                 for c in candidates:
                     kb_id = c.entity_
-                    # TODO: currently only training on the positive instances
+                    # Currently only training on the positive instances
                     if kb_id == gold_kb:
                         prior_prob = c.prior_prob
                         entity_encoding = c.entity_vector
                         entity_encodings.append(entity_encoding)
                         article_docs.append(first_par)
-                        sentence_docs.append(sentence.as_doc())
+                        sentence_docs.append(sentence)

         if len(entity_encodings) > 0:
             doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
@@ -1168,11 +1189,6 @@ class EntityLinker(Pipe):
             entity_encodings = np.asarray(entity_encodings, dtype=np.float32)

             loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
-            # print("scores", mention_encodings)
-            # print("golds", entity_encodings)
-            # print("loss", loss)
-            # print("d_scores", d_scores)
-
             mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention)

             # gradient : concat (doc+sent) vs. desc
@@ -1187,7 +1203,6 @@ class EntityLinker(Pipe):
             bp_sent(sent_gradients, sgd=self.sgd_sent)

             if losses is not None:
-                losses.setdefault(self.name, 0.0)
                 losses[self.name] += loss
             return loss

@@ -1230,16 +1245,25 @@ class EntityLinker(Pipe):
         self.require_model()
         self.require_kb()

-        if isinstance(docs, Doc):
-            docs = [docs]
-
         final_entities = list()
         final_kb_ids = list()

-        for i, article_doc in enumerate(docs):
-            if len(article_doc) > 0:
-                doc_encoding = self.article_encoder([article_doc])
-                for ent in article_doc.ents:
+        if not docs:
+            return final_entities, final_kb_ids
+
+        if isinstance(docs, Doc):
+            docs = [docs]
+
+        for i, doc in enumerate(docs):
+            if len(doc) > 0:
+                first_par_end = len(doc)
+                for index, sent in enumerate(doc.sents):
+                    if index == self.doc_cutoff-1:
+                        first_par_end = sent.end
+                first_par = doc[0:first_par_end].as_doc()
+
+                doc_encoding = self.article_encoder([first_par])
+                for ent in doc.ents:
                     sent_doc = ent.sent.as_doc()
                     if len(sent_doc) > 0:
                         sent_encoding = self.sent_encoder([sent_doc])
@@ -1254,7 +1278,7 @@ class EntityLinker(Pipe):
                             prior_prob = c.prior_prob * self.prior_weight
                             kb_id = c.entity_
                             entity_encoding = c.entity_vector
-                            sim = cosine(np.asarray([entity_encoding]), mention_enc_t) * self.context_weight
+                            sim = float(cosine(np.asarray([entity_encoding]), mention_enc_t)) * self.context_weight
                             score = prior_prob + sim - (prior_prob*sim)  # put weights on the different factors ?
                             scores.append(score)

@@ -1271,36 +1295,7 @@ class EntityLinker(Pipe):
                 for token in entity:
                     token.ent_kb_id_ = kb_id

-    def to_bytes(self, exclude=tuple(), **kwargs):
-        """Serialize the pipe to a bytestring.
-
-        exclude (list): String names of serialization fields to exclude.
-        RETURNS (bytes): The serialized object.
-        """
-        serialize = OrderedDict()
-        serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
-        serialize["kb"] = self.kb.to_bytes  # TODO
-        if self.mention_encoder not in (True, False, None):
-            serialize["article_encoder"] = self.article_encoder.to_bytes
-            serialize["sent_encoder"] = self.sent_encoder.to_bytes
-            serialize["mention_encoder"] = self.mention_encoder.to_bytes
-        exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
-        return util.to_bytes(serialize, exclude)
-
-    def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
-        """Load the pipe from a bytestring."""
-        deserialize = OrderedDict()
-        deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
-        deserialize["kb"] = lambda b: self.kb.from_bytes(b)  # TODO
-        deserialize["article_encoder"] = lambda b: self.article_encoder.from_bytes(b)
-        deserialize["sent_encoder"] = lambda b: self.sent_encoder.from_bytes(b)
-        deserialize["mention_encoder"] = lambda b: self.mention_encoder.from_bytes(b)
-        exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
-        util.from_bytes(bytes_data, deserialize, exclude)
-        return self
-
     def to_disk(self, path, exclude=tuple(), **kwargs):
-        """Serialize the pipe to disk."""
         serialize = OrderedDict()
         serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
         serialize["kb"] = lambda p: self.kb.dump(p)
@@ -1312,7 +1307,6 @@ class EntityLinker(Pipe):
         util.to_disk(path, serialize, exclude)

     def from_disk(self, path, exclude=tuple(), **kwargs):
-        """Load the pipe from disk."""
         def load_article_encoder(p):
             if self.article_encoder is True:
                 self.article_encoder, _, _ = self.Model(**self.cfg)