mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
* Update the CoNLL train script to get it working on other languages
This commit is contained in:
parent
6c633f2edc
commit
a676d66807
|
@ -5,7 +5,7 @@ from __future__ import unicode_literals
|
||||||
import os
|
import os
|
||||||
from os import path
|
from os import path
|
||||||
import shutil
|
import shutil
|
||||||
import codecs
|
import io
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import gzip
|
import gzip
|
||||||
|
@ -56,12 +56,20 @@ def _parse_line(line):
|
||||||
if len(pieces) == 4:
|
if len(pieces) == 4:
|
||||||
word, pos, head_idx, label = pieces
|
word, pos, head_idx, label = pieces
|
||||||
head_idx = int(head_idx)
|
head_idx = int(head_idx)
|
||||||
|
elif len(pieces) == 15:
|
||||||
|
id_ = int(pieces[0].split('_')[-1])
|
||||||
|
word = pieces[1]
|
||||||
|
pos = pieces[4]
|
||||||
|
head_idx = int(pieces[8])-1
|
||||||
|
label = pieces[10]
|
||||||
else:
|
else:
|
||||||
id_ = int(pieces[0])
|
id_ = int(pieces[0].split('_')[-1])
|
||||||
word = pieces[1]
|
word = pieces[1]
|
||||||
pos = pieces[4]
|
pos = pieces[4]
|
||||||
head_idx = int(pieces[6])-1
|
head_idx = int(pieces[6])-1
|
||||||
label = pieces[7]
|
label = pieces[7]
|
||||||
|
if head_idx == 0:
|
||||||
|
label = 'ROOT'
|
||||||
return word, pos, head_idx, label
|
return word, pos, head_idx, label
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,8 +77,8 @@ def score_model(scorer, nlp, raw_text, annot_tuples, verbose=False):
|
||||||
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
tokens = nlp.tokenizer.tokens_from_list(annot_tuples[1])
|
||||||
nlp.tagger(tokens)
|
nlp.tagger(tokens)
|
||||||
nlp.parser(tokens)
|
nlp.parser(tokens)
|
||||||
gold = GoldParse(tokens, annot_tuples)
|
gold = GoldParse(tokens, annot_tuples, make_projective=False)
|
||||||
scorer.score(tokens, gold, verbose=verbose)
|
scorer.score(tokens, gold, verbose=verbose, punct_labels=('--', 'p', 'punct'))
|
||||||
|
|
||||||
|
|
||||||
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0,
|
||||||
|
@ -122,11 +130,11 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic', seed=0
|
||||||
|
|
||||||
|
|
||||||
def main(train_loc, dev_loc, model_dir):
|
def main(train_loc, dev_loc, model_dir):
|
||||||
with codecs.open(train_loc, 'r', 'utf8') as file_:
|
with io.open(train_loc, 'r', encoding='utf8') as file_:
|
||||||
train_sents = read_conll(file_)
|
train_sents = read_conll(file_)
|
||||||
train(English, train_sents, model_dir)
|
#train(English, train_sents, model_dir)
|
||||||
nlp = English(data_dir=model_dir)
|
nlp = English(data_dir=model_dir)
|
||||||
dev_sents = read_conll(open(dev_loc))
|
dev_sents = read_conll(io.open(dev_loc, 'r', encoding='utf8'))
|
||||||
scorer = Scorer()
|
scorer = Scorer()
|
||||||
for _, sents in dev_sents:
|
for _, sents in dev_sents:
|
||||||
for annot_tuples, _ in sents:
|
for annot_tuples, _ in sents:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user