mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
Minibatch by number of tokens, support other vectors, refactor CoNLL printing
This commit is contained in:
parent
dd78ef066a
commit
c388833ca6
|
@ -8,21 +8,38 @@ import re
|
|||
import sys
|
||||
import spacy
|
||||
import spacy.util
|
||||
from spacy.tokens import Doc
|
||||
from spacy.tokens import Token, Doc
|
||||
from spacy.gold import GoldParse, minibatch
|
||||
from spacy.syntax.nonproj import projectivize
|
||||
from collections import defaultdict, Counter
|
||||
from timeit import default_timer as timer
|
||||
from spacy.matcher import Matcher
|
||||
|
||||
import itertools
|
||||
import random
|
||||
import numpy.random
|
||||
import cytoolz
|
||||
|
||||
from spacy._align import align
|
||||
|
||||
random.seed(0)
|
||||
numpy.random.seed(0)
|
||||
|
||||
def minibatch_by_words(items, size=5000):
|
||||
if isinstance(size, int):
|
||||
size_ = itertools.repeat(size)
|
||||
else:
|
||||
size_ = size
|
||||
items = iter(items)
|
||||
while True:
|
||||
batch_size = next(size_)
|
||||
batch = []
|
||||
while batch_size >= 0:
|
||||
doc, gold = next(items)
|
||||
batch_size -= len(doc)
|
||||
batch.append((doc, gold))
|
||||
yield batch
|
||||
|
||||
|
||||
def get_token_acc(docs, golds):
|
||||
'''Quick function to evaluate tokenization accuracy.'''
|
||||
|
@ -214,31 +231,51 @@ def print_conllu(docs, file_):
|
|||
offsets = [(span.start_char, span.end_char) for span in spans]
|
||||
for start_char, end_char in offsets:
|
||||
doc.merge(start_char, end_char)
|
||||
#print([t.text for t in doc])
|
||||
file_.write("# newdoc id = {i}\n".format(i=i))
|
||||
for j, sent in enumerate(doc.sents):
|
||||
file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
|
||||
file_.write("# text = {text}\n".format(text=sent.text))
|
||||
for k, t in enumerate(sent):
|
||||
if t.head.i == t.i:
|
||||
head = 0
|
||||
else:
|
||||
head = k + (t.head.i - t.i) + 1
|
||||
fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_',
|
||||
str(head), t.dep_.lower(), '_', '_']
|
||||
file_.write('\t'.join(fields) + '\n')
|
||||
for k, token in enumerate(sent):
|
||||
file_.write(token._.get_conllu_lines(k) + '\n')
|
||||
file_.write('\n')
|
||||
|
||||
#def get_sent_conllu(sent, sent_id):
|
||||
# lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
|
||||
|
||||
def get_token_conllu(token, i):
|
||||
if token._.begins_fused:
|
||||
n = 1
|
||||
while token.nbor(n)._.inside_fused:
|
||||
n += 1
|
||||
id_ = '%d-%d' % (k, k+n)
|
||||
lines = [id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_']
|
||||
else:
|
||||
lines = []
|
||||
if token.head.i == token.i:
|
||||
head = 0
|
||||
else:
|
||||
head = i + (token.head.i - token.i) + 1
|
||||
fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
|
||||
str(head), token.dep_.lower(), '_', '_']
|
||||
lines.append('\t'.join(fields))
|
||||
return '\n'.join(lines)
|
||||
|
||||
Token.set_extension('get_conllu_lines', method=get_token_conllu)
|
||||
Token.set_extension('begins_fused', default=False)
|
||||
Token.set_extension('inside_fused', default=False)
|
||||
|
||||
|
||||
def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
|
||||
output_loc):
|
||||
nlp = spacy.blank(lang)
|
||||
if lang == 'en':
|
||||
nlp = spacy.blank(lang)
|
||||
vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0')
|
||||
nlp.vocab.vectors = vec_nlp.vocab.vectors
|
||||
for lex in vec_nlp.vocab:
|
||||
_ = nlp.vocab[lex.orth_]
|
||||
vec_nlp = None
|
||||
else:
|
||||
nlp = spacy.load(lang)
|
||||
with open(conllu_train_loc) as conllu_file:
|
||||
with open(text_train_loc) as text_file:
|
||||
docs, golds = read_data(nlp, conllu_file, text_file,
|
||||
|
@ -272,7 +309,7 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
|
|||
spacy.util.env_opt('batch_compound', 1.001))
|
||||
for i in range(30):
|
||||
docs = refresh_docs(docs)
|
||||
batches = minibatch(list(zip(docs, golds)), size=batch_sizes)
|
||||
batches = minibatch_by_words(list(zip(docs, golds)), size=1000)
|
||||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||
losses = {}
|
||||
for batch in batches:
|
||||
|
|
Loading…
Reference in New Issue
Block a user