mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	WIP on add_label bug during NER training
Currently when a new label is introduced to NER during training, it causes the labels to be read in in an unexpected order. This invalidates the model.
This commit is contained in:
		
							parent
							
								
									33ba5066eb
								
							
						
					
					
						commit
						354458484c
					
				|  | @ -17,17 +17,13 @@ cdef class EntityRecognizer(Parser): | ||||||
|     feature_templates = get_feature_templates('ner') |     feature_templates = get_feature_templates('ner') | ||||||
| 
 | 
 | ||||||
|     def add_label(self, label): |     def add_label(self, label): | ||||||
|         for action in self.moves.action_types: |         Parser.add_label(self, label) | ||||||
|             self.moves.add_action(action, label) |  | ||||||
|             if 'actions' in self.cfg: |  | ||||||
|                 self.cfg['actions'].setdefault(action, |  | ||||||
|                                         {}).setdefault(label, True) |  | ||||||
|         if isinstance(label, basestring): |         if isinstance(label, basestring): | ||||||
|             label = self.vocab.strings[label] |             label = self.vocab.strings[label] | ||||||
|  |         # Set label into serializer. Super hacky :( | ||||||
|         for attr, freqs in self.vocab.serializer_freqs: |         for attr, freqs in self.vocab.serializer_freqs: | ||||||
|             if attr == ENT_TYPE and label not in freqs: |             if attr == ENT_TYPE and label not in freqs: | ||||||
|                 freqs.append([label, 1]) |                 freqs.append([label, 1]) | ||||||
|         # Super hacky :( |  | ||||||
|         self.vocab._serializer = None |         self.vocab._serializer = None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -38,17 +34,13 @@ cdef class BeamEntityRecognizer(BeamParser): | ||||||
|     feature_templates = get_feature_templates('ner') |     feature_templates = get_feature_templates('ner') | ||||||
|      |      | ||||||
|     def add_label(self, label): |     def add_label(self, label): | ||||||
|         for action in self.moves.action_types: |         Parser.add_label(self, label) | ||||||
|             self.moves.add_action(action, label) |  | ||||||
|             if 'actions' in self.cfg: |  | ||||||
|                 self.cfg['actions'].setdefault(action, |  | ||||||
|                                         {}).setdefault(label, True) |  | ||||||
|         if isinstance(label, basestring): |         if isinstance(label, basestring): | ||||||
|             label = self.vocab.strings[label] |             label = self.vocab.strings[label] | ||||||
|  |         # Set label into serializer. Super hacky :( | ||||||
|         for attr, freqs in self.vocab.serializer_freqs: |         for attr, freqs in self.vocab.serializer_freqs: | ||||||
|             if attr == ENT_TYPE and label not in freqs: |             if attr == ENT_TYPE and label not in freqs: | ||||||
|                 freqs.append([label, 1]) |                 freqs.append([label, 1]) | ||||||
|         # Super hacky :( |  | ||||||
|         self.vocab._serializer = None |         self.vocab._serializer = None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -58,11 +50,7 @@ cdef class DependencyParser(Parser): | ||||||
|     feature_templates = get_feature_templates('basic') |     feature_templates = get_feature_templates('basic') | ||||||
| 
 | 
 | ||||||
|     def add_label(self, label): |     def add_label(self, label): | ||||||
|         for action in self.moves.action_types: |         Parser.add_label(self, label) | ||||||
|             self.moves.add_action(action, label) |  | ||||||
|             if 'actions' in self.cfg: |  | ||||||
|                 self.cfg['actions'].setdefault(action, |  | ||||||
|                                         {}).setdefault(label, True) |  | ||||||
|         if isinstance(label, basestring): |         if isinstance(label, basestring): | ||||||
|             label = self.vocab.strings[label] |             label = self.vocab.strings[label] | ||||||
|         for attr, freqs in self.vocab.serializer_freqs: |         for attr, freqs in self.vocab.serializer_freqs: | ||||||
|  | @ -78,11 +66,7 @@ cdef class BeamDependencyParser(BeamParser): | ||||||
|     feature_templates = get_feature_templates('basic') |     feature_templates = get_feature_templates('basic') | ||||||
| 
 | 
 | ||||||
