mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Improve train_ud script
This commit is contained in:
parent
363f09e68c
commit
4ff92184f1
|
@ -18,6 +18,7 @@ from spacy.pipeline import DependencyParser
|
||||||
from spacy.syntax.parser import get_templates
|
from spacy.syntax.parser import get_templates
|
||||||
from spacy.syntax.arc_eager import ArcEager
|
from spacy.syntax.arc_eager import ArcEager
|
||||||
from spacy.scorer import Scorer
|
from spacy.scorer import Scorer
|
||||||
|
from spacy.language_data.tag_map import TAG_MAP as DEFAULT_TAG_MAP
|
||||||
import spacy.attrs
|
import spacy.attrs
|
||||||
import io
|
import io
|
||||||
|
|
||||||
|
@ -61,9 +62,12 @@ def score_model(vocab, tagger, parser, gold_docs, verbose=False):
|
||||||
return scorer
|
return scorer
|
||||||
|
|
||||||
|
|
||||||
def main(train_loc, dev_loc, model_dir, tag_map_loc):
|
def main(train_loc, dev_loc, model_dir, tag_map_loc=None):
|
||||||
with open(tag_map_loc) as file_:
|
if tag_map_loc:
|
||||||
tag_map = json.loads(file_.read())
|
with open(tag_map_loc) as file_:
|
||||||
|
tag_map = json.loads(file_.read())
|
||||||
|
else:
|
||||||
|
tag_map = DEFAULT_TAG_MAP
|
||||||
train_sents = list(read_conllx(train_loc))
|
train_sents = list(read_conllx(train_loc))
|
||||||
train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
|
train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
|
||||||
|
|
||||||
|
@ -73,9 +77,10 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc):
|
||||||
model_dir = pathlib.Path(model_dir)
|
model_dir = pathlib.Path(model_dir)
|
||||||
if not (model_dir / 'deps').exists():
|
if not (model_dir / 'deps').exists():
|
||||||
(model_dir / 'deps').mkdir()
|
(model_dir / 'deps').mkdir()
|
||||||
with (model_dir / 'deps' / 'config.json').open('w') as file_:
|
with (model_dir / 'deps' / 'config.json').open('wb') as file_:
|
||||||
json.dump({'pseudoprojective': True, 'labels': actions, 'features': features}, file_)
|
file_.write(
|
||||||
|
json.dumps(
|
||||||
|
{'pseudoprojective': True, 'labels': actions, 'features': features}).encode('utf8'))
|
||||||
vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map)
|
vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map)
|
||||||
# Populate vocab
|
# Populate vocab
|
||||||
for _, doc_sents in train_sents:
|
for _, doc_sents in train_sents:
|
||||||
|
@ -86,8 +91,9 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc):
|
||||||
_ = vocab[dep]
|
_ = vocab[dep]
|
||||||
for tag in tags:
|
for tag in tags:
|
||||||
_ = vocab[tag]
|
_ = vocab[tag]
|
||||||
for tag in tags:
|
if tag_map:
|
||||||
assert tag in tag_map, repr(tag)
|
for tag in tags:
|
||||||
|
assert tag in tag_map, repr(tag)
|
||||||
tagger = Tagger(vocab, tag_map=tag_map)
|
tagger = Tagger(vocab, tag_map=tag_map)
|
||||||
parser = DependencyParser(vocab, actions=actions, features=features)
|
parser = DependencyParser(vocab, actions=actions, features=features)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user