mirror of https://github.com/explosion/spaCy.git
	sentence encoder only (removing article/mention encoder)
This commit is contained in:
commit ffae7d3555 (parent 6332af40de)
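In short, this commit drops the separate article and mention encoders from the experimental EntityLinker and keeps a single sentence encoder as self.model: the context for a mention is now just the sentence it occurs in, encoded directly into the entity-vector space of the knowledge base. A rough sketch of the simplified training step is shown below; the real code is in EntityLinker.update() further down, and the names here (sent_encoder, get_loss, entity_vectors) are placeholders rather than the actual spaCy API.

# Rough sketch of the single-encoder training step this commit moves to.
# Placeholder names; only the overall flow mirrors the diff below.
def update_sketch(sent_encoder, get_loss, sentence_docs, entity_vectors, sgd, drop=0.0):
    # one encoder turns each entity's sentence into a vector ...
    sent_encodings, bp_sent = sent_encoder.begin_update(sentence_docs, drop=drop)
    # ... which is trained to match the pretrained entity vector from the KB
    loss, d_scores = get_loss(scores=sent_encodings, golds=entity_vectors, docs=None)
    bp_sent(d_scores, sgd=sgd)  # single backprop call; no article/mention encoders
    return loss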
@@ -294,7 +294,6 @@ def read_training(nlp, training_dir, dev, limit):
    # we assume the data is written sequentially
    current_article_id = None
    current_doc = None
    gold_entities = list()
    ents_by_offset = dict()
    skip_articles = set()
    total_entities = 0
@@ -302,8 +301,6 @@ def read_training(nlp, training_dir, dev, limit):
    with open(entityfile_loc, mode='r', encoding='utf8') as file:
        for line in file:
            if not limit or len(data) < limit:
                if len(data) > 0 and len(data) % 50 == 0:
                    print("Read", total_entities, "entities in", len(data), "articles")
                fields = line.replace('\n', "").split(sep='|')
                article_id = fields[0]
                alias = fields[1]
@@ -313,34 +310,42 @@ def read_training(nlp, training_dir, dev, limit):

                if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
                    if not current_doc or (current_article_id != article_id):
                        # store the data from the previous article
                        if gold_entities and current_doc:
                            gold = GoldParse(doc=current_doc, links=gold_entities)
                            data.append((current_doc, gold))
                            total_entities += len(gold_entities)

                        # parse the new article text
                        file_name = article_id + ".txt"
                        try:
                            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
                                text = f.read()
                                current_doc = nlp(text)
                                for ent in current_doc.ents:
                                    ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent.text
                                if len(text) < 30000:   # threshold for convenience / speed of processing
                                    current_doc = nlp(text)
                                    current_article_id = article_id
                                    ents_by_offset = dict()
                                    for ent in current_doc.ents:
                                        ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
                                else:
                                    skip_articles.add(current_article_id)
                                    current_doc = None
                        except Exception as e:
                            print("Problem parsing article", article_id, e)

                        current_article_id = article_id
                        gold_entities = list()

                    # repeat checking this condition in case an exception was thrown
                    if current_doc and (current_article_id == article_id):
                        found_ent = ents_by_offset.get(start + "_" + end,  None)
                        if found_ent:
                            if found_ent != alias:
                            if found_ent.text != alias:
                                skip_articles.add(current_article_id)
                                current_doc = None
                            else:
                                gold_entities.append((int(start), int(end), wp_title))
                                sent = found_ent.sent.as_doc()
                                # currently feeding the gold data one entity per sentence at a time
                                gold_start = int(start) - found_ent.sent.start_char
                                gold_end = int(end) - found_ent.sent.start_char
                                gold_entities = list()
                                gold_entities.append((gold_start, gold_end, wp_title))
                                gold = GoldParse(doc=current_doc, links=gold_entities)
                                data.append((sent, gold))
                                total_entities += 1
                                if len(data) % 500 == 0:
                                    print(" -read", total_entities, "entities")

    print("Read", total_entities, "entities in", len(data), "articles")
    print(" -read", total_entities, "entities")
    return data

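The training reader now emits one (sentence, GoldParse) pair per entity, with character offsets re-based to the sentence rather than the full article. A small illustration of that offset arithmetic, assuming any pipeline with NER and sentence boundaries; the model name, example text and title string are placeholders, not part of the commit:

# Illustration of the sentence-local gold offsets used above (placeholder data).
import spacy

nlp = spacy.load("en_core_web_sm")  # assumed: any pipeline with NER + sentence boundaries
doc = nlp("Douglas Adams wrote the book. He was born in Cambridge.")
ent = doc.ents[0]                         # e.g. "Douglas Adams"
sent = ent.sent.as_doc()                  # the sentence as a standalone Doc
gold_start = ent.start_char - ent.sent.start_char
gold_end = ent.end_char - ent.sent.start_char
links = [(gold_start, gold_end, "Douglas Adams")]  # offsets now relative to the sentence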
@@ -9,7 +9,6 @@ from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_
from examples.pipeline.wiki_entity_linking.kb_creator import DESC_WIDTH

import spacy
from spacy.vocab import Vocab
from spacy.kb import KnowledgeBase
import datetime

@@ -64,8 +63,8 @@ def run_pipeline():
    to_test_pipeline = True

    # write the NLP object, read back in and test again
    to_write_nlp = True
    to_read_nlp = True
    to_write_nlp = False
    to_read_nlp = False

    # STEP 1 : create prior probabilities from WP
    # run only once !
@@ -134,8 +133,8 @@ def run_pipeline():

    if train_pipe:
        print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
        train_limit = 5
        dev_limit = 2
        train_limit = 25000
        dev_limit = 1000

        train_data = training_set_creator.read_training(nlp=nlp_2,
                                                        training_dir=TRAINING_DIR,
@@ -345,7 +344,11 @@ def calculate_acc(correct_by_label, incorrect_by_label):
    acc_by_label = dict()
    total_correct = 0
    total_incorrect = 0
    for label, correct in correct_by_label.items():
    all_keys = set()
    all_keys.update(correct_by_label.keys())
    all_keys.update(incorrect_by_label.keys())
    for label in sorted(all_keys):
        correct = correct_by_label.get(label, 0)
        incorrect = incorrect_by_label.get(label, 0)
        total_correct += correct
        total_incorrect += incorrect

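The accuracy helper now iterates over the union of labels seen in either dictionary, so labels that were only ever predicted incorrectly are no longer silently dropped. A hedged sketch of the whole helper under that reading; the return values and zero-division handling are guesses, only the loop shown above comes from the diff:

# Hedged sketch of per-label accuracy over the union of observed labels.
def calculate_acc(correct_by_label, incorrect_by_label):
    acc_by_label = dict()
    total_correct = total_incorrect = 0
    for label in sorted(set(correct_by_label) | set(incorrect_by_label)):
        correct = correct_by_label.get(label, 0)
        incorrect = incorrect_by_label.get(label, 0)
        total_correct += correct
        total_incorrect += incorrect
        if correct + incorrect:
            acc_by_label[label] = correct / (correct + incorrect)
    total = total_correct + total_incorrect
    acc = total_correct / total if total else 0.0
    return acc, acc_by_label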
@@ -1079,36 +1079,39 @@ class EntityLinker(Pipe):

        embed_width = cfg.get("embed_width", 300)
        hidden_width = cfg.get("hidden_width", 32)
        article_width = cfg.get("article_width", 128)
        sent_width = cfg.get("sent_width", 64)
        entity_width = cfg.get("entity_width")  # no default because this needs to correspond with the KB
        sent_width = entity_width

        article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
        sent_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
        model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)

        # dimension of the mention encoder needs to match the dimension of the entity encoder
        mention_width = article_width + sent_width
        mention_encoder = Affine(entity_width, mention_width, drop_factor=0.0)
        # article_width = cfg.get("article_width", 128)
        # sent_width = cfg.get("sent_width", 64)
        # article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
        # mention_width = article_width + sent_width
        # mention_encoder = Affine(entity_width, mention_width, drop_factor=0.0)
        # return article_encoder, sent_encoder, mention_encoder

        return article_encoder, sent_encoder, mention_encoder
        return model

    def __init__(self, **cfg):
        self.article_encoder = True
        self.sent_encoder = True
        self.mention_encoder = True
        # self.article_encoder = True
        # self.sent_encoder = True
        # self.mention_encoder = True
        self.model = True
        self.kb = None
        self.cfg = dict(cfg)
        self.doc_cutoff = self.cfg.get("doc_cutoff", 5)
        self.sgd_article = None
        self.sgd_sent = None
        self.sgd_mention = None
        # self.sgd_article = None
        # self.sgd_sent = None
        # self.sgd_mention = None

    def set_kb(self, kb):
        self.kb = kb

    def require_model(self):
        # Raise an error if the component's model is not initialized.
        if getattr(self, "mention_encoder", None) in (None, True, False):
        if getattr(self, "model", None) in (None, True, False):
            raise ValueError(Errors.E109.format(name=self.name))

    def require_kb(self):
@@ -1121,12 +1124,19 @@ class EntityLinker(Pipe):
        self.require_kb()
        self.cfg["entity_width"] = self.kb.entity_vector_length

        if self.mention_encoder is True:
            self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
        self.sgd_article = create_default_optimizer(self.article_encoder.ops)
        self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
        self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
        return self.sgd_article
        if self.model is True:
            self.model = self.Model(**self.cfg)

        if sgd is None:
            sgd = self.create_optimizer()
        return sgd

        # if self.mention_encoder is True:
        #    self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
        # self.sgd_article = create_default_optimizer(self.article_encoder.ops)
        # self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
        # self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
        # return self.sgd_article

    def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
        self.require_model()
@@ -1146,7 +1156,7 @@ class EntityLinker(Pipe):
            docs = [docs]
            golds = [golds]

        article_docs = list()
        # article_docs = list()
        sentence_docs = list()
        entity_encodings = list()

@@ -1173,34 +1183,32 @@ class EntityLinker(Pipe):
                    if kb_id == gold_kb:
                        prior_prob = c.prior_prob
                        entity_encoding = c.entity_vector

                        entity_encodings.append(entity_encoding)
                        article_docs.append(first_par)
                        # article_docs.append(first_par)
                        sentence_docs.append(sentence)

        if len(entity_encodings) > 0:
            doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
            sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
            # doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
            # sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)

            concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
                                range(len(article_docs))]
            mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)
            # concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in range(len(article_docs))]
            # mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)

            sent_encodings, bp_sent = self.model.begin_update(sentence_docs, drop=drop)
            entity_encodings = np.asarray(entity_encodings, dtype=np.float32)

            loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
            mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention)
            loss, d_scores = self.get_loss(scores=sent_encodings, golds=entity_encodings, docs=None)
            bp_sent(d_scores, sgd=sgd)

            # gradient : concat (doc+sent) vs. desc
            sent_start = self.article_encoder.nO
            sent_gradients = list()
            doc_gradients = list()
            for x in mention_gradient:
                doc_gradients.append(list(x[0:sent_start]))
                sent_gradients.append(list(x[sent_start:]))

            bp_doc(doc_gradients, sgd=self.sgd_article)
            bp_sent(sent_gradients, sgd=self.sgd_sent)
            # sent_start = self.article_encoder.nO
            # sent_gradients = list()
            # doc_gradients = list()
            # for x in mention_gradient:
                # doc_gradients.append(list(x[0:sent_start]))
                # sent_gradients.append(list(x[sent_start:]))
            # bp_doc(doc_gradients, sgd=self.sgd_article)
            # bp_sent(sent_gradients, sgd=self.sgd_sent)

            if losses is not None:
                losses[self.name] += loss
@@ -1262,14 +1270,17 @@ class EntityLinker(Pipe):
                        first_par_end = sent.end
                first_par = doc[0:first_par_end].as_doc()

                doc_encoding = self.article_encoder([first_par])
                # doc_encoding = self.article_encoder([first_par])
                for ent in doc.ents:
                    sent_doc = ent.sent.as_doc()
                    if len(sent_doc) > 0:
                        sent_encoding = self.sent_encoder([sent_doc])
                        concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
                        mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
                        mention_enc_t = np.transpose(mention_encoding)
                        # sent_encoding = self.sent_encoder([sent_doc])
                        # concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
                        # mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
                        # mention_enc_t = np.transpose(mention_encoding)

                        sent_encoding = self.model([sent_doc])
                        sent_enc_t = np.transpose(sent_encoding)

                        candidates = self.kb.get_candidates(ent.text)
                        if candidates:
@@ -1278,7 +1289,7 @@ class EntityLinker(Pipe):
                                prior_prob = c.prior_prob * self.prior_weight
                                kb_id = c.entity_
                                entity_encoding = c.entity_vector
                                sim = float(cosine(np.asarray([entity_encoding]), mention_enc_t)) * self.context_weight
                                sim = float(cosine(np.asarray([entity_encoding]), sent_enc_t)) * self.context_weight
                                score = prior_prob + sim - (prior_prob*sim)  # put weights on the different factors ?
                                scores.append(score)

@@ -1299,34 +1310,20 @@ class EntityLinker(Pipe):
        serialize = OrderedDict()
        serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
        serialize["kb"] = lambda p: self.kb.dump(p)
        if self.mention_encoder not in (None, True, False):
            serialize["article_encoder"] = lambda p: p.open("wb").write(self.article_encoder.to_bytes())
            serialize["sent_encoder"] = lambda p: p.open("wb").write(self.sent_encoder.to_bytes())
            serialize["mention_encoder"] = lambda p: p.open("wb").write(self.mention_encoder.to_bytes())
        if self.model not in (None, True, False):
            serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
        exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
        util.to_disk(path, serialize, exclude)

    def from_disk(self, path, exclude=tuple(), **kwargs):
        def load_article_encoder(p):
            if self.article_encoder is True:
                self.article_encoder, _, _ = self.Model(**self.cfg)
            self.article_encoder.from_bytes(p.open("rb").read())

        def load_sent_encoder(p):
            if self.sent_encoder is True:
                _, self.sent_encoder, _ = self.Model(**self.cfg)
            self.sent_encoder.from_bytes(p.open("rb").read())

        def load_mention_encoder(p):
             if self.mention_encoder is True:
                _, _, self.mention_encoder = self.Model(**self.cfg)
             self.mention_encoder.from_bytes(p.open("rb").read())
        def load_model(p):
             if self.model is True:
                self.model = self.Model(**self.cfg)
             self.model.from_bytes(p.open("rb").read())

        deserialize = OrderedDict()
        deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
        deserialize["article_encoder"] = load_article_encoder
        deserialize["sent_encoder"] = load_sent_encoder
        deserialize["mention_encoder"] = load_mention_encoder
        deserialize["model"] = load_model
        exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
        util.from_disk(path, deserialize, exclude)
        return self

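At prediction time the commit keeps the simple combination of prior probability and context similarity shown above: score = prior + sim - prior*sim, i.e. 1 - (1 - prior)(1 - sim), so the score only reaches 1 when one of the two signals is fully confident. A small numeric illustration, with prior_weight and context_weight left at 1.0 for simplicity:

# Numeric illustration of the score combination used in predict() above.
prior_prob = 0.6   # from Wikipedia anchor statistics
sim = 0.8          # cosine similarity between entity vector and sentence encoding
score = prior_prob + sim - (prior_prob * sim)
print(score)       # 0.92 -> higher than either signal on its own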