Refactor conllu script, fix interface, generalize
This commit is contained in:
parent 551c93fe01
commit 9e960d24fc
@@ -13,7 +13,7 @@ import json
 import spacy
 import spacy.util
 from spacy.tokens import Token, Doc
-from spacy.gold import GoldParse, minibatch
+from spacy.gold import GoldParse
 from spacy.syntax.nonproj import projectivize
 from collections import defaultdict, Counter
 from timeit import default_timer as timer
@@ -24,7 +24,7 @@ import random
 import numpy.random
 import cytoolz
 
-from spacy._align import align
+import conll17_ud_eval
 
 random.seed(0)
 numpy.random.seed(0)
@@ -43,7 +43,8 @@ def minibatch_by_words(items, size=5000):
             try:
                 doc, gold = next(items)
             except StopIteration:
-                yield batch
+                if batch:
+                    yield batch
                 return
             batch_size -= len(doc)
             batch.append((doc, gold))
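Note on the `minibatch_by_words` change: the `if batch:` guard stops the generator from yielding an empty list when the input runs out exactly on a batch boundary. A minimal standalone sketch of the same word-budget batching (the shuffling and variable batch-size handling of the full script are left out):

```python
def minibatch_by_words(items, size=5000):
    # items: (doc, gold) pairs; each yielded batch holds roughly `size` words.
    items = iter(items)
    while True:
        batch_size = size
        batch = []
        while batch_size >= 0:
            try:
                doc, gold = next(items)
            except StopIteration:
                if batch:  # avoid a trailing empty batch
                    yield batch
                return
            batch_size -= len(doc)
            batch.append((doc, gold))
        yield batch
```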
@@ -56,9 +57,9 @@ def minibatch_by_words(items, size=5000):
 # Data reading #
 ################
 
+space_re = re.compile('\s+')
 def split_text(text):
-    return [par.strip().replace('\n', ' ')
-            for par in text.split('\n\n')]
+    return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')]
 
 
 def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
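Note on `split_text`: the rewrite collapses any run of whitespace (newlines included) inside a paragraph to a single space, rather than only replacing newlines. A small self-contained example of the new behaviour:

```python
import re

space_re = re.compile(r'\s+')

def split_text(text):
    # Paragraphs are separated by blank lines; normalise whitespace within each.
    return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')]

print(split_text("A sentence\nsplit over two lines.\n\nA second   paragraph."))
# ['A sentence split over two lines.', 'A second paragraph.']
```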
@@ -132,7 +133,10 @@ def read_conllu(file_):
             doc.append(sent)
             sent = []
         else:
-            sent.append(line.strip().split())
+            sent.append(list(line.strip().split('\t')))
+            if len(sent[-1]) != 10:
+                print(repr(line))
+                raise ValueError
     if sent:
         doc.append(sent)
     if doc:
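Note on `read_conllu`: a CoNLL-U token line has exactly 10 tab-separated fields (ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC), and splitting on tabs rather than arbitrary whitespace keeps FORM values containing spaces intact; the length check fails fast on malformed input. A hypothetical helper illustrating the same validation on a single line:

```python
CONLLU_FIELDS = ['id', 'form', 'lemma', 'upos', 'xpos',
                 'feats', 'head', 'deprel', 'deps', 'misc']

def parse_token_line(line):
    # Split on tabs only: spaces can legally occur inside FORM, LEMMA and MISC.
    cols = line.strip().split('\t')
    if len(cols) != 10:
        raise ValueError('Expected 10 columns, got %d: %r' % (len(cols), line))
    return dict(zip(CONLLU_FIELDS, cols))

print(parse_token_line('1\tDogs\tdog\tNOUN\tNNS\tNumber=Plur\t2\tnsubj\t_\t_'))
```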
@@ -176,50 +180,21 @@ def golds_to_gold_tuples(docs, golds):
 # Evaluation #
 ##############
 
-def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
-                   joint_sbd=True, limit=None):
-    with open(text_loc) as text_file:
-        with open(conllu_loc) as conllu_file:
-            docs, golds = read_data(nlp, conllu_file, text_file,
-                                    oracle_segments=oracle_segments, limit=limit)
-    if joint_sbd:
-        pass
-    else:
-        sbd = nlp.create_pipe('sentencizer')
-        for doc in docs:
-            doc = sbd(doc)
-            for sent in doc.sents:
-                sent[0].is_sent_start = True
-                for word in sent[1:]:
-                    word.is_sent_start = False
-    scorer = nlp.evaluate(zip(docs, golds))
-    return docs, scorer
-
-
-def print_progress(itn, losses, scorer):
-    scores = {}
-    for col in ['dep_loss', 'tag_loss', 'uas', 'tags_acc', 'token_acc',
-                'ents_p', 'ents_r', 'ents_f', 'cpu_wps', 'gpu_wps']:
-        scores[col] = 0.0
-    scores['dep_loss'] = losses.get('parser', 0.0)
-    scores['ner_loss'] = losses.get('ner', 0.0)
-    scores['tag_loss'] = losses.get('tagger', 0.0)
-    scores.update(scorer.scores)
-    tpl = '\t'.join((
-        '{:d}',
-        '{dep_loss:.3f}',
-        '{ner_loss:.3f}',
-        '{uas:.3f}',
-        '{ents_p:.3f}',
-        '{ents_r:.3f}',
-        '{ents_f:.3f}',
-        '{tags_acc:.3f}',
-        '{token_acc:.3f}',
-    ))
-    print(tpl.format(itn, **scores))
-
-
-def print_conllu(docs, file_):
+def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
+    with text_loc.open('r', encoding='utf8') as text_file:
+        texts = split_text(text_file.read())
+    docs = list(nlp.pipe(texts))
+    with sys_loc.open('w', encoding='utf8') as out_file:
+        write_conllu(docs, out_file)
+    with gold_loc.open('r', encoding='utf8') as gold_file:
+        gold_ud = conll17_ud_eval.load_conllu(gold_file)
+    with sys_loc.open('r', encoding='utf8') as sys_file:
+        sys_ud = conll17_ud_eval.load_conllu(sys_file)
+    scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
+    return scores
+
+
+def write_conllu(docs, file_):
     merger = Matcher(docs[0].vocab)
     merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
     for i, doc in enumerate(docs):
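Note on the new `evaluate`: scoring now goes through the CoNLL 2017 shared task script instead of spaCy's internal `Scorer`. The pipeline's output is written to disk as CoNLL-U, then both the gold and system files are loaded and compared by `conll17_ud_eval`. A rough sketch of just the scoring step, assuming `conll17_ud_eval.py` from the shared task is importable and the two file names are placeholders:

```python
from pathlib import Path

import conll17_ud_eval  # CoNLL 2017 shared task evaluation script

def score_files(gold_loc, sys_loc):
    # Load gold and system CoNLL-U files and return the shared-task metrics.
    with Path(gold_loc).open('r', encoding='utf8') as gold_file:
        gold_ud = conll17_ud_eval.load_conllu(gold_file)
    with Path(sys_loc).open('r', encoding='utf8') as sys_file:
        sys_ud = conll17_ud_eval.load_conllu(sys_file)
    return conll17_ud_eval.evaluate(gold_ud, sys_ud)

scores = score_files('dev-gold.conllu', 'dev-system.conllu')
print('LAS f1: %.1f' % (scores['LAS'].f1 * 100))
```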
@@ -236,6 +211,31 @@ def print_conllu(docs, file_):
             file_.write(token._.get_conllu_lines(k) + '\n')
         file_.write('\n')
 
 
+def print_progress(itn, losses, ud_scores):
+    fields = {
+        'dep_loss': losses.get('parser', 0.0),
+        'tag_loss': losses.get('tagger', 0.0),
+        'words': ud_scores['Words'].f1 * 100,
+        'sents': ud_scores['Sentences'].f1 * 100,
+        'tags': ud_scores['XPOS'].f1 * 100,
+        'uas': ud_scores['UAS'].f1 * 100,
+        'las': ud_scores['LAS'].f1 * 100,
+    }
+    header = ['Epoch', 'Loss', 'LAS', 'UAS', 'TAG', 'SENT', 'WORD']
+    if itn == 0:
+        print('\t'.join(header))
+    tpl = '\t'.join((
+        '{:d}',
+        '{dep_loss:.1f}',
+        '{las:.1f}',
+        '{uas:.1f}',
+        '{tags:.1f}',
+        '{sents:.1f}',
+        '{words:.1f}',
+    ))
+    print(tpl.format(itn, **fields))
+
+
 #def get_sent_conllu(sent, sent_id):
 #    lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
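Note on the new `print_progress`: it pulls f1 values out of the `conll17_ud_eval` result, prints the header row only before the first epoch, and then one tab-separated row per epoch. A hypothetical smoke test with zeroed inputs, using a namedtuple to stand in for the shared-task score objects:

```python
from collections import namedtuple

# Minimal stand-in for the score objects returned by conll17_ud_eval.evaluate().
Score = namedtuple('Score', ['f1'])

dummy_scores = {name: Score(f1=0.0)
                for name in ['Words', 'Sentences', 'XPOS', 'UAS', 'LAS']}
print_progress(0, {'parser': 0.0, 'tagger': 0.0}, dummy_scores)
# Prints the header once (itn == 0), then: 0  0.0  0.0  0.0  0.0  0.0  0.0
```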
@@ -275,7 +275,6 @@ def load_nlp(corpus, config):
     return nlp
 
 def initialize_pipeline(nlp, docs, golds, config):
-    print("Create parser")
     nlp.add_pipe(nlp.create_pipe('parser'))
     if config.multitask_tag:
         nlp.parser.add_multitask_objective('tag')
@@ -347,14 +346,16 @@ class TreebankPaths(object):
 
 @plac.annotations(
     ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
-    config=("Path to json formatted config file", "positional", None, Config.load),
     corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
             "positional", None, str),
-    parses_loc=("Path to write the development parses", "positional", None, Path),
+    parses_dir=("Directory to write the development parses", "positional", None, Path),
+    config=("Path to json formatted config file", "positional", None, Config.load),
     limit=("Size limit", "option", "n", int)
 )
-def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
+def main(ud_dir, parses_dir, config, corpus, limit=0):
     paths = TreebankPaths(ud_dir, corpus)
+    if not (parses_dir / corpus).exists():
+        (parses_dir / corpus).mkdir()
     print("Train and evaluate", corpus, "using lang", paths.lang)
     nlp = load_nlp(paths.lang, config)
 
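Note on the CLI change: with `plac`, the order of positional arguments on the command line follows the function signature rather than the annotation order, so invocations now pass `ud_dir parses_dir config corpus`. The `exists()`/`mkdir()` pair creates the per-corpus output directory; a sketch of an equivalent call with `pathlib` (paths and corpus name are placeholders):

```python
from pathlib import Path

parses_dir = Path('/tmp/parses')   # hypothetical output root
corpus = 'es_ancora'               # hypothetical UD corpus name

# Same effect as the exists()/mkdir() pair in the diff, but also creates missing
# parents and doesn't race if two runs start at once (Python 3.5+).
(parses_dir / corpus).mkdir(parents=True, exist_ok=True)
```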
@@ -362,6 +363,7 @@ def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
                             max_doc_length=config.max_doc_length, limit=limit)
 
     optimizer = initialize_pipeline(nlp, docs, golds, config)
+
     for i in range(config.nr_epoch):
        docs = [nlp.make_doc(doc.text) for doc in docs]
        batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
@@ -374,11 +376,10 @@ def main(ud_dir, corpus, config, parses_loc='/tmp/dev.conllu', limit=10):
             nlp.update(batch_docs, batch_gold, sgd=optimizer,
                        drop=config.dropout, losses=losses)
 
+        out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
         with nlp.use_params(optimizer.averages):
-            dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu)
-            print_progress(i, losses, scorer)
-            with open(parses_loc, 'w') as file_:
-                print_conllu(dev_docs, file_)
+            scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
+            print_progress(i, losses, scores)
 
 
 if __name__ == '__main__':
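Note on the per-epoch evaluation: the dev parses for each epoch now go to their own `epoch-{i}.conllu` file, and both the written parses and the reported scores come from the averaged parameters, because `evaluate` runs inside `nlp.use_params(optimizer.averages)`. A sketch of the pattern as a standalone helper (the dev paths and the per-epoch update loop are placeholders):

```python
from pathlib import Path

def evaluate_epochs(nlp, optimizer, parses_dir, corpus, dev_text, dev_conllu, nr_epoch):
    # Hypothetical helper mirroring the evaluation step of the training loop.
    for i in range(nr_epoch):
        losses = {}
        # ... one epoch of nlp.update(..., losses=losses) calls would go here ...
        out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
        with nlp.use_params(optimizer.averages):
            # Averaged weights are active only inside this block; training then
            # resumes with the live weights.
            scores = evaluate(nlp, dev_text, dev_conllu, out_path)
        print_progress(i, losses, scores)
```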