spaCy/examples/training/train_parser.py

76 lines
2.2 KiB
Python
Raw Normal View History

2016-10-16 18:05:55 +03:00
from __future__ import unicode_literals, print_function
import json
import pathlib
import random
import spacy
from spacy.pipeline import DependencyParser
from spacy.gold import GoldParse
from spacy.tokens import Doc
def train_parser(nlp, train_data, left_labels, right_labels, n_iter=1000):
    """Train a new DependencyParser on a small in-memory corpus.

    Args:
        nlp: loaded pipeline; its ``vocab`` is shared by the new parser and
            by every training ``Doc``.
        train_data: iterable of ``(words, heads, deps)`` triples, passed
            straight to ``Doc(words=...)`` and ``GoldParse(heads=..., deps=...)``.
            The caller's sequence is not modified.
        left_labels: dependency labels given to the parser as ``left_labels``
            (main() derives them from tokens that precede their head).
        right_labels: dependency labels given to the parser as
            ``right_labels`` (tokens that follow their head).
        n_iter: number of passes over the training data (default 1000).

    Returns:
        The trained parser, after ``parser.model.end_training()`` was called.
    """
    parser = DependencyParser(
        nlp.vocab,
        left_labels=left_labels,
        right_labels=right_labels)
    # Shuffle a private copy so the caller's list is not reordered in place.
    examples = list(train_data)
    for _ in range(n_iter):
        random.shuffle(examples)
        for words, heads, deps in examples:
            doc = Doc(nlp.vocab, words=words)
            gold = GoldParse(doc, heads=heads, deps=deps)
            parser.update(doc, gold)
    parser.model.end_training()
    return parser
def main(model_dir=None):
    """Train a toy dependency parser, print one parse, optionally save it.

    Args:
        model_dir: optional output directory path. When given, it is created
            if it does not exist, and ``config.json`` (the parser's ``cfg``)
            plus the model file ``model`` are written into it.

    Raises:
        ValueError: if ``model_dir`` exists but is not a directory.
    """
    if model_dir is not None:
        model_dir = pathlib.Path(model_dir)
        if not model_dir.exists():
            # parents=True so nested output paths (e.g. out/parser) also work.
            model_dir.mkdir(parents=True)
        if not model_dir.is_dir():
            # Explicit check instead of `assert`, which `python -O` strips.
            raise ValueError('model_dir is not a directory: %s' % model_dir)

    # Load 'en' with the tagger, parser, entity and vectors flags all off;
    # presumably this leaves just the tokenizer/vocab -- confirm for your
    # spaCy version.
    nlp = spacy.load('en', tagger=False, parser=False, entity=False, vectors=False)

    # Each example is (words, heads, deps): heads[i] is the index of token
    # i's head (the ROOT token points at itself) and deps[i] is its label.
    train_data = [
        (
            ['They', 'trade', 'mortgage', '-', 'backed', 'securities', '.'],
            [1, 1, 4, 4, 5, 1, 1],
            ['nsubj', 'ROOT', 'compound', 'punct', 'nmod', 'dobj', 'punct']
        ),
        (
            ['I', 'like', 'London', 'and', 'Berlin', '.'],
            [1, 1, 1, 2, 2, 1],
            ['nsubj', 'ROOT', 'dobj', 'cc', 'conj', 'punct']
        )
    ]

    # Split labels by arc direction: a token before its head contributes to
    # left_labels, a token after it to right_labels; the root (i == head)
    # contributes to neither.
    left_labels = set()
    right_labels = set()
    for _, heads, deps in train_data:
        for i, (head, dep) in enumerate(zip(heads, deps)):
            if i < head:
                left_labels.add(dep)
            elif i > head:
                right_labels.add(dep)
    parser = train_parser(nlp, train_data, sorted(left_labels), sorted(right_labels))

    # Parse an unseen sentence and print "token label head" per token.
    doc = Doc(nlp.vocab, words=['I', 'like', 'securities', '.'])
    parser(doc)
    for word in doc:
        print(word.text, word.dep_, word.head.text)

    if model_dir is not None:
        with (model_dir / 'config.json').open('w') as file_:
            json.dump(parser.cfg, file_)
        parser.model.dump(str(model_dir / 'model'))
if __name__ == '__main__':
    import sys
    # Optional first command-line argument: directory to save the model to.
    # With no argument this is the same as the original `main()` call.
    main(sys.argv[1] if len(sys.argv) > 1 else None)
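# Sample output (may differ between runs, since training shuffles the data
# with an unseeded random generator):
# I nsubj like
# like ROOT like
# securities dobj like
# . cc securities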