mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 13:41:21 +03:00 
			
		
		
		
	Improve the API for the GoldParse class.
This commit is contained in:
		
							parent
							
								
									e07fe92b27
								
							
						
					
					
						commit
						a48aa15384
					
				|  | @ -216,7 +216,12 @@ def _consume_ent(tags): | |||
| 
 | ||||
| cdef class GoldParse: | ||||
|     @classmethod | ||||
|     def new_init(cls, doc, annot_tuples=None, words=None, tags=None, heads=None, | ||||
|     def from_annot_tuples(cls, doc, annot_tuples, make_projective=False): | ||||
|         _, words, tags, heads, deps, entities = annot_tuples | ||||
|         return cls(doc, words=words, tags=tags, heads=heads, deps=deps, entities=entities, | ||||
|                    make_projective=make_projective) | ||||
| 
 | ||||
|     def __init__(cls, doc, annot_tuples=None, words=None, tags=None, heads=None, | ||||
|                  deps=None, entities=None): | ||||
|         if words is None: | ||||
|             words = [token.text for token in doc] | ||||
|  | @ -233,50 +238,48 @@ cdef class GoldParse: | |||
|         elif not isinstance(entities[0], basestring): | ||||
|             # Assume we have entities specified by character offset. | ||||
|             entities = biluo_tags_from_offsets(doc, entities) | ||||
|         return cls(doc, [(range(len(doc)), words, tags, heads, deps, entities)]) | ||||
| 
 | ||||
|     def __init__(self, tokens, annot_tuples, make_projective=False): | ||||
|         self.mem = Pool() | ||||
|         self.loss = 0 | ||||
|         self.length = len(tokens) | ||||
| 
 | ||||
|         # These are filled by the tagger/parser/entity recogniser | ||||
|         self.c.tags = <int*>self.mem.alloc(len(tokens), sizeof(int)) | ||||
|         self.c.heads = <int*>self.mem.alloc(len(tokens), sizeof(int)) | ||||
|         self.c.labels = <int*>self.mem.alloc(len(tokens), sizeof(int)) | ||||
|         self.c.ner = <Transition*>self.mem.alloc(len(tokens), sizeof(Transition)) | ||||
|         self.c.tags = <int*>self.mem.alloc(len(doc), sizeof(int)) | ||||
|         self.c.heads = <int*>self.mem.alloc(len(doc), sizeof(int)) | ||||
|         self.c.labels = <int*>self.mem.alloc(len(doc), sizeof(int)) | ||||
|         self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition)) | ||||
| 
 | ||||
|         self.tags = [None] * len(tokens) | ||||
|         self.heads = [None] * len(tokens) | ||||
|         self.labels = [''] * len(tokens) | ||||
|         self.ner = ['-'] * len(tokens) | ||||
|         self.tags = [None] * len(doc) | ||||
|         self.heads = [None] * len(doc) | ||||
|         self.labels = [''] * len(doc) | ||||
|         self.ner = ['-'] * len(doc) | ||||
| 
 | ||||
|         self.cand_to_gold = align([t.orth_ for t in tokens], annot_tuples[1]) | ||||
|         self.gold_to_cand = align(annot_tuples[1], [t.orth_ for t in tokens]) | ||||
|         self.cand_to_gold = align([t.orth_ for t in doc], words) | ||||
|         self.gold_to_cand = align(words, [t.orth_ for t in doc]) | ||||
| 
 | ||||
|         annot_tuples = (range(len(words)), words, tags, heads, deps, entities) | ||||
|         self.orig_annot = list(zip(*annot_tuples)) | ||||
| 
 | ||||
|         words = [w.orth_ for w in tokens] | ||||
|         for i, gold_i in enumerate(self.cand_to_gold): | ||||
|             if words[i].isspace(): | ||||
|             if doc[i].isspace(): | ||||
|                 self.tags[i] = 'SP' | ||||
|                 self.heads[i] = None | ||||
|                 self.labels[i] = None | ||||
|                 self.ner[i] = 'O' | ||||
|             if gold_i is None: | ||||
|             elif gold_i is None: | ||||
|                 pass | ||||
|             else: | ||||
|                 self.tags[i] = annot_tuples[2][gold_i] | ||||
|                 self.heads[i] = self.gold_to_cand[annot_tuples[3][gold_i]] | ||||
|                 self.labels[i] = annot_tuples[4][gold_i] | ||||
|                 self.ner[i] = annot_tuples[5][gold_i] | ||||
|                 self.tags[i] = tags[gold_i] | ||||
|                 self.heads[i] = self.gold_to_cand[heads[gold_i]] | ||||
|                 self.labels[i] = deps[gold_i] | ||||
|                 self.ner[i] = entities[gold_i] | ||||
| 
 | ||||
|         cycle = nonproj.contains_cycle(self.heads) | ||||
|         if cycle != None: | ||||
|             raise Exception("Cycle found: %s" % cycle) | ||||
| 
 | ||||
|         if make_projective: | ||||
|             proj_heads,_ = nonproj.PseudoProjectivity.projectivize(self.heads,self.labels) | ||||
|             proj_heads,_ = nonproj.PseudoProjectivity.projectivize(self.heads, self.labels) | ||||
|             self.heads = proj_heads | ||||
| 
 | ||||
|     def __len__(self): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user