	Refactor conllu script
This commit is contained in:
parent c388833ca6
commit 44e496a82e
@@ -4,8 +4,12 @@
from __future__ import unicode_literals
import plac
import tqdm
import attr
from pathlib import Path
import re
import sys
import json

import spacy
import spacy.util
from spacy.tokens import Token, Doc
@@ -40,32 +44,9 @@ def minibatch_by_words(items, size=5000):
            batch.append((doc, gold))
        yield batch


def get_token_acc(docs, golds):
    '''Quick function to evaluate tokenization accuracy.'''
    miss = 0
    hit = 0
    for doc, gold in zip(docs, golds):
        for i in range(len(doc)):
            token = doc[i]
            align = gold.words[i]
            if align == None:
                miss += 1
            else:
                hit += 1
    return miss, hit


def golds_to_gold_tuples(docs, golds):
    '''Get out the annoying 'tuples' format used by begin_training, given the
    GoldParse objects.'''
    tuples = []
    for doc, gold in zip(docs, golds):
        text = doc.text
        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
        sents = [((ids, words, tags, heads, labels, iob), [])]
        tuples.append((text, sents))
    return tuples
################
# Data reading #
################

def split_text(text):
    return [par.strip().replace('\n', ' ')
@@ -127,34 +108,6 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
    return docs, golds


def _make_gold(nlp, text, sent_annots):
    # Flatten the conll annotations, and adjust the head indices
    flat = defaultdict(list)
    for sent in sent_annots:
        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
            flat[field].extend(sent[field])
    # Construct text if necessary
    assert len(flat['words']) == len(flat['spaces'])
    if text is None:
        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
    doc = nlp.make_doc(text)
    flat.pop('spaces')
    gold = GoldParse(doc, **flat)
    #for annot in gold.orig_annot:
    #    print(annot)
    #for i in range(len(doc)):
    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
    return doc, gold


def refresh_docs(docs):
    vocab = docs[0].vocab
    return [Doc(vocab, words=[t.text for t in doc],
                       spaces=[t.whitespace_ for t in doc])
            for doc in docs]


def read_conllu(file_):
    docs = []
    sent = []
@@ -179,6 +132,52 @@ def read_conllu(file_):
    return docs


def _make_gold(nlp, text, sent_annots):
    # Flatten the conll annotations, and adjust the head indices
    flat = defaultdict(list)
    for sent in sent_annots:
        flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
            flat[field].extend(sent[field])
    # Construct text if necessary
    assert len(flat['words']) == len(flat['spaces'])
    if text is None:
        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
    doc = nlp.make_doc(text)
    flat.pop('spaces')
    gold = GoldParse(doc, **flat)
    #for annot in gold.orig_annot:
    #    print(annot)
    #for i in range(len(doc)):
    #    print(doc[i].text, gold.words[i], gold.labels[i], gold.heads[i])
    return doc, gold

#############################
# Data transforms for spaCy #
#############################

def golds_to_gold_tuples(docs, golds):
    '''Get out the annoying 'tuples' format used by begin_training, given the
    GoldParse objects.'''
    tuples = []
    for doc, gold in zip(docs, golds):
        text = doc.text
        ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
        sents = [((ids, words, tags, heads, labels, iob), [])]
        tuples.append((text, sents))
    return tuples


def refresh_docs(docs):
    vocab = docs[0].vocab
    return [Doc(vocab, words=[t.text for t in doc],
                       spaces=[t.whitespace_ for t in doc])
            for doc in docs]

##############
# Evaluation #
##############

def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False,
                   joint_sbd=True, limit=None):
    with open(text_loc) as text_file:
@@ -265,33 +264,31 @@ Token.set_extension('begins_fused', default=False)
Token.set_extension('inside_fused', default=False)


def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
         output_loc):
    if lang == 'en':
        nlp = spacy.blank(lang)
        vec_nlp = spacy.util.load_model('spacy/data/en_core_web_lg/en_core_web_lg-2.0.0')
        nlp.vocab.vectors = vec_nlp.vocab.vectors
        for lex in vec_nlp.vocab:
            _ = nlp.vocab[lex.orth_]
        vec_nlp = None
    else:
        nlp = spacy.load(lang)
    with open(conllu_train_loc) as conllu_file:
        with open(text_train_loc) as text_file:
            docs, golds = read_data(nlp, conllu_file, text_file,
                                    oracle_segments=False, raw_text=True,
                                    max_doc_length=10, limit=None)
##################
# Initialization #
##################


def load_nlp(corpus, config):
    lang = corpus.split('_')[0]
    nlp = spacy.blank(lang)
    if config.vectors:
        nlp.vocab.from_disk(config.vectors / 'vocab')
    return nlp

def initialize_pipeline(nlp, docs, golds, config):
    print("Create parser")
    nlp.add_pipe(nlp.create_pipe('parser'))
    nlp.parser.add_multitask_objective('tag')
    nlp.parser.add_multitask_objective('sent_start')
    if config.multitask_tag:
        nlp.parser.add_multitask_objective('tag')
    if config.multitask_sent:
        nlp.parser.add_multitask_objective('sent_start')
    nlp.parser.moves.add_action(2, 'subtok')
    nlp.add_pipe(nlp.create_pipe('tagger'))
    for gold in golds:
        for tag in gold.tags:
            if tag is not None:
                nlp.tagger.add_label(tag)
    optimizer = nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))
    # Replace labels that didn't make the frequency cutoff
    actions = set(nlp.parser.labels)
    label_set = set([act.split('-')[1] for act in actions if '-' in act])
@@ -299,38 +296,92 @@ def main(lang, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc,
        for i, label in enumerate(gold.labels):
            if label is not None and label not in label_set:
                gold.labels[i] = label.split('||')[0]
    n_train_words = sum(len(doc) for doc in docs)
    print(n_train_words)
    print("Begin training")
    # Batch size starts at 1 and grows, so that we make updates quickly
    # at the beginning of training.
    batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 1),
                                   spacy.util.env_opt('batch_to', 8),
                                   spacy.util.env_opt('batch_compound', 1.001))
    for i in range(30):
        docs = refresh_docs(docs)
        batches = minibatch_by_words(list(zip(docs, golds)), size=1000)
        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
            losses = {}
            for batch in batches:
                if not batch:
                    continue
                batch_docs, batch_gold = zip(*batch)
    return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds))

                nlp.update(batch_docs, batch_gold, sgd=optimizer,
                           drop=0.2, losses=losses)
                pbar.update(sum(len(doc) for doc in batch_docs))

########################
# Command line helpers #
########################

@attr.s
class Config(object):
    vectors = attr.ib(default=None)
    max_doc_length = attr.ib(default=10)
    multitask_tag = attr.ib(default=True)
    multitask_sent = attr.ib(default=True)
    nr_epoch = attr.ib(default=30)
    batch_size = attr.ib(default=1000)
    dropout = attr.ib(default=0.2)

    @classmethod
    def load(cls, loc):
        with Path(loc).open('r', encoding='utf8') as file_:
            cfg = json.load(file_)
        return cls(**cfg)


class Dataset(object):
    def __init__(self, path, section):
        self.path = path
        self.section = section
        self.conllu = None
        self.text = None
        for file_path in self.path.iterdir():
            name = file_path.parts[-1]
            if section in name and name.endswith('conllu'):
                self.conllu = file_path
            elif section in name and name.endswith('txt'):
                self.text = file_path
        if self.conllu is None:
            msg = "Could not find .txt file in {path} for {section}"
            raise IOError(msg.format(section=section, path=path))
        if self.text is None:
            msg = "Could not find .txt file in {path} for {section}"
        self.lang = self.conllu.parts[-1].split('-')[0].split('_')[0]


class TreebankPaths(object):
    def __init__(self, ud_path, treebank, **cfg):
        self.train = Dataset(ud_path / treebank, 'train')
        self.dev = Dataset(ud_path / treebank, 'dev')
        self.lang = self.train.lang


@plac.annotations(
    ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
    config=("Path to json formatted config file", "positional", None, Config.load),
    corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
            "positional", None, str),
    parses=("Path to write the development parses", "positional", None, Path)
)
def main(ud_dir, corpus, config, parses='/tmp/dev.conllu'):
    paths = TreebankPaths(ud_dir, corpus)
    nlp = load_nlp(paths.lang, config)

    docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
                            config)

    optimizer = initialize_pipeline(nlp, docs, golds, config)
    n_train_words = sum(len(doc) for doc in docs)
    print("Begin training (%d words)" % n_train_words)
    for i in range(config.nr_epoch):
        docs = refresh_docs(docs)
        batches = minibatch_by_words(list(zip(docs, golds)), size=config.batch_size)
        losses = {}
        for batch in tqdm.tqdm(batches, total=n_train_words//config.batch_size):
            if not batch:
                continue
            batch_docs, batch_gold = zip(*batch)

            nlp.update(batch_docs, batch_gold, sgd=optimizer,
                        drop=config.dropout, losses=losses)

        with nlp.use_params(optimizer.averages):
            dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc,
                                              oracle_segments=False, joint_sbd=True)
            dev_docs, scorer = parse_dev_data(nlp, paths.dev.text, paths.dev.conllu,
                                              **attr.asdict(config))
            print_progress(i, losses, scorer)
            with open(output_loc, 'w') as file_:
                print_conllu(dev_docs, file_)
            with open('/tmp/train.conllu', 'w') as file_:
                print_conllu(list(nlp.pipe([d.text for d in batch_docs])), file_)




if __name__ == '__main__':
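For orientation, a minimal usage sketch of the refactored entry point. This is not part of the commit: the script filename, all paths, and the "es_ancora" corpus name are placeholder assumptions taken from the plac annotations above, and the JSON config simply mirrors the Config defaults defined in the diff.

# Hypothetical usage sketch; filenames, paths and corpus name are assumptions.
import json

config = {
    "max_doc_length": 10,
    "multitask_tag": True,
    "multitask_sent": True,
    "nr_epoch": 30,
    "batch_size": 1000,
    "dropout": 0.2,
}
with open("config.json", "w", encoding="utf8") as file_:
    json.dump(config, file_)

# main() is wired up with plac, so the script would then be run from the shell,
# with arguments in signature order (ud_dir, corpus, config, parses), e.g.:
#   python conllu_train.py /path/to/ud-treebanks es_ancora config.json /tmp/dev.conllu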