Mirror of https://github.com/explosion/spaCy.git
	Minibatch by number of tokens, support other vectors, refactor CoNLL printing
This commit is contained in:
parent dd78ef066a
commit c388833ca6
@@ -8,21 +8,38 @@ import re
 import sys
 import spacy
 import spacy.util
-from spacy.tokens import Doc
+from spacy.tokens import Token, Doc
 from spacy.gold import GoldParse, minibatch
 from spacy.syntax.nonproj import projectivize
 from collections import defaultdict, Counter
 from timeit import default_timer as timer
 from spacy.matcher import Matcher
 
+import itertools
 import random
 import numpy.random
+import cytoolz
 
 from spacy._align import align
 
 random.seed(0)
 numpy.random.seed(0)
 
 
+def minibatch_by_words(items, size=5000):
+    if isinstance(size, int):
+        size_ = itertools.repeat(size)
+    else:
+        size_ = size
+    items = iter(items)
+    while True:
+        batch_size = next(size_)
+        batch = []
+        while batch_size >= 0:
+            doc, gold = next(items)
+            batch_size -= len(doc)
+            batch.append((doc, gold))
+        yield batch
+
+
 def get_token_acc(docs, golds):
     '''Quick function to evaluate tokenization accuracy.'''
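A note on the new helper: minibatch_by_words budgets each batch by token count rather than by number of examples, so a handful of very long documents can no longer blow up a fixed-size batch. As committed, it pulls from the iterator with a bare next(), which drops the final partial batch when the data runs out (and on Python 3.7+, PEP 479 turns that StopIteration into a RuntimeError). A minimal self-contained sketch, not part of the commit, that handles exhaustion explicitly; FakeDoc is a stand-in for spacy.tokens.Doc, since only len() matters to the batcher:

import itertools

def minibatch_by_words_safe(items, size=5000):
    # Accept a fixed token budget or an iterable schedule of budgets.
    size_ = itertools.repeat(size) if isinstance(size, int) else iter(size)
    items = iter(items)
    while True:
        batch_size = next(size_)
        batch = []
        while batch_size >= 0:
            try:
                doc, gold = next(items)
            except StopIteration:
                if batch:
                    yield batch  # flush the final partial batch
                return
            batch_size -= len(doc)  # spend the budget in tokens, not examples
            batch.append((doc, gold))
        yield batch

class FakeDoc:
    def __init__(self, n_tokens):
        self.n_tokens = n_tokens
    def __len__(self):
        return self.n_tokens

pairs = [(FakeDoc(n), None) for n in (400, 300, 500, 200, 700)]
for batch in minibatch_by_words_safe(pairs, size=1000):
    print([len(doc) for doc, _ in batch])
# -> [400, 300, 500] then [200, 700]: a batch may overshoot its budget by
#    at most one document, matching the committed behaviour.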
@@ -214,31 +231,51 @@ def print_conllu(docs, file_):
         offsets = [(span.start_char, span.end_char) for span in spans]
         for start_char, end_char in offsets:
             doc.merge(start_char, end_char)
-        #print([t.text for t in doc])
         file_.write("# newdoc id = {i}\n".format(i=i))
         for j, sent in enumerate(doc.sents):
             file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
             file_.write("# text = {text}\n".format(text=sent.text))
-            for k, t in enumerate(sent):
-                if t.head.i == t.i:
-                    head = 0
-                else:
-                    head = k + (t.head.i - t.i) + 1
-                fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_',
-                          str(head), t.dep_.lower(), '_', '_']
-                file_.write('\t'.join(fields) + '\n')
-            file_.write('\n')
+            for k, token in enumerate(sent):
+                file_.write(token._.get_conllu_lines(k) + '\n')
+            file_.write('\n')
+
+#def get_sent_conllu(sent, sent_id):
+#    lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
+
+def get_token_conllu(token, i):
+    if token._.begins_fused:
+        n = 1
+        while token.nbor(n)._.inside_fused:
+            n += 1
+        id_ = '%d-%d' % (i, i+n)
+        lines = [id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_']
+    else:
+        lines = []
+    if token.head.i == token.i:
+        head = 0
+    else:
+        head = i + (token.head.i - token.i) + 1
+    fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
+              str(head), token.dep_.lower(), '_', '_']
+    lines.append('\t'.join(fields))
+    return '\n'.join(lines)
+
+
+Token.set_extension('get_conllu_lines', method=get_token_conllu)
+Token.set_extension('begins_fused', default=False)
+Token.set_extension('inside_fused', default=False)
 
 
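The CoNLL-U refactor hangs the per-token formatting off spaCy's custom-attribute API: get_token_conllu is registered once as a method extension and then called through the underscore namespace as token._.get_conllu_lines(k), while begins_fused and inside_fused are plain data attributes with defaults. A short illustration of that mechanism, separate from the commit (the 'describe' extension here is hypothetical):

import spacy
from spacy.tokens import Token

def describe(token, i):
    # Any function taking (token, *args) can be bound as a method extension.
    return '%d\t%s' % (i + 1, token.text)

Token.set_extension('describe', method=describe)  # called as token._.describe(i)
Token.set_extension('is_fused', default=False)    # data attribute with a default

nlp = spacy.blank('en')
doc = nlp('Hello world')
for k, token in enumerate(doc):
    print(token._.describe(k))       # '1\tHello', then '2\tworld'
    assert token._.is_fused is False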
 def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
          output_loc):
-    nlp = spacy.blank(lang)
     if lang == 'en':
+        nlp = spacy.blank(lang)
         vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0')
         nlp.vocab.vectors = vec_nlp.vocab.vectors
         for lex in vec_nlp.vocab:
             _ = nlp.vocab[lex.orth_]
         vec_nlp = None
+    else:
+        nlp = spacy.load(lang)
     with open(conllu_train_loc) as conllu_file:
         with open(text_train_loc) as text_file:
             docs, golds = read_data(nlp, conllu_file, text_file,
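On the "support other vectors" change above: for English the script keeps building a blank pipeline and grafting vectors from a local en_core_web_lg build onto it; for every other language it now simply loads an installed model with spacy.load(lang). The vector-grafting pattern, sketched with a packaged model rather than the commit's hard-coded path (assumes en_core_web_lg is installed):

import spacy

nlp = spacy.blank('en')
vec_nlp = spacy.load('en_core_web_lg')       # any pipeline that ships vectors
nlp.vocab.vectors = vec_nlp.vocab.vectors    # share the vector table directly
for lex in vec_nlp.vocab:
    _ = nlp.vocab[lex.orth_]                 # materialise each lexeme in the new vocab
vec_nlp = None                               # release the donor pipeline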
@@ -272,7 +309,7 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
                                    spacy.util.env_opt('batch_compound', 1.001))
     for i in range(30):
         docs = refresh_docs(docs)
-        batches = minibatch(list(zip(docs, golds)), size=batch_sizes)
+        batches = minibatch_by_words(list(zip(docs, golds)), size=1000)
         with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
             losses = {}
             for batch in batches:
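The training loop previously drew a compounding schedule of example counts (the batch_compound env_opt kept above) and now fixes the budget at 1000 tokens. Because minibatch_by_words also accepts an iterator of sizes, a compounding token budget would slot straight in; a sketch, with illustrative bounds not taken from the commit:

import spacy.util

# Token budget that starts at 500, grows by a factor of 1.001 per batch,
# and is capped at 5000; minibatch_by_words calls next() on it per batch.
token_budgets = spacy.util.compounding(500., 5000., 1.001)
print([int(next(token_budgets)) for _ in range(3)])  # 500, 500, 501
# batches = minibatch_by_words(zip(docs, golds), size=token_budgets)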