mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 05:01:02 +03:00 
			
		
		
		
	Merge pull request #2019 from explosion/feature/better-gold
Make Levenshtein alignment faster, bug fixes to parser, add UD parsing script
This commit is contained in:
		
						commit
						dd3ebe4931
					
				
							
								
								
									
										303
									
								
								examples/training/conllu.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										303
									
								
								examples/training/conllu.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,303 @@ | ||||||
|  | '''Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes | ||||||
|  | .conllu format for development data, allowing the official scorer to be used. | ||||||
|  | ''' | ||||||
|  | from __future__ import unicode_literals | ||||||
|  | import plac | ||||||
|  | import tqdm | ||||||
|  | import re | ||||||
|  | import sys | ||||||
|  | import spacy | ||||||
|  | import spacy.util | ||||||
|  | from spacy.tokens import Doc | ||||||
|  | from spacy.gold import GoldParse, minibatch | ||||||
|  | from spacy.syntax.nonproj import projectivize | ||||||
|  | from collections import Counter | ||||||
|  | from timeit import default_timer as timer | ||||||
|  | 
 | ||||||
|  | from spacy._align import align | ||||||
|  | 
 | ||||||
|  | def prevent_bad_sentences(doc): | ||||||
|  |     '''This is an example pipeline component for fixing sentence segmentation | ||||||
|  |     mistakes. The component sets is_sent_start to False, which means the | ||||||
|  |     parser will be prevented from making a sentence boundary there. The | ||||||
|  |     rules here aren't necessarily a good idea.''' | ||||||
|  |     for token in doc[1:]: | ||||||
|  |         if token.nbor(-1).text == ',': | ||||||
|  |             token.is_sent_start = False | ||||||
|  |         elif not token.nbor(-1).whitespace_: | ||||||
|  |             token.is_sent_start = False | ||||||
|  |         elif not token.nbor(-1).is_punct: | ||||||
|  |             token.is_sent_start = False | ||||||
|  |         elif token.nbor(-1).is_left_punct: | ||||||
|  |             token.is_sent_start = False | ||||||
|  |     return doc | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def load_model(lang): | ||||||
|  |     '''This shows how to adjust the tokenization rules, to special-case | ||||||
|  |     for ways the CoNLLU tokenization differs. We need to get the tokenizer | ||||||
|  |     accuracy high on the various treebanks in order to do well. If we don't | ||||||
|  |     align on a content word, all dependencies to and from that word will | ||||||
|  |     be marked as incorrect. | ||||||
|  |     ''' | ||||||
|  |     English = spacy.util.get_lang_class(lang) | ||||||
|  |     English.Defaults.infixes += ('(?<=[^-\d])[+\-\*^](?=[^-\d])',) | ||||||
|  |     English.Defaults.infixes += ('(?<=[^-])[+\-\*^](?=[^-\d])',) | ||||||
|  |     English.Defaults.infixes += ('(?<=[^-\d])[+\-\*^](?=[^-])',) | ||||||
|  |     English.Defaults.token_match = re.compile(r'=+').match | ||||||
|  |     nlp = English() | ||||||
|  |     nlp.tokenizer.add_special_case('***', [{'ORTH': '***'}]) | ||||||
|  |     nlp.tokenizer.add_special_case("):", [{'ORTH': ")"}, {"ORTH": ":"}]) | ||||||
|  |     nlp.tokenizer.add_special_case("and/or", [{'ORTH': "and"}, {"ORTH": "/"}, {"ORTH": "or"}]) | ||||||
|  |     nlp.tokenizer.add_special_case("non-Microsoft", [{'ORTH': "non-Microsoft"}]) | ||||||
|  |     nlp.tokenizer.add_special_case("mis-matches", [{'ORTH': "mis-matches"}]) | ||||||
|  |     nlp.tokenizer.add_special_case("X.", [{'ORTH': "X"}, {"ORTH": "."}]) | ||||||
|  |     nlp.tokenizer.add_special_case("b/c", [{'ORTH': "b/c"}]) | ||||||
|  |     return nlp | ||||||
|  |      | ||||||
|  | 
 | ||||||
|  | def get_token_acc(docs, golds): | ||||||
|  |     '''Quick function to evaluate tokenization accuracy.''' | ||||||
|  |     miss = 0 | ||||||
|  |     hit = 0 | ||||||
|  |     for doc, gold in zip(docs, golds): | ||||||
|  |         for i in range(len(doc)): | ||||||
|  |             token = doc[i] | ||||||
|  |             align = gold.words[i] | ||||||
|  |             if align == None: | ||||||
|  |                 miss += 1 | ||||||
|  |             else: | ||||||
|  |                 hit += 1 | ||||||
|  |     return miss, hit | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def golds_to_gold_tuples(docs, golds): | ||||||
|  |     '''Get out the annoying 'tuples' format used by begin_training, given the | ||||||
|  |     GoldParse objects.''' | ||||||
|  |     tuples = [] | ||||||
|  |     for doc, gold in zip(docs, golds): | ||||||
|  |         text = doc.text | ||||||
|  |         ids, words, tags, heads, labels, iob = zip(*gold.orig_annot) | ||||||
|  |         sents = [((ids, words, tags, heads, labels, iob), [])] | ||||||
|  |         tuples.append((text, sents)) | ||||||
|  |     return tuples | ||||||
|  | 
 | ||||||
|  | def split_text(text): | ||||||
|  |     return [par.strip().replace('\n', ' ') | ||||||
|  |             for par in text.split('\n\n')] | ||||||
|  |   | ||||||
|  | 
 | ||||||
