Refactor conllu script

Mirror of https://github.com/explosion/spaCy.git
commit 44e496a82e (parent c388833ca6)
@@ -4,8 +4,12 @@
 from __future__ import unicode_literals
 import plac
 import tqdm
+import attr
+from pathlib import Path
 import re
 import sys
+import json
+
 import spacy
 import spacy.util
 from spacy.tokens import Token, Doc
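Note: the new imports back the command-line refactor further down -- attr for the Config class, pathlib.Path and json for Config.load and the treebank paths. A minimal sketch of the attrs pattern (illustrative only, not part of the commit):

import attr

@attr.s
class ExampleConfig(object):
    # hypothetical class, mirrors the style of the Config class added below
    dropout = attr.ib(default=0.2)
    nr_epoch = attr.ib(default=30)

cfg = ExampleConfig()
print(cfg)               # ExampleConfig(dropout=0.2, nr_epoch=30)
print(attr.asdict(cfg))  # {'dropout': 0.2, 'nr_epoch': 30}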
@@ -40,32 +44,9 @@ def minibatch_by_words(items, size=5000):
             batch.append((doc, gold))
     yield batch
 
-
-def get_token_acc(docs, golds):
-    '''Quick function to evaluate tokenization accuracy.'''
-    miss = 0
-    hit = 0
-    for doc, gold in zip(docs, golds):
-        for i in range(len(doc)):
-            token = doc[i]
-            align = gold.words[i]
-            if align == None:
-                miss += 1
-            else:
-                hit += 1
-    return miss, hit
-
-
-def golds_to_gold_tuples(docs, golds):
-    '''Get out the annoying 'tuples' format used by begin_training, given the
-    GoldParse objects.'''
-    tuples = []
-    for doc, gold in zip(docs, golds):
-        text = doc.text
-        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
-        sents = [((ids, words, tags, heads, labels, iob), [])]
-        tuples.append((text, sents))
-    return tuples
+################
+# Data reading #
+################
 
 def split_text(text):
     return [par.strip().replace('\n', ' ')
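Note: get_token_acc is dropped outright, while golds_to_gold_tuples is moved rather than deleted -- it reappears below under the new "Data transforms for spaCy" section. For orientation, the "tuples" format it builds for begin_training has roughly this shape (toy values, illustrative only):

# One (text, sentences) pair per document; each sentence is a tuple of
# parallel sequences (ids, words, tags, heads, labels, iob) plus an empty
# list for bracket annotations.
text = 'I slept'
ids = (0, 1)                 # token indices
words = ('I', 'slept')
tags = ('PRP', 'VBD')
heads = (1, 1)               # each token's head index ('slept' heads both)
labels = ('nsubj', 'ROOT')
iob = ('O', 'O')
sents = [((ids, words, tags, heads, labels, iob), [])]
tuples = [(text, sents)]
print(tuples)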
@@ -127,34 +108,6 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
     return docs, golds
 
 
-def _make_gold(nlp, text, sent_annots):
-    # Flatten the conll annotations, and adjust the head indices
-    flat = defaultdict(list)
-    for sent in sent_annots:
-        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
-        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
-            flat[field].extend(sent[field])
-    # Construct text if necessary
-    assert len(flat['words']) == len(flat['spaces'])
-    if text is None:
-        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
-    doc = nlp.make_doc(text)
-    flat.pop('spaces')
-    gold = GoldParse(doc, **flat)
-    #for annot in gold.orig_annot:
-    #    print(annot)
-    #for i in range(len(doc)):
-    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
-    return doc, gold
-
-
-def refresh_docs(docs):
-    vocab = docs[0].vocab
-    return [Doc(vocab, words=[t.text for t in doc],
-                spaces=[t.whitespace_ for t in doc])
-            for doc in docs]
-
-
 def read_conllu(file_):
     docs = []
     sent = []
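Note: _make_gold and refresh_docs are likewise moved rather than removed; both reappear in the next hunk. What refresh_docs is for, shown standalone: rebuilding each Doc from its words and trailing whitespace yields a fresh, unannotated copy, so every epoch can start again from plain tokens. A minimal sketch, assuming spaCy is installed (illustrative only):

import spacy
from spacy.tokens import Doc

nlp = spacy.blank('en')
doc = Doc(nlp.vocab, words=['Hello', 'world', '!'], spaces=[True, False, False])
# Same rebuild as refresh_docs: keep text and spacing, drop any annotations.
fresh = Doc(doc.vocab, words=[t.text for t in doc],
            spaces=[t.whitespace_ for t in doc])
assert fresh.text == doc.text == 'Hello world!'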
@@ -179,6 +132,52 @@ def read_conllu(file_):
     return docs
 
 
+def _make_gold(nlp, text, sent_annots):
+    # Flatten the conll annotations, and adjust the head indices
+    flat = defaultdict(list)
+    for sent in sent_annots:
+        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
+        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
+            flat[field].extend(sent[field])
+    # Construct text if necessary
+    assert len(flat['words']) == len(flat['spaces'])
+    if text is None:
+        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
+    doc = nlp.make_doc(text)
+    flat.pop('spaces')
+    gold = GoldParse(doc, **flat)
+    #for annot in gold.orig_annot:
+    #    print(annot)
+    #for i in range(len(doc)):
+    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
+    return doc, gold
+
+
+#############################
+# Data transforms for spaCy #
+#############################
+
+def golds_to_gold_tuples(docs, golds):
+    '''Get out the annoying 'tuples' format used by begin_training, given the
+    GoldParse objects.'''
+    tuples = []
+    for doc, gold in zip(docs, golds):
+        text = doc.text
+        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
+        sents = [((ids, words, tags, heads, labels, iob), [])]
+        tuples.append((text, sents))
+    return tuples
+
+
+def refresh_docs(docs):
+    vocab = docs[0].vocab
+    return [Doc(vocab, words=[t.text for t in doc],
+                spaces=[t.whitespace_ for t in doc])
+            for doc in docs]
+
+
+##############
+# Evaluation #
+##############
+
 def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
                    joint_sbd=True, limit=None):
     with open(text_loc) as text_file:
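Note on the head-index arithmetic in _make_gold above: heads in each sentence annotation are sentence-local, so offsetting them by len(flat['words']) before that sentence's words are appended turns them into document-local indices. A toy illustration (pure Python, illustrative only):

from collections import defaultdict

sent_annots = [
    {'words': ['I', 'slept'], 'heads': [1, 1]},    # first sentence
    {'words': ['It', 'rained'], 'heads': [1, 1]},  # second sentence
]
flat = defaultdict(list)
for sent in sent_annots:
    # len(flat['words']) is still the pre-extend length here, as in _make_gold
    flat['heads'].extend(len(flat['words']) + head for head in sent['heads'])
    flat['words'].extend(sent['words'])
print(flat['heads'])  # [1, 1, 3, 3] -- the second sentence now points at token 3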
@@ -265,25 +264,24 @@ Token.set_extension('begins_fused', default=False)
 Token.set_extension('inside_fused', default=False)
 
 
-def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
-         output_loc):
-    if lang == 'en':
+##################
+# Initialization #
+##################
+
+
+def load_nlp(corpus, config):
+    lang = corpus.split('_')[0]
     nlp = spacy.blank(lang)
-        vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0')
-        nlp.vocab.vectors = vec_nlp.vocab.vectors
-        for lex in vec_nlp.vocab:
-            _ = nlp.vocab[lex.orth_]
-        vec_nlp = None
-    else:
-        nlp = spacy.load(lang)
-    with open(conllu_train_loc) as conllu_file:
-        with open(text_train_loc) as text_file:
-            docs, golds = read_data(nlp, conllu_file, text_file,
-                                    oracle_segments=False, raw_text=True,
-                                    max_doc_length=10, limit=None)
+    if config.vectors:
+        nlp.vocab.from_disk(config.vectors / 'vocab')
+    return nlp
+
+
+def initialize_pipeline(nlp, docs, golds, config):
     print("Create parser")
     nlp.add_pipe(nlp.create_pipe('parser'))
+    if config.multitask_tag:
         nlp.parser.add_multitask_objective('tag')
+    if config.multitask_sent:
         nlp.parser.add_multitask_objective('sent_start')
     nlp.parser.moves.add_action(2, 'subtok')
     nlp.add_pipe(nlp.create_pipe('tagger'))
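Note: load_nlp now fills the vocab from a saved vocab directory (config.vectors / 'vocab') instead of loading en_core_web_lg in-process and copying its vectors over. A minimal sketch of that pattern, assuming a pipeline was previously saved with nlp.to_disk() at a hypothetical path (illustrative only):

from pathlib import Path
import spacy

vectors_dir = Path('/tmp/vectors_model')   # hypothetical location of a saved model
nlp = spacy.blank('en')
if (vectors_dir / 'vocab').exists():
    nlp.vocab.from_disk(vectors_dir / 'vocab')
    print(nlp.vocab.vectors.shape)         # (number of vectors, vector width)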
@@ -291,7 +289,6 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         for tag in gold.tags:
             if tag is not None:
                 nlp.tagger.add_label(tag)
-    optimizer = nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
     # Replace labels that didn't make the frequency cutoff
     actions = set(nlp.parser.labels)
     label_set = set([act.split('-')[1] for act in actions if '-' in act])
@@ -299,38 +296,92 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         for i, label in enumerate(gold.labels):
             if label is not None and label not in label_set:
                 gold.labels[i] = label.split('||')[0]
+    return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
+
+
+########################
+# Command line helpers #
+########################
+
+
+@attr.s
+class Config(object):
+    vectors = attr.ib(default=None)
+    max_doc_length = attr.ib(default=10)
+    multitask_tag = attr.ib(default=True)
+    multitask_sent = attr.ib(default=True)
+    nr_epoch = attr.ib(default=30)
+    batch_size = attr.ib(default=1000)
+    dropout = attr.ib(default=0.2)
+
+    @classmethod
+    def load(cls, loc):
+        with Path(loc).open('r', encoding='utf8') as file_:
+            cfg = json.load(file_)
+        return cls(**cfg)
+
+
+class Dataset(object):
+    def __init__(self, path, section):
+        self.path = path
+        self.section = section
+        self.conllu = None
+        self.text = None
+        for file_path in self.path.iterdir():
+            name = file_path.parts[-1]
+            if section in name and name.endswith('conllu'):
+                self.conllu = file_path
+            elif section in name and name.endswith('txt'):
+                self.text = file_path
+        if self.conllu is None:
+            msg = "Could not find .txt file in {path} for {section}"
+            raise IOError(msg.format(section=section, path=path))
+        if self.text is None:
+            msg = "Could not find .txt file in {path} for {section}"
+        self.lang = self.conllu.parts[-1].split('-')[0].split('_')[0]
+
+
+class TreebankPaths(object):
+    def __init__(self, ud_path, treebank, **cfg):
+        self.train = Dataset(ud_path / treebank, 'train')
+        self.dev = Dataset(ud_path / treebank, 'dev')
+        self.lang = self.train.lang
+
+
+@plac.annotations(
+    ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
+    config=("Path to json formatted config file", "positional", None, Config.load),
+    corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
+            "positional", None, str),
+    parses=("Path to write the development parses", "positional", None, Path)
+)
+def main(ud_dir, corpus, config, parses='/tmp/dev.conllu'):
+    paths = TreebankPaths(ud_dir, corpus)
+    nlp = load_nlp(paths.lang, config)
+
+    docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
+                            config)
+
+    optimizer = initialize_pipeline(nlp, docs, golds, config)
     n_train_words = sum(len(doc) for doc in docs)
-    print(n_train_words)
-    print("Begin training")
-    # Batch size starts at 1 and grows, so that we make updates quickly
-    # at the beginning of training.
-    batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
-                                         spacy.util.env_opt('batch_to', 8),
-                                         spacy.util.env_opt('batch_compound', 1.001))
-    for i in range(30):
+    print("Begin training (%d words)" % n_train_words)
+    for i in range(config.nr_epoch):
         docs = refresh_docs(docs)
-        batches = minibatch_by_words(list(zip(docs, golds)), size=1000)
-        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
+        batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
         losses = {}
-        for batch in batches:
+        for batch in tqdm.tqdm(batches, total=n_train_words//config.batch_size):
             if not batch:
                 continue
             batch_docs, batch_gold = zip(*batch)
             nlp.update(batch_docs, batch_gold, sgd=optimizer,
-                       drop=0.2, losses=losses)
-            pbar.update(sum(len(doc) for doc in batch_docs))
+                       drop=config.dropout, losses=losses)
 
         with nlp.use_params(optimizer.averages):
-            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
-                                              oracle_segments=False, joint_sbd=True)
+            dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu,
+                                              **attr.asdict(config))
             print_progress(i, losses, scorer)
             with open(output_loc, 'w') as file_:
                 print_conllu(dev_docs, file_)
-            with open('/tmp/train.conllu', 'w') as file_:
-                print_conllu(list(nlp.pipe([d.text for d in batch_docs])), file_)
 
 
 if __name__ == '__main__':
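Note: Config.load above reads the hyperparameters from a JSON file whose keys mirror the attr.ib fields. A hypothetical config.json, written and re-read in Python (the values shown are just the class defaults, illustrative only):

import json
from pathlib import Path

cfg = {
    "vectors": None,
    "max_doc_length": 10,
    "multitask_tag": True,
    "multitask_sent": True,
    "nr_epoch": 30,
    "batch_size": 1000,
    "dropout": 0.2,
}
with Path('/tmp/config.json').open('w', encoding='utf8') as file_:
    json.dump(cfg, file_, indent=2)
# Config.load('/tmp/config.json') would then return Config(**cfg).

With the plac annotations, the refactored script takes the UD directory, the corpus name and the config path as positional arguments, plus an optional output path for the development parses. As committed, the epoch loop still writes to output_loc, which the new main() signature no longer defines; the parses argument looks like the intended replacement.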