Mirror of https://github.com/explosion/spaCy.git (synced 2025-03-03 10:55:52 +03:00)
Add example for training parser
This commit is contained in: parent 195d998a12, commit 4574fe87c6
examples/training/train_parser.py (new file, 81 lines)
@@ -0,0 +1,81 @@
from __future__ import unicode_literals, print_function

import json
import pathlib
import random

import spacy
from spacy.pipeline import DependencyParser
from spacy.gold import GoldParse
from spacy.tokens import Doc


def train_parser(nlp, train_data, left_labels, right_labels):
    # Build a blank parser with the label sets observed in the training
    # data, then train it for a fixed number of iterations.
    parser = DependencyParser.blank(
        nlp.vocab,
        left_labels=left_labels,
        right_labels=right_labels,
        features=nlp.defaults.parser_features)
    for itn in range(1000):
        random.shuffle(train_data)
        loss = 0
        for words, heads, deps in train_data:
            doc = nlp.make_doc(words)
            gold = GoldParse(doc, heads=heads, deps=deps)
            loss += parser.update(doc, gold)
    parser.model.end_training()
    return parser


def main(model_dir=None):
    if model_dir is not None:
        model_dir = pathlib.Path(model_dir)
        if not model_dir.exists():
            model_dir.mkdir()
        assert model_dir.is_dir()

    nlp = spacy.load('en', tagger=False, parser=False, entity=False, vectors=False)
    # Create Doc objects directly from pre-tokenized words, pairing each
    # word with True to mark a trailing space.
    nlp.make_doc = lambda words: Doc(nlp.vocab, zip(words, [True] * len(words)))

    # Each example is (words, head indices, dependency labels).
    train_data = [
        (
            ['They', 'trade', 'mortgage', '-', 'backed', 'securities', '.'],
            [1, 1, 4, 4, 5, 1, 1],
            ['nsubj', 'ROOT', 'compound', 'punct', 'nmod', 'dobj', 'punct']
        ),
        (
            ['I', 'like', 'London', 'and', 'Berlin', '.'],
            [1, 1, 1, 2, 2, 1],
            ['nsubj', 'ROOT', 'dobj', 'cc', 'conj', 'punct']
        )
    ]
    # Collect the labels of leftward and rightward arcs seen in the data.
    left_labels = set()
    right_labels = set()
    for _, heads, deps in train_data:
        for i, (head, dep) in enumerate(zip(heads, deps)):
            if i < head:
                left_labels.add(dep)
            elif i > head:
                right_labels.add(dep)
    parser = train_parser(nlp, train_data, sorted(left_labels), sorted(right_labels))

    doc = nlp.make_doc(['I', 'like', 'securities', '.'])
    # Step through the parser's transitions one action at a time;
    # calling parser(doc) directly would have the same effect.
    with parser.step_through(doc) as state:
        while not state.is_final:
            action = state.predict()
            state.transition(action)
    # parser(doc)
    for word in doc:
        print(word.text, word.dep_, word.head.text)

    if model_dir is not None:
        # json.dump writes text, so the config file is opened in text mode.
        with (model_dir / 'config.json').open('w') as file_:
            json.dump(parser.cfg, file_)
        parser.model.dump(str(model_dir / 'model'))


if __name__ == '__main__':
    main()
    # Expected output:
    # I nsubj like
    # like ROOT like
    # securities dobj like
    # . cc securities