|  | def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False, | ||||||
|  |               limit=None): | ||||||
|  |     '''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True, | ||||||
|  |     include Doc objects created using nlp.make_doc and then aligned against | ||||||
|  |     the gold-standard sequences. If oracle_segments=True, include Doc objects | ||||||
|  |     created from the gold-standard segments. At least one must be True.''' | ||||||
|  |     if not raw_text and not oracle_segments: | ||||||
|  |         raise ValueError("At least one of raw_text or oracle_segments must be True") | ||||||
|  |     paragraphs = split_text(text_file.read()) | ||||||
|  |     conllu = read_conllu(conllu_file) | ||||||
|  |     # sd is spacy doc; cd is conllu doc | ||||||
|  |     # cs is conllu sent, ct is conllu token | ||||||
|  |     docs = [] | ||||||
|  |     golds = [] | ||||||
|  |     for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)): | ||||||
|  |         doc_words = [] | ||||||
|  |         doc_tags = [] | ||||||
|  |         doc_heads = [] | ||||||
|  |         doc_deps = [] | ||||||
|  |         doc_ents = [] | ||||||
|  |         for cs in cd: | ||||||
|  |             sent_words = [] | ||||||
|  |             sent_tags = [] | ||||||
|  |             sent_heads = [] | ||||||
|  |             sent_deps = [] | ||||||
|  |             for id_, word, lemma, pos, tag, morph, head, dep, _1, _2 in cs: | ||||||
|  |                 if '.' in id_: | ||||||
|  |                     continue | ||||||
|  |                 if '-' in id_: | ||||||
|  |                     continue | ||||||
|  |                 id_ = int(id_)-1 | ||||||
|  |                 head = int(head)-1 if head != '0' else id_ | ||||||
|  |                 sent_words.append(word) | ||||||
|  |                 sent_tags.append(tag) | ||||||
|  |                 sent_heads.append(head) | ||||||
|  |                 sent_deps.append('ROOT' if dep == 'root' else dep) | ||||||
|  |             if oracle_segments: | ||||||
|  |                 sent_heads, sent_deps = projectivize(sent_heads, sent_deps) | ||||||
|  |                 docs.append(Doc(nlp.vocab, words=sent_words)) | ||||||
|  |                 golds.append(GoldParse(docs[-1], words=sent_words, heads=sent_heads, | ||||||
|  |                                        tags=sent_tags, deps=sent_deps, | ||||||
|  |                                        entities=['-']*len(sent_words))) | ||||||
|  |             for head in sent_heads: | ||||||
|  |                 doc_heads.append(len(doc_words)+head) | ||||||
|  |             doc_words.extend(sent_words) | ||||||
|  |             doc_tags.extend(sent_tags) | ||||||
|  |             doc_deps.extend(sent_deps) | ||||||
|  |             doc_ents.extend(['-']*len(sent_words)) | ||||||
|  |             # Create a GoldParse object for the sentence | ||||||
|  |         doc_heads, doc_deps = projectivize(doc_heads, doc_deps) | ||||||
|  |         if raw_text: | ||||||
|  |             docs.append(nlp.make_doc(text)) | ||||||
|  |             golds.append(GoldParse(docs[-1], words=doc_words, tags=doc_tags, | ||||||
|  |                                    heads=doc_heads, deps=doc_deps, | ||||||
|  |                                    entities=doc_ents)) | ||||||
|  |         if limit and doc_id >= limit: | ||||||
|  |             break | ||||||
|  |     return docs, golds | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def refresh_docs(docs): | ||||||
|  |     vocab = docs[0].vocab | ||||||
|  |     return [Doc(vocab, words=[t.text for t in doc], | ||||||
|  |                        spaces=[t.whitespace_ for t in doc]) | ||||||
|  |             for doc in docs] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def read_conllu(file_): | ||||||
|  |     docs = [] | ||||||
|  |     doc = None | ||||||
|  |     sent = [] | ||||||
|  |     for line in file_: | ||||||
|  |         if line.startswith('# newdoc'): | ||||||
|  |             if doc: | ||||||
|  |                 docs.append(doc) | ||||||
|  |             doc = [] | ||||||
|  |         elif line.startswith('#'): | ||||||
|  |             continue | ||||||
|  |         elif not line.strip(): | ||||||
|  |             if sent: | ||||||
|  |                 if doc is None: | ||||||
|  |                     docs.append([sent]) | ||||||
|  |                 else: | ||||||
|  |                     doc.append(sent) | ||||||
|  |             sent = [] | ||||||
|  |         else: | ||||||
|  |             sent.append(line.strip().split()) | ||||||
|  |     if sent: | ||||||
|  |         if doc is None: | ||||||
|  |             docs.append([sent]) | ||||||
|  |         else: | ||||||
|  |             doc.append(sent) | ||||||
|  |     if doc: | ||||||
|  |         docs.append(doc) | ||||||
|  |     return docs | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def parse_dev_data(nlp, text_loc, conllu_loc, oracle_segments=False, | ||||||
|  |                    joint_sbd=True): | ||||||
|  |     with open(text_loc) as text_file: | ||||||
|  |         with open(conllu_loc) as conllu_file: | ||||||
|  |             docs, golds = read_data(nlp, conllu_file, text_file, | ||||||
|  |                                     oracle_segments=oracle_segments) | ||||||
|  |     if joint_sbd: | ||||||
|  |         pass | ||||||
|  |     else: | ||||||
|  |         sbd = nlp.create_pipe('sentencizer') | ||||||
|  |         for doc in docs: | ||||||
|  |             doc = sbd(doc) | ||||||
|  |             for sent in doc.sents: | ||||||
|  |                 sent[0].is_sent_start = True | ||||||
|  |                 for word in sent[1:]: | ||||||
|  |                     word.is_sent_start = False | ||||||
|  |     scorer = nlp.evaluate(zip(docs, golds)) | ||||||
|  |     return docs, scorer | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def print_progress(itn, losses, scorer): | ||||||
|  |     scores = {} | ||||||
|  |     for col in ['dep_loss', 'tag_loss', 'uas', 'tags_acc', 'token_acc', | ||||||
|  |                 'ents_p', 'ents_r', 'ents_f', 'cpu_wps', 'gpu_wps']: | ||||||
|  |         scores[col] = 0.0 | ||||||
|  |     scores['dep_loss'] = losses.get('parser', 0.0) | ||||||
|  |     scores['ner_loss'] = losses.get('ner', 0.0) | ||||||
|  |     scores['tag_loss'] = losses.get('tagger', 0.0) | ||||||
|  |     scores.update(scorer.scores) | ||||||
|  |     tpl = '\t'.join(( | ||||||
|  |         '{:d}', | ||||||
|  |         '{dep_loss:.3f}', | ||||||
|  |         '{ner_loss:.3f}', | ||||||
|  |         '{uas:.3f}', | ||||||
|  |         '{ents_p:.3f}', | ||||||
|  |         '{ents_r:.3f}', | ||||||
|  |         '{ents_f:.3f}', | ||||||
|  |         '{tags_acc:.3f}', | ||||||
|  |         '{token_acc:.3f}', | ||||||
|  |     )) | ||||||
|  |     print(tpl.format(itn, **scores)) | ||||||
|  | 
 | ||||||
