mirror of https://github.com/explosion/spaCy.git
Minibatch by number of tokens, support other vectors, refactor CoNLL printing
parent dd78ef066a
commit c388833ca6
@@ -8,21 +8,38 @@ import re
 import sys
 import spacy
 import spacy.util
-from spacy.tokens import Doc
+from spacy.tokens import Token, Doc
 from spacy.gold import GoldParse, minibatch
 from spacy.syntax.nonproj import projectivize
 from collections import defaultdict, Counter
 from timeit import default_timer as timer
 from spacy.matcher import Matcher

+import itertools
 import random
 import numpy.random
+import cytoolz

 from spacy._align import align

 random.seed(0)
 numpy.random.seed(0)


+def minibatch_by_words(items, size=5000):
+    if isinstance(size, int):
+        size_ = itertools.repeat(size)
+    else:
+        size_ = size
+    items = iter(items)
+    while True:
+        batch_size = next(size_)
+        batch = []
+        while batch_size >= 0:
+            doc, gold = next(items)
+            batch_size -= len(doc)
+            batch.append((doc, gold))
+        yield batch
+
+
 def get_token_acc(docs, golds):
     '''Quick function to evaluate tokenization accuracy.'''
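The first hunk adds minibatch_by_words, which fills each batch by a running token budget rather than by a fixed number of examples. Below is a minimal, self-contained sketch of the same idea (assumptions: the helper name is hypothetical, plain token lists stand in for Doc objects, and an explicit stop condition is added where the committed generator relies on next(items) raising StopIteration):

import itertools

def minibatch_by_words_sketch(items, size=5000):
    # Same batching idea as the added minibatch_by_words: keep pulling
    # (doc, gold) pairs until the token budget for this batch is spent.
    size_ = itertools.repeat(size) if isinstance(size, int) else iter(size)
    items = iter(items)
    while True:
        batch_size = next(size_)
        batch = []
        while batch_size >= 0:
            try:
                doc, gold = next(items)
            except StopIteration:
                if batch:
                    yield batch
                return
            batch_size -= len(doc)       # spend the budget by token count
            batch.append((doc, gold))
        yield batch

# Toy data: "docs" are just token lists of varying length, golds are dummies.
docs = [['tok'] * n for n in (3, 7, 2, 9, 4)]
golds = [None] * len(docs)
for batch in minibatch_by_words_sketch(zip(docs, golds), size=10):
    print([len(doc) for doc, _ in batch])   # prints [3, 7, 2] then [9, 4]

Note that a batch is only closed once the budget goes negative, so batches overshoot the requested size by at most one document.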
@@ -214,31 +231,51 @@ def print_conllu(docs, file_):
         offsets = [(span.start_char, span.end_char) for span in spans]
         for start_char, end_char in offsets:
             doc.merge(start_char, end_char)
-        #print([t.text for t in doc])
         file_.write("# newdoc id = {i}\n".format(i=i))
         for j, sent in enumerate(doc.sents):
             file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
             file_.write("# text = {text}\n".format(text=sent.text))
-            for k, t in enumerate(sent):
-                if t.head.i == t.i:
-                    head = 0
-                else:
-                    head = k + (t.head.i - t.i) + 1
-                fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_',
-                          str(head), t.dep_.lower(), '_', '_']
-                file_.write('\t'.join(fields) + '\n')
+            for k, token in enumerate(sent):
+                file_.write(token._.get_conllu_lines(k) + '\n')
             file_.write('\n')


+#def get_sent_conllu(sent, sent_id):
+#    lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
+
+
+def get_token_conllu(token, i):
+    if token._.begins_fused:
+        n = 1
+        while token.nbor(n)._.inside_fused:
+            n += 1
+        id_ = '%d-%d' % (k, k+n)
+        lines = [id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_']
+    else:
+        lines = []
+    if token.head.i == token.i:
+        head = 0
+    else:
+        head = i + (token.head.i - token.i) + 1
+    fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
+              str(head), token.dep_.lower(), '_', '_']
+    lines.append('\t'.join(fields))
+    return '\n'.join(lines)
+
+
+Token.set_extension('get_conllu_lines', method=get_token_conllu)
+Token.set_extension('begins_fused', default=False)
+Token.set_extension('inside_fused', default=False)
+
+
 def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
          output_loc):
-    nlp = spacy.blank(lang)
     if lang == 'en':
+        nlp = spacy.blank(lang)
         vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0')
         nlp.vocab.vectors = vec_nlp.vocab.vectors
         for lex in vec_nlp.vocab:
            _ = nlp.vocab[lex.orth_]
         vec_nlp = None
+    else:
+        nlp = spacy.load(lang)
     with open(conllu_train_loc) as conllu_file:
         with open(text_train_loc) as text_file:
             docs, golds = read_data(nlp, conllu_file, text_file,
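The refactored print_conllu delegates the per-token CoNLL-U line to a Token method extension, token._.get_conllu_lines(k), registered with Token.set_extension(..., method=get_token_conllu). A minimal sketch of that extension pattern follows (assumptions: the extension name conllu_line and the blank pipeline are illustrative only; only the non-fused branch is mirrored, since the fused branch in the hunk above refers to k, which is not a parameter of get_token_conllu):

import spacy
from spacy.tokens import Token

def token_conllu_line(token, i):
    # Same field layout as the non-fused branch of get_token_conllu above.
    if token.head.i == token.i:
        head = 0
    else:
        head = i + (token.head.i - token.i) + 1
    fields = [str(i + 1), token.text, token.lemma_, token.pos_, token.tag_, '_',
              str(head), token.dep_.lower(), '_', '_']
    return '\t'.join(fields)

# Register the function as a method extension, then call it via `token._.`.
Token.set_extension('conllu_line', method=token_conllu_line, force=True)

nlp = spacy.blank('en')
doc = nlp('This is a sentence')
for k, token in enumerate(doc):
    # With a blank pipeline the tag/lemma/dep fields stay empty; a trained
    # pipeline fills them in, but the extension wiring is identical.
    print(token._.conllu_line(k))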
@@ -272,7 +309,7 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
                                         spacy.util.env_opt('batch_compound', 1.001))
     for i in range(30):
         docs = refresh_docs(docs)
-        batches = minibatch(list(zip(docs, golds)), size=batch_sizes)
+        batches = minibatch_by_words(list(zip(docs, golds)), size=1000)
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
             losses = {}
             for batch in batches:
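The last hunk swaps the compounding, example-count minibatch call for token-count batching with a flat budget of 1000 words, so each update sees roughly the same amount of text regardless of sentence length. A rough sketch of the contrast (the compounding generator below only loosely mirrors spacy.util.compounding and is illustrative only):

import itertools

def compounding_sketch(start, stop, compound):
    # Loose stand-in for spacy.util.compounding: batch sizes grow slowly
    # from `start` towards `stop` by a factor of `compound` per batch.
    value = float(start)
    while True:
        yield min(value, stop)
        value *= compound

# Old behaviour: batch size counted in *examples*, growing each batch.
sizes = compounding_sketch(1, 8, 1.001)
print([round(s, 3) for s in itertools.islice(sizes, 5)])
# [1.0, 1.001, 1.002, 1.003, 1.004]

# New behaviour: every batch holds roughly 1000 *tokens*, however many
# (doc, gold) pairs that takes (see minibatch_by_words in the first hunk).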