mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	Fix bug when there are too many entity types. Fixes #2800
This commit is contained in:
		
							parent
							
								
									8809dc4514
								
							
						
					
					
						commit
						96fe314d8d
					
				|  | @ -747,7 +747,8 @@ cdef class Parser: | |||
| 
 | ||||
|     def transition_batch(self, states, float[:, ::1] scores): | ||||
|         cdef StateClass state | ||||
|         cdef int[500] is_valid # TODO: Unhack | ||||
|         cdef Pool mem = Pool() | ||||
|         is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int)) | ||||
|         cdef float* c_scores = &scores[0, 0] | ||||
|         for state in states: | ||||
|             self.moves.set_valid(is_valid, state.c) | ||||
|  |  | |||
							
								
								
									
										34
									
								
								spacy/tests/regression/test_issue2800.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								spacy/tests/regression/test_issue2800.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,34 @@ | |||
| '''Test issue that arises when too many labels are added to NER model.''' | ||||
| import random | ||||
| from ...lang.en import English | ||||
| 
 | ||||
def train_model(train_data, entity_types):
    """Train an English pipeline that contains only an NER component.

    The pipeline is created with an empty ``pipeline`` list, an "ner" pipe is
    added to it, every label in ``entity_types`` is registered on that pipe,
    and the pipeline is then updated for 20 passes over ``train_data``.

    train_data: Sequence of ``(statement, entities)`` pairs; each pair is
        passed to ``nlp.update`` as a one-element batch. The sequence itself
        is not modified.
    entity_types: Iterable of labels to add to the NER component.
    RETURNS: The ``English`` pipeline object after training.
    """
    nlp = English(pipeline=[])

    ner = nlp.create_pipe("ner")
    nlp.add_pipe(ner)

    for entity_type in entity_types:
        ner.add_label(entity_type)

    optimizer = nlp.begin_training()

    # Shuffle a private copy so the caller's list keeps its original order.
    examples = list(train_data)

    # Start training
    for _ in range(20):
        losses = {}
        random.shuffle(examples)

        for statement, entities in examples:
            nlp.update([statement], [entities], sgd=optimizer, losses=losses, drop=0.5)
    return nlp
| 
 | ||||
| 
 | ||||
def test_train_with_many_entity_types():
    """Regression test for issue #2800: train an NER model with 1000 labels.

    There are no assertions; the test passes as long as training finishes
    without raising.
    """
    examples = [("One sentence", {"entities": []})]
    labels = [str(number) for number in range(1000)]

    model = train_model(examples, labels)
| 
 | ||||
|      | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user