diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index ea8221cff..32cb3a7d7 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -17,17 +17,13 @@ cdef class EntityRecognizer(Parser): feature_templates = get_feature_templates('ner') def add_label(self, label): - for action in self.moves.action_types: - self.moves.add_action(action, label) - if 'actions' in self.cfg: - self.cfg['actions'].setdefault(action, - {}).setdefault(label, True) + Parser.add_label(self, label) if isinstance(label, basestring): label = self.vocab.strings[label] + # Set label into serializer. Super hacky :( for attr, freqs in self.vocab.serializer_freqs: if attr == ENT_TYPE and label not in freqs: freqs.append([label, 1]) - # Super hacky :( self.vocab._serializer = None @@ -36,19 +32,15 @@ cdef class BeamEntityRecognizer(BeamParser): TransitionSystem = BiluoPushDown feature_templates = get_feature_templates('ner') - + def add_label(self, label): - for action in self.moves.action_types: - self.moves.add_action(action, label) - if 'actions' in self.cfg: - self.cfg['actions'].setdefault(action, - {}).setdefault(label, True) + Parser.add_label(self, label) if isinstance(label, basestring): label = self.vocab.strings[label] + # Set label into serializer. Super hacky :( for attr, freqs in self.vocab.serializer_freqs: if attr == ENT_TYPE and label not in freqs: freqs.append([label, 1]) - # Super hacky :( self.vocab._serializer = None @@ -58,11 +50,7 @@ cdef class DependencyParser(Parser): feature_templates = get_feature_templates('basic') def add_label(self, label): - for action in self.moves.action_types: - self.moves.add_action(action, label) - if 'actions' in self.cfg: - self.cfg['actions'].setdefault(action, - {}).setdefault(label, True) + Parser.add_label(self, label) if isinstance(label, basestring): label = self.vocab.strings[label] for attr, freqs in self.vocab.serializer_freqs: @@ -78,11 +66,7 @@ cdef class BeamDependencyParser(BeamParser): feature_templates = get_feature_templates('basic') def add_label(self, label): - for action in self.moves.action_types: - self.moves.add_action(action, label) - if 'actions' in self.cfg: - self.cfg['actions'].setdefault(action, - {}).setdefault(label, True) + Parser.add_label(self, label) if isinstance(label, basestring): label = self.vocab.strings[label] for attr, freqs in self.vocab.serializer_freqs: diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 93bc21e22..eac71eaa8 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -317,17 +317,20 @@ cdef class ArcEager(TransitionSystem): def get_actions(cls, **kwargs): actions = kwargs.get('actions', { - SHIFT: {'': True}, - REDUCE: {'': True}, - RIGHT: {}, - LEFT: {}, - BREAK: {'ROOT': True}}) + SHIFT: [''], + REDUCE: [''], + RIGHT: [], + LEFT: [], + BREAK: ['ROOT']}) + seen_actions = set() for label in kwargs.get('left_labels', []): if label.upper() != 'ROOT': - actions[LEFT][label] = True + if (LEFT, label) not in seen_actions: + actions[LEFT].append(label) for label in kwargs.get('right_labels', []): if label.upper() != 'ROOT': - actions[RIGHT][label] = True + if (RIGHT, label) not in seen_actions: + actions[RIGHT].append(label) for raw_text, sents in kwargs.get('gold_parses', []): for (ids, words, tags, heads, labels, iob), ctnts in sents: @@ -336,9 +339,11 @@ cdef class ArcEager(TransitionSystem): label = 'ROOT' if label != 'ROOT': if head < child: - actions[RIGHT][label] = True + if (RIGHT, label) not in seen_actions: + actions[RIGHT].append(label) elif head > child: - actions[LEFT][label] = True + if (LEFT, label) not in seen_actions: + actions[LEFT].append(label) return actions property action_types: diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 736cc0039..1090f546f 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -21,6 +21,7 @@ cdef enum: LAST UNIT OUT + ISNT N_MOVES @@ -31,6 +32,7 @@ MOVE_NAMES[IN] = 'I' MOVE_NAMES[LAST] = 'L' MOVE_NAMES[UNIT] = 'U' MOVE_NAMES[OUT] = 'O' +MOVE_NAMES[ISNT] = 'x' cdef do_func_t[N_MOVES] do_funcs @@ -54,16 +56,20 @@ cdef class BiluoPushDown(TransitionSystem): def get_actions(cls, **kwargs): actions = kwargs.get('actions', { - MISSING: {'': True}, - BEGIN: {}, - IN: {}, - LAST: {}, - UNIT: {}, - OUT: {'': True} + MISSING: [''], + BEGIN: [], + IN: [], + LAST: [], + UNIT: [], + OUT: [''] }) + seen_entities = set() for entity_type in kwargs.get('entity_types', []): + if entity_type in seen_entities: + continue + seen_entities.add(entity_type) for action in (BEGIN, IN, LAST, UNIT): - actions[action][entity_type] = True + actions[action].append(entity_type) moves = ('M', 'B', 'I', 'L', 'U') for raw_text, sents in kwargs.get('gold_parses', []): for (ids, words, tags, heads, labels, biluo), _ in sents: @@ -72,8 +78,10 @@ cdef class BiluoPushDown(TransitionSystem): if ner_tag.count('-') != 1: raise ValueError(ner_tag) _, label = ner_tag.split('-') - for move_str in ('B', 'I', 'L', 'U'): - actions[moves.index(move_str)][label] = True + if label not in seen_entities: + seen_entities.add(label) + for move_str in ('B', 'I', 'L', 'U'): + actions[moves.index(move_str)].append(label) return actions property action_types: @@ -111,11 +119,17 @@ cdef class BiluoPushDown(TransitionSystem): label = 0 elif '-' in name: move_str, label_str = name.split('-', 1) + # Hacky way to denote 'not this entity' + if label_str.startswith('!'): + label_str = label_str[1:] + move_str = 'x' label = self.strings[label_str] else: move_str = name label = 0 move = MOVE_NAMES.index(move_str) + if move == ISNT: + return Transition(clas=0, move=ISNT, label=label, score=0) for i in range(self.n_moves): if self.c[i].move == move and self.c[i].label == label: return self.c[i] @@ -225,6 +239,9 @@ cdef class Begin: elif g_act == BEGIN: # B, Gold B --> Label match return label != g_tag + # Support partial supervision in the form of "not this label" + elif g_act == ISNT: + return label == g_tag else: # B, Gold I --> False (P) # B, Gold L --> False (P) @@ -359,6 +376,9 @@ cdef class Unit: elif g_act == UNIT: # U, Gold U --> True iff tag match return label != g_tag + # Support partial supervision in the form of "not this label" + elif g_act == ISNT: + return label == g_tag else: # U, Gold B --> False # U, Gold I --> False @@ -388,7 +408,7 @@ cdef class Out: cdef int g_act = gold.ner[s.B(0)].move cdef int g_tag = gold.ner[s.B(0)].label - if g_act == MISSING: + if g_act == MISSING or g_act == ISNT: return 0 elif g_act == BEGIN: # O, Gold B --> False diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 344ac5568..969c4ef06 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -52,7 +52,7 @@ from ._parse_features cimport fill_context from .stateclass cimport StateClass from ._state cimport StateC -USE_FTRL = True +USE_FTRL = False DEBUG = False def set_debug(val): global DEBUG @@ -152,6 +152,13 @@ cdef class Parser: # TODO: remove this shim when we don't have to support older data if 'labels' in cfg and 'actions' not in cfg: cfg['actions'] = cfg.pop('labels') + # TODO: remove this shim when we don't have to support older data + for action_name, labels in dict(cfg['actions']).items(): + # We need this to be sorted + if isinstance(labels, dict): + labels = list(sorted(labels.keys())) + cfg['actions'][action_name] = labels + print(cfg['actions']) self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg) if (path / 'model').exists(): self.model.load(str(path / 'model')) @@ -362,6 +369,10 @@ cdef class Parser: # Doesn't set label into serializer -- subclasses override it to do that. for action in self.moves.action_types: self.moves.add_action(action, label) + if 'actions' in self.cfg: + # Important that the labels be stored as a list! We need the + # order, or the model goes out of synch + self.cfg['actions'].setdefault(str(action), []).append(label) cdef class StepwiseState: diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 7e5577885..e6a96062b 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -32,7 +32,7 @@ cdef class TransitionSystem: self.c = self.mem.alloc(self._size, sizeof(Transition)) for action, label_strs in sorted(labels_by_action.items()): - for label_str in sorted(label_strs): + for label_str in label_strs: self.add_action(int(action), label_str) self.root_label = self.strings['ROOT'] self.freqs = {} if _freqs is None else _freqs @@ -105,5 +105,6 @@ cdef class TransitionSystem: self.c = self.mem.realloc(self.c, self._size * sizeof(self.c[0])) self.c[self.n_moves] = self.init_transition(self.n_moves, action, label) + print("Add action", action, self.strings[label], self.n_moves) self.n_moves += 1 return 1