|  | def print_conllu(docs, file_): | ||||||
|  |     for i, doc in enumerate(docs): | ||||||
|  |         file_.write("# newdoc id = {i}\n".format(i=i)) | ||||||
|  |         for j, sent in enumerate(doc.sents): | ||||||
|  |             file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j)) | ||||||
|  |             file_.write("# text = {text}\n".format(text=sent.text)) | ||||||
|  |             for k, t in enumerate(sent): | ||||||
|  |                 if t.head.i == t.i: | ||||||
|  |                     head = 0 | ||||||
|  |                 else: | ||||||
|  |                     head = k + (t.head.i - t.i) + 1 | ||||||
|  |                 fields = [str(k+1), t.text, t.lemma_, t.pos_, t.tag_, '_', | ||||||
|  |                           str(head), t.dep_.lower(), '_', '_'] | ||||||
|  |                 file_.write('\t'.join(fields) + '\n') | ||||||
|  |             file_.write('\n') | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def main(spacy_model, conllu_train_loc, text_train_loc, conllu_dev_loc, text_dev_loc, | ||||||
|  |          output_loc): | ||||||
|  |     nlp = load_model(spacy_model) | ||||||
|  |     with open(conllu_train_loc) as conllu_file: | ||||||
|  |         with open(text_train_loc) as text_file: | ||||||
|  |             docs, golds = read_data(nlp, conllu_file, text_file, | ||||||
|  |                                     oracle_segments=True, raw_text=True, | ||||||
|  |                                     limit=None) | ||||||
|  |     print("Create parser") | ||||||
|  |     nlp.add_pipe(nlp.create_pipe('parser')) | ||||||
|  |     nlp.add_pipe(nlp.create_pipe('tagger')) | ||||||
|  |     for gold in golds: | ||||||
|  |         for tag in gold.tags: | ||||||
|  |             if tag is not None: | ||||||
|  |                 nlp.tagger.add_label(tag) | ||||||
|  |     optimizer = nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds)) | ||||||
|  |     # Replace labels that didn't make the frequency cutoff | ||||||
|  |     actions = set(nlp.parser.labels) | ||||||
|  |     label_set = set([act.split('-')[1] for act in actions if '-' in act]) | ||||||
|  |     for gold in golds: | ||||||
|  |         for i, label in enumerate(gold.labels): | ||||||
|  |             if label is not None and label not in label_set: | ||||||
|  |                 gold.labels[i] = label.split('||')[0] | ||||||
|  |     n_train_words = sum(len(doc) for doc in docs) | ||||||
|  |     print(n_train_words) | ||||||
|  |     print("Begin training") | ||||||
|  |     # Batch size starts at 1 and grows, so that we make updates quickly | ||||||
|  |     # at the beginning of training. | ||||||
|  |     batch_sizes = spacy.util.compounding(spacy.util.env_opt('batch_from', 8), | ||||||
|  |                                    spacy.util.env_opt('batch_to', 8), | ||||||
|  |                                    spacy.util.env_opt('batch_compound', 1.001)) | ||||||
|  |     for i in range(30): | ||||||
|  |         docs = refresh_docs(docs) | ||||||
|  |         batches = minibatch(list(zip(docs, golds)), size=batch_sizes) | ||||||
|  |         with tqdm.tqdm(total=n_train_words, leave=False) as pbar: | ||||||
|  |             losses = {} | ||||||
|  |             for batch in batches: | ||||||
|  |                 if not batch: | ||||||
|  |                     continue | ||||||
|  |                 batch_docs, batch_gold = zip(*batch) | ||||||
|  | 
 | ||||||
