mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 10:26:35 +03:00
write entity linking pipe to file and keep vocab consistent between kb and nlp
This commit is contained in:
parent
b12001f368
commit
78dd3e11da
|
@ -40,8 +40,8 @@ def create_kb(nlp, max_entities_per_alias, min_occ,
|
||||||
title_list = list(title_to_id.keys())
|
title_list = list(title_to_id.keys())
|
||||||
|
|
||||||
# TODO: remove this filter (just for quicker testing of code)
|
# TODO: remove this filter (just for quicker testing of code)
|
||||||
# title_list = title_list[0:34200]
|
title_list = title_list[0:342]
|
||||||
# title_to_id = {t: title_to_id[t] for t in title_list}
|
title_to_id = {t: title_to_id[t] for t in title_list}
|
||||||
|
|
||||||
entity_list = [title_to_id[x] for x in title_list]
|
entity_list = [title_to_id[x] for x in title_list]
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import random
|
||||||
from spacy.util import minibatch, compounding
|
from spacy.util import minibatch, compounding
|
||||||
|
|
||||||
from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
|
from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
|
||||||
|
from examples.pipeline.wiki_entity_linking.kb_creator import DESC_WIDTH
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
|
@ -22,41 +23,48 @@ ENTITY_DEFS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_defs.csv'
|
||||||
ENTITY_DESCR = 'C:/Users/Sofie/Documents/data/wikipedia/entity_descriptions.csv'
|
ENTITY_DESCR = 'C:/Users/Sofie/Documents/data/wikipedia/entity_descriptions.csv'
|
||||||
|
|
||||||
KB_FILE = 'C:/Users/Sofie/Documents/data/wikipedia/kb'
|
KB_FILE = 'C:/Users/Sofie/Documents/data/wikipedia/kb'
|
||||||
VOCAB_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/vocab'
|
NLP_1_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_1'
|
||||||
|
NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2'
|
||||||
|
|
||||||
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
|
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
|
||||||
|
|
||||||
MAX_CANDIDATES = 10
|
MAX_CANDIDATES = 10
|
||||||
MIN_PAIR_OCC = 5
|
MIN_PAIR_OCC = 5
|
||||||
DOC_CHAR_CUTOFF = 300
|
DOC_CHAR_CUTOFF = 300
|
||||||
EPOCHS = 10
|
EPOCHS = 2
|
||||||
DROPOUT = 0.1
|
DROPOUT = 0.1
|
||||||
|
|
||||||
|
|
||||||
def run_pipeline():
|
def run_pipeline():
|
||||||
print("START", datetime.datetime.now())
|
print("START", datetime.datetime.now())
|
||||||
print()
|
print()
|
||||||
nlp = spacy.load('en_core_web_lg')
|
nlp_1 = spacy.load('en_core_web_lg')
|
||||||
my_kb = None
|
nlp_2 = None
|
||||||
|
kb_1 = None
|
||||||
|
kb_2 = None
|
||||||
|
|
||||||
# one-time methods to create KB and write to file
|
# one-time methods to create KB and write to file
|
||||||
to_create_prior_probs = False
|
to_create_prior_probs = False
|
||||||
to_create_entity_counts = False
|
to_create_entity_counts = False
|
||||||
to_create_kb = False
|
to_create_kb = True
|
||||||
|
|
||||||
# read KB back in from file
|
# read KB back in from file
|
||||||
to_read_kb = True
|
to_read_kb = True
|
||||||
to_test_kb = False
|
to_test_kb = True
|
||||||
|
|
||||||
# create training dataset
|
# create training dataset
|
||||||
create_wp_training = False
|
create_wp_training = False
|
||||||
|
|
||||||
# train the EL pipe
|
# train the EL pipe
|
||||||
train_pipe = True
|
train_pipe = True
|
||||||
|
measure_performance = False
|
||||||
|
|
||||||
# test the EL pipe on a simple example
|
# test the EL pipe on a simple example
|
||||||
to_test_pipeline = True
|
to_test_pipeline = True
|
||||||
|
|
||||||
|
# write the NLP object, read back in and test again
|
||||||
|
test_nlp_io = True
|
||||||
|
|
||||||
# STEP 1 : create prior probabilities from WP
|
# STEP 1 : create prior probabilities from WP
|
||||||
# run only once !
|
# run only once !
|
||||||
if to_create_prior_probs:
|
if to_create_prior_probs:
|
||||||
|
@ -75,7 +83,7 @@ def run_pipeline():
|
||||||
# run only once !
|
# run only once !
|
||||||
if to_create_kb:
|
if to_create_kb:
|
||||||
print("STEP 3a: to_create_kb", datetime.datetime.now())
|
print("STEP 3a: to_create_kb", datetime.datetime.now())
|
||||||
my_kb = kb_creator.create_kb(nlp,
|
kb_1 = kb_creator.create_kb(nlp_1,
|
||||||
max_entities_per_alias=MAX_CANDIDATES,
|
max_entities_per_alias=MAX_CANDIDATES,
|
||||||
min_occ=MIN_PAIR_OCC,
|
min_occ=MIN_PAIR_OCC,
|
||||||
entity_def_output=ENTITY_DEFS,
|
entity_def_output=ENTITY_DEFS,
|
||||||
|
@ -83,63 +91,66 @@ def run_pipeline():
|
||||||
count_input=ENTITY_COUNTS,
|
count_input=ENTITY_COUNTS,
|
||||||
prior_prob_input=PRIOR_PROB,
|
prior_prob_input=PRIOR_PROB,
|
||||||
to_print=False)
|
to_print=False)
|
||||||
print("kb entities:", my_kb.get_size_entities())
|
print("kb entities:", kb_1.get_size_entities())
|
||||||
print("kb aliases:", my_kb.get_size_aliases())
|
print("kb aliases:", kb_1.get_size_aliases())
|
||||||
print()
|
print()
|
||||||
|
|
||||||
print("STEP 3b: write KB", datetime.datetime.now())
|
print("STEP 3b: write KB and NLP", datetime.datetime.now())
|
||||||
my_kb.dump(KB_FILE)
|
kb_1.dump(KB_FILE)
|
||||||
nlp.vocab.to_disk(VOCAB_DIR)
|
nlp_1.to_disk(NLP_1_DIR)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# STEP 4 : read KB back in from file
|
# STEP 4 : read KB back in from file
|
||||||
if to_read_kb:
|
if to_read_kb:
|
||||||
print("STEP 4: to_read_kb", datetime.datetime.now())
|
print("STEP 4: to_read_kb", datetime.datetime.now())
|
||||||
my_vocab = Vocab()
|
# my_vocab = Vocab()
|
||||||
my_vocab.from_disk(VOCAB_DIR)
|
# my_vocab.from_disk(VOCAB_DIR)
|
||||||
my_kb = KnowledgeBase(vocab=my_vocab, entity_vector_length=64) # TODO entity vectors
|
# my_kb = KnowledgeBase(vocab=my_vocab, entity_vector_length=64)
|
||||||
my_kb.load_bulk(KB_FILE)
|
nlp_2 = spacy.load(NLP_1_DIR)
|
||||||
print("kb entities:", my_kb.get_size_entities())
|
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
|
||||||
print("kb aliases:", my_kb.get_size_aliases())
|
kb_2.load_bulk(KB_FILE)
|
||||||
|
print("kb entities:", kb_2.get_size_entities())
|
||||||
|
print("kb aliases:", kb_2.get_size_aliases())
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# test KB
|
# test KB
|
||||||
if to_test_kb:
|
if to_test_kb:
|
||||||
run_el.run_kb_toy_example(kb=my_kb)
|
run_el.run_kb_toy_example(kb=kb_2)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# STEP 5: create a training dataset from WP
|
# STEP 5: create a training dataset from WP
|
||||||
if create_wp_training:
|
if create_wp_training:
|
||||||
print("STEP 5: create training dataset", datetime.datetime.now())
|
print("STEP 5: create training dataset", datetime.datetime.now())
|
||||||
training_set_creator.create_training(kb=my_kb, entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
|
training_set_creator.create_training(kb=kb_2, entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
|
||||||
|
|
||||||
# STEP 6: create the entity linking pipe
|
# STEP 6: create the entity linking pipe
|
||||||
if train_pipe:
|
if train_pipe:
|
||||||
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
||||||
train_limit = 5000
|
train_limit = 10
|
||||||
dev_limit = 1000
|
dev_limit = 5
|
||||||
print("Training on", train_limit, "articles")
|
print("Training on", train_limit, "articles")
|
||||||
print("Dev testing on", dev_limit, "articles")
|
print("Dev testing on", dev_limit, "articles")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
train_data = training_set_creator.read_training(nlp=nlp,
|
train_data = training_set_creator.read_training(nlp=nlp_2,
|
||||||
training_dir=TRAINING_DIR,
|
training_dir=TRAINING_DIR,
|
||||||
dev=False,
|
dev=False,
|
||||||
limit=train_limit,
|
limit=train_limit,
|
||||||
to_print=False)
|
to_print=False)
|
||||||
|
|
||||||
dev_data = training_set_creator.read_training(nlp=nlp,
|
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||||
training_dir=TRAINING_DIR,
|
training_dir=TRAINING_DIR,
|
||||||
dev=True,
|
dev=True,
|
||||||
limit=dev_limit,
|
limit=dev_limit,
|
||||||
to_print=False)
|
to_print=False)
|
||||||
|
|
||||||
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb, "doc_cutoff": DOC_CHAR_CUTOFF})
|
el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_CHAR_CUTOFF})
|
||||||
nlp.add_pipe(el_pipe, last=True)
|
el_pipe.set_kb(kb_2)
|
||||||
|
nlp_2.add_pipe(el_pipe, last=True)
|
||||||
|
|
||||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
|
other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
|
||||||
with nlp.disable_pipes(*other_pipes): # only train Entity Linking
|
with nlp_2.disable_pipes(*other_pipes): # only train Entity Linking
|
||||||
nlp.begin_training()
|
nlp_2.begin_training()
|
||||||
|
|
||||||
for itn in range(EPOCHS):
|
for itn in range(EPOCHS):
|
||||||
random.shuffle(train_data)
|
random.shuffle(train_data)
|
||||||
|
@ -147,11 +158,11 @@ def run_pipeline():
|
||||||
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
|
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
|
||||||
batchnr = 0
|
batchnr = 0
|
||||||
|
|
||||||
with nlp.disable_pipes(*other_pipes):
|
with nlp_2.disable_pipes(*other_pipes):
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
try:
|
try:
|
||||||
docs, golds = zip(*batch)
|
docs, golds = zip(*batch)
|
||||||
nlp.update(
|
nlp_2.update(
|
||||||
docs,
|
docs,
|
||||||
golds,
|
golds,
|
||||||
drop=DROPOUT,
|
drop=DROPOUT,
|
||||||
|
@ -164,6 +175,7 @@ def run_pipeline():
|
||||||
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
||||||
print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
|
print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
|
||||||
|
|
||||||
|
if measure_performance:
|
||||||
print()
|
print()
|
||||||
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
|
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
|
||||||
print()
|
print()
|
||||||
|
@ -195,9 +207,30 @@ def run_pipeline():
|
||||||
print()
|
print()
|
||||||
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
|
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
|
||||||
print()
|
print()
|
||||||
run_el_toy_example(kb=my_kb, nlp=nlp)
|
run_el_toy_example(nlp=nlp_2)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
if test_nlp_io:
|
||||||
|
print()
|
||||||
|
print("STEP 9: testing NLP IO", datetime.datetime.now())
|
||||||
|
print()
|
||||||
|
print("writing to", NLP_2_DIR)
|
||||||
|
print(" vocab len nlp_2", len(nlp_2.vocab))
|
||||||
|
print(" vocab len kb_2", len(kb_2.vocab))
|
||||||
|
nlp_2.to_disk(NLP_2_DIR)
|
||||||
|
print()
|
||||||
|
print("reading from", NLP_2_DIR)
|
||||||
|
nlp_3 = spacy.load(NLP_2_DIR)
|
||||||
|
print(" vocab len nlp_3", len(nlp_3.vocab))
|
||||||
|
|
||||||
|
for pipe_name, pipe in nlp_3.pipeline:
|
||||||
|
if pipe_name == "entity_linker":
|
||||||
|
print(" vocab len kb_3", len(pipe.kb.vocab))
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("running toy example with NLP 2")
|
||||||
|
run_el_toy_example(nlp=nlp_3)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print("STOP", datetime.datetime.now())
|
print("STOP", datetime.datetime.now())
|
||||||
|
|
||||||
|
@ -239,7 +272,7 @@ def _measure_accuracy(data, el_pipe):
|
||||||
return acc
|
return acc
|
||||||
|
|
||||||
|
|
||||||
def run_el_toy_example(nlp, kb):
|
def run_el_toy_example(nlp):
|
||||||
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
|
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
|
||||||
"Douglas reminds us to always bring our towel. " \
|
"Douglas reminds us to always bring our towel. " \
|
||||||
"The main character in Doug's novel is the man Arthur Dent, " \
|
"The main character in Doug's novel is the man Arthur Dent, " \
|
||||||
|
|
|
@ -2,6 +2,8 @@
|
||||||
# cython: profile=True
|
# cython: profile=True
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from pathlib import Path, WindowsPath
|
||||||
|
|
||||||
from cpython.exc cimport PyErr_CheckSignals
|
from cpython.exc cimport PyErr_CheckSignals
|
||||||
|
|
||||||
from spacy import util
|
from spacy import util
|
||||||
|
@ -389,6 +391,8 @@ cdef class Writer:
|
||||||
def __init__(self, object loc):
|
def __init__(self, object loc):
|
||||||
if path.exists(loc):
|
if path.exists(loc):
|
||||||
assert not path.isdir(loc), "%s is directory." % loc
|
assert not path.isdir(loc), "%s is directory." % loc
|
||||||
|
if isinstance(loc, Path):
|
||||||
|
loc = bytes(loc)
|
||||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||||
self._fp = fopen(<char*>bytes_loc, 'wb')
|
self._fp = fopen(<char*>bytes_loc, 'wb')
|
||||||
assert self._fp != NULL
|
assert self._fp != NULL
|
||||||
|
@ -431,6 +435,8 @@ cdef class Reader:
|
||||||
def __init__(self, object loc):
|
def __init__(self, object loc):
|
||||||
assert path.exists(loc)
|
assert path.exists(loc)
|
||||||
assert not path.isdir(loc)
|
assert not path.isdir(loc)
|
||||||
|
if isinstance(loc, Path):
|
||||||
|
loc = bytes(loc)
|
||||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||||
self._fp = fopen(<char*>bytes_loc, 'rb')
|
self._fp = fopen(<char*>bytes_loc, 'rb')
|
||||||
if not self._fp:
|
if not self._fp:
|
||||||
|
|
|
@ -11,6 +11,7 @@ from copy import copy, deepcopy
|
||||||
from thinc.neural import Model
|
from thinc.neural import Model
|
||||||
import srsly
|
import srsly
|
||||||
|
|
||||||
|
from spacy.kb import KnowledgeBase
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
from .vocab import Vocab
|
from .vocab import Vocab
|
||||||
from .lemmatizer import Lemmatizer
|
from .lemmatizer import Lemmatizer
|
||||||
|
@ -809,6 +810,14 @@ class Language(object):
|
||||||
# Convert to list here in case exclude is (default) tuple
|
# Convert to list here in case exclude is (default) tuple
|
||||||
exclude = list(exclude) + ["vocab"]
|
exclude = list(exclude) + ["vocab"]
|
||||||
util.from_disk(path, deserializers, exclude)
|
util.from_disk(path, deserializers, exclude)
|
||||||
|
|
||||||
|
# download the KB for the entity linking component - requires the vocab
|
||||||
|
for pipe_name, pipe in self.pipeline:
|
||||||
|
if pipe_name == "entity_linker":
|
||||||
|
kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=pipe.cfg["entity_width"])
|
||||||
|
kb.load_bulk(path / pipe_name / "kb")
|
||||||
|
pipe.set_kb(kb)
|
||||||
|
|
||||||
self._path = path
|
self._path = path
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ from thinc.misc import LayerNorm
|
||||||
from thinc.neural.util import to_categorical
|
from thinc.neural.util import to_categorical
|
||||||
from thinc.neural.util import get_array_module
|
from thinc.neural.util import get_array_module
|
||||||
|
|
||||||
|
from spacy.kb import KnowledgeBase
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
from ..syntax.nn_parser cimport Parser
|
from ..syntax.nn_parser cimport Parser
|
||||||
from ..syntax.ner cimport BiluoPushDown
|
from ..syntax.ner cimport BiluoPushDown
|
||||||
|
@ -1077,7 +1078,7 @@ class EntityLinker(Pipe):
|
||||||
hidden_width = cfg.get("hidden_width", 32)
|
hidden_width = cfg.get("hidden_width", 32)
|
||||||
article_width = cfg.get("article_width", 128)
|
article_width = cfg.get("article_width", 128)
|
||||||
sent_width = cfg.get("sent_width", 64)
|
sent_width = cfg.get("sent_width", 64)
|
||||||
entity_width = cfg["kb"].entity_vector_length
|
entity_width = cfg.get("entity_width") # no default because this needs to correspond with the KB
|
||||||
|
|
||||||
article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
|
article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
|
||||||
sent_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
|
sent_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
|
||||||
|
@ -1089,25 +1090,31 @@ class EntityLinker(Pipe):
|
||||||
return article_encoder, sent_encoder, mention_encoder
|
return article_encoder, sent_encoder, mention_encoder
|
||||||
|
|
||||||
def __init__(self, **cfg):
|
def __init__(self, **cfg):
|
||||||
|
self.article_encoder = True
|
||||||
|
self.sent_encoder = True
|
||||||
self.mention_encoder = True
|
self.mention_encoder = True
|
||||||
|
self.kb = None
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
self.kb = self.cfg["kb"]
|
self.doc_cutoff = self.cfg.get("doc_cutoff", 150)
|
||||||
self.doc_cutoff = self.cfg["doc_cutoff"]
|
|
||||||
|
|
||||||
def use_avg_params(self):
|
|
||||||
# Modify the pipe's encoders/models, to use their average parameter values.
|
|
||||||
# TODO: this doesn't work yet because there's no exit method
|
|
||||||
self.article_encoder.use_params(self.sgd_article.averages)
|
|
||||||
self.sent_encoder.use_params(self.sgd_sent.averages)
|
|
||||||
self.mention_encoder.use_params(self.sgd_mention.averages)
|
|
||||||
|
|
||||||
|
def set_kb(self, kb):
|
||||||
|
self.kb = kb
|
||||||
|
|
||||||
def require_model(self):
|
def require_model(self):
|
||||||
# Raise an error if the component's model is not initialized.
|
# Raise an error if the component's model is not initialized.
|
||||||
if getattr(self, "mention_encoder", None) in (None, True, False):
|
if getattr(self, "mention_encoder", None) in (None, True, False):
|
||||||
raise ValueError(Errors.E109.format(name=self.name))
|
raise ValueError(Errors.E109.format(name=self.name))
|
||||||
|
|
||||||
|
def require_kb(self):
|
||||||
|
# Raise an error if the knowledge base is not initialized.
|
||||||
|
if getattr(self, "kb", None) in (None, True, False):
|
||||||
|
# TODO: custom error
|
||||||
|
raise ValueError(Errors.E109.format(name=self.name))
|
||||||
|
|
||||||
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
|
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
|
||||||
|
self.require_kb()
|
||||||
|
self.cfg["entity_width"] = self.kb.entity_vector_length
|
||||||
|
|
||||||
if self.mention_encoder is True:
|
if self.mention_encoder is True:
|
||||||
self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
|
self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
|
||||||
self.sgd_article = create_default_optimizer(self.article_encoder.ops)
|
self.sgd_article = create_default_optimizer(self.article_encoder.ops)
|
||||||
|
@ -1117,6 +1124,7 @@ class EntityLinker(Pipe):
|
||||||
|
|
||||||
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
|
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
|
||||||
self.require_model()
|
self.require_model()
|
||||||
|
self.require_kb()
|
||||||
|
|
||||||
if len(docs) != len(golds):
|
if len(docs) != len(golds):
|
||||||
raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
|
raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
|
||||||
|
@ -1220,6 +1228,7 @@ class EntityLinker(Pipe):
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
self.require_model()
|
self.require_model()
|
||||||
|
self.require_kb()
|
||||||
|
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
|
@ -1228,9 +1237,11 @@ class EntityLinker(Pipe):
|
||||||
final_kb_ids = list()
|
final_kb_ids = list()
|
||||||
|
|
||||||
for i, article_doc in enumerate(docs):
|
for i, article_doc in enumerate(docs):
|
||||||
|
if len(article_doc) > 0:
|
||||||
doc_encoding = self.article_encoder([article_doc])
|
doc_encoding = self.article_encoder([article_doc])
|
||||||
for ent in article_doc.ents:
|
for ent in article_doc.ents:
|
||||||
sent_doc = ent.sent.as_doc()
|
sent_doc = ent.sent.as_doc()
|
||||||
|
if len(sent_doc) > 0:
|
||||||
sent_encoding = self.sent_encoder([sent_doc])
|
sent_encoding = self.sent_encoder([sent_doc])
|
||||||
concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
|
concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
|
||||||
mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
|
mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
|
||||||
|
@ -1260,6 +1271,80 @@ class EntityLinker(Pipe):
|
||||||
for token in entity:
|
for token in entity:
|
||||||
token.ent_kb_id_ = kb_id
|
token.ent_kb_id_ = kb_id
|
||||||
|
|
||||||
|
def to_bytes(self, exclude=tuple(), **kwargs):
|
||||||
|
"""Serialize the pipe to a bytestring.
|
||||||
|
|
||||||
|
exclude (list): String names of serialization fields to exclude.
|
||||||
|
RETURNS (bytes): The serialized object.
|
||||||
|
"""
|
||||||
|
serialize = OrderedDict()
|
||||||
|
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
||||||
|
serialize["kb"] = self.kb.to_bytes # TODO
|
||||||
|
if self.mention_encoder not in (True, False, None):
|
||||||
|
serialize["article_encoder"] = self.article_encoder.to_bytes
|
||||||
|
serialize["sent_encoder"] = self.sent_encoder.to_bytes
|
||||||
|
serialize["mention_encoder"] = self.mention_encoder.to_bytes
|
||||||
|
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
|
||||||
|
return util.to_bytes(serialize, exclude)
|
||||||
|
|
||||||
|
def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
|
||||||
|
"""Load the pipe from a bytestring."""
|
||||||
|
deserialize = OrderedDict()
|
||||||
|
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
|
||||||
|
deserialize["kb"] = lambda b: self.kb.from_bytes(b) # TODO
|
||||||
|
deserialize["article_encoder"] = lambda b: self.article_encoder.from_bytes(b)
|
||||||
|
deserialize["sent_encoder"] = lambda b: self.sent_encoder.from_bytes(b)
|
||||||
|
deserialize["mention_encoder"] = lambda b: self.mention_encoder.from_bytes(b)
|
||||||
|
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
|
||||||
|
util.from_bytes(bytes_data, deserialize, exclude)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def to_disk(self, path, exclude=tuple(), **kwargs):
|
||||||
|
"""Serialize the pipe to disk."""
|
||||||
|
serialize = OrderedDict()
|
||||||
|
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
||||||
|
serialize["kb"] = lambda p: self.kb.dump(p)
|
||||||
|
if self.mention_encoder not in (None, True, False):
|
||||||
|
serialize["article_encoder"] = lambda p: p.open("wb").write(self.article_encoder.to_bytes())
|
||||||
|
serialize["sent_encoder"] = lambda p: p.open("wb").write(self.sent_encoder.to_bytes())
|
||||||
|
serialize["mention_encoder"] = lambda p: p.open("wb").write(self.mention_encoder.to_bytes())
|
||||||
|
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
|
||||||
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
|
def from_disk(self, path, exclude=tuple(), **kwargs):
|
||||||
|
"""Load the pipe from disk."""
|
||||||
|
def load_article_encoder(p):
|
||||||
|
if self.article_encoder is True:
|
||||||
|
self.article_encoder, _, _ = self.Model(**self.cfg)
|
||||||
|
self.article_encoder.from_bytes(p.open("rb").read())
|
||||||
|
|
||||||
|
def load_sent_encoder(p):
|
||||||
|
if self.sent_encoder is True:
|
||||||
|
_, self.sent_encoder, _ = self.Model(**self.cfg)
|
||||||
|
self.sent_encoder.from_bytes(p.open("rb").read())
|
||||||
|
|
||||||
|
def load_mention_encoder(p):
|
||||||
|
if self.mention_encoder is True:
|
||||||
|
_, _, self.mention_encoder = self.Model(**self.cfg)
|
||||||
|
self.mention_encoder.from_bytes(p.open("rb").read())
|
||||||
|
|
||||||
|
deserialize = OrderedDict()
|
||||||
|
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
|
||||||
|
deserialize["article_encoder"] = load_article_encoder
|
||||||
|
deserialize["sent_encoder"] = load_sent_encoder
|
||||||
|
deserialize["mention_encoder"] = load_mention_encoder
|
||||||
|
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
|
||||||
|
util.from_disk(path, deserialize, exclude)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def rehearse(self, docs, sgd=None, losses=None, **config):
|
||||||
|
# TODO
|
||||||
|
pass
|
||||||
|
|
||||||
|
def add_label(self, label):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Sentencizer(object):
|
class Sentencizer(object):
|
||||||
"""Segment the Doc into sentences using a rule-based strategy.
|
"""Segment the Doc into sentences using a rule-based strategy.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user