diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 489b8b124..9a2e51d84 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -148,6 +148,9 @@ def read_json_file(loc, docs_filter=None): tags.append(token['tag']) heads.append(token['head'] + i) labels.append(token['dep']) + # Ensure ROOT label is case-insensitive + if labels[-1].lower() == 'root': + labels[-1] = 'ROOT' ner.append(token.get('ner', '-')) sents.append(( (ids, words, tags, heads, labels, ner), diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 4d89ad386..663ffd2cb 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -284,12 +284,14 @@ cdef int _get_root(int word, const GoldParseC* gold) nogil: cdef class ArcEager(TransitionSystem): @classmethod def get_labels(cls, gold_parses): - move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {'root': True}, - LEFT: {'root': True}, BREAK: {'root': True}} + move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {'ROOT': True}, + LEFT: {'ROOT': True}, BREAK: {'ROOT': True}} for raw_text, sents in gold_parses: for (ids, words, tags, heads, labels, iob), ctnts in sents: for child, head, label in zip(ids, heads, labels): - if label != 'root': + if label.upper() == 'ROOT': + label = 'ROOT' + if label != 'ROOT': if head < child: move_labels[RIGHT][label] = True elif head > child: @@ -302,8 +304,11 @@ cdef class ArcEager(TransitionSystem): gold.c.heads[i] = i gold.c.labels[i] = -1 else: + label = gold.labels[i] + if label.upper() == 'ROOT': + label = 'ROOT' gold.c.heads[i] = gold.heads[i] - gold.c.labels[i] = self.strings[gold.labels[i]] + gold.c.labels[i] = self.strings[label] for end, brackets in gold.brackets.items(): for start, label_strs in brackets.items(): gold.c.brackets[start][end] = 1