|     def add_label(self, label): |     def add_label(self, label): | ||||||
|         for action in self.moves.action_types: |         Parser.add_label(self, label) | ||||||
|             self.moves.add_action(action, label) |  | ||||||
|             if 'actions' in self.cfg: |  | ||||||
|                 self.cfg['actions'].setdefault(action, |  | ||||||
|                                         {}).setdefault(label, True) |  | ||||||
|         if isinstance(label, basestring): |         if isinstance(label, basestring): | ||||||
|             label = self.vocab.strings[label] |             label = self.vocab.strings[label] | ||||||
|         for attr, freqs in self.vocab.serializer_freqs: |         for attr, freqs in self.vocab.serializer_freqs: | ||||||
|  |  | ||||||
|  | @ -317,17 +317,20 @@ cdef class ArcEager(TransitionSystem): | ||||||
|     def get_actions(cls, **kwargs): |     def get_actions(cls, **kwargs): | ||||||
|         actions = kwargs.get('actions', |         actions = kwargs.get('actions', | ||||||
|                     { |                     { | ||||||
|                         SHIFT: {'': True}, |                         SHIFT: [''], | ||||||
|                         REDUCE: {'': True}, |                         REDUCE: [''], | ||||||
|                         RIGHT: {}, |                         RIGHT: [], | ||||||
|                         LEFT: {}, |                         LEFT: [], | ||||||
|                         BREAK: {'ROOT': True}}) |                         BREAK: ['ROOT']}) | ||||||
|  |         seen_actions = set() | ||||||
|         for label in kwargs.get('left_labels', []): |         for label in kwargs.get('left_labels', []): | ||||||
|             if label.upper() != 'ROOT': |             if label.upper() != 'ROOT': | ||||||
|                 actions[LEFT][label] = True |                 if (LEFT, label) not in seen_actions: | ||||||
|  |                     actions[LEFT].append(label) | ||||||
|         for label in kwargs.get('right_labels', []): |         for label in kwargs.get('right_labels', []): | ||||||
|             if label.upper() != 'ROOT': |             if label.upper() != 'ROOT': | ||||||
|                 actions[RIGHT][label] = True |                 if (RIGHT, label) not in seen_actions: | ||||||
|  |                     actions[RIGHT].append(label) | ||||||
| 
 | 
 | ||||||
|         for raw_text, sents in kwargs.get('gold_parses', []): |         for raw_text, sents in kwargs.get('gold_parses', []): | ||||||
|             for (ids, words, tags, heads, labels, iob), ctnts in sents: |             for (ids, words, tags, heads, labels, iob), ctnts in sents: | ||||||
|  | @ -336,9 +339,11 @@ cdef class ArcEager(TransitionSystem): | ||||||
|                         label = 'ROOT' |                         label = 'ROOT' | ||||||
|                     if label != 'ROOT': |                     if label != 'ROOT': | ||||||
|                         if head < child: |                         if head < child: | ||||||
|                             actions[RIGHT][label] = True |                             if (RIGHT, label) not in seen_actions: | ||||||
|  |                                 actions[RIGHT].append(label) | ||||||
|                         elif head > child: |                         elif head > child: | ||||||
|                             actions[LEFT][label] = True |                             if (LEFT, label) not in seen_actions: | ||||||
|  |                                 actions[LEFT].append(label) | ||||||
|         return actions |         return actions | ||||||
| 
 | 
 | ||||||
