mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Refactor conllu script
This commit is contained in:
parent
c388833ca6
commit
44e496a82e
|
@ -4,8 +4,12 @@
|
|||
from __future__ import unicode_literals
|
||||
import plac
|
||||
import tqdm
|
||||
import attr
|
||||
from pathlib import Path
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
|
||||
import spacy
|
||||
import spacy.util
|
||||
from spacy.tokens import Token, Doc
|
||||
|
@ -40,32 +44,9 @@ def minibatch_by_words(items, size=5000):
|
|||
batch.append((doc, gold))
|
||||
yield batch
|
||||
|
||||
|
||||
def get_token_acc(docs, golds):
    '''Quick function to evaluate tokenization accuracy.

    docs: sequence of tokenized documents; only the length of each one is used.
    golds: parallel sequence of objects whose `words` attribute can be indexed
        by token position. An entry that is None counts as a miss, anything
        else counts as a hit.

    Returns a `(miss, hit)` tuple of counts accumulated over all documents.
    '''
    miss = 0
    hit = 0
    for doc, gold in zip(docs, golds):
        for i in range(len(doc)):
            # Identity test: an entry that merely compares equal to None
            # through a custom __eq__ must not be counted as a miss.
            if gold.words[i] is None:
                miss += 1
            else:
                hit += 1
    return miss, hit
|
||||
|
||||
|
||||
def golds_to_gold_tuples(docs, golds):
    '''Convert parallel docs and GoldParse objects into the nested 'tuples'
    format used by begin_training.

    Each result entry is `(text, sents)`, where `text` is `doc.text` and
    `sents` holds a single `((ids, words, tags, heads, labels, iob), [])` pair
    built by transposing the rows of `gold.orig_annot`.
    '''
    def _to_entry(doc, gold):
        raw_text = doc.text
        # Transpose the per-token rows into six per-column tuples.
        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
        columns = (ids, words, tags, heads, labels, iob)
        return (raw_text, [(columns, [])])

    return [_to_entry(doc, gold) for doc, gold in zip(docs, golds)]
