"""This example shows how to add a multi-task objective that is trained
alongside the entity recognizer. This is an alternative to adding features
to the model.

The multi-task idea is to train an auxiliary model to predict some attribute,
with weights shared between the auxiliary model and the main model. In this
example, we're predicting the position of the word in the document.

The model that predicts the position of the word encourages the convolutional
layers to include the position information in their representation. The
information is then available to the main model, as a feature.

The overall idea is that we might know something about what sort of features
we'd like the CNN to extract. The multi-task objectives can encourage the
extraction of this type of feature. The multi-task objective is only used
during training. We discard the auxiliary model before run-time.

The specific example here is not necessarily a good idea --- but it shows
how an arbitrary per-word objective function can be used.

Developed and tested for spaCy 2.0.6. Updated for v2.2.2.
"""
import os.path
import random

import plac
import spacy
from spacy.gold import read_json_file, GoldParse
from spacy.tokens import Doc

random.seed(0)

PWD = os.path.dirname(__file__)

TRAIN_DATA = list(read_json_file(os.path.join(PWD, "training-data.json")))
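# The JSON file is assumed to be in spaCy's (v2.x) JSON training format,
# i.e. the corpus format that "spacy convert" produces; read_json_file
# parses it into examples carrying doc- and token-level annotations, which
# the training loop below unpacks.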


def get_position_label(i, token_annotation):
    """Return a label indicating the position of word i in the document.

    Documents with fewer than 20 words get the same label for every token;
    otherwise the label depends on the token's index.
    """
    if len(token_annotation.words) < 20:
        return "short-doc"
    elif i == 0:
        return "first-word"
    elif i < 10:
        return "early-word"
    elif i < 20:
        return "mid-word"
    elif i == len(token_annotation.words) - 1:
        return "last-word"
    else:
        return "late-word"
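
# A quick sanity check of the labelling scheme above (a sketch, not part of
# the original example; SimpleNamespace stands in for the token annotation,
# which only needs a `.words` attribute here):
#
#     from types import SimpleNamespace
#     ann = SimpleNamespace(words=["tok"] * 25)
#     get_position_label(0, ann)   # -> "first-word"
#     get_position_label(5, ann)   # -> "early-word"
#     get_position_label(24, ann)  # -> "last-word"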


def main(n_iter=10):
    nlp = spacy.blank("en")
    ner = nlp.create_pipe("ner")
    # Attach the auxiliary objective to the entity recognizer. It shares the
    # convolutional layers with the NER model and is discarded before
    # run-time.
    ner.add_multitask_objective(get_position_label)
    nlp.add_pipe(ner)
    print(nlp.pipeline)

    print("Number of training examples:", len(TRAIN_DATA))
    optimizer = nlp.begin_training(get_examples=lambda: TRAIN_DATA)
    for itn in range(n_iter):
        random.shuffle(TRAIN_DATA)
        losses = {}
        for example in TRAIN_DATA:
            for token_annotation in example.token_annotations:
                doc = Doc(nlp.vocab, words=token_annotation.words)
                gold = GoldParse.from_annotation(
                    doc, example.doc_annotation, token_annotation
                )

                nlp.update(
                    examples=[(doc, gold)],  # batch of a single (doc, gold) pair
                    drop=0.2,  # dropout - make it harder to memorise data
                    sgd=optimizer,  # callable to update weights
                    losses=losses,
                )
        # "nn_labeller" accumulates the multi-task objective's loss;
        # "ner" is the main entity recognizer's loss.
        print(losses.get("nn_labeller", 0.0), losses["ner"])

    # Test the trained model on the examples that have raw text.
    for example in TRAIN_DATA:
        if example.text is not None:
            doc = nlp(example.text)
            print("Entities", [(ent.text, ent.label_) for ent in doc.ents])
            print("Tokens", [(t.text, t.ent_type_, t.ent_iob) for t in doc])


if __name__ == "__main__":
    plac.call(main)
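
# Usage sketch (the script filename below is a placeholder, not part of the
# original example): run the script directly, with a "training-data.json"
# file sitting next to it:
#
#     python multitask_objective.py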