mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
WIP on add_label bug during NER training
Currently when a new label is introduced to NER during training, it causes the labels to be read in in an unexpected order. This invalidates the model.
This commit is contained in:
parent
33ba5066eb
commit
354458484c
|
@ -17,17 +17,13 @@ cdef class EntityRecognizer(Parser):
|
||||||
feature_templates = get_feature_templates('ner')
|
feature_templates = get_feature_templates('ner')
|
||||||
|
|
||||||
def add_label(self, label):
|
def add_label(self, label):
|
||||||
for action in self.moves.action_types:
|
Parser.add_label(self, label)
|
||||||
self.moves.add_action(action, label)
|
|
||||||
if 'actions' in self.cfg:
|
|
||||||
self.cfg['actions'].setdefault(action,
|
|
||||||
{}).setdefault(label, True)
|
|
||||||
if isinstance(label, basestring):
|
if isinstance(label, basestring):
|
||||||
label = self.vocab.strings[label]
|
label = self.vocab.strings[label]
|
||||||
|
# Set label into serializer. Super hacky :(
|
||||||
for attr, freqs in self.vocab.serializer_freqs:
|
for attr, freqs in self.vocab.serializer_freqs:
|
||||||
if attr == ENT_TYPE and label not in freqs:
|
if attr == ENT_TYPE and label not in freqs:
|
||||||
freqs.append([label, 1])
|
freqs.append([label, 1])
|
||||||
# Super hacky :(
|
|
||||||
self.vocab._serializer = None
|
self.vocab._serializer = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,19 +32,15 @@ cdef class BeamEntityRecognizer(BeamParser):
|
||||||
TransitionSystem = BiluoPushDown
|
TransitionSystem = BiluoPushDown
|
||||||
|
|
||||||
feature_templates = get_feature_templates('ner')
|
feature_templates = get_feature_templates('ner')
|
||||||
|
|
||||||
def add_label(self, label):
|
def add_label(self, label):
|
||||||
for action in self.moves.action_types:
|
Parser.add_label(self, label)
|
||||||
self.moves.add_action(action, label)
|
|
||||||
if 'actions' in self.cfg:
|
|
||||||
self.cfg['actions'].setdefault(action,
|
|
||||||
{}).setdefault(label, True)
|
|
||||||
if isinstance(label, basestring):
|
if isinstance(label, basestring):
|
||||||
label = self.vocab.strings[label]
|
label = self.vocab.strings[label]
|
||||||
|
# Set label into serializer. Super hacky :(
|
||||||
for attr, freqs in self.vocab.serializer_freqs:
|
for attr, freqs in self.vocab.serializer_freqs:
|
||||||
if attr == ENT_TYPE and label not in freqs:
|
if attr == ENT_TYPE and label not in freqs:
|
||||||
freqs.append([label, 1])
|
freqs.append([label, 1])
|
||||||
# Super hacky :(
|
|
||||||
self.vocab._serializer = None
|
self.vocab._serializer = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,11 +50,7 @@ cdef class DependencyParser(Parser):
|
||||||
feature_templates = get_feature_templates('basic')
|
feature_templates = get_feature_templates('basic')
|
||||||
|
|
||||||
def add_label(self, label):
|
def add_label(self, label):
|
||||||
for action in self.moves.action_types:
|
Parser.add_label(self, label)
|
||||||
self.moves.add_action(action, label)
|
|
||||||
if 'actions' in self.cfg:
|
|
||||||
self.cfg['actions'].setdefault(action,
|
|
||||||
{}).setdefault(label, True)
|
|
||||||
if isinstance(label, basestring):
|
if isinstance(label, basestring):
|
||||||
label = self.vocab.strings[label]
|
label = self.vocab.strings[label]
|
||||||
for attr, freqs in self.vocab.serializer_freqs:
|
for attr, freqs in self.vocab.serializer_freqs:
|
||||||
|
@ -78,11 +66,7 @@ cdef class BeamDependencyParser(BeamParser):
|
||||||
feature_templates = get_feature_templates('basic')
|
feature_templates = get_feature_templates('basic')
|
||||||
|
|
||||||
def add_label(self, label):
|
def add_label(self, label):
|
||||||
for action in self.moves.action_types:
|
Parser.add_label(self, label)
|
||||||
self.moves.add_action(action, label)
|
|
||||||
if 'actions' in self.cfg:
|
|
||||||
self.cfg['actions'].setdefault(action,
|
|
||||||
{}).setdefault(label, True)
|
|
||||||
if isinstance(label, basestring):
|
if isinstance(label, basestring):
|
||||||
label = self.vocab.strings[label]
|
label = self.vocab.strings[label]
|
||||||
for attr, freqs in self.vocab.serializer_freqs:
|
for attr, freqs in self.vocab.serializer_freqs:
|
||||||
|
|
|
@ -317,17 +317,20 @@ cdef class ArcEager(TransitionSystem):
|
||||||
def get_actions(cls, **kwargs):
|
def get_actions(cls, **kwargs):
|
||||||
actions = kwargs.get('actions',
|
actions = kwargs.get('actions',
|
||||||
{
|
{
|
||||||
SHIFT: {'': True},
|
SHIFT: [''],
|
||||||
REDUCE: {'': True},
|
REDUCE: [''],
|
||||||
RIGHT: {},
|
RIGHT: [],
|
||||||
LEFT: {},
|
LEFT: [],
|
||||||
BREAK: {'ROOT': True}})
|
BREAK: ['ROOT']})
|
||||||
|
seen_actions = set()
|
||||||
for label in kwargs.get('left_labels', []):
|
for label in kwargs.get('left_labels', []):
|
||||||
if label.upper() != 'ROOT':
|
if label.upper() != 'ROOT':
|
||||||
actions[LEFT][label] = True
|
if (LEFT, label) not in seen_actions:
|
||||||
|
actions[LEFT].append(label)
|
||||||
for label in kwargs.get('right_labels', []):
|
for label in kwargs.get('right_labels', []):
|
||||||
if label.upper() != 'ROOT':
|
if label.upper() != 'ROOT':
|
||||||
actions[RIGHT][label] = True
|
if (RIGHT, label) not in seen_actions:
|
||||||
|
actions[RIGHT].append(label)
|
||||||
|
|
||||||
for raw_text, sents in kwargs.get('gold_parses', []):
|
for raw_text, sents in kwargs.get('gold_parses', []):
|
||||||
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
for (ids, words, tags, heads, labels, iob), ctnts in sents:
|
||||||
|
@ -336,9 +339,11 @@ cdef class ArcEager(TransitionSystem):
|
||||||
label = 'ROOT'
|
label = 'ROOT'
|
||||||
if label != 'ROOT':
|
if label != 'ROOT':
|
||||||
if head < child:
|
if head < child:
|
||||||
actions[RIGHT][label] = True
|
if (RIGHT, label) not in seen_actions:
|
||||||
|
actions[RIGHT].append(label)
|
||||||
elif head > child:
|
elif head > child:
|
||||||
actions[LEFT][label] = True
|
if (LEFT, label) not in seen_actions:
|
||||||
|
actions[LEFT].append(label)
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
property action_types:
|
property action_types:
|
||||||
|
|
|
@ -21,6 +21,7 @@ cdef enum:
|
||||||
LAST
|
LAST
|
||||||
UNIT
|
UNIT
|
||||||
OUT
|
OUT
|
||||||
|
ISNT
|
||||||
N_MOVES
|
N_MOVES
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,6 +32,7 @@ MOVE_NAMES[IN] = 'I'
|
||||||
MOVE_NAMES[LAST] = 'L'
|
MOVE_NAMES[LAST] = 'L'
|
||||||
MOVE_NAMES[UNIT] = 'U'
|
MOVE_NAMES[UNIT] = 'U'
|
||||||
MOVE_NAMES[OUT] = 'O'
|
MOVE_NAMES[OUT] = 'O'
|
||||||
|
MOVE_NAMES[ISNT] = 'x'
|
||||||
|
|
||||||
|
|
||||||
cdef do_func_t[N_MOVES] do_funcs
|
cdef do_func_t[N_MOVES] do_funcs
|
||||||
|
@ -54,16 +56,20 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
def get_actions(cls, **kwargs):
|
def get_actions(cls, **kwargs):
|
||||||
actions = kwargs.get('actions',
|
actions = kwargs.get('actions',
|
||||||
{
|
{
|
||||||
MISSING: {'': True},
|
MISSING: [''],
|
||||||
BEGIN: {},
|
BEGIN: [],
|
||||||
IN: {},
|
IN: [],
|
||||||
LAST: {},
|
LAST: [],
|
||||||
UNIT: {},
|
UNIT: [],
|
||||||
OUT: {'': True}
|
OUT: ['']
|
||||||
})
|
})
|
||||||
|
seen_entities = set()
|
||||||
for entity_type in kwargs.get('entity_types', []):
|
for entity_type in kwargs.get('entity_types', []):
|
||||||
|
if entity_type in seen_entities:
|
||||||
|
continue
|
||||||
|
seen_entities.add(entity_type)
|
||||||
for action in (BEGIN, IN, LAST, UNIT):
|
for action in (BEGIN, IN, LAST, UNIT):
|
||||||
actions[action][entity_type] = True
|
actions[action].append(entity_type)
|
||||||
moves = ('M', 'B', 'I', 'L', 'U')
|
moves = ('M', 'B', 'I', 'L', 'U')
|
||||||
for raw_text, sents in kwargs.get('gold_parses', []):
|
for raw_text, sents in kwargs.get('gold_parses', []):
|
||||||
for (ids, words, tags, heads, labels, biluo), _ in sents:
|
for (ids, words, tags, heads, labels, biluo), _ in sents:
|
||||||
|
@ -72,8 +78,10 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
if ner_tag.count('-') != 1:
|
if ner_tag.count('-') != 1:
|
||||||
raise ValueError(ner_tag)
|
raise ValueError(ner_tag)
|
||||||
_, label = ner_tag.split('-')
|
_, label = ner_tag.split('-')
|
||||||
for move_str in ('B', 'I', 'L', 'U'):
|
if label not in seen_entities:
|
||||||
actions[moves.index(move_str)][label] = True
|
seen_entities.add(label)
|
||||||
|
for move_str in ('B', 'I', 'L', 'U'):
|
||||||
|
actions[moves.index(move_str)].append(label)
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
property action_types:
|
property action_types:
|
||||||
|
@ -111,11 +119,17 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
label = 0
|
label = 0
|
||||||
elif '-' in name:
|
elif '-' in name:
|
||||||
move_str, label_str = name.split('-', 1)
|
move_str, label_str = name.split('-', 1)
|
||||||
|
# Hacky way to denote 'not this entity'
|
||||||
|
if label_str.startswith('!'):
|
||||||
|
label_str = label_str[1:]
|
||||||
|
move_str = 'x'
|
||||||
label = self.strings[label_str]
|
label = self.strings[label_str]
|
||||||
else:
|
else:
|
||||||
move_str = name
|
move_str = name
|
||||||
label = 0
|
label = 0
|
||||||
move = MOVE_NAMES.index(move_str)
|
move = MOVE_NAMES.index(move_str)
|
||||||
|
if move == ISNT:
|
||||||
|
return Transition(clas=0, move=ISNT, label=label, score=0)
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
if self.c[i].move == move and self.c[i].label == label:
|
if self.c[i].move == move and self.c[i].label == label:
|
||||||
return self.c[i]
|
return self.c[i]
|
||||||
|
@ -225,6 +239,9 @@ cdef class Begin:
|
||||||
elif g_act == BEGIN:
|
elif g_act == BEGIN:
|
||||||
# B, Gold B --> Label match
|
# B, Gold B --> Label match
|
||||||
return label != g_tag
|
return label != g_tag
|
||||||
|
# Support partial supervision in the form of "not this label"
|
||||||
|
elif g_act == ISNT:
|
||||||
|
return label == g_tag
|
||||||
else:
|
else:
|
||||||
# B, Gold I --> False (P)
|
# B, Gold I --> False (P)
|
||||||
# B, Gold L --> False (P)
|
# B, Gold L --> False (P)
|
||||||
|
@ -359,6 +376,9 @@ cdef class Unit:
|
||||||
elif g_act == UNIT:
|
elif g_act == UNIT:
|
||||||
# U, Gold U --> True iff tag match
|
# U, Gold U --> True iff tag match
|
||||||
return label != g_tag
|
return label != g_tag
|
||||||
|
# Support partial supervision in the form of "not this label"
|
||||||
|
elif g_act == ISNT:
|
||||||
|
return label == g_tag
|
||||||
else:
|
else:
|
||||||
# U, Gold B --> False
|
# U, Gold B --> False
|
||||||
# U, Gold I --> False
|
# U, Gold I --> False
|
||||||
|
@ -388,7 +408,7 @@ cdef class Out:
|
||||||
cdef int g_act = gold.ner[s.B(0)].move
|
cdef int g_act = gold.ner[s.B(0)].move
|
||||||
cdef int g_tag = gold.ner[s.B(0)].label
|
cdef int g_tag = gold.ner[s.B(0)].label
|
||||||
|
|
||||||
if g_act == MISSING:
|
if g_act == MISSING or g_act == ISNT:
|
||||||
return 0
|
return 0
|
||||||
elif g_act == BEGIN:
|
elif g_act == BEGIN:
|
||||||
# O, Gold B --> False
|
# O, Gold B --> False
|
||||||
|
|
|
@ -52,7 +52,7 @@ from ._parse_features cimport fill_context
|
||||||
from .stateclass cimport StateClass
|
from .stateclass cimport StateClass
|
||||||
from ._state cimport StateC
|
from ._state cimport StateC
|
||||||
|
|
||||||
USE_FTRL = True
|
USE_FTRL = False
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
def set_debug(val):
|
def set_debug(val):
|
||||||
global DEBUG
|
global DEBUG
|
||||||
|
@ -152,6 +152,13 @@ cdef class Parser:
|
||||||
# TODO: remove this shim when we don't have to support older data
|
# TODO: remove this shim when we don't have to support older data
|
||||||
if 'labels' in cfg and 'actions' not in cfg:
|
if 'labels' in cfg and 'actions' not in cfg:
|
||||||
cfg['actions'] = cfg.pop('labels')
|
cfg['actions'] = cfg.pop('labels')
|
||||||
|
# TODO: remove this shim when we don't have to support older data
|
||||||
|
for action_name, labels in dict(cfg['actions']).items():
|
||||||
|
# We need this to be sorted
|
||||||
|
if isinstance(labels, dict):
|
||||||
|
labels = list(sorted(labels.keys()))
|
||||||
|
cfg['actions'][action_name] = labels
|
||||||
|
print(cfg['actions'])
|
||||||
self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
|
self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
|
||||||
if (path / 'model').exists():
|
if (path / 'model').exists():
|
||||||
self.model.load(str(path / 'model'))
|
self.model.load(str(path / 'model'))
|
||||||
|
@ -362,6 +369,10 @@ cdef class Parser:
|
||||||
# Doesn't set label into serializer -- subclasses override it to do that.
|
# Doesn't set label into serializer -- subclasses override it to do that.
|
||||||
for action in self.moves.action_types:
|
for action in self.moves.action_types:
|
||||||
self.moves.add_action(action, label)
|
self.moves.add_action(action, label)
|
||||||
|
if 'actions' in self.cfg:
|
||||||
|
# Important that the labels be stored as a list! We need the
|
||||||
|
# order, or the model goes out of synch
|
||||||
|
self.cfg['actions'].setdefault(str(action), []).append(label)
|
||||||
|
|
||||||
|
|
||||||
cdef class StepwiseState:
|
cdef class StepwiseState:
|
||||||
|
|
|
@ -32,7 +32,7 @@ cdef class TransitionSystem:
|
||||||
self.c = <Transition*>self.mem.alloc(self._size, sizeof(Transition))
|
self.c = <Transition*>self.mem.alloc(self._size, sizeof(Transition))
|
||||||
|
|
||||||
for action, label_strs in sorted(labels_by_action.items()):
|
for action, label_strs in sorted(labels_by_action.items()):
|
||||||
for label_str in sorted(label_strs):
|
for label_str in label_strs:
|
||||||
self.add_action(int(action), label_str)
|
self.add_action(int(action), label_str)
|
||||||
self.root_label = self.strings['ROOT']
|
self.root_label = self.strings['ROOT']
|
||||||
self.freqs = {} if _freqs is None else _freqs
|
self.freqs = {} if _freqs is None else _freqs
|
||||||
|
@ -105,5 +105,6 @@ cdef class TransitionSystem:
|
||||||
self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
|
self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
|
||||||
|
|
||||||
self.c[self.n_moves] = self.init_transition(self.n_moves, action, label)
|
self.c[self.n_moves] = self.init_transition(self.n_moves, action, label)
|
||||||
|
print("Add action", action, self.strings[label], self.n_moves)
|
||||||
self.n_moves += 1
|
self.n_moves += 1
|
||||||
return 1
|
return 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user