|  |                 nlp.update(batch_docs, batch_gold, sgd=optimizer, | ||||||
|  |                            drop=0.2, losses=losses) | ||||||
|  |                 pbar.update(sum(len(doc) for doc in batch_docs)) | ||||||
|  |          | ||||||
|  |         with nlp.use_params(optimizer.averages): | ||||||
|  |             dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc, | ||||||
|  |                                               oracle_segments=False, joint_sbd=True) | ||||||
|  |             print_progress(i, losses, scorer) | ||||||
|  |             with open(output_loc, 'w') as file_: | ||||||
|  |                 print_conllu(dev_docs, file_) | ||||||
|  |             dev_docs, scorer = parse_dev_data(nlp, text_dev_loc, conllu_dev_loc, | ||||||
|  |                                               oracle_segments=False, joint_sbd=False) | ||||||
|  |             print_progress(i, losses, scorer) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     plac.call(main) | ||||||
							
								
								
									
										1
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								setup.py
									
									
									
									
									
								
							|  | @ -18,6 +18,7 @@ PACKAGES = find_packages() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| MOD_NAMES = [ | MOD_NAMES = [ | ||||||
|  |     'spacy._align', | ||||||
|     'spacy.parts_of_speech', |     'spacy.parts_of_speech', | ||||||
|     'spacy.strings', |     'spacy.strings', | ||||||
|     'spacy.lexeme', |     'spacy.lexeme', | ||||||
|  |  | ||||||
							
								
								
									
										175
									
								
								spacy/_align.pyx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										175
									
								
								spacy/_align.pyx
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,175 @@ | ||||||
|  | # cython: infer_types=True | ||||||
|  | '''Do Levenshtein alignment, for evaluation of tokenized input. | ||||||
|  | 
 | ||||||
|  | Random notes: | ||||||
|  | 
 | ||||||
|  |   r i n g | ||||||
|  |   0 1 2 3 4 | ||||||
|  | r 1 0 1 2 3 | ||||||
|  | a 2 1 1 2 3 | ||||||
|  | n 3 2 2 1 2 | ||||||
|  | g 4 3 3 2 1 | ||||||
|  | 
 | ||||||
|  | 0,0: (1,1)=min(0+0,1+1,1+1)=0 S | ||||||
|  | 1,0: (2,1)=min(1+1,0+1,2+1)=1 D | ||||||
|  | 2,0: (3,1)=min(2+1,3+1,1+1)=2 D | ||||||
|  | 3,0: (4,1)=min(3+1,4+1,2+1)=3 D | ||||||
|  | 0,1: (1,2)=min(1+1,2+1,0+1)=1 D | ||||||
|  | 1,1: (2,2)=min(0+1,1+1,1+1)=1 S | ||||||
|  | 2,1: (3,2)=min(1+1,1+1,2+1)=2 S or I | ||||||
|  | 3,1: (4,2)=min(2+1,2+1,3+1)=3 S or I | ||||||
|  | 0,2: (1,3)=min(2+1,3+1,1+1)=2 I | ||||||
|  | 1,2: (2,3)=min(1+1,2+1,1+1)=2 S or I | ||||||
|  | 2,2: (3,3) | ||||||
|  | 3,2: (4,3) | ||||||
|  | At state (i, j) we're asking "How do I transform S[:i+1] to T[:j+1]?" | ||||||
|  | 
 | ||||||
|  | We know the costs to transition: | ||||||
|  | 
 | ||||||
|  | S[:i]   -> T[:j]   (at D[i,j]) | ||||||
|  | S[:i+1] -> T[:j]   (at D[i+1,j]) | ||||||
|  | S[:i]   -> T[:j+1] (at D[i,j+1]) | ||||||
|  |      | ||||||
|  | Further, we now we can tranform: | ||||||
|  | S[:i+1] -> S[:i] (DEL) for 1, | ||||||
|  | T[:j+1] -> T[:j] (INS) for 1. | ||||||
|  | S[i+1]  -> T[j+1] (SUB) for 0 or 1 | ||||||
|  | 
 | ||||||
|  | Therefore we have the costs: | ||||||
|  | SUB: Cost(S[:i]->T[:j])   + Cost(S[i]->S[j]) | ||||||
|  | i.e. D[i, j] + S[i+1] != T[j+1] | ||||||
|  | INS: Cost(S[:i+1]->T[:j]) + Cost(T[:j+1]->T[:j]) | ||||||
|  | i.e. D[i+1,j] + 1 | ||||||
|  | DEL: Cost(S[:i]->T[:j+1]) + Cost(S[:i+1]->S[:i])  | ||||||
|  | i.e. D[i,j+1] + 1 | ||||||
|  | 
 | ||||||
|  |     Source string S has length m, with index i | ||||||
|  |     Target string T has length n, with index j | ||||||
|  | 
 | ||||||
|  |     Output two alignment vectors: i2j (length m) and j2i (length n) | ||||||
|  |     # function LevenshteinDistance(char s[1..m], char t[1..n]): | ||||||
|  |     # for all i and j, d[i,j] will hold the Levenshtein distance between | ||||||
|  |     # the first i characters of s and the first j characters of t | ||||||
|  |     # note that d has (m+1)*(n+1) values | ||||||
|  |     # set each element in d to zero | ||||||
|  |     ring rang | ||||||
|  |       - r i n g | ||||||
|  |     - 0 0 0 0 0 | ||||||
|  |     r 0 0 0 0 0 | ||||||
|  |     a 0 0 0 0 0 | ||||||
|  |     n 0 0 0 0 0 | ||||||
|  |     g 0 0 0 0 0 | ||||||
|  | 
 | ||||||
|  |     # source prefixes can be transformed into empty string by | ||||||
|  |     # dropping all characters | ||||||
|  |     # d[i, 0] := i | ||||||
|  |     ring rang | ||||||
|  |       - r i n g | ||||||
|  |     - 0 0 0 0 0 | ||||||
|  |     r 1 0 0 0 0 | ||||||
|  |     a 2 0 0 0 0 | ||||||
|  |     n 3 0 0 0 0 | ||||||
|  |     g 4 0 0 0 0 | ||||||
|  | 
 | ||||||
|  |     # target prefixes can be reached from empty source prefix | ||||||
|  |     # by inserting every character | ||||||
|  |     # d[0, j] := j | ||||||
|  |       - r i n g | ||||||
|  |     - 0 1 2 3 4 | ||||||
|  |     r 1 0 0 0 0 | ||||||
|  |     a 2 0 0 0 0 | ||||||
|  |     n 3 0 0 0 0 | ||||||
|  |     g 4 0 0 0 0 | ||||||
|  | 
 | ||||||
|  | ''' | ||||||
|  | import numpy | ||||||
|  | cimport numpy as np | ||||||
|  | from .compat import unicode_ | ||||||
|  | from murmurhash.mrmr cimport hash32 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def align(S, T): | ||||||
|  |     cdef int m = len(S) | ||||||
|  |     cdef int n = len(T) | ||||||
|  |     cdef np.ndarray matrix = numpy.zeros((m+1, n+1), dtype='int32') | ||||||
|  |     cdef np.ndarray i2j = numpy.zeros((m,), dtype='i') | ||||||
|  |     cdef np.ndarray j2i = numpy.zeros((n,), dtype='i') | ||||||
|  | 
 | ||||||
|  |     cdef np.ndarray S_arr = _convert_sequence(S) | ||||||
|  |     cdef np.ndarray T_arr = _convert_sequence(T) | ||||||
|  | 
 | ||||||
|  |     fill_matrix(<int*>matrix.data, | ||||||
|  |         <const int*>S_arr.data, m, <const int*>T_arr.data, n) | ||||||
|  |     fill_i2j(i2j, matrix) | ||||||
|  |     fill_j2i(j2i, matrix) | ||||||
|  |     return matrix[-1,-1], i2j, j2i, matrix | ||||||
|  | 
 | ||||||
|  | def _convert_sequence(seq): | ||||||
|  |     if isinstance(seq, numpy.ndarray): | ||||||
|  |         return numpy.ascontiguousarray(seq, dtype='i') | ||||||
|  |     cdef np.ndarray output = numpy.zeros((len(seq),), dtype='i') | ||||||
|  |     cdef bytes item_bytes | ||||||
|  |     for i, item in enumerate(seq): | ||||||
|  |         if isinstance(item, unicode): | ||||||
|  |             item_bytes = item.encode('utf8') | ||||||
|  |         else: | ||||||
|  |             item_bytes = item | ||||||
|  |         output[i] = hash32(<void*><char*>item_bytes, len(item_bytes), 0) | ||||||
|  |     return output | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | cdef void fill_matrix(int* D,  | ||||||
|  |         const int* S, int m, const int* T, int n) nogil: | ||||||
|  |     m1 = m+1 | ||||||
|  |     n1 = n+1 | ||||||
|  |     for i in range(m1*n1): | ||||||
|  |         D[i] = 0 | ||||||
|  |   | ||||||
|  |     for i in range(m1): | ||||||
|  |         D[i*n1] = i | ||||||
|  |   | ||||||
|  |     for j in range(n1): | ||||||
|  |         D[j] = j | ||||||
|  |   | ||||||
|  |     cdef int sub_cost, ins_cost, del_cost | ||||||
|  |     for j in range(n): | ||||||
|  |         for i in range(m): | ||||||
|  |             i_j = i*n1 + j | ||||||
|  |             i1_j1 = (i+1)*n1 + j+1 | ||||||
|  |             i1_j = (i+1)*n1 + j | ||||||
|  |             i_j1 = i*n1 + j+1 | ||||||
|  |             if S[i] != T[j]: | ||||||
|  |                 sub_cost = D[i_j] + 1 | ||||||
|  |             else: | ||||||
|  |                 sub_cost = D[i_j] | ||||||
|  |             del_cost = D[i_j1] + 1 | ||||||
|  |             ins_cost = D[i1_j] + 1 | ||||||
|  |             best = min(min(sub_cost, ins_cost), del_cost) | ||||||
|  |             D[i1_j1] = best | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | cdef void fill_i2j(np.ndarray i2j, np.ndarray D) except *: | ||||||
|  |     j = D.shape[1]-2 | ||||||
|  |     cdef int i = D.shape[0]-2 | ||||||
|  |     while i >= 0: | ||||||
|  |         while D[i+1, j] < D[i+1, j+1]: | ||||||
|  |             j -= 1 | ||||||
|  |         if D[i, j+1] < D[i+1, j+1]: | ||||||
|  |             i2j[i] = -1 | ||||||
|  |         else: | ||||||
|  |             i2j[i] = j | ||||||
|  |             j -= 1 | ||||||
|  |         i -= 1 | ||||||
|  | 
 | ||||||
|  | cdef void fill_j2i(np.ndarray j2i, np.ndarray D) except *: | ||||||
|  |     i = D.shape[0]-2 | ||||||
|  |     cdef int j = D.shape[1]-2 | ||||||
|  |     while j >= 0: | ||||||
|  |         while D[i, j+1] < D[i+1, j+1]: | ||||||
|  |             i -= 1 | ||||||
|  |         if D[i+1, j] < D[i+1, j+1]: | ||||||
|  |             j2i[j] = -1 | ||||||
|  |         else: | ||||||
|  |             j2i[j] = i | ||||||
|  |             i -= 1 | ||||||
|  |         j -= 1 | ||||||
|  | @ -7,7 +7,9 @@ import ujson | ||||||
| import random | import random | ||||||
| import cytoolz | import cytoolz | ||||||
| import itertools | import itertools | ||||||
|  | import numpy | ||||||
| 
 | 
 | ||||||
|  | from . import _align  | ||||||
| from .syntax import nonproj | from .syntax import nonproj | ||||||
| from .tokens import Doc | from .tokens import Doc | ||||||
| from . import util | from . import util | ||||||
|  | @ -59,90 +61,15 @@ def merge_sents(sents): | ||||||
|     return [(m_deps, m_brackets)] |     return [(m_deps, m_brackets)] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def align(cand_words, gold_words): |  | ||||||
|     cost, edit_path = _min_edit_path(cand_words, gold_words) |  | ||||||
|     alignment = [] |  | ||||||
|     i_of_gold = 0 |  | ||||||
|     for move in edit_path: |  | ||||||
|         if move == 'M': |  | ||||||
|             alignment.append(i_of_gold) |  | ||||||
|             i_of_gold += 1 |  | ||||||
|         elif move == 'S': |  | ||||||
|             alignment.append(None) |  | ||||||
|             i_of_gold += 1 |  | ||||||
|         elif move == 'D': |  | ||||||
|             alignment.append(None) |  | ||||||
|         elif move == 'I': |  | ||||||
|             i_of_gold += 1 |  | ||||||
|         else: |  | ||||||
|             raise Exception(move) |  | ||||||
|     return alignment |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| punct_re = re.compile(r'\W') | punct_re = re.compile(r'\W') | ||||||
| 
 | def align(cand_words, gold_words): | ||||||
| 
 |  | ||||||
| def _min_edit_path(cand_words, gold_words): |  | ||||||
|     cdef: |  | ||||||
|         Pool mem |  | ||||||
|         int i, j, n_cand, n_gold |  | ||||||
|         int* curr_costs |  | ||||||
|         int* prev_costs |  | ||||||
| 
 |  | ||||||
|     # TODO: Fix this --- just do it properly, make the full edit matrix and |  | ||||||
|     # then walk back over it... |  | ||||||
|     # Preprocess inputs |  | ||||||
|     cand_words = [punct_re.sub('', w).lower() for w in cand_words] |     cand_words = [punct_re.sub('', w).lower() for w in cand_words] | ||||||
|     gold_words = [punct_re.sub('', w).lower() for w in gold_words] |     gold_words = [punct_re.sub('', w).lower() for w in gold_words] | ||||||
| 
 |  | ||||||
|     if cand_words == gold_words: |     if cand_words == gold_words: | ||||||
|         return 0, ''.join(['M' for _ in gold_words]) |         alignment = numpy.arange(len(cand_words)) | ||||||
|     mem = Pool() |         return 0, alignment, alignment | ||||||
|     n_cand = len(cand_words) |     cost, i2j, j2i, matrix = _align.align(cand_words, gold_words) | ||||||
|     n_gold = len(gold_words) |     return cost, i2j, j2i | ||||||
|     # Levenshtein distance, except we need the history, and we may want |  | ||||||
|     # different costs. Mark operations with a string, and score the history |  | ||||||
|     # using _edit_cost. |  | ||||||
|     previous_row = [] |  | ||||||
|     prev_costs = <int*>mem.alloc(n_gold + 1, sizeof(int)) |  | ||||||
|     curr_costs = <int*>mem.alloc(n_gold + 1, sizeof(int)) |  | ||||||
|     for i in range(n_gold + 1): |  | ||||||
|         cell = '' |  | ||||||
|         for j in range(i): |  | ||||||
|             cell += 'I' |  | ||||||
|         previous_row.append('I' * i) |  | ||||||
|         prev_costs[i] = i |  | ||||||
|     for i, cand in enumerate(cand_words): |  | ||||||
|         current_row = ['D' * (i + 1)] |  | ||||||
|         curr_costs[0] = i+1 |  | ||||||
|         for j, gold in enumerate(gold_words): |  | ||||||
|             if gold.lower() == cand.lower(): |  | ||||||
|                 s_cost = prev_costs[j] |  | ||||||
|                 i_cost = curr_costs[j] + 1 |  | ||||||
|                 d_cost = prev_costs[j + 1] + 1 |  | ||||||
|             else: |  | ||||||
|                 s_cost = prev_costs[j] + 1 |  | ||||||
|                 i_cost = curr_costs[j] + 1 |  | ||||||
|                 d_cost = prev_costs[j + 1] + (1 if cand else 0) |  | ||||||
| 
 |  | ||||||
|             if s_cost <= i_cost and s_cost <= d_cost: |  | ||||||
|                 best_cost = s_cost |  | ||||||
|                 best_hist = previous_row[j] + ('M' if gold == cand else 'S') |  | ||||||
|             elif i_cost <= s_cost and i_cost <= d_cost: |  | ||||||
|                 best_cost = i_cost |  | ||||||
|                 best_hist = current_row[j] + 'I' |  | ||||||
|             else: |  | ||||||
|                 best_cost = d_cost |  | ||||||
|                 best_hist = previous_row[j + 1] + 'D' |  | ||||||
| 
 |  | ||||||
|             current_row.append(best_hist) |  | ||||||
|             curr_costs[j+1] = best_cost |  | ||||||
|         previous_row = current_row |  | ||||||
|         for j in range(len(gold_words) + 1): |  | ||||||
|             prev_costs[j] = curr_costs[j] |  | ||||||
|             curr_costs[j] = 0 |  | ||||||
| 
 |  | ||||||
|     return prev_costs[n_gold], previous_row[-1] |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class GoldCorpus(object): | class GoldCorpus(object): | ||||||
|  | @ -434,8 +361,9 @@ cdef class GoldParse: | ||||||
|         self.labels = [None] * len(doc) |         self.labels = [None] * len(doc) | ||||||
|         self.ner = [None] * len(doc) |         self.ner = [None] * len(doc) | ||||||
| 
 | 
 | ||||||
|         self.cand_to_gold = align([t.orth_ for t in doc], words) |         cost, i2j, j2i = align([t.orth_ for t in doc], words) | ||||||
|         self.gold_to_cand = align(words, [t.orth_ for t in doc]) |         self.cand_to_gold = [(j if j != -1 else None) for j in i2j] | ||||||
|  |         self.gold_to_cand = [(i if i != -1 else None) for i in j2i] | ||||||
| 
 | 
 | ||||||
|         annot_tuples = (range(len(words)), words, tags, heads, deps, entities) |         annot_tuples = (range(len(words)), words, tags, heads, deps, entities) | ||||||
|         self.orig_annot = list(zip(*annot_tuples)) |         self.orig_annot = list(zip(*annot_tuples)) | ||||||
|  |  | ||||||
|  | @ -1,7 +1,7 @@ | ||||||
| # coding: utf8 | # coding: utf8 | ||||||
| from __future__ import unicode_literals | from __future__ import unicode_literals | ||||||
| 
 | 
 | ||||||
| from .symbols import POS, NOUN, VERB, ADJ, PUNCT | from .symbols import POS, NOUN, VERB, ADJ, PUNCT, PROPN | ||||||
| from .symbols import VerbForm_inf, VerbForm_none, Number_sing, Degree_pos | from .symbols import VerbForm_inf, VerbForm_none, Number_sing, Degree_pos | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -27,11 +27,13 @@ class Lemmatizer(object): | ||||||
|             univ_pos = 'adj' |             univ_pos = 'adj' | ||||||
|         elif univ_pos in (PUNCT, 'PUNCT', 'punct'): |         elif univ_pos in (PUNCT, 'PUNCT', 'punct'): | ||||||
|             univ_pos = 'punct' |             univ_pos = 'punct' | ||||||
|  |         elif univ_pos in (PROPN, 'PROPN'): | ||||||
|  |             return [string] | ||||||
|         else: |         else: | ||||||
|             return list(set([string.lower()])) |             return [string.lower()] | ||||||
|         # See Issue #435 for example of where this logic is requied. |         # See Issue #435 for example of where this logic is requied. | ||||||
|         if self.is_base_form(univ_pos, morphology): |         if self.is_base_form(univ_pos, morphology): | ||||||
|             return list(set([string.lower()])) |             return [string.lower()] | ||||||
|         lemmas = lemmatize(string, self.index.get(univ_pos, {}), |         lemmas = lemmatize(string, self.index.get(univ_pos, {}), | ||||||
|                            self.exc.get(univ_pos, {}), |                            self.exc.get(univ_pos, {}), | ||||||
|                            self.rules.get(univ_pos, [])) |                            self.rules.get(univ_pos, [])) | ||||||
|  | @ -88,6 +90,7 @@ class Lemmatizer(object): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def lemmatize(string, index, exceptions, rules): | def lemmatize(string, index, exceptions, rules): | ||||||
|  |     orig = string | ||||||
|     string = string.lower() |     string = string.lower() | ||||||
|     forms = [] |     forms = [] | ||||||
|     forms.extend(exceptions.get(string, [])) |     forms.extend(exceptions.get(string, [])) | ||||||
|  | @ -105,5 +108,5 @@ def lemmatize(string, index, exceptions, rules): | ||||||
|     if not forms: |     if not forms: | ||||||
|         forms.extend(oov_forms) |         forms.extend(oov_forms) | ||||||
|     if not forms: |     if not forms: | ||||||
|         forms.append(string) |         forms.append(orig) | ||||||
|     return list(set(forms)) |     return list(set(forms)) | ||||||
|  |  | ||||||
|  | @ -110,7 +110,8 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil: | ||||||
| cdef class Shift: | cdef class Shift: | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef bint is_valid(const StateC* st, attr_t label) nogil: |     cdef bint is_valid(const StateC* st, attr_t label) nogil: | ||||||
|         return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and st.B_(0).sent_start != 1 |         sent_start = st._sent[st.B_(0).l_edge].sent_start | ||||||
|  |         return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and sent_start != 1 | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef int transition(StateC* st, attr_t label) nogil: |     cdef int transition(StateC* st, attr_t label) nogil: | ||||||
|  | @ -170,7 +171,8 @@ cdef class Reduce: | ||||||
| cdef class LeftArc: | cdef class LeftArc: | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef bint is_valid(const StateC* st, attr_t label) nogil: |     cdef bint is_valid(const StateC* st, attr_t label) nogil: | ||||||
|         return st.B_(0).sent_start != 1 |         sent_start = st._sent[st.B_(0).l_edge].sent_start | ||||||
|  |         return sent_start != 1 | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef int transition(StateC* st, attr_t label) nogil: |     cdef int transition(StateC* st, attr_t label) nogil: | ||||||
|  | @ -205,7 +207,8 @@ cdef class RightArc: | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef bint is_valid(const StateC* st, attr_t label) nogil: |     cdef bint is_valid(const StateC* st, attr_t label) nogil: | ||||||
|         # If there's (perhaps partial) parse pre-set, don't allow cycle. |         # If there's (perhaps partial) parse pre-set, don't allow cycle. | ||||||
|         return st.B_(0).sent_start != 1 and st.H(st.S(0)) != st.B(0) |         sent_start = st._sent[st.B_(0).l_edge].sent_start | ||||||
|  |         return sent_start != 1 and st.H(st.S(0)) != st.B(0) | ||||||
| 
 | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef int transition(StateC* st, attr_t label) nogil: |     cdef int transition(StateC* st, attr_t label) nogil: | ||||||
|  | @ -527,7 +530,12 @@ cdef class ArcEager(TransitionSystem): | ||||||
|                 is_valid[i] = False |                 is_valid[i] = False | ||||||
|                 costs[i] = 9000 |                 costs[i] = 9000 | ||||||
|         if n_gold < 1: |         if n_gold < 1: | ||||||
|             # Check projectivity --- leading cause |             # Check label set --- leading cause | ||||||
|  |             label_set = set([self.strings[self.c[i].label] for i in range(self.n_moves)]) | ||||||
|  |             for label_str in gold.labels: | ||||||
|  |                 if label_str is not None and label_str not in label_set: | ||||||
|  |                     raise ValueError("Cannot get gold parser action: unknown label: %s" % label_str) | ||||||
|  |             # Check projectivity --- other leading cause | ||||||
|             if is_nonproj_tree(gold.heads): |             if is_nonproj_tree(gold.heads): | ||||||
|                 raise ValueError( |                 raise ValueError( | ||||||
|                     "Could not find a gold-standard action to supervise the " |                     "Could not find a gold-standard action to supervise the " | ||||||
|  |  | ||||||
|  | @ -555,7 +555,10 @@ cdef class Parser: | ||||||
|         for multitask in self._multitasks: |         for multitask in self._multitasks: | ||||||
|             multitask.update(docs, golds, drop=drop, sgd=sgd) |             multitask.update(docs, golds, drop=drop, sgd=sgd) | ||||||
|         cuda_stream = util.get_cuda_stream() |         cuda_stream = util.get_cuda_stream() | ||||||
|         states, golds, max_steps = self._init_gold_batch(docs, golds) |         # Chop sequences into lengths of this many transitions, to make the | ||||||
|  |         # batch uniform length. | ||||||
|  |         cut_gold = numpy.random.choice(range(20, 100)) | ||||||
|  |         states, golds, max_steps = self._init_gold_batch(docs, golds, max_length=cut_gold) | ||||||
|         (tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream, |         (tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream, | ||||||
|                                                                             drop) |                                                                             drop) | ||||||
|         todo = [(s, g) for (s, g) in zip(states, golds) |         todo = [(s, g) for (s, g) in zip(states, golds) | ||||||
|  | @ -659,7 +662,7 @@ cdef class Parser: | ||||||
|             _cleanup(beam) |             _cleanup(beam) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     def _init_gold_batch(self, whole_docs, whole_golds): |     def _init_gold_batch(self, whole_docs, whole_golds, min_length=5, max_length=500): | ||||||
|         """Make a square batch, of length equal to the shortest doc. A long |         """Make a square batch, of length equal to the shortest doc. A long | ||||||
|         doc will get multiple states. Let's say we have a doc of length 2*N, |         doc will get multiple states. Let's say we have a doc of length 2*N, | ||||||
|         where N is the shortest doc. We'll make two states, one representing |         where N is the shortest doc. We'll make two states, one representing | ||||||
|  | @ -668,7 +671,7 @@ cdef class Parser: | ||||||
|             StateClass state |             StateClass state | ||||||
|             Transition action |             Transition action | ||||||
|         whole_states = self.moves.init_batch(whole_docs) |         whole_states = self.moves.init_batch(whole_docs) | ||||||
|         max_length = max(5, min(50, min([len(doc) for doc in whole_docs]))) |         max_length = max(min_length, min(max_length, min([len(doc) for doc in whole_docs]))) | ||||||
|         max_moves = 0 |         max_moves = 0 | ||||||
|         states = [] |         states = [] | ||||||
|         golds = [] |         golds = [] | ||||||
|  | @ -790,6 +793,11 @@ cdef class Parser: | ||||||
|                 for doc in docs: |                 for doc in docs: | ||||||
|                     hook(doc) |                     hook(doc) | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     def labels(self): | ||||||
|  |         class_names = [self.moves.get_class_name(i) for i in range(self.moves.n_moves)] | ||||||
|  |         return class_names | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def tok2vec(self): |     def tok2vec(self): | ||||||
|         '''Return the embedding and convolutional layer of the model.''' |         '''Return the embedding and convolutional layer of the model.''' | ||||||
|  | @ -825,7 +833,7 @@ cdef class Parser: | ||||||
|         if 'model' in cfg: |         if 'model' in cfg: | ||||||
|             self.model = cfg['model'] |             self.model = cfg['model'] | ||||||
|         gold_tuples = nonproj.preprocess_training_data(gold_tuples, |         gold_tuples = nonproj.preprocess_training_data(gold_tuples, | ||||||
|                                                        label_freq_cutoff=100) |                                                        label_freq_cutoff=30) | ||||||
|         actions = self.moves.get_actions(gold_parses=gold_tuples) |         actions = self.moves.get_actions(gold_parses=gold_tuples) | ||||||
|         for action, labels in actions.items(): |         for action, labels in actions.items(): | ||||||
|             for label in labels: |             for label in labels: | ||||||
|  |  | ||||||
|  | @ -1,36 +0,0 @@ | ||||||
| # coding: utf-8 |  | ||||||
| """Find the min-cost alignment between two tokenizations""" |  | ||||||
| 
 |  | ||||||
| from __future__ import unicode_literals |  | ||||||
| 
 |  | ||||||
| from ...gold import _min_edit_path as min_edit_path |  | ||||||
| from ...gold import align |  | ||||||
| 
 |  | ||||||
| import pytest |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @pytest.mark.parametrize('cand,gold,path', [ |  | ||||||
|     (["U.S", ".", "policy"], ["U.S.", "policy"], (0, 'MDM')), |  | ||||||
|     (["U.N", ".", "policy"], ["U.S.", "policy"], (1, 'SDM')), |  | ||||||
|     (["The", "cat", "sat", "down"], ["The", "cat", "sat", "down"], (0, 'MMMM')), |  | ||||||
|     (["cat", "sat", "down"], ["The", "cat", "sat", "down"], (1, 'IMMM')), |  | ||||||
|     (["The", "cat", "down"], ["The", "cat", "sat", "down"], (1, 'MMIM')), |  | ||||||
|     (["The", "cat", "sag", "down"], ["The", "cat", "sat", "down"], (1, 'MMSM'))]) |  | ||||||
| def test_gold_lev_align_edit_path(cand, gold, path): |  | ||||||
|     assert min_edit_path(cand, gold) == path |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def test_gold_lev_align_edit_path2(): |  | ||||||
|     cand = ["your", "stuff"] |  | ||||||
|     gold = ["you", "r", "stuff"] |  | ||||||
|     assert min_edit_path(cand, gold) in [(2, 'ISM'), (2, 'SIM')] |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @pytest.mark.parametrize('cand,gold,result', [ |  | ||||||
|     (["U.S", ".", "policy"], ["U.S.", "policy"], [0, None, 1]), |  | ||||||
|     (["your", "stuff"], ["you", "r", "stuff"], [None, 2]), |  | ||||||
|     (["i", "like", "2", "guys", "   ", "well", "id", "just", "come", "straight", "out"], |  | ||||||
|      ["i", "like", "2", "guys", "well", "i", "d", "just", "come", "straight", "out"], |  | ||||||
|      [0, 1, 2, 3, None, 4, None, 7, 8, 9, 10])]) |  | ||||||
| def test_gold_lev_align(cand, gold, result): |  | ||||||
|     assert align(cand, gold) == result |  | ||||||
							
								
								
									
										46
									
								
								spacy/tests/test_align.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								spacy/tests/test_align.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,46 @@ | ||||||
|  | import pytest | ||||||
|  | from .._align import align | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.parametrize('string1,string2,cost', [ | ||||||
|  |     ('hello', 'hell', 1), | ||||||
|  |     ('rat', 'cat', 1), | ||||||
|  |     ('rat', 'rat', 0), | ||||||
|  |     ('rat', 'catsie', 4), | ||||||
|  |     ('t', 'catsie', 5), | ||||||
|  | ]) | ||||||
|  | def test_align_costs(string1, string2, cost): | ||||||
|  |     output_cost, i2j, j2i, matrix = align(string1, string2) | ||||||
|  |     assert output_cost == cost | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.parametrize('string1,string2,i2j', [ | ||||||
|  |     ('hello', 'hell', [0,1,2,3,-1]), | ||||||
|  |     ('rat', 'cat', [0,1,2]), | ||||||
|  |     ('rat', 'rat', [0,1,2]), | ||||||
|  |     ('rat', 'catsie', [0,1,2]), | ||||||
|  |     ('t', 'catsie', [2]), | ||||||
|  | ]) | ||||||
|  | def test_align_i2j(string1, string2, i2j): | ||||||
|  |     output_cost, output_i2j, j2i, matrix = align(string1, string2) | ||||||
|  |     assert list(output_i2j) == i2j | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.parametrize('string1,string2,j2i', [ | ||||||
|  |     ('hello', 'hell', [0,1,2,3]), | ||||||
|  |     ('rat', 'cat', [0,1,2]), | ||||||
|  |     ('rat', 'rat', [0,1,2]), | ||||||
|  |     ('rat', 'catsie', [0,1,2, -1, -1, -1]), | ||||||
|  |     ('t', 'catsie', [-1, -1, 0, -1, -1, -1]), | ||||||
|  | ]) | ||||||
|  | def test_align_i2j(string1, string2, j2i): | ||||||
|  |     output_cost, output_i2j, output_j2i, matrix = align(string1, string2) | ||||||
|  |     assert list(output_j2i) == j2i | ||||||
|  | 
 | ||||||
|  | def test_align_strings(): | ||||||
|  |     words1 = ['hello', 'this', 'is', 'test!'] | ||||||
|  |     words2 = ['hellothis', 'is', 'test', '!'] | ||||||
|  |     cost, i2j, j2i, matrix = align(words1, words2) | ||||||
|  |     assert cost == 4 | ||||||
|  |     assert list(i2j) == [0, -1, 1, 2] | ||||||
|  |     assert list(j2i) == [0, 2, 3, -1] | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user