|
||||
################
|
||||
# Data reading #
|
||||
################
|
||||
|
||||
def split_text(text):
|
||||
return [par.strip().replace('\n', ' ')
|
||||
|
@ -127,34 +108,6 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
|
|||
return docs, golds
|
||||
|
||||
|
||||
def _make_gold(nlp, text, sent_annots):
    '''Build a `(doc, gold)` pair from a list of per-sentence annotations.

    nlp: object providing `make_doc(text)`.
    text: text to tokenize, or None to rebuild it from the annotated words
        and their trailing-space flags.
    sent_annots: dicts with 'heads', 'words', 'tags', 'deps', 'entities' and
        'spaces' lists, one dict per sentence.
    '''
    copied_fields = ('words', 'tags', 'deps', 'entities', 'spaces')
    merged = defaultdict(list)
    for annot in sent_annots:
        head_list = merged['heads']
        # Heads are shifted by the number of words already merged, so that
        # they index into the flattened, document-level word list.
        offset = len(merged['words'])
        head_list.extend(offset + head for head in annot['heads'])
        for name in copied_fields:
            merged[name].extend(annot[name])
    assert len(merged['words']) == len(merged['spaces'])
    if text is None:
        # Rebuild the text: each word followed by a space where flagged.
        pieces = [word + ' ' * space
                  for word, space in zip(merged['words'], merged['spaces'])]
        text = ''.join(pieces)
    doc = nlp.make_doc(text)
    # 'spaces' is only needed to rebuild the text; it is not passed on.
    del merged['spaces']
    gold = GoldParse(doc, **merged)
    return doc, gold
|
||||
|
||||
|
||||
def refresh_docs(docs):
    '''Return new Doc objects carrying the same words and whitespace as `docs`.

    Every new Doc is built on the vocab of the first input document, from
    each token's `text` and `whitespace_`. An empty input yields an empty
    list instead of failing on `docs[0]`.
    '''
    if not docs:
        return []
    vocab = docs[0].vocab
    return [Doc(vocab, words=[t.text for t in doc],
                spaces=[t.whitespace_ for t in doc])
            for doc in docs]
|
||||
|
||||
|
||||
def read_conllu(file_):
|
||||
docs = []
|
||||
sent = []
|
||||
|
@ -179,6 +132,52 @@ def read_conllu(file_):
|
|||
return docs
|
||||
|
||||
|
||||
def _make_gold(nlp, text, sent_annots):
    '''Build a `(doc, gold)` pair from a list of per-sentence annotations.

    nlp: object providing `make_doc(text)`.
    text: text to tokenize, or None to rebuild it from the annotated words
        and their trailing-space flags.
    sent_annots: dicts with 'heads', 'words', 'tags', 'deps', 'entities' and
        'spaces' lists, one dict per sentence.
    '''
    copied_fields = ('words', 'tags', 'deps', 'entities', 'spaces')
    merged = defaultdict(list)
    for annot in sent_annots:
        head_list = merged['heads']
        # Heads are shifted by the number of words already merged, so that
        # they index into the flattened, document-level word list.
        offset = len(merged['words'])
        head_list.extend(offset + head for head in annot['heads'])
        for name in copied_fields:
            merged[name].extend(annot[name])
    assert len(merged['words']) == len(merged['spaces'])
    if text is None:
        # Rebuild the text: each word followed by a space where flagged.
        pieces = [word + ' ' * space
                  for word, space in zip(merged['words'], merged['spaces'])]
        text = ''.join(pieces)
    doc = nlp.make_doc(text)
    # 'spaces' is only needed to rebuild the text; it is not passed on.
    del merged['spaces']
    gold = GoldParse(doc, **merged)
    return doc, gold
|
||||
|
||||
#############################
|
||||
# Data transforms for spaCy #
|
||||
#############################
|
||||
|
||||
def golds_to_gold_tuples(docs, golds):
    '''Convert parallel docs and GoldParse objects into the nested 'tuples'
    format used by begin_training.

    Each result entry is `(text, sents)`, where `text` is `doc.text` and
    `sents` holds a single `((ids, words, tags, heads, labels, iob), [])` pair
    built by transposing the rows of `gold.orig_annot`.
    '''
    def _to_entry(doc, gold):
        raw_text = doc.text
        # Transpose the per-token rows into six per-column tuples.
        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
        columns = (ids, words, tags, heads, labels, iob)
        return (raw_text, [(columns, [])])

    return [_to_entry(doc, gold) for doc, gold in zip(docs, golds)]
|
||||
|
||||
|
||||
def refresh_docs(docs):
    '''Return new Doc objects carrying the same words and whitespace as `docs`.

    Every new Doc is built on the vocab of the first input document, from
    each token's `text` and `whitespace_`. An empty input yields an empty
    list instead of failing on `docs[0]`.
    '''
    if not docs:
        return []
    vocab = docs[0].vocab
    return [Doc(vocab, words=[t.text for t in doc],
                spaces=[t.whitespace_ for t in doc])
            for doc in docs]
|
||||
|
||||
##############
|
||||
# Evaluation #
|
||||
##############
|
||||
|
||||
def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
|
||||
joint_sbd=True, limit=None):
|
||||
with open(text_loc) as text_file:
|
||||
|
@ -265,25 +264,24 @@ Token.set_extension('begins_fused', default=False)
|
|||
Token.set_extension('inside_fused', default=False)
|
||||
|
||||
|
||||
def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
|
||||
output_loc):
|
||||
if lang == 'en':
|
||||
##################
|
||||
# Initialization #
|
||||
##################
|
||||
|
||||
|
||||
def load_nlp(corpus, config):
|
||||
lang = corpus.split('_')[0]
|
||||
nlp = spacy.blank(lang)
|
||||
vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0')
|
||||
nlp.vocab.vectors = vec_nlp.vocab.vectors
|
||||
for lex in vec_nlp.vocab:
|
||||
_ = nlp.vocab[lex.orth_]
|
||||
vec_nlp = None
|
||||
else:
|
||||
nlp = spacy.load(lang)
|
||||
with open(conllu_train_loc) as conllu_file:
|
||||
with open(text_train_loc) as text_file:
|
||||
docs, golds = read_data(nlp, conllu_file, text_file,
|
||||
oracle_segments=False, raw_text=True,
|
||||
max_doc_length=10, limit=None)
|
||||
if config.vectors:
|
||||
nlp.vocab.from_disk(config.vectors / 'vocab')
|
||||
return nlp
|
||||
|
||||
def initialize_pipeline(nlp, docs, golds, config):
|
||||
print("Create parser")
|
||||
nlp.add_pipe(nlp.create_pipe('parser'))
|
||||
if config.multitask_tag:
|
||||
nlp.parser.add_multitask_objective('tag')
|
||||
if config.multitask_sent:
|
||||
nlp.parser.add_multitask_objective('sent_start')
|
||||
nlp.parser.moves.add_action(2, 'subtok')
|
||||
nlp.add_pipe(nlp.create_pipe('tagger'))
|
||||
|
@ -291,7 +289,6 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
|
|||
for tag in gold.tags:
|
||||
if tag is not None:
|
||||
nlp.tagger.add_label(tag)
|
||||
optimizer = nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
|
||||
# Replace labels that didn't make the frequency cutoff
|
||||
actions = set(nlp.parser.labels)
|
||||
label_set = set([act.split('-')[1] for act in actions if '-' in act])
|
||||
|
@ -299,38 +296,92 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
|
|||
for i, label in enumerate(gold.labels):
|
||||
if label is not None and label not in label_set:
|
||||
gold.labels[i] = label.split('||')[0]
|
||||
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
|
||||
|
||||
|
||||
########################
|
||||
# Command line helpers #
|
||||
########################
|
||||
|
||||
@attr.s
class Config(object):
    '''Training settings, constructible from a JSON file via `Config.load`.'''
    # NOTE(review): load_nlp joins this with '/', so it is presumably meant
    # to be a pathlib.Path; a value read from JSON is a plain string - confirm.
    vectors = attr.ib(default=None)
    max_doc_length = attr.ib(default=10)
    # Toggle the parser's extra 'tag' and 'sent_start' objectives.
    multitask_tag = attr.ib(default=True)
    multitask_sent = attr.ib(default=True)
    # Training-loop settings: epochs, minibatch size and dropout rate.
    nr_epoch = attr.ib(default=30)
    batch_size = attr.ib(default=1000)
    dropout = attr.ib(default=0.2)

    @classmethod
    def load(cls, loc):
        '''Read the JSON object stored at `loc` and pass its keys to the
        constructor as keyword arguments.'''
        config_path = Path(loc)
        with config_path.open('r', encoding='utf8') as handle:
            settings = json.load(handle)
        return cls(**settings)