|     property action_types: |     property action_types: | ||||||
|  |  | ||||||
|  | @ -21,6 +21,7 @@ cdef enum: | ||||||
|     LAST |     LAST | ||||||
|     UNIT |     UNIT | ||||||
|     OUT |     OUT | ||||||
|  |     ISNT | ||||||
|     N_MOVES |     N_MOVES | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -31,6 +32,7 @@ MOVE_NAMES[IN] = 'I' | ||||||
| MOVE_NAMES[LAST] = 'L' | MOVE_NAMES[LAST] = 'L' | ||||||
| MOVE_NAMES[UNIT] = 'U' | MOVE_NAMES[UNIT] = 'U' | ||||||
| MOVE_NAMES[OUT] = 'O' | MOVE_NAMES[OUT] = 'O' | ||||||
|  | MOVE_NAMES[ISNT] = 'x' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef do_func_t[N_MOVES] do_funcs | cdef do_func_t[N_MOVES] do_funcs | ||||||
|  | @ -54,16 +56,20 @@ cdef class BiluoPushDown(TransitionSystem): | ||||||
|     def get_actions(cls, **kwargs): |     def get_actions(cls, **kwargs): | ||||||
|         actions = kwargs.get('actions', |         actions = kwargs.get('actions', | ||||||
|                     { |                     { | ||||||
|                         MISSING: {'': True}, |                         MISSING: [''], | ||||||
|                         BEGIN: {}, |                         BEGIN: [], | ||||||
|                         IN: {}, |                         IN: [], | ||||||
|                         LAST: {}, |                         LAST: [], | ||||||
|                         UNIT: {}, |                         UNIT: [], | ||||||
|                         OUT: {'': True} |                         OUT: [''] | ||||||
|                     }) |                     }) | ||||||
|  |         seen_entities = set() | ||||||
|         for entity_type in kwargs.get('entity_types', []): |         for entity_type in kwargs.get('entity_types', []): | ||||||
|  |             if entity_type in seen_entities: | ||||||
|  |                 continue | ||||||
|  |             seen_entities.add(entity_type) | ||||||
|             for action in (BEGIN, IN, LAST, UNIT): |             for action in (BEGIN, IN, LAST, UNIT): | ||||||
|                 actions[action][entity_type] = True |                 actions[action].append(entity_type) | ||||||
|         moves = ('M', 'B', 'I', 'L', 'U') |         moves = ('M', 'B', 'I', 'L', 'U') | ||||||
|         for raw_text, sents in kwargs.get('gold_parses', []): |         for raw_text, sents in kwargs.get('gold_parses', []): | ||||||
|             for (ids, words, tags, heads, labels, biluo), _ in sents: |             for (ids, words, tags, heads, labels, biluo), _ in sents: | ||||||
|  | @ -72,8 +78,10 @@ cdef class BiluoPushDown(TransitionSystem): | ||||||
|                         if ner_tag.count('-') != 1: |                         if ner_tag.count('-') != 1: | ||||||
|                             raise ValueError(ner_tag) |                             raise ValueError(ner_tag) | ||||||
|                         _, label = ner_tag.split('-') |                         _, label = ner_tag.split('-') | ||||||
|  |                         if label not in seen_entities: | ||||||
|  |                             seen_entities.add(label) | ||||||
|                             for move_str in ('B', 'I', 'L', 'U'): |                             for move_str in ('B', 'I', 'L', 'U'): | ||||||
|                             actions[moves.index(move_str)][label] = True |                                 actions[moves.index(move_str)].append(label) | ||||||
|         return actions |         return actions | ||||||
| 
 | 
 | ||||||
|     property action_types: |     property action_types: | ||||||
|  | @ -111,11 +119,17 @@ cdef class BiluoPushDown(TransitionSystem): | ||||||
|             label = 0 |             label = 0 | ||||||
|         elif '-' in name: |         elif '-' in name: | ||||||
|             move_str, label_str = name.split('-', 1) |             move_str, label_str = name.split('-', 1) | ||||||
|  |             # Hacky way to denote 'not this entity' | ||||||
|  |             if label_str.startswith('!'): | ||||||
|  |                 label_str = label_str[1:] | ||||||
|  |                 move_str = 'x' | ||||||
|             label = self.strings[label_str] |             label = self.strings[label_str] | ||||||
|         else: |         else: | ||||||
|             move_str = name |             move_str = name | ||||||
|             label = 0 |             label = 0 | ||||||
|         move = MOVE_NAMES.index(move_str) |         move = MOVE_NAMES.index(move_str) | ||||||
|  |         if move == ISNT: | ||||||
|  |             return Transition(clas=0, move=ISNT, label=label, score=0) | ||||||
|         for i in range(self.n_moves): |         for i in range(self.n_moves): | ||||||
|             if self.c[i].move == move and self.c[i].label == label: |             if self.c[i].move == move and self.c[i].label == label: | ||||||
|                 return self.c[i] |                 return self.c[i] | ||||||
|  | @ -225,6 +239,9 @@ cdef class Begin: | ||||||
|         elif g_act == BEGIN: |         elif g_act == BEGIN: | ||||||
|             # B, Gold B --> Label match |             # B, Gold B --> Label match | ||||||
|             return label != g_tag |             return label != g_tag | ||||||
|  |         # Support partial supervision in the form of "not this label" | ||||||
|  |         elif g_act == ISNT: | ||||||
|  |             return label == g_tag | ||||||
|         else: |         else: | ||||||
|             # B, Gold I --> False (P) |             # B, Gold I --> False (P) | ||||||
|             # B, Gold L --> False (P) |             # B, Gold L --> False (P) | ||||||
|  | @ -359,6 +376,9 @@ cdef class Unit: | ||||||
|         elif g_act == UNIT: |         elif g_act == UNIT: | ||||||
|             # U, Gold U --> True iff tag match |             # U, Gold U --> True iff tag match | ||||||
|             return label != g_tag |             return label != g_tag | ||||||
|  |         # Support partial supervision in the form of "not this label" | ||||||
|  |         elif g_act == ISNT: | ||||||
|  |             return label == g_tag | ||||||
|         else: |         else: | ||||||
|             # U, Gold B --> False |             # U, Gold B --> False | ||||||
|             # U, Gold I --> False |             # U, Gold I --> False | ||||||
|  | @ -388,7 +408,7 @@ cdef class Out: | ||||||
|         cdef int g_act = gold.ner[s.B(0)].move |         cdef int g_act = gold.ner[s.B(0)].move | ||||||
|         cdef int g_tag = gold.ner[s.B(0)].label |         cdef int g_tag = gold.ner[s.B(0)].label | ||||||
| 
 | 
 | ||||||
