Add support for text-classification labels to GoldParse

This commit is contained in:
Matthew Honnibal 2017-07-20 00:17:47 +02:00
parent 727481377e
commit 7ea50182a5
2 changed files with 20 additions and 1 deletions

View File

@ -29,6 +29,7 @@ cdef class GoldParse:
cdef public list ner
cdef public list ents
cdef public dict brackets
cdef public object cats
cdef readonly list cand_to_gold
cdef readonly list gold_to_cand

View File

@ -381,7 +381,8 @@ cdef class GoldParse:
make_projective=make_projective)
def __init__(self, doc, annot_tuples=None, words=None, tags=None, heads=None,
deps=None, entities=None, make_projective=False):
deps=None, entities=None, make_projective=False,
cats=tuple()):
"""Create a GoldParse.
doc (Doc): The document the annotations refer to.
@ -392,6 +393,12 @@ cdef class GoldParse:
entities (iterable): A sequence of named entity annotations, either as
BILUO tag strings, or as `(start_char, end_char, label)` tuples,
representing the entity positions.
cats (iterable): A sequence of labels for text classification. Each
label may be a string or an int, or a `(start_char, end_char, label)`
tuple, indicating that the label is applied to only part of the
document (usually a sentence). Unlike entity annotations, label
annotations can overlap, i.e. a single word can be covered by
multiple labelled spans.
RETURNS (GoldParse): The newly constructed object.
"""
if words is None:
@ -421,6 +428,17 @@ cdef class GoldParse:
self.c.has_dep = <int*>self.mem.alloc(len(doc), sizeof(int))
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
self.cats = []
for item in cats:
if isinstance(item, int):
self.cats.append((0, len(doc.text), self.vocab.strings[item]))
elif isinstance(item, str):
self.cats.append((0, len(doc.text), item))
elif hasattr(item, '__len__') and len(item) == 3:
start_char, end_char, label = item
if isinstance(label, int):
label = self.vocab.strings[label]
self.cats.append((start_char, end_char, label))
self.words = [None] * len(doc)
self.tags = [None] * len(doc)
self.heads = [None] * len(doc)