|
||||
|
||||
|
||||
class Dataset(object):
    '''The .conllu and .txt files of one section (e.g. 'train') of a treebank.

    path: directory (pathlib.Path) to search.
    section: substring that must occur in the file names, e.g. 'train'.

    Sets `conllu` and `text` to the matching file paths and `lang` to the
    part of the .conllu file name before the first '-' and '_'.

    Raises IOError if either the .conllu or the .txt file is missing.
    '''
    def __init__(self, path, section):
        self.path = path
        self.section = section
        self.conllu = None
        self.text = None
        for file_path in self.path.iterdir():
            name = file_path.name
            if section not in name:
                continue
            if name.endswith('conllu'):
                self.conllu = file_path
            elif name.endswith('txt'):
                self.text = file_path
        if self.conllu is None:
            # Previously this message wrongly named the .txt file.
            msg = "Could not find .conllu file in {path} for {section}"
            raise IOError(msg.format(section=section, path=path))
        if self.text is None:
            # Previously the message was built but never raised, so a missing
            # text file only surfaced later as an error on `None`.
            msg = "Could not find .txt file in {path} for {section}"
            raise IOError(msg.format(section=section, path=path))
        # e.g. 'en_ewt-ud-train.conllu' -> 'en'
        self.lang = self.conllu.name.split('-')[0].split('_')[0]
|
||||
|
||||
|
||||
class TreebankPaths(object):
    '''Locate the 'train' and 'dev' sections of one treebank directory.

    ud_path: root directory of the corpus; joined with `treebank` using '/'.
    treebank: name of the treebank sub-directory.
    **cfg: accepted but not used.

    `lang` is taken from the train section.
    '''
    def __init__(self, ud_path, treebank, **cfg):
        treebank_dir = ud_path / treebank
        self.train = Dataset(treebank_dir, 'train')
        self.dev = Dataset(treebank_dir, 'dev')
        self.lang = self.train.lang
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
|
||||
config=("Path to json formatted config file", "positional", None, Config.load),
|
||||
corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
|
||||
"positional", None, str),
|
||||
parses=("Path to write the development parses", "positional", None, Path)
|
||||
)
|
||||
def main(ud_dir, corpus, config, parses='/tmp/dev.conllu'):
|
||||
paths = TreebankPaths(ud_dir, corpus)
|
||||
nlp = load_nlp(paths.lang, config)
|
||||
|
||||
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
|
||||
config)
|
||||
|
||||
optimizer = initialize_pipeline(nlp, docs, golds, config)
|
||||
n_train_words = sum(len(doc) for doc in docs)
|
||||
print(n_train_words)
|
||||
print("Begin training")
|
||||
# Batch size starts at 1 and grows, so that we make updates quickly
|
||||
# at the beginning of training.
|
||||
batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
|
||||
spacy.util.env_opt('batch_to', 8),
|
||||
spacy.util.env_opt('batch_compound', 1.001))
|
||||
for i in range(30):
|
||||
print("Begin training (%d words)" % n_train_words)
|
||||
for i in range(config.nr_epoch):
|
||||
docs = refresh_docs(docs)
|
||||
batches = minibatch_by_words(list(zip(docs, golds)), size=1000)
|
||||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||
batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
|
||||
losses = {}
|
||||
for batch in batches:
|
||||
for batch in tqdm.tqdm(batches, total=n_train_words//config.batch_size):
|
||||
if not batch:
|
||||
continue
|
||||
batch_docs, batch_gold = zip(*batch)
|
||||
|
||||
nlp.update(batch_docs, batch_gold, sgd=optimizer,
|
||||
drop=0.2, losses=losses)
|
||||
pbar.update(sum(len(doc) for doc in batch_docs))
|
||||
drop=config.dropout, losses=losses)
|
||||
|
||||
with nlp.use_params(optimizer.averages):
|
||||
dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
|
||||
oracle_segments=False, joint_sbd=True)
|
||||
dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu,
|
||||
**attr.asdict(config))
|
||||
print_progress(i, losses, scorer)
|
||||
with open(output_loc, 'w') as file_:
|
||||
print_conllu(dev_docs, file_)
|
||||
with open('/tmp/train.conllu', 'w') as file_:
|
||||
print_conllu(list(nlp.pipe([d.text for d in batch_docs])), file_)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue
Block a user