Mirror of https://github.com/explosion/spaCy.git

Commit 4361da2bba: Merge branch 'master' into spacy.io
				|  | @ -13,9 +13,17 @@ INPUT_DIM = 300  # dimension of pre-trained input vectors | |||
| DESC_WIDTH = 64  # dimension of output entity vectors | ||||
| 
 | ||||
| 
 | ||||
| def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ, | ||||
|               entity_def_output, entity_descr_output, | ||||
|               count_input, prior_prob_input, wikidata_input): | ||||
| def create_kb( | ||||
|     nlp, | ||||
|     max_entities_per_alias, | ||||
|     min_entity_freq, | ||||
|     min_occ, | ||||
|     entity_def_output, | ||||
|     entity_descr_output, | ||||
|     count_input, | ||||
|     prior_prob_input, | ||||
|     wikidata_input, | ||||
| ): | ||||
|     # Create the knowledge base from Wikidata entries | ||||
|     kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH) | ||||
| 
 | ||||
|  | @ -28,7 +36,9 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ, | |||
|         title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input) | ||||
| 
 | ||||
|         # write the title-ID and ID-description mappings to file | ||||
|         _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr) | ||||
|         _write_entity_files( | ||||
|             entity_def_output, entity_descr_output, title_to_id, id_to_descr | ||||
|         ) | ||||
| 
 | ||||
|     else: | ||||
|         # read the mappings from file | ||||
|  | @ -54,8 +64,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ, | |||
|             frequency_list.append(freq) | ||||
|             filtered_title_to_id[title] = entity | ||||
| 
 | ||||
|     print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), | ||||
|           "titles with filter frequency", min_entity_freq) | ||||
|     print(len(title_to_id.keys()), "original titles") | ||||
|     print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq) | ||||
| 
 | ||||
|     print() | ||||
|     print(" * train entity encoder", datetime.datetime.now()) | ||||
|  | @ -70,14 +80,20 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ, | |||
| 
 | ||||
|     print() | ||||
|     print(" * adding", len(entity_list), "entities", datetime.datetime.now()) | ||||
|     kb.set_entities(entity_list=entity_list, prob_list=frequency_list, vector_list=embeddings) | ||||
|     kb.set_entities( | ||||
|         entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings | ||||
|     ) | ||||
| 
 | ||||
|     print() | ||||
|     print(" * adding aliases", datetime.datetime.now()) | ||||
|     print() | ||||
|     _add_aliases(kb, title_to_id=filtered_title_to_id, | ||||
|                  max_entities_per_alias=max_entities_per_alias, min_occ=min_occ, | ||||
|                  prior_prob_input=prior_prob_input) | ||||
|     _add_aliases( | ||||
|         kb, | ||||
|         title_to_id=filtered_title_to_id, | ||||
|         max_entities_per_alias=max_entities_per_alias, | ||||
|         min_occ=min_occ, | ||||
|         prior_prob_input=prior_prob_input, | ||||
|     ) | ||||
| 
 | ||||
|     print() | ||||
|     print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases()) | ||||
|  | @ -86,13 +102,15 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ, | |||
|     return kb | ||||
| 
 | ||||
| 
 | ||||
| def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr): | ||||
|     with open(entity_def_output, mode='w', encoding='utf8') as id_file: | ||||
| def _write_entity_files( | ||||
|     entity_def_output, entity_descr_output, title_to_id, id_to_descr | ||||
| ): | ||||
|     with entity_def_output.open("w", encoding="utf8") as id_file: | ||||
|         id_file.write("WP_title" + "|" + "WD_id" + "\n") | ||||
|         for title, qid in title_to_id.items(): | ||||
|             id_file.write(title + "|" + str(qid) + "\n") | ||||
| 
 | ||||
|     with open(entity_descr_output, mode='w', encoding='utf8') as descr_file: | ||||
|     with entity_descr_output.open("w", encoding="utf8") as descr_file: | ||||
|         descr_file.write("WD_id" + "|" + "description" + "\n") | ||||
|         for qid, descr in id_to_descr.items(): | ||||
|             descr_file.write(str(qid) + "|" + descr + "\n") | ||||
|  | @ -100,8 +118,8 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_ | |||
| 
 | ||||
| def get_entity_to_id(entity_def_output): | ||||
|     entity_to_id = dict() | ||||
|     with open(entity_def_output, 'r', encoding='utf8') as csvfile: | ||||
|         csvreader = csv.reader(csvfile, delimiter='|') | ||||
|     with entity_def_output.open("r", encoding="utf8") as csvfile: | ||||
|         csvreader = csv.reader(csvfile, delimiter="|") | ||||
|         # skip header | ||||
|         next(csvreader) | ||||
|         for row in csvreader: | ||||
|  | @ -111,8 +129,8 @@ def get_entity_to_id(entity_def_output): | |||
| 
 | ||||
| def get_id_to_description(entity_descr_output): | ||||
|     id_to_desc = dict() | ||||
|     with open(entity_descr_output, 'r', encoding='utf8') as csvfile: | ||||
|         csvreader = csv.reader(csvfile, delimiter='|') | ||||
|     with entity_descr_output.open("r", encoding="utf8") as csvfile: | ||||
|         csvreader = csv.reader(csvfile, delimiter="|") | ||||
|         # skip header | ||||
|         next(csvreader) | ||||
|         for row in csvreader: | ||||
|  | @ -125,7 +143,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in | |||
| 
 | ||||
|     # adding aliases with prior probabilities | ||||
|     # we can read this file sequentially, it's sorted by alias, and then by count | ||||
|     with open(prior_prob_input, mode='r', encoding='utf8') as prior_file: | ||||
|     with prior_prob_input.open("r", encoding="utf8") as prior_file: | ||||
|         # skip header | ||||
|         prior_file.readline() | ||||
|         line = prior_file.readline() | ||||
|  | @ -134,7 +152,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in | |||
|         counts = [] | ||||
|         entities = [] | ||||
|         while line: | ||||
|             splits = line.replace('\n', "").split(sep='|') | ||||
|             splits = line.replace("\n", "").split(sep="|") | ||||
|             new_alias = splits[0] | ||||
|             count = int(splits[1]) | ||||
|             entity = splits[2] | ||||
|  | @ -153,7 +171,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in | |||
| 
 | ||||
|                     if selected_entities: | ||||
|                         try: | ||||
|                             kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs) | ||||
|                             kb.add_alias( | ||||
|                                 alias=previous_alias, | ||||
|                                 entities=selected_entities, | ||||
|                                 probabilities=prior_probs, | ||||
|                             ) | ||||
|                         except ValueError as e: | ||||
|                             print(e) | ||||
|                 total_count = 0 | ||||
|  | @ -168,4 +190,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in | |||
|             previous_alias = new_alias | ||||
| 
 | ||||
|             line = prior_file.readline() | ||||
| 
 | ||||
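For orientation, here is a minimal, self-contained sketch (not part of this commit) of the KnowledgeBase calls that kb_creator now makes, in particular the freq_list keyword of set_entities that replaces the earlier prob_list. The entity IDs, frequencies, vectors and the small vector length below are toy assumptions; the real pipeline uses DESC_WIDTH = 64 and encoder output vectors.

from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

# toy KB with 3-dimensional entity vectors instead of DESC_WIDTH = 64
kb = KnowledgeBase(vocab=Vocab(), entity_vector_length=3)
kb.set_entities(
    entity_list=["Q42", "Q5301561"],
    freq_list=[320, 12],  # raw corpus frequencies (previously passed as prob_list)
    vector_list=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
)
kb.add_alias(alias="Douglas", entities=["Q42", "Q5301561"], probabilities=[0.8, 0.2])
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())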
|  |  | |||
|  | @ -1,7 +1,7 @@ | |||
| # coding: utf-8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import os | ||||
| import random | ||||
| import re | ||||
| import bz2 | ||||
| import datetime | ||||
|  | @ -17,6 +17,10 @@ Gold-standard entities are stored in one file in standoff format (by character o | |||
| ENTITY_FILE = "gold_entities.csv" | ||||
| 
 | ||||
| 
 | ||||
| def now(): | ||||
|     return datetime.datetime.now() | ||||
| 
 | ||||
| 
 | ||||
| def create_training(wikipedia_input, entity_def_input, training_output): | ||||
|     wp_to_id = kb_creator.get_entity_to_id(entity_def_input) | ||||
|     _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None) | ||||
|  | @ -27,21 +31,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N | |||
|     Read the XML wikipedia data to parse out training data: | ||||
|     raw text data + positive instances | ||||
|     """ | ||||
|     title_regex = re.compile(r'(?<=<title>).*(?=</title>)') | ||||
|     id_regex = re.compile(r'(?<=<id>)\d*(?=</id>)') | ||||
|     title_regex = re.compile(r"(?<=<title>).*(?=</title>)") | ||||
|     id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)") | ||||
| 
 | ||||
|     read_ids = set() | ||||
|     entityfile_loc = training_output / ENTITY_FILE | ||||
|     with open(entityfile_loc, mode="w", encoding='utf8') as entityfile: | ||||
|     with entityfile_loc.open("w", encoding="utf8") as entityfile: | ||||
|         # write entity training header file | ||||
|         _write_training_entity(outputfile=entityfile, | ||||
|                                article_id="article_id", | ||||
|                                alias="alias", | ||||
|                                entity="WD_id", | ||||
|                                start="start", | ||||
|                                end="end") | ||||
|         _write_training_entity( | ||||
|             outputfile=entityfile, | ||||
|             article_id="article_id", | ||||
|             alias="alias", | ||||
|             entity="WD_id", | ||||
|             start="start", | ||||
|             end="end", | ||||
|         ) | ||||
| 
 | ||||
|         with bz2.open(wikipedia_input, mode='rb') as file: | ||||
|         with bz2.open(wikipedia_input, mode="rb") as file: | ||||
|             line = file.readline() | ||||
|             cnt = 0 | ||||
|             article_text = "" | ||||
|  | @ -51,7 +57,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N | |||
|             reading_revision = False | ||||
|             while line and (not limit or cnt < limit): | ||||
|                 if cnt % 1000000 == 0: | ||||
|                     print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump") | ||||
|                     print(now(), "processed", cnt, "lines of Wikipedia dump") | ||||
|                 clean_line = line.strip().decode("utf-8") | ||||
| 
 | ||||
|                 if clean_line == "<revision>": | ||||
|  | @ -69,12 +75,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N | |||
|                 elif clean_line == "</page>": | ||||
|                     if article_id: | ||||
|                         try: | ||||
|                             _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(), | ||||
|                                              training_output) | ||||
|                             _process_wp_text( | ||||
|                                 wp_to_id, | ||||
|                                 entityfile, | ||||
|                                 article_id, | ||||
|                                 article_title, | ||||
|                                 article_text.strip(), | ||||
|                                 training_output, | ||||
|                             ) | ||||
|                         except Exception as e: | ||||
|                             print("Error processing article", article_id, article_title, e) | ||||
|                             print( | ||||
|                                 "Error processing article", article_id, article_title, e | ||||
|                             ) | ||||
|                     else: | ||||
|                         print("Done processing a page, but couldn't find an article_id ?", article_title) | ||||
|                         print( | ||||
|                             "Done processing a page, but couldn't find an article_id ?", | ||||
|                             article_title, | ||||
|                         ) | ||||
|                     article_text = "" | ||||
|                     article_title = None | ||||
|                     article_id = None | ||||
|  | @ -98,7 +115,9 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N | |||
|                     if ids: | ||||
|                         article_id = ids[0] | ||||
|                         if article_id in read_ids: | ||||
|                             print("Found duplicate article ID", article_id, clean_line)  # This should never happen ... | ||||
|                             print( | ||||
|                                 "Found duplicate article ID", article_id, clean_line | ||||
|                             )  # This should never happen ... | ||||
|                         read_ids.add(article_id) | ||||
| 
 | ||||
|                 # read the title of this article (outside the revision portion of the document) | ||||
|  | @ -111,10 +130,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N | |||
|                 cnt += 1 | ||||
| 
 | ||||
| 
 | ||||
| text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)') | ||||
| text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)") | ||||
| 
 | ||||
| 
 | ||||
| def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output): | ||||
| def _process_wp_text( | ||||
|     wp_to_id, entityfile, article_id, article_title, article_text, training_output | ||||
| ): | ||||
|     found_entities = False | ||||
| 
 | ||||
|     # ignore meta Wikipedia pages | ||||
|  | @ -141,11 +162,11 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te | |||
|     entity_buffer = "" | ||||
|     mention_buffer = "" | ||||
|     for index, letter in enumerate(clean_text): | ||||
|         if letter == '[': | ||||
|         if letter == "[": | ||||
|             open_read += 1 | ||||
|         elif letter == ']': | ||||
|         elif letter == "]": | ||||
|             open_read -= 1 | ||||
|         elif letter == '|': | ||||
|         elif letter == "|": | ||||
|             if reading_text: | ||||
|                 final_text += letter | ||||
|             # switch from reading entity to mention in the [[entity|mention]] pattern | ||||
|  | @ -163,7 +184,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te | |||
|             elif reading_text: | ||||
|                 final_text += letter | ||||
|             else: | ||||
|                 raise ValueError("Not sure at point", clean_text[index-2:index+2]) | ||||
|                 raise ValueError("Not sure at point", clean_text[index - 2 : index + 2]) | ||||
| 
 | ||||
|         if open_read > 2: | ||||
|             reading_special_case = True | ||||
|  | @ -175,7 +196,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te | |||
| 
 | ||||
|         # we just finished reading an entity | ||||
|         if open_read == 0 and not reading_text: | ||||
|             if '#' in entity_buffer or entity_buffer.startswith(':'): | ||||
|             if "#" in entity_buffer or entity_buffer.startswith(":"): | ||||
|                 reading_special_case = True | ||||
|             # Ignore cases with nested structures like File: handles etc | ||||
|             if not reading_special_case: | ||||
|  | @ -185,12 +206,14 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te | |||
|                 end = start + len(mention_buffer) | ||||
|                 qid = wp_to_id.get(entity_buffer, None) | ||||
|                 if qid: | ||||
|                     _write_training_entity(outputfile=entityfile, | ||||
|                                            article_id=article_id, | ||||
|                                            alias=mention_buffer, | ||||
|                                            entity=qid, | ||||
|                                            start=start, | ||||
|                                            end=end) | ||||
|                     _write_training_entity( | ||||
|                         outputfile=entityfile, | ||||
|                         article_id=article_id, | ||||
|                         alias=mention_buffer, | ||||
|                         entity=qid, | ||||
|                         start=start, | ||||
|                         end=end, | ||||
|                     ) | ||||
|                 found_entities = True | ||||
|                 final_text += mention_buffer | ||||
| 
 | ||||
|  | @ -203,29 +226,35 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te | |||
|             reading_special_case = False | ||||
| 
 | ||||
|     if found_entities: | ||||
|         _write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output) | ||||
|         _write_training_article( | ||||
|             article_id=article_id, | ||||
|             clean_text=final_text, | ||||
|             training_output=training_output, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| info_regex = re.compile(r'{[^{]*?}') | ||||
| htlm_regex = re.compile(r'<!--[^-]*-->') | ||||
| category_regex = re.compile(r'\[\[Category:[^\[]*]]') | ||||
| file_regex = re.compile(r'\[\[File:[^[\]]+]]') | ||||
| ref_regex = re.compile(r'<ref.*?>')     # non-greedy | ||||
| ref_2_regex = re.compile(r'</ref.*?>')  # non-greedy | ||||
| info_regex = re.compile(r"{[^{]*?}") | ||||
| htlm_regex = re.compile(r"<!--[^-]*-->") | ||||
| category_regex = re.compile(r"\[\[Category:[^\[]*]]") | ||||
| file_regex = re.compile(r"\[\[File:[^[\]]+]]") | ||||
| ref_regex = re.compile(r"<ref.*?>")  # non-greedy | ||||
| ref_2_regex = re.compile(r"</ref.*?>")  # non-greedy | ||||
| 
 | ||||
| 
 | ||||
| def _get_clean_wp_text(article_text): | ||||
|     clean_text = article_text.strip() | ||||
| 
 | ||||
|     # remove bolding & italic markup | ||||
|     clean_text = clean_text.replace('\'\'\'', '') | ||||
|     clean_text = clean_text.replace('\'\'', '') | ||||
|     clean_text = clean_text.replace("'''", "") | ||||
|     clean_text = clean_text.replace("''", "") | ||||
| 
 | ||||
|     # remove nested {{info}} statements by removing the inner/smallest ones first and iterating | ||||
|     try_again = True | ||||
|     previous_length = len(clean_text) | ||||
|     while try_again: | ||||
|         clean_text = info_regex.sub('', clean_text)  # non-greedy match excluding a nested { | ||||
|         clean_text = info_regex.sub( | ||||
|             "", clean_text | ||||
|         )  # non-greedy match excluding a nested { | ||||
|         if len(clean_text) < previous_length: | ||||
|             try_again = True | ||||
|         else: | ||||
|  | @ -233,14 +262,14 @@ def _get_clean_wp_text(article_text): | |||
|         previous_length = len(clean_text) | ||||
| 
 | ||||
|     # remove HTML comments | ||||
|     clean_text = htlm_regex.sub('', clean_text) | ||||
|     clean_text = htlm_regex.sub("", clean_text) | ||||
| 
 | ||||
|     # remove Category and File statements | ||||
|     clean_text = category_regex.sub('', clean_text) | ||||
|     clean_text = file_regex.sub('', clean_text) | ||||
|     clean_text = category_regex.sub("", clean_text) | ||||
|     clean_text = file_regex.sub("", clean_text) | ||||
| 
 | ||||
|     # remove multiple = | ||||
|     while '==' in clean_text: | ||||
|     while "==" in clean_text: | ||||
|         clean_text = clean_text.replace("==", "=") | ||||
| 
 | ||||
|     clean_text = clean_text.replace(". =", ".") | ||||
|  | @ -249,43 +278,47 @@ def _get_clean_wp_text(article_text): | |||
|     clean_text = clean_text.replace(" =", "") | ||||
| 
 | ||||
|     # remove refs (non-greedy match) | ||||
|     clean_text = ref_regex.sub('', clean_text) | ||||
|     clean_text = ref_2_regex.sub('', clean_text) | ||||
|     clean_text = ref_regex.sub("", clean_text) | ||||
|     clean_text = ref_2_regex.sub("", clean_text) | ||||
| 
 | ||||
|     # remove additional wikiformatting | ||||
|     clean_text = re.sub(r'<blockquote>', '', clean_text) | ||||
|     clean_text = re.sub(r'</blockquote>', '', clean_text) | ||||
|     clean_text = re.sub(r"<blockquote>", "", clean_text) | ||||
|     clean_text = re.sub(r"</blockquote>", "", clean_text) | ||||
| 
 | ||||
|     # change special characters back to normal ones | ||||
|     clean_text = clean_text.replace(r'&lt;', '<') | ||||
|     clean_text = clean_text.replace(r'&gt;', '>') | ||||
|     clean_text = clean_text.replace(r'&quot;', '"') | ||||
|     clean_text = clean_text.replace(r'&amp;nbsp;', ' ') | ||||
|     clean_text = clean_text.replace(r'&amp;', '&') | ||||
|     clean_text = clean_text.replace(r"&lt;", "<") | ||||
|     clean_text = clean_text.replace(r"&gt;", ">") | ||||
|     clean_text = clean_text.replace(r"&quot;", '"') | ||||
|     clean_text = clean_text.replace(r"&amp;nbsp;", " ") | ||||
|     clean_text = clean_text.replace(r"&amp;", "&") | ||||
| 
 | ||||
|     # remove multiple spaces | ||||
|     while '  ' in clean_text: | ||||
|         clean_text = clean_text.replace('  ', ' ') | ||||
|     while "  " in clean_text: | ||||
|         clean_text = clean_text.replace("  ", " ") | ||||
| 
 | ||||
|     return clean_text.strip() | ||||
| 
 | ||||
| 
 | ||||
| def _write_training_article(article_id, clean_text, training_output): | ||||
|     file_loc = training_output / str(article_id) + ".txt" | ||||
|     with open(file_loc, mode='w', encoding='utf8') as outputfile: | ||||
|     file_loc = training_output / "{}.txt".format(article_id) | ||||
|     with file_loc.open("w", encoding="utf8") as outputfile: | ||||
|         outputfile.write(clean_text) | ||||
| 
 | ||||
| 
 | ||||
| def _write_training_entity(outputfile, article_id, alias, entity, start, end): | ||||
|     outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n") | ||||
|     line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end) | ||||
|     outputfile.write(line) | ||||
| 
 | ||||
| 
 | ||||
| def is_dev(article_id): | ||||
|     return article_id.endswith("3") | ||||
| 
 | ||||
| 
 | ||||
| def read_training(nlp, training_dir, dev, limit): | ||||
|     # This method provides training examples that correspond to the entity annotations found by the nlp object | ||||
| def read_training(nlp, training_dir, dev, limit, kb=None): | ||||
|     """ This method provides training examples that correspond to the entity annotations found by the nlp object. | ||||
|      When kb is provided (for training), it will include negative training examples by using the candidate generator, | ||||
|      and it will only keep positive training examples that can be found in the KB. | ||||
|      When kb=None (for testing), it will include all positive examples only.""" | ||||
|     entityfile_loc = training_dir / ENTITY_FILE | ||||
|     data = [] | ||||
| 
 | ||||
|  | @ -296,24 +329,30 @@ def read_training(nlp, training_dir, dev, limit): | |||
|     skip_articles = set() | ||||
|     total_entities = 0 | ||||
| 
 | ||||
|     with open(entityfile_loc, mode='r', encoding='utf8') as file: | ||||
|     with entityfile_loc.open("r", encoding="utf8") as file: | ||||
|         for line in file: | ||||
|             if not limit or len(data) < limit: | ||||
|                 fields = line.replace('\n', "").split(sep='|') | ||||
|                 fields = line.replace("\n", "").split(sep="|") | ||||
|                 article_id = fields[0] | ||||
|                 alias = fields[1] | ||||
|                 wp_title = fields[2] | ||||
|                 wd_id = fields[2] | ||||
|                 start = fields[3] | ||||
|                 end = fields[4] | ||||
| 
 | ||||
|                 if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles: | ||||
|                 if ( | ||||
|                     dev == is_dev(article_id) | ||||
|                     and article_id != "article_id" | ||||
|                     and article_id not in skip_articles | ||||
|                 ): | ||||
|                     if not current_doc or (current_article_id != article_id): | ||||
|                         # parse the new article text | ||||
|                         file_name = article_id + ".txt" | ||||
|                         try: | ||||
|                             with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f: | ||||
|                             training_file = training_dir / file_name | ||||
|                             with training_file.open("r", encoding="utf8") as f: | ||||
|                                 text = f.read() | ||||
|                                 if len(text) < 30000:   # threshold for convenience / speed of processing | ||||
|                                 # threshold for convenience / speed of processing | ||||
|                                 if len(text) < 30000: | ||||
|                                     current_doc = nlp(text) | ||||
|                                     current_article_id = article_id | ||||
|                                     ents_by_offset = dict() | ||||
|  | @ -321,33 +360,69 @@ def read_training(nlp, training_dir, dev, limit): | |||
|                                         sent_length = len(ent.sent) | ||||
|                                         # custom filtering to avoid too long or too short sentences | ||||
|                                         if 5 < sent_length < 100: | ||||
|                                             ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent | ||||
|                                             offset = "{}_{}".format( | ||||
|                                                 ent.start_char, ent.end_char | ||||
|                                             ) | ||||
|                                             ents_by_offset[offset] = ent | ||||
|                                 else: | ||||
|                                     skip_articles.add(article_id) | ||||
|                                     current_doc = None | ||||
|                         except Exception as e: | ||||
|                             print("Problem parsing article", article_id, e) | ||||
|                             skip_articles.add(article_id) | ||||
|                             raise e | ||||
| 
 | ||||
|                     # repeat checking this condition in case an exception was thrown | ||||
|                     if current_doc and (current_article_id == article_id): | ||||
|                         found_ent = ents_by_offset.get(start + "_" + end,  None) | ||||
|                         offset = "{}_{}".format(start, end) | ||||
|                         found_ent = ents_by_offset.get(offset, None) | ||||
|                         if found_ent: | ||||
|                             if found_ent.text != alias: | ||||
|                                 skip_articles.add(article_id) | ||||
|                                 current_doc = None | ||||
|                             else: | ||||
|                                 sent = found_ent.sent.as_doc() | ||||
|                                 # currently feeding the gold data one entity per sentence at a time | ||||
| 
 | ||||
|                                 gold_start = int(start) - found_ent.sent.start_char | ||||
|                                 gold_end = int(end) - found_ent.sent.start_char | ||||
|                                 gold_entities = [(gold_start, gold_end, wp_title)] | ||||
|                                 gold = GoldParse(doc=sent, links=gold_entities) | ||||
|                                 data.append((sent, gold)) | ||||
|                                 total_entities += 1 | ||||
|                                 if len(data) % 2500 == 0: | ||||
|                                     print(" -read", total_entities, "entities") | ||||
| 
 | ||||
|                                 gold_entities = {} | ||||
|                                 found_useful = False | ||||
|                                 for ent in sent.ents: | ||||
|                                     entry = (ent.start_char, ent.end_char) | ||||
|                                     gold_entry = (gold_start, gold_end) | ||||
|                                     if entry == gold_entry: | ||||
|                                         # add both pos and neg examples (in random order) | ||||
|                                         # this will exclude examples not in the KB | ||||
|                                         if kb: | ||||
|                                             value_by_id = {} | ||||
|                                             candidates = kb.get_candidates(alias) | ||||
|                                             candidate_ids = [ | ||||
|                                                 c.entity_ for c in candidates | ||||
|                                             ] | ||||
|                                             random.shuffle(candidate_ids) | ||||
|                                             for kb_id in candidate_ids: | ||||
|                                                 found_useful = True | ||||
|                                                 if kb_id != wd_id: | ||||
|                                                     value_by_id[kb_id] = 0.0 | ||||
|                                                 else: | ||||
|                                                     value_by_id[kb_id] = 1.0 | ||||
|                                             gold_entities[entry] = value_by_id | ||||
|                                         # if no KB, keep all positive examples | ||||
|                                         else: | ||||
|                                             found_useful = True | ||||
|                                             value_by_id = {wd_id: 1.0} | ||||
| 
 | ||||
|                                             gold_entities[entry] = value_by_id | ||||
|                                     # currently feeding the gold data one entity per sentence at a time | ||||
|                                     # setting all other entities to empty gold dictionary | ||||
|                                     else: | ||||
|                                         gold_entities[entry] = {} | ||||
|                                 if found_useful: | ||||
|                                     gold = GoldParse(doc=sent, links=gold_entities) | ||||
|                                     data.append((sent, gold)) | ||||
|                                     total_entities += 1 | ||||
|                                     if len(data) % 2500 == 0: | ||||
|                                         print(" -read", total_entities, "entities") | ||||
| 
 | ||||
|     print(" -read", total_entities, "entities") | ||||
|     return data | ||||
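To make the new gold format concrete: read_training parses the standoff lines written by the training set creator (article_id|alias|WD_id|start|end) and now builds a links dictionary that maps each mention's character offsets to candidate KB ids, scored 1.0 for the gold entity and 0.0 for negative candidates. A small hypothetical sketch (not from this commit; the sentence, offsets and QIDs are illustrative assumptions):

import spacy
from spacy.gold import GoldParse

nlp = spacy.blank("en")
sent = nlp("Douglas Adams wrote the Hitchhiker's Guide.")

links = {
    (0, 13): {"Q42": 1.0, "Q1004791": 0.0, "Q5301561": 0.0},  # gold + negative candidates (kb given)
    (24, 42): {},  # mention without a usable gold annotation
}
gold = GoldParse(doc=sent, links=links)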
|  |  | |||
|  | @ -14,22 +14,97 @@ Write these results to file for downstream KB and training data generation. | |||
| map_alias_to_link = dict() | ||||
| 
 | ||||
| # these will/should be matched ignoring case | ||||
| wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons", | ||||
|                    "d", "dbdump", "download", "Draft", "Education", "Foundation", | ||||
|                    "Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator", | ||||
|                    "m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki", | ||||
|                    "MediaZilla", "Meta", "Metawikipedia", "Module", | ||||
|                    "mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki", | ||||
|                    "Portal", "phab", "Phabricator", "Project", "q", "quality", "rev", | ||||
|                    "s", "spcom", "Special", "species", "Strategy", "sulutil", "svn", | ||||
|                    "Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools", | ||||
|                    "tswiki", "User", "User talk", "v", "voy", | ||||
|                    "w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews", | ||||
|                    "Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech", | ||||
|                    "Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"] | ||||
| wiki_namespaces = [ | ||||
|     "b", | ||||
|     "betawikiversity", | ||||
|     "Book", | ||||
|     "c", | ||||
|     "Category", | ||||
|     "Commons", | ||||
|     "d", | ||||
|     "dbdump", | ||||
|     "download", | ||||
|     "Draft", | ||||
|     "Education", | ||||
|     "Foundation", | ||||
|     "Gadget", | ||||
|     "Gadget definition", | ||||
|     "gerrit", | ||||
|     "File", | ||||
|     "Help", | ||||
|     "Image", | ||||
|     "Incubator", | ||||
|     "m", | ||||
|     "mail", | ||||
|     "mailarchive", | ||||
|     "media", | ||||
|     "MediaWiki", | ||||
|     "MediaWiki talk", | ||||
|     "Mediawikiwiki", | ||||
|     "MediaZilla", | ||||
|     "Meta", | ||||
|     "Metawikipedia", | ||||
|     "Module", | ||||
|     "mw", | ||||
|     "n", | ||||
|     "nost", | ||||
|     "oldwikisource", | ||||
|     "outreach", | ||||
|     "outreachwiki", | ||||
|     "otrs", | ||||
|     "OTRSwiki", | ||||
|     "Portal", | ||||
|     "phab", | ||||
|     "Phabricator", | ||||
|     "Project", | ||||
|     "q", | ||||
|     "quality", | ||||
|     "rev", | ||||
|     "s", | ||||
|     "spcom", | ||||
|     "Special", | ||||
|     "species", | ||||
|     "Strategy", | ||||
|     "sulutil", | ||||
|     "svn", | ||||
|     "Talk", | ||||
|     "Template", | ||||
|     "Template talk", | ||||
|     "Testwiki", | ||||
|     "ticket", | ||||
|     "TimedText", | ||||
|     "Toollabs", | ||||
|     "tools", | ||||
|     "tswiki", | ||||
|     "User", | ||||
|     "User talk", | ||||
|     "v", | ||||
|     "voy", | ||||
|     "w", | ||||
|     "Wikibooks", | ||||
|     "Wikidata", | ||||
|     "wikiHow", | ||||
|     "Wikinvest", | ||||
|     "wikilivres", | ||||
|     "Wikimedia", | ||||
|     "Wikinews", | ||||
|     "Wikipedia", | ||||
|     "Wikipedia talk", | ||||
|     "Wikiquote", | ||||
|     "Wikisource", | ||||
|     "Wikispecies", | ||||
|     "Wikitech", | ||||
|     "Wikiversity", | ||||
|     "Wikivoyage", | ||||
|     "wikt", | ||||
|     "wiktionary", | ||||
|     "wmf", | ||||
|     "wmania", | ||||
|     "WP", | ||||
| ] | ||||
| 
 | ||||
| # find the links | ||||
| link_regex = re.compile(r'\[\[[^\[\]]*\]\]') | ||||
| link_regex = re.compile(r"\[\[[^\[\]]*\]\]") | ||||
| 
 | ||||
| # match on interwiki links, e.g. `en:` or `:fr:` | ||||
| ns_regex = r":?" + "[a-z][a-z]" + ":" | ||||
|  | @ -41,18 +116,22 @@ for ns in wiki_namespaces: | |||
| ns_regex = re.compile(ns_regex, re.IGNORECASE) | ||||
| 
 | ||||
| 
 | ||||
| def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output): | ||||
| def now(): | ||||
|     return datetime.datetime.now() | ||||
| 
 | ||||
| 
 | ||||
| def read_prior_probs(wikipedia_input, prior_prob_output): | ||||
|     """ | ||||
|     Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities. | ||||
|     The full file takes about 2h to parse 1100M lines. | ||||
|     It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from. | ||||
|     """ | ||||
|     with bz2.open(wikipedia_input, mode='rb') as file: | ||||
|     with bz2.open(wikipedia_input, mode="rb") as file: | ||||
|         line = file.readline() | ||||
|         cnt = 0 | ||||
|         while line: | ||||
|             if cnt % 5000000 == 0: | ||||
|                 print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump") | ||||
|                 print(now(), "processed", cnt, "lines of Wikipedia dump") | ||||
|             clean_line = line.strip().decode("utf-8") | ||||
| 
 | ||||
|             aliases, entities, normalizations = get_wp_links(clean_line) | ||||
|  | @ -64,10 +143,11 @@ def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output): | |||
|             cnt += 1 | ||||
| 
 | ||||
|     # write all aliases and their entities and count occurrences to file | ||||
|     with open(prior_prob_output, mode='w', encoding='utf8') as outputfile: | ||||
|     with prior_prob_output.open("w", encoding="utf8") as outputfile: | ||||
|         outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n") | ||||
|         for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]): | ||||
|             for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True): | ||||
|             s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True) | ||||
|             for entity, count in s_dict: | ||||
|                 outputfile.write(alias + "|" + str(count) + "|" + entity + "\n") | ||||
| 
 | ||||
| 
 | ||||
|  | @ -140,13 +220,13 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False): | |||
|     entity_to_count = dict() | ||||
|     total_count = 0 | ||||
| 
 | ||||
|     with open(prior_prob_input, mode='r', encoding='utf8') as prior_file: | ||||
|     with prior_prob_input.open("r", encoding="utf8") as prior_file: | ||||
|         # skip header | ||||
|         prior_file.readline() | ||||
|         line = prior_file.readline() | ||||
| 
 | ||||
|         while line: | ||||
|             splits = line.replace('\n', "").split(sep='|') | ||||
|             splits = line.replace("\n", "").split(sep="|") | ||||
|             # alias = splits[0] | ||||
|             count = int(splits[1]) | ||||
|             entity = splits[2] | ||||
|  | @ -158,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False): | |||
| 
 | ||||
|             line = prior_file.readline() | ||||
| 
 | ||||
|     with open(count_output, mode='w', encoding='utf8') as entity_file: | ||||
|     with count_output.open("w", encoding="utf8") as entity_file: | ||||
|         entity_file.write("entity" + "|" + "count" + "\n") | ||||
|         for entity, count in entity_to_count.items(): | ||||
|             entity_file.write(entity + "|" + str(count) + "\n") | ||||
|  | @ -171,12 +251,11 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False): | |||
| 
 | ||||
| def get_all_frequencies(count_input): | ||||
|     entity_to_count = dict() | ||||
|     with open(count_input, 'r', encoding='utf8') as csvfile: | ||||
|         csvreader = csv.reader(csvfile, delimiter='|') | ||||
|     with count_input.open("r", encoding="utf8") as csvfile: | ||||
|         csvreader = csv.reader(csvfile, delimiter="|") | ||||
|         # skip header | ||||
|         next(csvreader) | ||||
|         for row in csvreader: | ||||
|             entity_to_count[row[0]] = int(row[1]) | ||||
| 
 | ||||
|     return entity_to_count | ||||
| 
 | ||||
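As a concrete illustration of the pipe-delimited intermediate files this module writes and re-reads (alias|count|entity for the prior probabilities, entity|count for the frequencies), a small sketch with made-up rows and a hypothetical path:

import csv
from pathlib import Path

prior_prob_output = Path("prior_prob.csv")  # hypothetical location
prior_prob_output.write_text(
    "alias|count|entity\n"
    "Douglas|12|Q42\n"  # rows are sorted by alias, then by descending count
    "Douglas|3|Q5301561\n",
    encoding="utf8",
)

with prior_prob_output.open("r", encoding="utf8") as csvfile:
    csvreader = csv.reader(csvfile, delimiter="|")
    next(csvreader)  # skip header
    for alias, count, entity in csvreader:
        print(alias, int(count), entity)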
|  |  | |||
|  | @ -14,15 +14,15 @@ def create_kb(vocab): | |||
|     # adding entities | ||||
|     entity_0 = "Q1004791_Douglas" | ||||
|     print("adding entity", entity_0) | ||||
|     kb.add_entity(entity=entity_0, prob=0.5, entity_vector=[0]) | ||||
|     kb.add_entity(entity=entity_0, freq=0.5, entity_vector=[0]) | ||||
| 
 | ||||
|     entity_1 = "Q42_Douglas_Adams" | ||||
|     print("adding entity", entity_1) | ||||
|     kb.add_entity(entity=entity_1, prob=0.5, entity_vector=[1]) | ||||
|     kb.add_entity(entity=entity_1, freq=0.5, entity_vector=[1]) | ||||
| 
 | ||||
|     entity_2 = "Q5301561_Douglas_Haig" | ||||
|     print("adding entity", entity_2) | ||||
|     kb.add_entity(entity=entity_2, prob=0.5, entity_vector=[2]) | ||||
|     kb.add_entity(entity=entity_2, freq=0.5, entity_vector=[2]) | ||||
| 
 | ||||
|     # adding aliases | ||||
|     print() | ||||
|  |  | |||
|  | @ -1,11 +1,14 @@ | |||
| # coding: utf-8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import os | ||||
| from os import path | ||||
| import random | ||||
| import datetime | ||||
| from pathlib import Path | ||||
| 
 | ||||
| from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp | ||||
| from bin.wiki_entity_linking import wikipedia_processor as wp | ||||
| from bin.wiki_entity_linking import training_set_creator, kb_creator | ||||
| from bin.wiki_entity_linking.kb_creator import DESC_WIDTH | ||||
| 
 | ||||
| import spacy | ||||
|  | @ -17,23 +20,26 @@ Demonstrate how to build a knowledge base from WikiData and run an Entity Linkin | |||
| """ | ||||
| 
 | ||||
| ROOT_DIR = Path("C:/Users/Sofie/Documents/data/") | ||||
| OUTPUT_DIR = ROOT_DIR / 'wikipedia' | ||||
| TRAINING_DIR = OUTPUT_DIR / 'training_data_nel' | ||||
| OUTPUT_DIR = ROOT_DIR / "wikipedia" | ||||
| TRAINING_DIR = OUTPUT_DIR / "training_data_nel" | ||||
| 
 | ||||
| PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv' | ||||
| ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv' | ||||
| ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv' | ||||
| ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv' | ||||
| PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv" | ||||
| ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv" | ||||
| ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv" | ||||
| ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv" | ||||
| 
 | ||||
| KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb' | ||||
| NLP_1_DIR = OUTPUT_DIR / 'nlp_1' | ||||
| NLP_2_DIR = OUTPUT_DIR / 'nlp_2' | ||||
| KB_DIR = OUTPUT_DIR / "kb_1" | ||||
| KB_FILE = "kb" | ||||
| NLP_1_DIR = OUTPUT_DIR / "nlp_1" | ||||
| NLP_2_DIR = OUTPUT_DIR / "nlp_2" | ||||
| 
 | ||||
| # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ | ||||
| WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2' | ||||
| WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2" | ||||
| 
 | ||||
| # get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/ | ||||
| ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2' | ||||
| ENWIKI_DUMP = ( | ||||
|     ROOT_DIR / "wikipedia" / "enwiki-20190320-pages-articles-multistream.xml.bz2" | ||||
| ) | ||||
| 
 | ||||
| # KB construction parameters | ||||
| MAX_CANDIDATES = 10 | ||||
|  | @ -48,11 +54,15 @@ L2 = 1e-6 | |||
| CONTEXT_WIDTH = 128 | ||||
| 
 | ||||
| 
 | ||||
| def now(): | ||||
|     return datetime.datetime.now() | ||||
| 
 | ||||
| 
 | ||||
| def run_pipeline(): | ||||
|     # set the appropriate booleans to define which parts of the pipeline should be re(run) | ||||
|     print("START", datetime.datetime.now()) | ||||
|     print("START", now()) | ||||
|     print() | ||||
|     nlp_1 = spacy.load('en_core_web_lg') | ||||
|     nlp_1 = spacy.load("en_core_web_lg") | ||||
|     nlp_2 = None | ||||
|     kb_2 = None | ||||
| 
 | ||||
|  | @ -82,43 +92,48 @@ def run_pipeline(): | |||
| 
 | ||||
|     # STEP 1 : create prior probabilities from WP (run only once) | ||||
|     if to_create_prior_probs: | ||||
|         print("STEP 1: to_create_prior_probs", datetime.datetime.now()) | ||||
|         wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB) | ||||
|         print("STEP 1: to_create_prior_probs", now()) | ||||
|         wp.read_prior_probs(ENWIKI_DUMP, PRIOR_PROB) | ||||
|         print() | ||||
| 
 | ||||
|     # STEP 2 : deduce entity frequencies from WP (run only once) | ||||
|     if to_create_entity_counts: | ||||
|         print("STEP 2: to_create_entity_counts", datetime.datetime.now()) | ||||
|         wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False) | ||||
|         print("STEP 2: to_create_entity_counts", now()) | ||||
|         wp.write_entity_counts(PRIOR_PROB, ENTITY_COUNTS, to_print=False) | ||||
|         print() | ||||
| 
 | ||||
|     # STEP 3 : create KB and write to file (run only once) | ||||
|     if to_create_kb: | ||||
|         print("STEP 3a: to_create_kb", datetime.datetime.now()) | ||||
|         kb_1 = kb_creator.create_kb(nlp_1, | ||||
|                                     max_entities_per_alias=MAX_CANDIDATES, | ||||
|                                     min_entity_freq=MIN_ENTITY_FREQ, | ||||
|                                     min_occ=MIN_PAIR_OCC, | ||||
|                                     entity_def_output=ENTITY_DEFS, | ||||
|                                     entity_descr_output=ENTITY_DESCR, | ||||
|                                     count_input=ENTITY_COUNTS, | ||||
|                                     prior_prob_input=PRIOR_PROB, | ||||
|                                     wikidata_input=WIKIDATA_JSON) | ||||
|         print("STEP 3a: to_create_kb", now()) | ||||
|         kb_1 = kb_creator.create_kb( | ||||
|             nlp=nlp_1, | ||||
|             max_entities_per_alias=MAX_CANDIDATES, | ||||
|             min_entity_freq=MIN_ENTITY_FREQ, | ||||
|             min_occ=MIN_PAIR_OCC, | ||||
|             entity_def_output=ENTITY_DEFS, | ||||
|             entity_descr_output=ENTITY_DESCR, | ||||
|             count_input=ENTITY_COUNTS, | ||||
|             prior_prob_input=PRIOR_PROB, | ||||
|             wikidata_input=WIKIDATA_JSON, | ||||
|         ) | ||||
|         print("kb entities:", kb_1.get_size_entities()) | ||||
|         print("kb aliases:", kb_1.get_size_aliases()) | ||||
|         print() | ||||
| 
 | ||||
|         print("STEP 3b: write KB and NLP", datetime.datetime.now()) | ||||
|         kb_1.dump(KB_FILE) | ||||
|         print("STEP 3b: write KB and NLP", now()) | ||||
| 
 | ||||
|         if not path.exists(KB_DIR): | ||||
|             os.makedirs(KB_DIR) | ||||
|         kb_1.dump(KB_DIR / KB_FILE) | ||||
|         nlp_1.to_disk(NLP_1_DIR) | ||||
|         print() | ||||
| 
 | ||||
|     # STEP 4 : read KB back in from file | ||||
|     if to_read_kb: | ||||
|         print("STEP 4: to_read_kb", datetime.datetime.now()) | ||||
|         print("STEP 4: to_read_kb", now()) | ||||
|         nlp_2 = spacy.load(NLP_1_DIR) | ||||
|         kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH) | ||||
|         kb_2.load_bulk(KB_FILE) | ||||
|         kb_2.load_bulk(KB_DIR / KB_FILE) | ||||
|         print("kb entities:", kb_2.get_size_entities()) | ||||
|         print("kb aliases:", kb_2.get_size_aliases()) | ||||
|         print() | ||||
|  | @ -130,20 +145,26 @@ def run_pipeline(): | |||
| 
 | ||||
|     # STEP 5: create a training dataset from WP | ||||
|     if create_wp_training: | ||||
|         print("STEP 5: create training dataset", datetime.datetime.now()) | ||||
|         training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP, | ||||
|                                              entity_def_input=ENTITY_DEFS, | ||||
|                                              training_output=TRAINING_DIR) | ||||
|         print("STEP 5: create training dataset", now()) | ||||
|         training_set_creator.create_training( | ||||
|             wikipedia_input=ENWIKI_DUMP, | ||||
|             entity_def_input=ENTITY_DEFS, | ||||
|             training_output=TRAINING_DIR, | ||||
|         ) | ||||
| 
 | ||||
|     # STEP 6: create and train the entity linking pipe | ||||
|     if train_pipe: | ||||
|         print("STEP 6: training Entity Linking pipe", datetime.datetime.now()) | ||||
|         print("STEP 6: training Entity Linking pipe", now()) | ||||
|         type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)} | ||||
|         print(" -analysing", len(type_to_int), "different entity types") | ||||
|         el_pipe = nlp_2.create_pipe(name='entity_linker', | ||||
|                                     config={"context_width": CONTEXT_WIDTH, | ||||
|                                             "pretrained_vectors": nlp_2.vocab.vectors.name, | ||||
|                                             "type_to_int": type_to_int}) | ||||
|         el_pipe = nlp_2.create_pipe( | ||||
|             name="entity_linker", | ||||
|             config={ | ||||
|                 "context_width": CONTEXT_WIDTH, | ||||
|                 "pretrained_vectors": nlp_2.vocab.vectors.name, | ||||
|                 "type_to_int": type_to_int, | ||||
|             }, | ||||
|         ) | ||||
|         el_pipe.set_kb(kb_2) | ||||
|         nlp_2.add_pipe(el_pipe, last=True) | ||||
| 
 | ||||
|  | @ -157,18 +178,22 @@ def run_pipeline(): | |||
|         train_limit = 5000 | ||||
|         dev_limit = 5000 | ||||
| 
 | ||||
|         train_data = training_set_creator.read_training(nlp=nlp_2, | ||||
|                                                         training_dir=TRAINING_DIR, | ||||
|                                                         dev=False, | ||||
|                                                         limit=train_limit) | ||||
|         # for training, get pos & neg instances that correspond to entries in the kb | ||||
|         train_data = training_set_creator.read_training( | ||||
|             nlp=nlp_2, | ||||
|             training_dir=TRAINING_DIR, | ||||
|             dev=False, | ||||
|             limit=train_limit, | ||||
|             kb=el_pipe.kb, | ||||
|         ) | ||||
| 
 | ||||
|         print("Training on", len(train_data), "articles") | ||||
|         print() | ||||
| 
 | ||||
|         dev_data = training_set_creator.read_training(nlp=nlp_2, | ||||
|                                                       training_dir=TRAINING_DIR, | ||||
|                                                       dev=True, | ||||
|                                                       limit=dev_limit) | ||||
|         # for testing, get all pos instances, whether or not they are in the kb | ||||
|         dev_data = training_set_creator.read_training( | ||||
|             nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None | ||||
|         ) | ||||
| 
 | ||||
|         print("Dev testing on", len(dev_data), "articles") | ||||
|         print() | ||||
|  | @ -187,8 +212,8 @@ def run_pipeline(): | |||
|                         try: | ||||
|                             docs, golds = zip(*batch) | ||||
|                             nlp_2.update( | ||||
|                                 docs, | ||||
|                                 golds, | ||||
|                                 docs=docs, | ||||
|                                 golds=golds, | ||||
|                                 sgd=optimizer, | ||||
|                                 drop=DROPOUT, | ||||
|                                 losses=losses, | ||||
|  | @ -200,48 +225,61 @@ def run_pipeline(): | |||
|                 if batchnr > 0: | ||||
|                     el_pipe.cfg["context_weight"] = 1 | ||||
|                     el_pipe.cfg["prior_weight"] = 1 | ||||
|                     dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe) | ||||
|                     losses['entity_linker'] = losses['entity_linker'] / batchnr | ||||
|                     print("Epoch, train loss", itn, round(losses['entity_linker'], 2), | ||||
|                           " / dev acc avg", round(dev_acc_context, 3)) | ||||
|                     dev_acc_context, _ = _measure_acc(dev_data, el_pipe) | ||||
|                     losses["entity_linker"] = losses["entity_linker"] / batchnr | ||||
|                     print( | ||||
|                         "Epoch, train loss", | ||||
|                         itn, | ||||
|                         round(losses["entity_linker"], 2), | ||||
|                         " / dev acc avg", | ||||
|                         round(dev_acc_context, 3), | ||||
|                     ) | ||||
| 
 | ||||
|         # STEP 7: measure the performance of our trained pipe on an independent dev set | ||||
|         if len(dev_data) and measure_performance: | ||||
|             print() | ||||
|             print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now()) | ||||
|             print("STEP 7: performance measurement of Entity Linking pipe", now()) | ||||
|             print() | ||||
| 
 | ||||
|             counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2) | ||||
|             counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines( | ||||
|                 dev_data, kb_2 | ||||
|             ) | ||||
|             print("dev counts:", sorted(counts.items(), key=lambda x: x[0])) | ||||
|             print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()]) | ||||
|             print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()]) | ||||
|             print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()]) | ||||
| 
 | ||||
|             oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()] | ||||
|             print("dev acc oracle:", round(acc_o, 3), oracle_by_label) | ||||
| 
 | ||||
|             random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()] | ||||
|             print("dev acc random:", round(acc_r, 3), random_by_label) | ||||
| 
 | ||||
|             prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()] | ||||
|             print("dev acc prior:", round(acc_p, 3), prior_by_label) | ||||
| 
 | ||||
|             # using only context | ||||
|             el_pipe.cfg["context_weight"] = 1 | ||||
|             el_pipe.cfg["prior_weight"] = 0 | ||||
|             dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe) | ||||
|             print("dev acc context avg:", round(dev_acc_context, 3), | ||||
|                   [(x, round(y, 3)) for x, y in dev_acc_context_dict.items()]) | ||||
|             dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe) | ||||
|             context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()] | ||||
|             print("dev acc context avg:", round(dev_acc_context, 3), context_by_label) | ||||
| 
 | ||||
|             # measuring combined accuracy (prior + context) | ||||
|             el_pipe.cfg["context_weight"] = 1 | ||||
|             el_pipe.cfg["prior_weight"] = 1 | ||||
|             dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False) | ||||
|             print("dev acc combo avg:", round(dev_acc_combo, 3), | ||||
|                   [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]) | ||||
|             dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe) | ||||
|             combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()] | ||||
|             print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label) | ||||
| 
 | ||||
|         # STEP 8: apply the EL pipe on a toy example | ||||
|         if to_test_pipeline: | ||||
|             print() | ||||
|             print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now()) | ||||
|             print("STEP 8: applying Entity Linking to toy example", now()) | ||||
|             print() | ||||
|             run_el_toy_example(nlp=nlp_2) | ||||
| 
 | ||||
|         # STEP 9: write the NLP pipeline (including entity linker) to file | ||||
|         if to_write_nlp: | ||||
|             print() | ||||
|             print("STEP 9: testing NLP IO", datetime.datetime.now()) | ||||
|             print("STEP 9: testing NLP IO", now()) | ||||
|             print() | ||||
|             print("writing to", NLP_2_DIR) | ||||
|             nlp_2.to_disk(NLP_2_DIR) | ||||
|  | @ -262,23 +300,22 @@ def run_pipeline(): | |||
|         el_pipe = nlp_3.get_pipe("entity_linker") | ||||
| 
 | ||||
|         dev_limit = 5000 | ||||
|         dev_data = training_set_creator.read_training(nlp=nlp_2, | ||||
|                                                       training_dir=TRAINING_DIR, | ||||
|                                                       dev=True, | ||||
|                                                       limit=dev_limit) | ||||
|         dev_data = training_set_creator.read_training( | ||||
|             nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None | ||||
|         ) | ||||
| 
 | ||||
|         print("Dev testing from file on", len(dev_data), "articles") | ||||
|         print() | ||||
| 
 | ||||
|         dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False) | ||||
|         print("dev acc combo avg:", round(dev_acc_combo, 3), | ||||
|               [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]) | ||||
|         dev_acc_combo, dev_acc_combo_dict = _measure_acc(dev_data, el_pipe) | ||||
|         combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()] | ||||
|         print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label) | ||||
| 
 | ||||
|     print() | ||||
|     print("STOP", datetime.datetime.now()) | ||||
|     print("STOP", now()) | ||||
| 
 | ||||
| 
 | ||||
| def _measure_accuracy(data, el_pipe=None, error_analysis=False): | ||||
| def _measure_acc(data, el_pipe=None, error_analysis=False): | ||||
|     # If the docs in the data require further processing with an entity linker, set el_pipe | ||||
|     correct_by_label = dict() | ||||
|     incorrect_by_label = dict() | ||||
|  | @ -291,16 +328,21 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False): | |||
|     for doc, gold in zip(docs, golds): | ||||
|         try: | ||||
|             correct_entries_per_article = dict() | ||||
|             for entity in gold.links: | ||||
|                 start, end, gold_kb = entity | ||||
|                 correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb | ||||
|             for entity, kb_dict in gold.links.items(): | ||||
|                 start, end = entity | ||||
|                 # only evaluating on positive examples | ||||
|                 for gold_kb, value in kb_dict.items(): | ||||
|                     if value: | ||||
|                         offset = _offset(start, end) | ||||
|                         correct_entries_per_article[offset] = gold_kb | ||||
| 
 | ||||
|             for ent in doc.ents: | ||||
|                 ent_label = ent.label_ | ||||
|                 pred_entity = ent.kb_id_ | ||||
|                 start = ent.start_char | ||||
|                 end = ent.end_char | ||||
|                 gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) | ||||
|                 offset = _offset(start, end) | ||||
|                 gold_entity = correct_entries_per_article.get(offset, None) | ||||
|                 # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' | ||||
|                 if gold_entity is not None: | ||||
|                     if gold_entity == pred_entity: | ||||
|  | @ -311,28 +353,33 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False): | |||
|                         incorrect_by_label[ent_label] = incorrect + 1 | ||||
|                         if error_analysis: | ||||
|                             print(ent.text, "in", doc) | ||||
|                             print("Predicted",  pred_entity, "should have been", gold_entity) | ||||
|                             print( | ||||
|                                 "Predicted", | ||||
|                                 pred_entity, | ||||
|                                 "should have been", | ||||
|                                 gold_entity, | ||||
|                             ) | ||||
|                             print() | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             print("Error assessing accuracy", e) | ||||
| 
 | ||||
|     acc, acc_by_label = calculate_acc(correct_by_label,  incorrect_by_label) | ||||
|     acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label) | ||||
|     return acc, acc_by_label | ||||
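The body of calculate_acc is cut off by the next hunk; as an assumption, it is expected to turn the two count dicts into an overall accuracy plus a per-label accuracy dict, roughly like this hypothetical stand-in:

    def calculate_acc_sketch(correct_by_label, incorrect_by_label):
        # hypothetical stand-in for calculate_acc: overall accuracy plus accuracy per NER label
        acc_by_label = {}
        total_correct = 0
        total = 0
        for label in set(correct_by_label) | set(incorrect_by_label):
            correct = correct_by_label.get(label, 0)
            incorrect = incorrect_by_label.get(label, 0)
            total_correct += correct
            total += correct + incorrect
            acc_by_label[label] = correct / (correct + incorrect) if (correct + incorrect) else 0.0
        acc = total_correct / total if total else 0.0
        return acc, acc_by_label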
| 
 | ||||
| 
 | ||||
| def _measure_baselines(data, kb): | ||||
|     # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound | ||||
|     counts_by_label = dict() | ||||
|     counts_d = dict() | ||||
| 
 | ||||
|     random_correct_by_label = dict() | ||||
|     random_incorrect_by_label = dict() | ||||
|     random_correct_d = dict() | ||||
|     random_incorrect_d = dict() | ||||
| 
 | ||||
|     oracle_correct_by_label = dict() | ||||
|     oracle_incorrect_by_label = dict() | ||||
|     oracle_correct_d = dict() | ||||
|     oracle_incorrect_d = dict() | ||||
| 
 | ||||
|     prior_correct_by_label = dict() | ||||
|     prior_incorrect_by_label = dict() | ||||
|     prior_correct_d = dict() | ||||
|     prior_incorrect_d = dict() | ||||
| 
 | ||||
|     docs = [d for d, g in data if len(d) > 0] | ||||
|     golds = [g for d, g in data if len(d) > 0] | ||||
|  | @ -340,19 +387,24 @@ def _measure_baselines(data, kb): | |||
|     for doc, gold in zip(docs, golds): | ||||
|         try: | ||||
|             correct_entries_per_article = dict() | ||||
|             for entity in gold.links: | ||||
|                 start, end, gold_kb = entity | ||||
|                 correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb | ||||
|             for entity, kb_dict in gold.links.items(): | ||||
|                 start, end = entity | ||||
|                 for gold_kb, value in kb_dict.items(): | ||||
|                     # only evaluating on positive examples | ||||
|                     if value: | ||||
|                         offset = _offset(start, end) | ||||
|                         correct_entries_per_article[offset] = gold_kb | ||||
| 
 | ||||
|             for ent in doc.ents: | ||||
|                 ent_label = ent.label_ | ||||
|                 label = ent.label_ | ||||
|                 start = ent.start_char | ||||
|                 end = ent.end_char | ||||
|                 gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) | ||||
|                 offset = _offset(start, end) | ||||
|                 gold_entity = correct_entries_per_article.get(offset, None) | ||||
| 
 | ||||
|                 # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' | ||||
|                 if gold_entity is not None: | ||||
|                     counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1 | ||||
|                     counts_d[label] = counts_d.get(label, 0) + 1 | ||||
|                     candidates = kb.get_candidates(ent.text) | ||||
|                     oracle_candidate = "" | ||||
|                     best_candidate = "" | ||||
|  | @ -370,28 +422,40 @@ def _measure_baselines(data, kb): | |||
|                         random_candidate = random.choice(candidates).entity_ | ||||
| 
 | ||||
|                     if gold_entity == best_candidate: | ||||
|                         prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1 | ||||
|                         prior_correct_d[label] = prior_correct_d.get(label, 0) + 1 | ||||
|                     else: | ||||
|                         prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1 | ||||
|                         prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1 | ||||
| 
 | ||||
|                     if gold_entity == random_candidate: | ||||
|                         random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1 | ||||
|                         random_correct_d[label] = random_correct_d.get(label, 0) + 1 | ||||
|                     else: | ||||
|                         random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1 | ||||
|                         random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1 | ||||
| 
 | ||||
|                     if gold_entity == oracle_candidate: | ||||
|                         oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1 | ||||
|                         oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1 | ||||
|                     else: | ||||
|                         oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1 | ||||
|                         oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1 | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             print("Error assessing accuracy", e) | ||||
| 
 | ||||
|     acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label) | ||||
|     acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label) | ||||
|     acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label) | ||||
|     acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d) | ||||
|     acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d) | ||||
|     acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d) | ||||
| 
 | ||||
|     return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label | ||||
|     return ( | ||||
|         counts_d, | ||||
|         acc_rand, | ||||
|         acc_rand_d, | ||||
|         acc_prior, | ||||
|         acc_prior_d, | ||||
|         acc_oracle, | ||||
|         acc_oracle_d, | ||||
|     ) | ||||
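The dev evaluation further up in run_pipeline consumes this 7-tuple; a sketch of the unpacking, matching the variable names used there (the KB variable is not shown in this excerpt and is assumed here):

    counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(dev_data, kb)
    print("dev counts:", [(label, count) for label, count in counts.items()])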
| 
 | ||||
| 
 | ||||
| def _offset(start, end): | ||||
|     return "{}_{}".format(start, end) | ||||
| 
 | ||||
| 
 | ||||
| def calculate_acc(correct_by_label, incorrect_by_label): | ||||
|  | @ -422,15 +486,23 @@ def check_kb(kb): | |||
| 
 | ||||
|         print("generating candidates for " + mention + " :") | ||||
|         for c in candidates: | ||||
|             print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")") | ||||
|             print( | ||||
|                 " ", | ||||
|                 c.prior_prob, | ||||
|                 c.alias_, | ||||
|                 "-->", | ||||
|                 c.entity_ + " (freq=" + str(c.entity_freq) + ")", | ||||
|             ) | ||||
|         print() | ||||
| 
 | ||||
| 
 | ||||
| def run_el_toy_example(nlp): | ||||
|     text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \ | ||||
|            "Douglas reminds us to always bring our towel, even in China or Brazil. " \ | ||||
|            "The main character in Doug's novel is the man Arthur Dent, " \ | ||||
|            "but Douglas doesn't write about George Washington or Homer Simpson." | ||||
|     text = ( | ||||
|         "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " | ||||
|         "Douglas reminds us to always bring our towel, even in China or Brazil. " | ||||
|         "The main character in Doug's novel is the man Arthur Dent, " | ||||
|         "but Dougledydoug doesn't write about George Washington or Homer Simpson." | ||||
|     ) | ||||
|     doc = nlp(text) | ||||
|     print(text) | ||||
|     for ent in doc.ents: | ||||
|  |  | |||
|  | @ -663,15 +663,14 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, | |||
| 
 | ||||
| 
 | ||||
| def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg): | ||||
|     # TODO proper error | ||||
|     if "entity_width" not in cfg: | ||||
|         raise ValueError("entity_width not found") | ||||
|         raise ValueError(Errors.E144.format(param="entity_width")) | ||||
|     if "context_width" not in cfg: | ||||
|         raise ValueError("context_width not found") | ||||
|         raise ValueError(Errors.E144.format(param="context_width")) | ||||
| 
 | ||||
|     conv_depth = cfg.get("conv_depth", 2) | ||||
|     cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3) | ||||
|     pretrained_vectors = cfg.get("pretrained_vectors")  # self.nlp.vocab.vectors.name | ||||
|     pretrained_vectors = cfg.get("pretrained_vectors", None) | ||||
|     context_width = cfg.get("context_width") | ||||
|     entity_width = cfg.get("entity_width") | ||||
| 
 | ||||
|  |  | |||
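The two new checks in build_nel_encoder above mean the encoder now refuses to build without explicit widths and raises a proper error code; a quick sketch of that behaviour, assuming the function is imported from spaCy's internal `_ml` module as in this file:

    import pytest
    from spacy._ml import build_nel_encoder

    # omitting entity_width should now raise the E144 error instead of a bare message
    with pytest.raises(ValueError):
        build_nel_encoder(embed_width=300, hidden_width=128, ner_types=[], context_width=128)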
|  | @ -406,7 +406,15 @@ class Errors(object): | |||
|     E141 = ("Entity vectors should be of length {required} instead of the provided {found}.") | ||||
|     E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'") | ||||
|     E143 = ("Labels for component '{name}' not initialized. Did you forget to call add_label()?") | ||||
| 
 | ||||
|     E144 = ("Could not find parameter `{param}` when building the entity linker model.") | ||||
|     E145 = ("Error reading `{param}` from input file.") | ||||
|     E146 = ("Could not access `{path}`.") | ||||
|     E147 = ("Unexpected error in the {method} functionality of the EntityLinker: {msg}. " | ||||
|             "This is likely a bug in spaCy, so feel free to open an issue.") | ||||
|     E148 = ("Expected {ents} KB identifiers but got {ids}. Make sure that each entity in `doc.ents` " | ||||
|             "is assigned to a KB identifier.") | ||||
|     E149 = ("Error deserializing model. Check that the config used to create the " | ||||
|             "component matches the model being loaded.") | ||||
| 
 | ||||
| @add_codes | ||||
| class TempErrors(object): | ||||
|  |  | |||
|  | @ -31,7 +31,7 @@ cdef class GoldParse: | |||
|     cdef public list ents | ||||
|     cdef public dict brackets | ||||
|     cdef public object cats | ||||
|     cdef public list links | ||||
|     cdef public dict links | ||||
| 
 | ||||
|     cdef readonly list cand_to_gold | ||||
|     cdef readonly list gold_to_cand | ||||
|  |  | |||
|  | @ -468,8 +468,11 @@ cdef class GoldParse: | |||
|             examples of a label to have the value 0.0. Labels not in the | ||||
|             dictionary are treated as missing - the gradient for those labels | ||||
|             will be zero. | ||||
|         links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples, | ||||
|             representing the external ID of an entity in a knowledge base. | ||||
|         links (dict): A dict with `(start_char, end_char)` keys, | ||||
|             and the values being dicts with kb_id:value entries, | ||||
|             representing the external IDs in a knowledge base (KB) | ||||
|             mapped to either 1.0 or 0.0, indicating positive and | ||||
|             negative examples respectively. | ||||
|         RETURNS (GoldParse): The newly constructed object. | ||||
|         """ | ||||
|         if words is None: | ||||
|  |  | |||
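A minimal sketch of the new `links` format described in the docstring above (character offsets and QIDs are made up for illustration):

    from spacy.lang.en import English
    from spacy.gold import GoldParse

    nlp = English()
    doc = nlp("Douglas Adams lived in Santa Barbara.")
    links = {
        (0, 13): {"Q42": 1.0},   # positive example: this span refers to Q42
        (23, 36): {"Q5": 0.0},   # negative example: this span does not refer to Q5
    }
    gold = GoldParse(doc, links=links)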
							
								
								
									
spacy/kb.pxd (12 lines changed)
|  | @ -79,7 +79,7 @@ cdef class KnowledgeBase: | |||
|         return new_index | ||||
| 
 | ||||
| 
 | ||||
|     cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob, | ||||
|     cdef inline int64_t c_add_entity(self, hash_t entity_hash, float freq, | ||||
|                                      int32_t vector_index, int feats_row) nogil: | ||||
|         """Add an entry to the vector of entries. | ||||
|         After calling this method, make sure to update also the _entry_index using the return value""" | ||||
|  | @ -92,7 +92,7 @@ cdef class KnowledgeBase: | |||
|         entry.entity_hash = entity_hash | ||||
|         entry.vector_index = vector_index | ||||
|         entry.feats_row = feats_row | ||||
|         entry.prob = prob | ||||
|         entry.freq = freq | ||||
| 
 | ||||
|         self._entries.push_back(entry) | ||||
|         return new_index | ||||
|  | @ -125,7 +125,7 @@ cdef class KnowledgeBase: | |||
|         entry.entity_hash = dummy_hash | ||||
|         entry.vector_index = dummy_value | ||||
|         entry.feats_row = dummy_value | ||||
|         entry.prob = dummy_value | ||||
|         entry.freq = dummy_value | ||||
| 
 | ||||
|         # Avoid struct initializer to enable nogil | ||||
|         cdef vector[int64_t] dummy_entry_indices | ||||
|  | @ -141,7 +141,7 @@ cdef class KnowledgeBase: | |||
|         self._aliases_table.push_back(alias) | ||||
| 
 | ||||
|     cpdef load_bulk(self, loc) | ||||
|     cpdef set_entities(self, entity_list, prob_list, vector_list) | ||||
|     cpdef set_entities(self, entity_list, freq_list, vector_list) | ||||
| 
 | ||||
| 
 | ||||
| cdef class Writer: | ||||
|  | @ -149,7 +149,7 @@ cdef class Writer: | |||
| 
 | ||||
|     cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1 | ||||
|     cdef int write_vector_element(self, float element) except -1 | ||||
|     cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1 | ||||
|     cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1 | ||||
| 
 | ||||
|     cdef int write_alias_length(self, int64_t alias_length) except -1 | ||||
|     cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1 | ||||
|  | @ -162,7 +162,7 @@ cdef class Reader: | |||
| 
 | ||||
|     cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1 | ||||
|     cdef int read_vector_element(self, float* element) except -1 | ||||
|     cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1 | ||||
|     cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1 | ||||
| 
 | ||||
|     cdef int read_alias_length(self, int64_t* alias_length) except -1 | ||||
|     cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1 | ||||
|  |  | |||
							
								
								
									
spacy/kb.pyx (86 lines changed)
|  | @ -94,7 +94,7 @@ cdef class KnowledgeBase: | |||
|     def get_alias_strings(self): | ||||
|         return [self.vocab.strings[x] for x in self._alias_index] | ||||
| 
 | ||||
|     def add_entity(self, unicode entity, float prob, vector[float] entity_vector): | ||||
|     def add_entity(self, unicode entity, float freq, vector[float] entity_vector): | ||||
|         """ | ||||
|         Add an entity to the KB, optionally specifying its corpus frequency. | ||||
|         Return the hash of the entity ID/name at the end. | ||||
|  | @ -113,15 +113,15 @@ cdef class KnowledgeBase: | |||
|         vector_index = self.c_add_vector(entity_vector=entity_vector) | ||||
| 
 | ||||
|         new_index = self.c_add_entity(entity_hash=entity_hash, | ||||
|                                       prob=prob, | ||||
|                                       freq=freq, | ||||
|                                       vector_index=vector_index, | ||||
|                                       feats_row=-1)  # Features table currently not implemented | ||||
|         self._entry_index[entity_hash] = new_index | ||||
| 
 | ||||
|         return entity_hash | ||||
| 
 | ||||
|     cpdef set_entities(self, entity_list, prob_list, vector_list): | ||||
|         if len(entity_list) != len(prob_list) or len(entity_list) != len(vector_list): | ||||
|     cpdef set_entities(self, entity_list, freq_list, vector_list): | ||||
|         if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list): | ||||
|             raise ValueError(Errors.E140) | ||||
| 
 | ||||
|         nr_entities = len(entity_list) | ||||
|  | @ -137,7 +137,7 @@ cdef class KnowledgeBase: | |||
| 
 | ||||
|             entity_hash = self.vocab.strings.add(entity_list[i]) | ||||
|             entry.entity_hash = entity_hash | ||||
|             entry.prob = prob_list[i] | ||||
|             entry.freq = freq_list[i] | ||||
| 
 | ||||
|             vector_index = self.c_add_vector(entity_vector=vector_list[i]) | ||||
|             entry.vector_index = vector_index | ||||
|  | @ -196,13 +196,42 @@ cdef class KnowledgeBase: | |||
| 
 | ||||
|         return [Candidate(kb=self, | ||||
|                           entity_hash=self._entries[entry_index].entity_hash, | ||||
|                           entity_freq=self._entries[entry_index].prob, | ||||
|                           entity_freq=self._entries[entry_index].freq, | ||||
|                           entity_vector=self._vectors_table[self._entries[entry_index].vector_index], | ||||
|                           alias_hash=alias_hash, | ||||
|                           prior_prob=prob) | ||||
|                 for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs) | ||||
|                           prior_prob=prior_prob) | ||||
|                 for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs) | ||||
|                 if entry_index != 0] | ||||
| 
 | ||||
|     def get_vector(self, unicode entity): | ||||
|         cdef hash_t entity_hash = self.vocab.strings[entity] | ||||
| 
 | ||||
|         # Return an empty list if this entity is unknown in this KB | ||||
|         if entity_hash not in self._entry_index: | ||||
|             return [0] * self.entity_vector_length | ||||
|         entry_index = self._entry_index[entity_hash] | ||||
| 
 | ||||
|         return self._vectors_table[self._entries[entry_index].vector_index] | ||||
| 
 | ||||
|     def get_prior_prob(self, unicode entity, unicode alias): | ||||
|         """ Return the prior probability of a given alias being linked to a given entity, | ||||
|         or return 0.0 when this combination is not known in the knowledge base""" | ||||
|         cdef hash_t alias_hash = self.vocab.strings[alias] | ||||
|         cdef hash_t entity_hash = self.vocab.strings[entity] | ||||
| 
 | ||||
|         if entity_hash not in self._entry_index or alias_hash not in self._alias_index: | ||||
|             return 0.0 | ||||
| 
 | ||||
|         alias_index = <int64_t>self._alias_index.get(alias_hash) | ||||
|         entry_index = self._entry_index[entity_hash] | ||||
| 
 | ||||
|         alias_entry = self._aliases_table[alias_index] | ||||
|         for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs): | ||||
|             if self._entries[entry_index].entity_hash == entity_hash: | ||||
|                 return prior_prob | ||||
| 
 | ||||
|         return 0.0 | ||||
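These two accessors provide what the training loop needs from the KB; a small usage sketch along the lines of the tests further down:

    from spacy.kb import KnowledgeBase
    from spacy.lang.en import English

    nlp = English()
    kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
    kb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
    kb.add_alias(alias="douglas", entities=["Q2"], probabilities=[0.8])

    assert kb.get_vector("Q2") == [2]
    assert abs(kb.get_prior_prob(entity="Q2", alias="douglas") - 0.8) < 0.0001
    assert kb.get_prior_prob(entity="Q2", alias="unknown alias") == 0.0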
| 
 | ||||
| 
 | ||||
|     def dump(self, loc): | ||||
|         cdef Writer writer = Writer(loc) | ||||
|  | @ -222,7 +251,7 @@ cdef class KnowledgeBase: | |||
|             entry = self._entries[entry_index] | ||||
|             assert entry.entity_hash == entry_hash | ||||
|             assert entry_index == i | ||||
|             writer.write_entry(entry.entity_hash, entry.prob, entry.vector_index) | ||||
|             writer.write_entry(entry.entity_hash, entry.freq, entry.vector_index) | ||||
|             i = i+1 | ||||
| 
 | ||||
|         writer.write_alias_length(self.get_size_aliases()) | ||||
|  | @ -248,7 +277,7 @@ cdef class KnowledgeBase: | |||
|         cdef hash_t entity_hash | ||||
|         cdef hash_t alias_hash | ||||
|         cdef int64_t entry_index | ||||
|         cdef float prob | ||||
|         cdef float freq, prob | ||||
|         cdef int32_t vector_index | ||||
|         cdef KBEntryC entry | ||||
|         cdef AliasC alias | ||||
|  | @ -284,10 +313,10 @@ cdef class KnowledgeBase: | |||
|         # index 0 is a dummy object not stored in the _entry_index and can be ignored. | ||||
|         i = 1 | ||||
|         while i <= nr_entities: | ||||
|             reader.read_entry(&entity_hash, &prob, &vector_index) | ||||
|             reader.read_entry(&entity_hash, &freq, &vector_index) | ||||
| 
 | ||||
|             entry.entity_hash = entity_hash | ||||
|             entry.prob = prob | ||||
|             entry.freq = freq | ||||
|             entry.vector_index = vector_index | ||||
|             entry.feats_row = -1    # Features table currently not implemented | ||||
| 
 | ||||
|  | @ -343,7 +372,8 @@ cdef class Writer: | |||
|             loc = bytes(loc) | ||||
|         cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc | ||||
|         self._fp = fopen(<char*>bytes_loc, 'wb') | ||||
|         assert self._fp != NULL | ||||
|         if not self._fp: | ||||
|             raise IOError(Errors.E146.format(path=loc)) | ||||
|         fseek(self._fp, 0, 0) | ||||
| 
 | ||||
|     def close(self): | ||||
|  | @ -357,9 +387,9 @@ cdef class Writer: | |||
|     cdef int write_vector_element(self, float element) except -1: | ||||
|         self._write(&element, sizeof(element)) | ||||
| 
 | ||||
|     cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1: | ||||
|     cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1: | ||||
|         self._write(&entry_hash, sizeof(entry_hash)) | ||||
|         self._write(&entry_prob, sizeof(entry_prob)) | ||||
|         self._write(&entry_freq, sizeof(entry_freq)) | ||||
|         self._write(&vector_index, sizeof(vector_index)) | ||||
|         # Features table currently not implemented and not written to file | ||||
| 
 | ||||
|  | @ -399,39 +429,39 @@ cdef class Reader: | |||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|                 return 0  # end of file | ||||
|             raise IOError("error reading header from input file") | ||||
|             raise IOError(Errors.E145.format(param="header")) | ||||
| 
 | ||||
|         status = self._read(entity_vector_length, sizeof(int64_t)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|                 return 0  # end of file | ||||
|             raise IOError("error reading header from input file") | ||||
|             raise IOError(Errors.E145.format(param="vector length")) | ||||
| 
 | ||||
|     cdef int read_vector_element(self, float* element) except -1: | ||||
|         status = self._read(element, sizeof(float)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|                 return 0  # end of file | ||||
|             raise IOError("error reading entity vector from input file") | ||||
|             raise IOError(Errors.E145.format(param="vector element")) | ||||
| 
 | ||||
|     cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1: | ||||
|     cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1: | ||||
|         status = self._read(entity_hash, sizeof(hash_t)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|                 return 0  # end of file | ||||
|             raise IOError("error reading entity hash from input file") | ||||
|             raise IOError(Errors.E145.format(param="entity hash")) | ||||
| 
 | ||||
|         status = self._read(prob, sizeof(float)) | ||||
|         status = self._read(freq, sizeof(float)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|                 return 0  # end of file | ||||
|             raise IOError("error reading entity prob from input file") | ||||
|             raise IOError(Errors.E145.format(param="entity freq")) | ||||
| 
 | ||||
|         status = self._read(vector_index, sizeof(int32_t)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|                 return 0  # end of file | ||||
|             raise IOError("error reading entity vector from input file") | ||||
|             raise IOError(Errors.E145.format(param="vector index")) | ||||
| 
 | ||||
|         if feof(self._fp): | ||||
|             return 0 | ||||
|  | @ -443,33 +473,33 @@ cdef class Reader: | |||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|                 return 0  # end of file | ||||
|             raise IOError("error reading alias length from input file") | ||||
|             raise IOError(Errors.E145.format(param="alias length")) | ||||
| 
 | ||||
|     cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1: | ||||
|         status = self._read(alias_hash, sizeof(hash_t)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|                 return 0  # end of file | ||||
|             raise IOError("error reading alias hash from input file") | ||||
|             raise IOError(Errors.E145.format(param="alias hash")) | ||||
| 
 | ||||
|         status = self._read(candidate_length, sizeof(int64_t)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|                 return 0  # end of file | ||||
|             raise IOError("error reading candidate length from input file") | ||||
|             raise IOError(Errors.E145.format(param="candidate length")) | ||||
| 
 | ||||
|     cdef int read_alias(self, int64_t* entry_index, float* prob) except -1: | ||||
|         status = self._read(entry_index, sizeof(int64_t)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|                 return 0  # end of file | ||||
|             raise IOError("error reading entry index for alias from input file") | ||||
|             raise IOError(Errors.E145.format(param="entry index")) | ||||
| 
 | ||||
|         status = self._read(prob, sizeof(float)) | ||||
|         if status < 1: | ||||
|             if feof(self._fp): | ||||
|                 return 0  # end of file | ||||
|             raise IOError("error reading prob for entity/alias from input file") | ||||
|             raise IOError(Errors.E145.format(param="prior probability")) | ||||
| 
 | ||||
|     cdef int _read(self, void* value, size_t size) except -1: | ||||
|         status = fread(value, size, 1, self._fp) | ||||
|  |  | |||
|  | @ -12,10 +12,6 @@ from ...language import Language | |||
| from ...attrs import LANG, NORM | ||||
| from ...util import update_exc, add_lookups | ||||
| 
 | ||||
| # Borrowing french syntax parser because both languages use | ||||
| # universal dependencies for tagging/parsing. | ||||
| # Read here for more: | ||||
| # https://github.com/explosion/spaCy/pull/1882#issuecomment-361409573 | ||||
| from .syntax_iterators import SYNTAX_ITERATORS | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -14,7 +14,6 @@ from thinc.neural.util import to_categorical | |||
| from thinc.neural.util import get_array_module | ||||
| 
 | ||||
| from spacy.kb import KnowledgeBase | ||||
| from ..cli.pretrain import get_cossim_loss | ||||
| from .functions import merge_subtokens | ||||
| from ..tokens.doc cimport Doc | ||||
| from ..syntax.nn_parser cimport Parser | ||||
|  | @ -168,7 +167,10 @@ class Pipe(object): | |||
|                 self.cfg["pretrained_vectors"] = self.vocab.vectors.name | ||||
|             if self.model is True: | ||||
|                 self.model = self.Model(**self.cfg) | ||||
|             self.model.from_bytes(b) | ||||
|             try: | ||||
|                 self.model.from_bytes(b) | ||||
|             except AttributeError: | ||||
|                 raise ValueError(Errors.E149) | ||||
| 
 | ||||
|         deserialize = OrderedDict() | ||||
|         deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) | ||||
|  | @ -197,7 +199,10 @@ class Pipe(object): | |||
|                 self.cfg["pretrained_vectors"] = self.vocab.vectors.name | ||||
|             if self.model is True: | ||||
|                 self.model = self.Model(**self.cfg) | ||||
|             self.model.from_bytes(p.open("rb").read()) | ||||
|             try: | ||||
|                 self.model.from_bytes(p.open("rb").read()) | ||||
|             except AttributeError: | ||||
|                 raise ValueError(Errors.E149) | ||||
| 
 | ||||
|         deserialize = OrderedDict() | ||||
|         deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p)) | ||||
|  | @ -563,7 +568,10 @@ class Tagger(Pipe): | |||
|                     "token_vector_width", | ||||
|                     self.cfg.get("token_vector_width", 96)) | ||||
|                 self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) | ||||
|             self.model.from_bytes(b) | ||||
|             try: | ||||
|                 self.model.from_bytes(b) | ||||
|             except AttributeError: | ||||
|                 raise ValueError(Errors.E149) | ||||
| 
 | ||||
|         def load_tag_map(b): | ||||
|             tag_map = srsly.msgpack_loads(b) | ||||
|  | @ -601,7 +609,10 @@ class Tagger(Pipe): | |||
|             if self.model is True: | ||||
|                 self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) | ||||
|             with p.open("rb") as file_: | ||||
|                 self.model.from_bytes(file_.read()) | ||||
|                 try: | ||||
|                     self.model.from_bytes(file_.read()) | ||||
|                 except AttributeError: | ||||
|                     raise ValueError(Errors.E149) | ||||
| 
 | ||||
|         def load_tag_map(p): | ||||
|             tag_map = srsly.read_msgpack(p) | ||||
|  | @ -1077,6 +1088,7 @@ class EntityLinker(Pipe): | |||
|     DOCS: TODO | ||||
|     """ | ||||
|     name = 'entity_linker' | ||||
|     NIL = "NIL"  # string used to refer to a non-existing link | ||||
| 
 | ||||
|     @classmethod | ||||
|     def Model(cls, **cfg): | ||||
|  | @ -1093,6 +1105,8 @@ class EntityLinker(Pipe): | |||
|         self.kb = None | ||||
|         self.cfg = dict(cfg) | ||||
|         self.sgd_context = None | ||||
|         if not self.cfg.get("context_width"): | ||||
|             self.cfg["context_width"] = 128 | ||||
| 
 | ||||
|     def set_kb(self, kb): | ||||
|         self.kb = kb | ||||
|  | @ -1140,7 +1154,7 @@ class EntityLinker(Pipe): | |||
| 
 | ||||
|         context_docs = [] | ||||
|         entity_encodings = [] | ||||
|         cats = [] | ||||
| 
 | ||||
|         priors = [] | ||||
|         type_vectors = [] | ||||
| 
 | ||||
|  | @ -1149,50 +1163,44 @@ class EntityLinker(Pipe): | |||
|         for doc, gold in zip(docs, golds): | ||||
|             ents_by_offset = dict() | ||||
|             for ent in doc.ents: | ||||
|                 ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent | ||||
|             for entity in gold.links: | ||||
|                 start, end, gold_kb = entity | ||||
|                 ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent | ||||
|             for entity, kb_dict in gold.links.items(): | ||||
|                 start, end = entity | ||||
|                 mention = doc.text[start:end] | ||||
|                 for kb_id, value in kb_dict.items(): | ||||
|                     entity_encoding = self.kb.get_vector(kb_id) | ||||
|                     prior_prob = self.kb.get_prior_prob(kb_id, mention) | ||||
| 
 | ||||
|                 gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] | ||||
|                 assert gold_ent is not None | ||||
|                 type_vector = [0 for i in range(len(type_to_int))] | ||||
|                 if len(type_to_int) > 0: | ||||
|                     type_vector[type_to_int[gold_ent.label_]] = 1 | ||||
|                     gold_ent = ents_by_offset.get("{}_{}".format(start, end)) | ||||
|                     if gold_ent is None: | ||||
|                         raise RuntimeError(Errors.E147.format(method="update", msg="gold entity not found")) | ||||
| 
 | ||||
|                 candidates = self.kb.get_candidates(mention) | ||||
|                 random.shuffle(candidates) | ||||
|                 nr_neg = 0 | ||||
|                 for c in candidates: | ||||
|                     kb_id = c.entity_ | ||||
|                     entity_encoding = c.entity_vector | ||||
|                     type_vector = [0 for i in range(len(type_to_int))] | ||||
|                     if len(type_to_int) > 0: | ||||
|                         type_vector[type_to_int[gold_ent.label_]] = 1 | ||||
| 
 | ||||
|                     # store data | ||||
|                     entity_encodings.append(entity_encoding) | ||||
|                     context_docs.append(doc) | ||||
|                     type_vectors.append(type_vector) | ||||
| 
 | ||||
|                     if self.cfg.get("prior_weight", 1) > 0: | ||||
|                         priors.append([c.prior_prob]) | ||||
|                         priors.append([prior_prob]) | ||||
|                     else: | ||||
|                         priors.append([0]) | ||||
| 
 | ||||
|                     if kb_id == gold_kb: | ||||
|                         cats.append([1]) | ||||
|                     else: | ||||
|                         nr_neg += 1 | ||||
|                         cats.append([0]) | ||||
| 
 | ||||
|         if len(entity_encodings) > 0: | ||||
|             assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors) | ||||
|             if not (len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)): | ||||
|                 raise RuntimeError(Errors.E147.format(method="update", msg="vector lengths not equal")) | ||||
| 
 | ||||
|             context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop) | ||||
|             entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") | ||||
| 
 | ||||
|             context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop) | ||||
|             mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i] | ||||
|                                  for i in range(len(entity_encodings))] | ||||
|             pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop) | ||||
|             cats = self.model.ops.asarray(cats, dtype="float32") | ||||
| 
 | ||||
|             loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None) | ||||
|             loss, d_scores = self.get_loss(scores=pred, golds=golds, docs=docs) | ||||
|             mention_gradient = bp_mention(d_scores, sgd=sgd) | ||||
| 
 | ||||
|             context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient] | ||||
|  | @ -1203,39 +1211,45 @@ class EntityLinker(Pipe): | |||
|             return loss | ||||
|         return 0 | ||||
| 
 | ||||
|     def get_loss(self, docs, golds, prediction): | ||||
|         d_scores = (prediction - golds) | ||||
|     def get_loss(self, docs, golds, scores): | ||||
|         cats = [] | ||||
|         for gold in golds: | ||||
|             for entity, kb_dict in gold.links.items(): | ||||
|                 for kb_id, value in kb_dict.items(): | ||||
|                     cats.append([value]) | ||||
| 
 | ||||
|         cats = self.model.ops.asarray(cats, dtype="float32") | ||||
|         if len(scores) != len(cats): | ||||
|             raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up")) | ||||
| 
 | ||||
|         d_scores = (scores - cats) | ||||
|         loss = (d_scores ** 2).sum() | ||||
|         loss = loss / len(golds) | ||||
|         loss = loss / len(cats) | ||||
|         return loss, d_scores | ||||
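The loss is a plain squared error between the candidate scores and the 1.0/0.0 targets collected from gold.links, averaged over the number of candidates; in isolation, with toy numbers:

    import numpy as np

    scores = np.asarray([[0.9], [0.2], [0.4]], dtype="float32")  # model output, one row per candidate
    cats = np.asarray([[1.0], [0.0], [0.0]], dtype="float32")    # gold values from gold.links
    d_scores = scores - cats
    loss = (d_scores ** 2).sum() / len(cats)
    print(loss)  # ~0.07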
| 
 | ||||
|     def get_loss_old(self, docs, golds, scores): | ||||
|         # this loss function assumes we're only using positive examples | ||||
|         loss, gradients = get_cossim_loss(yh=scores, y=golds) | ||||
|         loss = loss / len(golds) | ||||
|         return loss, gradients | ||||
| 
 | ||||
|     def __call__(self, doc): | ||||
|         entities, kb_ids = self.predict([doc]) | ||||
|         self.set_annotations([doc], entities, kb_ids) | ||||
|         kb_ids, tensors = self.predict([doc]) | ||||
|         self.set_annotations([doc], kb_ids, tensors=tensors) | ||||
|         return doc | ||||
| 
 | ||||
|     def pipe(self, stream, batch_size=128, n_threads=-1): | ||||
|         for docs in util.minibatch(stream, size=batch_size): | ||||
|             docs = list(docs) | ||||
|             entities, kb_ids = self.predict(docs) | ||||
|             self.set_annotations(docs, entities, kb_ids) | ||||
|             kb_ids, tensors = self.predict(docs) | ||||
|             self.set_annotations(docs, kb_ids, tensors=tensors) | ||||
|             yield from docs | ||||
| 
 | ||||
|     def predict(self, docs): | ||||
|         """ Return the KB IDs for each entity in each doc, including NIL if there is no prediction """ | ||||
|         self.require_model() | ||||
|         self.require_kb() | ||||
| 
 | ||||
|         final_entities = [] | ||||
|         entity_count = 0 | ||||
|         final_kb_ids = [] | ||||
|         final_tensors = [] | ||||
| 
 | ||||
|         if not docs: | ||||
|             return final_entities, final_kb_ids | ||||
|             return final_kb_ids, final_tensors | ||||
| 
 | ||||
|         if isinstance(docs, Doc): | ||||
|             docs = [docs] | ||||
|  | @ -1247,14 +1261,19 @@ class EntityLinker(Pipe): | |||
| 
 | ||||
|         for i, doc in enumerate(docs): | ||||
|             if len(doc) > 0: | ||||
|                 # currently, the context is the same for each entity in a sentence (should be refined) | ||||
|                 context_encoding = context_encodings[i] | ||||
|                 for ent in doc.ents: | ||||
|                     entity_count += 1 | ||||
|                     type_vector = [0 for i in range(len(type_to_int))] | ||||
|                     if len(type_to_int) > 0: | ||||
|                         type_vector[type_to_int[ent.label_]] = 1 | ||||
| 
 | ||||
|                     candidates = self.kb.get_candidates(ent.text) | ||||
|                     if candidates: | ||||
|                     if not candidates: | ||||
|                         final_kb_ids.append(self.NIL)  # no prediction possible for this entity | ||||
|                         final_tensors.append(context_encoding) | ||||
|                     else: | ||||
|                         random.shuffle(candidates) | ||||
| 
 | ||||
|                         # this will set the prior probabilities to 0 (just like in training) if their weight is 0 | ||||
|  | @ -1264,7 +1283,9 @@ class EntityLinker(Pipe): | |||
| 
 | ||||
|                         if self.cfg.get("context_weight", 1) > 0: | ||||
|                             entity_encodings = xp.asarray([c.entity_vector for c in candidates]) | ||||
|                             assert len(entity_encodings) == len(prior_probs) | ||||
|                             if len(entity_encodings) != len(prior_probs): | ||||
|                                 raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length")) | ||||
| 
 | ||||
|                             mention_encodings = [list(context_encoding) + list(entity_encodings[i]) | ||||
|                                                  + list(prior_probs[i]) + type_vector | ||||
|                                                  for i in range(len(entity_encodings))] | ||||
|  | @ -1273,15 +1294,26 @@ class EntityLinker(Pipe): | |||
|                         # TODO: thresholding | ||||
|                         best_index = scores.argmax() | ||||
|                         best_candidate = candidates[best_index] | ||||
|                         final_entities.append(ent) | ||||
|                         final_kb_ids.append(best_candidate.entity_) | ||||
|                         final_tensors.append(context_encoding) | ||||
| 
 | ||||
|         return final_entities, final_kb_ids | ||||
|         if not (len(final_tensors) == len(final_kb_ids) == entity_count): | ||||
|             raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length")) | ||||
| 
 | ||||
|     def set_annotations(self, docs, entities, kb_ids=None): | ||||
|         for entity, kb_id in zip(entities, kb_ids): | ||||
|             for token in entity: | ||||
|                 token.ent_kb_id_ = kb_id | ||||
|         return final_kb_ids, final_tensors | ||||
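For clarity, each candidate's input row scored in predict (and in update above) is a concatenation of four pieces; a toy illustration with made-up numbers and sizes:

    context_encoding = [0.1, 0.2]        # sentence context from the tok2vec encoder (context_width)
    entity_encoding = [0.3, 0.4, 0.5]    # candidate's entity vector from the KB (entity_width)
    prior_prob = [0.8]                   # candidate's prior probability (zeroed if prior_weight == 0)
    type_vector = [0, 1, 0]              # one-hot NER label of the mention

    mention_encoding = context_encoding + entity_encoding + prior_prob + type_vector
    assert len(mention_encoding) == 2 + 3 + 1 + 3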
| 
 | ||||
|     def set_annotations(self, docs, kb_ids, tensors=None): | ||||
|         count_ents = len([ent for doc in docs for ent in doc.ents]) | ||||
|         if count_ents != len(kb_ids): | ||||
|             raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids))) | ||||
| 
 | ||||
|         i = 0 | ||||
|         for doc in docs: | ||||
|             for ent in doc.ents: | ||||
|                 kb_id = kb_ids[i] | ||||
|                 i += 1 | ||||
|                 for token in ent: | ||||
|                     token.ent_kb_id_ = kb_id | ||||
| 
 | ||||
|     def to_disk(self, path, exclude=tuple(), **kwargs): | ||||
|         serialize = OrderedDict() | ||||
|  | @ -1295,9 +1327,12 @@ class EntityLinker(Pipe): | |||
| 
 | ||||
|     def from_disk(self, path, exclude=tuple(), **kwargs): | ||||
|         def load_model(p): | ||||
|              if self.model is True: | ||||
|             if self.model is True: | ||||
|                 self.model = self.Model(**self.cfg) | ||||
|              self.model.from_bytes(p.open("rb").read()) | ||||
|             try:  | ||||
|                 self.model.from_bytes(p.open("rb").read()) | ||||
|             except AttributeError: | ||||
|                 raise ValueError(Errors.E149) | ||||
| 
 | ||||
|         def load_kb(p): | ||||
|             kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"]) | ||||
|  |  | |||
|  | @ -93,7 +93,7 @@ cdef struct KBEntryC: | |||
|     int32_t feats_row | ||||
| 
 | ||||
|     # log probability of entity, based on corpus frequency | ||||
|     float prob | ||||
|     float freq | ||||
| 
 | ||||
| 
 | ||||
| # Each alias struct stores a list of Entry pointers with their prior probabilities | ||||
|  |  | |||
|  | @ -631,7 +631,10 @@ cdef class Parser: | |||
|                 cfg = {} | ||||
|             with (path / 'model').open('rb') as file_: | ||||
|                 bytes_data = file_.read() | ||||
|             self.model.from_bytes(bytes_data) | ||||
|             try: | ||||
|                 self.model.from_bytes(bytes_data) | ||||
|             except AttributeError: | ||||
|                 raise ValueError(Errors.E149) | ||||
|             self.cfg.update(cfg) | ||||
|         return self | ||||
| 
 | ||||
|  | @ -663,6 +666,9 @@ cdef class Parser: | |||
|             else: | ||||
|                 cfg = {} | ||||
|             if 'model' in msg: | ||||
|                 self.model.from_bytes(msg['model']) | ||||
|                 try: | ||||
|                     self.model.from_bytes(msg['model']) | ||||
|                 except AttributeError: | ||||
|                     raise ValueError(Errors.E149) | ||||
|             self.cfg.update(cfg) | ||||
|         return self | ||||
|  |  | |||
|  | @ -13,22 +13,38 @@ def nlp(): | |||
|     return English() | ||||
| 
 | ||||
| 
 | ||||
| def assert_almost_equal(a, b): | ||||
|     delta = 0.0001 | ||||
|     assert a - delta <= b <= a + delta | ||||
| 
 | ||||
| 
 | ||||
| def test_kb_valid_entities(nlp): | ||||
|     """Test the valid construction of a KB with 3 entities and two aliases""" | ||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) | ||||
| 
 | ||||
|     # adding entities | ||||
|     mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) | ||||
|     mykb.add_entity(entity='Q2', prob=0.5, entity_vector=[2]) | ||||
|     mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) | ||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[8, 4, 3]) | ||||
|     mykb.add_entity(entity="Q2", freq=0.5, entity_vector=[2, 1, 0]) | ||||
|     mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[-1, -6, 5]) | ||||
| 
 | ||||
|     # adding aliases | ||||
|     mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2]) | ||||
|     mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9]) | ||||
|     mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2]) | ||||
|     mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) | ||||
| 
 | ||||
|     # test the size of the corresponding KB | ||||
|     assert(mykb.get_size_entities() == 3) | ||||
|     assert(mykb.get_size_aliases() == 2) | ||||
|     assert mykb.get_size_entities() == 3 | ||||
|     assert mykb.get_size_aliases() == 2 | ||||
| 
 | ||||
|     # test retrieval of the entity vectors | ||||
|     assert mykb.get_vector("Q1") == [8, 4, 3] | ||||
|     assert mykb.get_vector("Q2") == [2, 1, 0] | ||||
|     assert mykb.get_vector("Q3") == [-1, -6, 5] | ||||
| 
 | ||||
|     # test retrieval of prior probabilities | ||||
|     assert_almost_equal(mykb.get_prior_prob(entity="Q2", alias="douglas"), 0.8) | ||||
|     assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglas"), 0.2) | ||||
|     assert_almost_equal(mykb.get_prior_prob(entity="Q342", alias="douglas"), 0.0) | ||||
|     assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglassssss"), 0.0) | ||||
| 
 | ||||
| 
 | ||||
| def test_kb_invalid_entities(nlp): | ||||
|  | @ -36,13 +52,15 @@ def test_kb_invalid_entities(nlp): | |||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||
| 
 | ||||
|     # adding entities | ||||
|     mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) | ||||
|     mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) | ||||
|     mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) | ||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1]) | ||||
|     mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2]) | ||||
|     mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3]) | ||||
| 
 | ||||
|     # adding aliases - should fail because one of the given IDs is not valid | ||||
|     with pytest.raises(ValueError): | ||||
|         mykb.add_alias(alias='douglas', entities=['Q2', 'Q342'], probabilities=[0.8, 0.2]) | ||||
|         mykb.add_alias( | ||||
|             alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2] | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def test_kb_invalid_probabilities(nlp): | ||||
|  | @ -50,13 +68,13 @@ def test_kb_invalid_probabilities(nlp): | |||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||
| 
 | ||||
|     # adding entities | ||||
|     mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) | ||||
|     mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) | ||||
|     mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) | ||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1]) | ||||
|     mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2]) | ||||
|     mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3]) | ||||
| 
 | ||||
|     # adding aliases - should fail because the sum of the probabilities exceeds 1 | ||||
|     with pytest.raises(ValueError): | ||||
|         mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.4]) | ||||
|         mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.4]) | ||||
| 
 | ||||
| 
 | ||||
| def test_kb_invalid_combination(nlp): | ||||
|  | @ -64,13 +82,15 @@ def test_kb_invalid_combination(nlp): | |||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||
| 
 | ||||
|     # adding entities | ||||
|     mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) | ||||
|     mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) | ||||
|     mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) | ||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1]) | ||||
|     mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2]) | ||||
|     mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3]) | ||||
| 
 | ||||
|     # adding aliases - should fail because the entities and probabilities vectors are not of equal length | ||||
|     with pytest.raises(ValueError): | ||||
|         mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.3, 0.4, 0.1]) | ||||
|         mykb.add_alias( | ||||
|             alias="douglas", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1] | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def test_kb_invalid_entity_vector(nlp): | ||||
|  | @ -78,11 +98,11 @@ def test_kb_invalid_entity_vector(nlp): | |||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) | ||||
| 
 | ||||
|     # adding entities | ||||
|     mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1, 2, 3]) | ||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1, 2, 3]) | ||||
| 
 | ||||
|     # this should fail because the kb's expected entity vector length is 3 | ||||
|     with pytest.raises(ValueError): | ||||
|         mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) | ||||
|         mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2]) | ||||
| 
 | ||||
| 
 | ||||
| def test_candidate_generation(nlp): | ||||
|  | @ -90,18 +110,24 @@ def test_candidate_generation(nlp): | |||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||
| 
 | ||||
|     # adding entities | ||||
|     mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) | ||||
|     mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) | ||||
|     mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) | ||||
|     mykb.add_entity(entity="Q1", freq=0.7, entity_vector=[1]) | ||||
|     mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2]) | ||||
|     mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3]) | ||||
| 
 | ||||
|     # adding aliases | ||||
|     mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2]) | ||||
|     mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9]) | ||||
|     mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]) | ||||
|     mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) | ||||
| 
 | ||||
|     # test the size of the relevant candidates | ||||
|     assert(len(mykb.get_candidates('douglas')) == 2) | ||||
|     assert(len(mykb.get_candidates('adam')) == 1) | ||||
|     assert(len(mykb.get_candidates('shrubbery')) == 0) | ||||
|     assert len(mykb.get_candidates("douglas")) == 2 | ||||
|     assert len(mykb.get_candidates("adam")) == 1 | ||||
|     assert len(mykb.get_candidates("shrubbery")) == 0 | ||||
| 
 | ||||
|     # test the content of the candidates | ||||
|     assert mykb.get_candidates("adam")[0].entity_ == "Q2" | ||||
|     assert mykb.get_candidates("adam")[0].alias_ == "adam" | ||||
|     assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2) | ||||
|     assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9) | ||||
| 
 | ||||
| 
 | ||||
| def test_preserving_links_asdoc(nlp): | ||||
|  | @ -109,24 +135,26 @@ def test_preserving_links_asdoc(nlp): | |||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||
| 
 | ||||
|     # adding entities | ||||
|     mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) | ||||
|     mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1]) | ||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1]) | ||||
|     mykb.add_entity(entity="Q2", freq=0.8, entity_vector=[1]) | ||||
| 
 | ||||
|     # adding aliases | ||||
|     mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7]) | ||||
|     mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6]) | ||||
|     mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7]) | ||||
|     mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6]) | ||||
| 
 | ||||
|     # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained) | ||||
|     sentencizer = nlp.create_pipe("sentencizer") | ||||
|     nlp.add_pipe(sentencizer) | ||||
| 
 | ||||
|     ruler = EntityRuler(nlp) | ||||
|     patterns = [{"label": "GPE", "pattern": "Boston"}, | ||||
|                 {"label": "GPE", "pattern": "Denver"}] | ||||
|     patterns = [ | ||||
|         {"label": "GPE", "pattern": "Boston"}, | ||||
|         {"label": "GPE", "pattern": "Denver"}, | ||||
|     ] | ||||
|     ruler.add_patterns(patterns) | ||||
|     nlp.add_pipe(ruler) | ||||
| 
 | ||||
|     el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64}) | ||||
|     el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64}) | ||||
|     el_pipe.set_kb(mykb) | ||||
|     el_pipe.begin_training() | ||||
|     el_pipe.context_weight = 0 | ||||
|  |  | |||
							
								
								
									
spacy/tests/regression/test_issue3962.py (new file, 112 lines)
							|  | @ -0,0 +1,112 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import pytest | ||||
| 
 | ||||
| from ..util import get_doc | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def doc(en_tokenizer): | ||||
|     text = "He jests at scars, that never felt a wound." | ||||
|     heads = [1, 6, -1, -1, 3, 2, 1, 0, 1, -2, -3] | ||||
|     deps = [ | ||||
|         "nsubj", | ||||
|         "ccomp", | ||||
|         "prep", | ||||
|         "pobj", | ||||
|         "punct", | ||||
|         "nsubj", | ||||
|         "neg", | ||||
|         "ROOT", | ||||
|         "det", | ||||
|         "dobj", | ||||
|         "punct", | ||||
|     ] | ||||
|     tokens = en_tokenizer(text) | ||||
|     return get_doc(tokens.vocab, words=[t.text for t in tokens], heads=heads, deps=deps) | ||||
| 
 | ||||
| 
 | ||||
| def test_issue3962(doc): | ||||
|     """ Ensure that as_doc does not result in out-of-bound access of tokens. | ||||
|     This is achieved by setting the head to itself if it would lie out of the span otherwise.""" | ||||
|     span2 = doc[1:5]  # "jests at scars ," | ||||
|     doc2 = span2.as_doc() | ||||
|     doc2_json = doc2.to_json() | ||||
|     assert doc2_json | ||||
| 
 | ||||
|     assert doc2[0].head.text == "jests"  # head set to itself, being the new artificial root | ||||
|     assert doc2[0].dep_ == "dep" | ||||
|     assert doc2[1].head.text == "jests" | ||||
|     assert doc2[1].dep_ == "prep" | ||||
|     assert doc2[2].head.text == "at" | ||||
|     assert doc2[2].dep_ == "pobj" | ||||
|     assert doc2[3].head.text == "jests"  # head set to the new artificial root | ||||
|     assert doc2[3].dep_ == "dep" | ||||
| 
 | ||||
|     # We should still have 1 sentence | ||||
|     assert len(list(doc2.sents)) == 1 | ||||
| 
 | ||||
|     span3 = doc[6:9]  # "never felt a" | ||||
|     doc3 = span3.as_doc() | ||||
|     doc3_json = doc3.to_json() | ||||
|     assert doc3_json | ||||
| 
 | ||||
|     assert doc3[0].head.text == "felt" | ||||
|     assert doc3[0].dep_ == "neg" | ||||
|     assert doc3[1].head.text == "felt" | ||||
|     assert doc3[1].dep_ == "ROOT" | ||||
|     assert doc3[2].head.text == "felt"  # head set to ancestor | ||||
|     assert doc3[2].dep_ == "dep" | ||||
| 
 | ||||
|     # We should still have 1 sentence as "a" can be attached to "felt" instead of "wound" | ||||
|     assert len(list(doc3.sents)) == 1 | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def two_sent_doc(en_tokenizer): | ||||
|     text = "He jests at scars. They never felt a wound." | ||||
|     heads = [1, 0, -1, -1, -3, 2, 1, 0, 1, -2, -3] | ||||
|     deps = [ | ||||
|         "nsubj", | ||||
|         "ROOT", | ||||
|         "prep", | ||||
|         "pobj", | ||||
|         "punct", | ||||
|         "nsubj", | ||||
|         "neg", | ||||
|         "ROOT", | ||||
|         "det", | ||||
|         "dobj", | ||||
|         "punct", | ||||
|     ] | ||||
|     tokens = en_tokenizer(text) | ||||
|     return get_doc(tokens.vocab, words=[t.text for t in tokens], heads=heads, deps=deps) | ||||
| 
 | ||||
| 
 | ||||
| def test_issue3962_long(two_sent_doc): | ||||
|     """ Ensure that as_doc does not result in out-of-bound access of tokens. | ||||
|     This is achieved by setting the head to itself if it would lie out of the span otherwise.""" | ||||
|     span2 = two_sent_doc[1:7]  # "jests at scars. They never" | ||||
|     doc2 = span2.as_doc() | ||||
|     doc2_json = doc2.to_json() | ||||
|     assert doc2_json | ||||
| 
 | ||||
|     assert doc2[0].head.text == "jests"  # head set to itself, being the new artificial root (in sentence 1) | ||||
|     assert doc2[0].dep_ == "ROOT" | ||||
|     assert doc2[1].head.text == "jests" | ||||
|     assert doc2[1].dep_ == "prep" | ||||
|     assert doc2[2].head.text == "at" | ||||
|     assert doc2[2].dep_ == "pobj" | ||||
|     assert doc2[3].head.text == "jests" | ||||
|     assert doc2[3].dep_ == "punct" | ||||
|     assert doc2[4].head.text == "They"  # head set to itself, being the new artificial root (in sentence 2) | ||||
|     assert doc2[4].dep_ == "dep" | ||||
|     assert doc2[4].head.text == "They"  # head set to the new artificial head (in sentence 2) | ||||
|     assert doc2[4].dep_ == "dep" | ||||
| 
 | ||||
|     # We should still have 2 sentences | ||||
|     sents = list(doc2.sents) | ||||
|     assert len(sents) == 2 | ||||
|     assert sents[0].text == "jests at scars ." | ||||
|     assert sents[1].text == "They never" | ||||
spacy/tests/regression/test_issue4002.py (new file, 28 lines)
|  | @ -0,0 +1,28 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import pytest | ||||
| from spacy.matcher import PhraseMatcher | ||||
| from spacy.tokens import Doc | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.xfail | ||||
| def test_issue4002(en_vocab): | ||||
|     """Test that the PhraseMatcher can match on overwritten NORM attributes. | ||||
|     """ | ||||
|     matcher = PhraseMatcher(en_vocab, attr="NORM") | ||||
|     pattern1 = Doc(en_vocab, words=["c", "d"]) | ||||
|     assert [t.norm_ for t in pattern1] == ["c", "d"] | ||||
|     matcher.add("TEST", None, pattern1) | ||||
|     doc = Doc(en_vocab, words=["a", "b", "c", "d"]) | ||||
|     assert [t.norm_ for t in doc] == ["a", "b", "c", "d"] | ||||
|     matches = matcher(doc) | ||||
|     assert len(matches) == 1 | ||||
|     matcher = PhraseMatcher(en_vocab, attr="NORM") | ||||
|     pattern2 = Doc(en_vocab, words=["1", "2"]) | ||||
|     pattern2[0].norm_ = "c" | ||||
|     pattern2[1].norm_ = "d" | ||||
|     assert [t.norm_ for t in pattern2] == ["c", "d"] | ||||
|     matcher.add("TEST", None, pattern2) | ||||
|     matches = matcher(doc) | ||||
|     assert len(matches) == 1 | ||||
|  | @ -30,10 +30,10 @@ def test_serialize_kb_disk(en_vocab): | |||
| def _get_dummy_kb(vocab): | ||||
|     kb = KnowledgeBase(vocab=vocab, entity_vector_length=3) | ||||
| 
 | ||||
|     kb.add_entity(entity='Q53', prob=0.33, entity_vector=[0, 5, 3]) | ||||
|     kb.add_entity(entity='Q17', prob=0.2, entity_vector=[7, 1, 0]) | ||||
|     kb.add_entity(entity='Q007', prob=0.7, entity_vector=[0, 0, 7]) | ||||
|     kb.add_entity(entity='Q44', prob=0.4, entity_vector=[4, 4, 4]) | ||||
|     kb.add_entity(entity='Q53', freq=0.33, entity_vector=[0, 5, 3]) | ||||
|     kb.add_entity(entity='Q17', freq=0.2, entity_vector=[7, 1, 0]) | ||||
|     kb.add_entity(entity='Q007', freq=0.7, entity_vector=[0, 0, 7]) | ||||
|     kb.add_entity(entity='Q44', freq=0.4, entity_vector=[4, 4, 4]) | ||||
| 
 | ||||
|     kb.add_alias(alias='double07', entities=['Q17', 'Q007'], probabilities=[0.1, 0.9]) | ||||
|     kb.add_alias(alias='guy', entities=['Q53', 'Q007', 'Q17', 'Q44'], probabilities=[0.3, 0.3, 0.2, 0.1]) | ||||
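A hedged sketch of the disk round trip that `test_serialize_kb_disk` exercises, assuming the `dump` and `load_bulk` methods of this KnowledgeBase and a writable (hypothetical) path:

from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

vocab = Vocab()
kb1 = _get_dummy_kb(vocab)
kb1.dump("/tmp/kb")    # hypothetical path; writes entities, vectors and aliases to disk

kb2 = KnowledgeBase(vocab=vocab, entity_vector_length=3)
kb2.load_bulk("/tmp/kb")
assert kb2.get_size_entities() == kb1.get_size_entities()
assert kb2.get_size_aliases() == kb1.get_size_aliases()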
|  |  | |||
|  | @ -348,7 +348,7 @@ cdef class Tokenizer: | |||
|         """Add a special-case tokenization rule. | ||||
| 
 | ||||
|         string (unicode): The string to specially tokenize. | ||||
|         token_attrs (iterable): A sequence of dicts, where each dict describes | ||||
|         substrings (iterable): A sequence of dicts, where each dict describes | ||||
|             a token and its attributes. The `ORTH` fields of the attributes | ||||
|             must exactly match the string when they are concatenated. | ||||
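For reference, a small sketch of the documented call using the renamed `substrings` parameter; the custom string and its split are illustrative:

from spacy.attrs import ORTH
from spacy.lang.en import English

nlp = English()
# each dict describes one token; the ORTH values concatenate back to "gimme"
nlp.tokenizer.add_special_case("gimme", [{ORTH: "gim"}, {ORTH: "me"}])
assert [t.text for t in nlp("gimme that")] == ["gim", "me", "that"]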
| 
 | ||||
|  |  | |||
|  | @ -794,7 +794,7 @@ cdef class Doc: | |||
|                 if array[i, col] != 0: | ||||
|                     self.vocab.morphology.assign_tag(&tokens[i], array[i, col]) | ||||
|         # Now load the data | ||||
|         for i in range(self.length): | ||||
|         for i in range(length): | ||||
|             token = &self.c[i] | ||||
|             for j in range(n_attrs): | ||||
|                 if attr_ids[j] != TAG: | ||||
|  | @ -804,7 +804,7 @@ cdef class Doc: | |||
|         self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs) | ||||
|         # If document is parsed, set children | ||||
|         if self.is_parsed: | ||||
|             set_children_from_heads(self.c, self.length) | ||||
|             set_children_from_heads(self.c, length) | ||||
|         return self | ||||
| 
 | ||||
|     def get_lca_matrix(self): | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ from ..attrs cimport attr_id_t | |||
| from ..parts_of_speech cimport univ_pos_t | ||||
| from ..attrs cimport * | ||||
| from ..lexeme cimport Lexeme | ||||
| from ..symbols cimport dep | ||||
| 
 | ||||
| from ..util import normalize_slice | ||||
| from ..compat import is_config, basestring_ | ||||
|  | @ -206,7 +207,6 @@ cdef class Span: | |||
| 
 | ||||
|         DOCS: https://spacy.io/api/span#as_doc | ||||
|         """ | ||||
|         # TODO: Fix! | ||||
|         words = [t.text for t in self] | ||||
|         spaces = [bool(t.whitespace_) for t in self] | ||||
|         cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces) | ||||
|  | @ -220,7 +220,9 @@ cdef class Span: | |||
|         else: | ||||
|             array_head.append(SENT_START) | ||||
|         array = self.doc.to_array(array_head) | ||||
|         doc.from_array(array_head, array[self.start : self.end]) | ||||
|         array = array[self.start : self.end] | ||||
|         self._fix_dep_copy(array_head, array) | ||||
|         doc.from_array(array_head, array) | ||||
|         doc.noun_chunks_iterator = self.doc.noun_chunks_iterator | ||||
|         doc.user_hooks = self.doc.user_hooks | ||||
|         doc.user_span_hooks = self.doc.user_span_hooks | ||||
|  | @ -235,6 +237,44 @@ cdef class Span: | |||
|                     doc.cats[cat_label] = value | ||||
|         return doc | ||||
| 
 | ||||
|     def _fix_dep_copy(self, attrs, array): | ||||
|         """ Rewire dependency links to make sure their heads fall into the span | ||||
|         while still keeping the correct number of sentences. """ | ||||
|         cdef int length = len(array) | ||||
|         cdef attr_t value | ||||
|         cdef int i, head_col, ancestor_i | ||||
|         old_to_new_root = dict() | ||||
|         if HEAD in attrs: | ||||
|             head_col = attrs.index(HEAD) | ||||
|             for i in range(length): | ||||
|                 # if the HEAD refers to a token outside this span, find a more appropriate ancestor | ||||
|                 token = self[i] | ||||
|                 ancestor_i = token.head.i - self.start   # span offset | ||||
|                 if ancestor_i not in range(length): | ||||
|                     if DEP in attrs: | ||||
|                         array[i, attrs.index(DEP)] = dep | ||||
| 
 | ||||
|                     # try finding an ancestor within this span | ||||
|                     ancestors = token.ancestors | ||||
|                     for ancestor in ancestors: | ||||
|                         ancestor_i = ancestor.i - self.start | ||||
|                         if ancestor_i in range(length): | ||||
|                             array[i, head_col] = ancestor_i - i | ||||
| 
 | ||||
|                 # if there is no appropriate ancestor, define a new artificial root | ||||
|                 value = array[i, head_col] | ||||
|                 if (i+value) not in range(length): | ||||
|                     new_root = old_to_new_root.get(ancestor_i, None) | ||||
|                     if new_root is not None: | ||||
|                         # take the same artificial root as a previous token from the same sentence | ||||
|                         array[i, head_col] = new_root - i | ||||
|                     else: | ||||
|                         # set this token as the new artificial root | ||||
|                         array[i, head_col] = 0 | ||||
|                         old_to_new_root[ancestor_i] = i | ||||
| 
 | ||||
|         return array | ||||
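Illustratively (not part of this patch), the observable effect of the rewiring when copying a parsed span whose tokens have heads outside it, e.g. the fixture from the regression test above:

span = doc[1:5]              # heads of the outermost tokens lie outside this span
new_doc = span.as_doc()      # out-of-span heads are re-rooted instead of pointing out of bounds
assert new_doc.to_json()     # downstream consumers such as to_json() no longer break
assert len(list(new_doc.sents)) == 1   # the sentence count is preserved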
| 
 | ||||
|     def merge(self, *args, **attributes): | ||||
|         """Retokenize the document, such that the span is merged into a single | ||||
|         token. | ||||
|  | @ -500,7 +540,7 @@ cdef class Span: | |||
|         if "root" in self.doc.user_span_hooks: | ||||
|             return self.doc.user_span_hooks["root"](self) | ||||
|         # This should probably be called 'head', and the other one called | ||||
|         # 'gov'. But we went with 'head' elsehwhere, and now we're stuck =/ | ||||
|         # 'gov'. But we went with 'head' elsewhere, and now we're stuck =/ | ||||
|         cdef int i | ||||
|         # First, we scan through the Span, and check whether there's a word | ||||
|         # with head==0, i.e. a sentence root. If so, we can return it. The | ||||
|  |  | |||
|  | @ -45,10 +45,11 @@ Whether the provided syntactic annotations form a projective dependency tree. | |||
| 
 | ||||
| | Name                              | Type | Description                                                                                                                                              | | ||||
| | --------------------------------- | ---- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||||
| | `words`                           | list | The words.                                                                                                                                               | | ||||
| | `tags`                            | list | The part-of-speech tag annotations.                                                                                                                      | | ||||
| | `heads`                           | list | The syntactic head annotations.                                                                                                                          | | ||||
| | `labels`                          | list | The syntactic relation-type annotations.                                                                                                                 | | ||||
| | `ents`                            | list | The named entity annotations.                                                                                                                            | | ||||
| | `ner`                             | list | The named entity annotations as BILUO tags.                                                                                                              | | ||||
| | `cand_to_gold`                    | list | The alignment from candidate tokenization to gold tokenization.                                                                                          | | ||||
| | `gold_to_cand`                    | list | The alignment from gold tokenization to candidate tokenization.                                                                                          | | ||||
| | `cats` <Tag variant="new">2</Tag> | list | Entries in the list should be either a label, or a `(start, end, label)` triple. The tuple form is used for categories applied to spans of the document. | | ||||
|  |  | |||