|         if g_act == MISSING: |         if g_act == MISSING or g_act == ISNT: | ||||||
|             return 0 |             return 0 | ||||||
|         elif g_act == BEGIN: |         elif g_act == BEGIN: | ||||||
|             # O, Gold B --> False |             # O, Gold B --> False | ||||||
|  |  | ||||||
|  | @ -52,7 +52,7 @@ from ._parse_features cimport fill_context | ||||||
| from .stateclass cimport StateClass | from .stateclass cimport StateClass | ||||||
| from ._state cimport StateC | from ._state cimport StateC | ||||||
| 
 | 
 | ||||||
| USE_FTRL = True | USE_FTRL = False | ||||||
| DEBUG = False | DEBUG = False | ||||||
| def set_debug(val): | def set_debug(val): | ||||||
|     global DEBUG |     global DEBUG | ||||||
|  | @ -152,6 +152,13 @@ cdef class Parser: | ||||||
|         # TODO: remove this shim when we don't have to support older data |         # TODO: remove this shim when we don't have to support older data | ||||||
|         if 'labels' in cfg and 'actions' not in cfg: |         if 'labels' in cfg and 'actions' not in cfg: | ||||||
|             cfg['actions'] = cfg.pop('labels') |             cfg['actions'] = cfg.pop('labels') | ||||||
|  |         # TODO: remove this shim when we don't have to support older data | ||||||
|  |         for action_name, labels in dict(cfg['actions']).items(): | ||||||
|  |             # We need this to be sorted | ||||||
|  |             if isinstance(labels, dict): | ||||||
|  |                 labels = list(sorted(labels.keys())) | ||||||
|  |             cfg['actions'][action_name] = labels | ||||||
|  |         print(cfg['actions']) | ||||||
|         self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg) |         self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg) | ||||||
|         if (path / 'model').exists(): |         if (path / 'model').exists(): | ||||||
|             self.model.load(str(path / 'model')) |             self.model.load(str(path / 'model')) | ||||||
|  | @ -362,6 +369,10 @@ cdef class Parser: | ||||||
|         # Doesn't set label into serializer -- subclasses override it to do that. |         # Doesn't set label into serializer -- subclasses override it to do that. | ||||||
|         for action in self.moves.action_types: |         for action in self.moves.action_types: | ||||||
|             self.moves.add_action(action, label) |             self.moves.add_action(action, label) | ||||||
|  |             if 'actions' in self.cfg: | ||||||
|  |                 # Important that the labels be stored as a list! We need the | ||||||
|  |                 # order, or the model goes out of synch | ||||||
|  |                 self.cfg['actions'].setdefault(str(action), []).append(label) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef class StepwiseState: | cdef class StepwiseState: | ||||||
|  |  | ||||||
|  | @ -32,7 +32,7 @@ cdef class TransitionSystem: | ||||||
|         self.c = <Transition*>self.mem.alloc(self._size, sizeof(Transition)) |         self.c = <Transition*>self.mem.alloc(self._size, sizeof(Transition)) | ||||||
| 
 | 
 | ||||||
|         for action, label_strs in sorted(labels_by_action.items()): |         for action, label_strs in sorted(labels_by_action.items()): | ||||||
|             for label_str in sorted(label_strs): |             for label_str in label_strs: | ||||||
|                 self.add_action(int(action), label_str) |                 self.add_action(int(action), label_str) | ||||||
|         self.root_label = self.strings['ROOT'] |         self.root_label = self.strings['ROOT'] | ||||||
|         self.freqs = {} if _freqs is None else _freqs |         self.freqs = {} if _freqs is None else _freqs | ||||||
|  | @ -105,5 +105,6 @@ cdef class TransitionSystem: | ||||||
|             self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0])) |             self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0])) | ||||||
| 
 | 
 | ||||||
|         self.c[self.n_moves] = self.init_transition(self.n_moves, action, label) |         self.c[self.n_moves] = self.init_transition(self.n_moves, action, label) | ||||||
|  |         print("Add action", action, self.strings[label], self.n_moves) | ||||||
|         self.n_moves += 1 |         self.n_moves += 1 | ||||||
|         return 1 |         return 1 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user