Fix passing of cats in gold.pyx

This commit is contained in:
Matthw Honnibal 2019-10-07 16:49:00 +02:00
parent a132da1558
commit f4040a98f0

View File

@ -56,10 +56,10 @@ def tags_to_entities(tags):
def merge_sents(sents):
m_deps = [[], [], [], [], [], []]
m_cats = {}
m_brackets = []
m_cats = sents.pop()
i = 0
for (ids, words, tags, heads, labels, ner), brackets in sents:
for (ids, words, tags, heads, labels, ner), (cats, brackets) in sents:
m_deps[0].extend(id_ + i for id_ in ids)
m_deps[1].extend(words)
m_deps[2].extend(tags)
@ -68,9 +68,9 @@ def merge_sents(sents):
m_deps[5].extend(ner)
m_brackets.extend((b["first"] + i, b["last"] + i, b["label"])
for b in brackets)
m_cats.update(cats)
i += len(ids)
m_deps.append(m_cats)
return [(m_deps, m_brackets)]
return [(m_deps, (m_cats, m_brackets))]
def align(tokens_a, tokens_b):
@ -201,7 +201,6 @@ class GoldCorpus(object):
n = 0
i = 0
for raw_text, paragraph_tuples in self.train_tuples:
cats = paragraph_tuples.pop()
for sent_tuples, brackets in paragraph_tuples:
n += len(sent_tuples[1])
if self.limit and i >= self.limit:
@ -253,11 +252,6 @@ class GoldCorpus(object):
return [nlp.make_doc(raw_text)], paragraph_tuples
else:
raw_text, paragraph_tuples = make_orth_variants(nlp, None, paragraph_tuples, orth_variant_level=orth_variant_level)
cats = paragraph_tuples.pop()
for i in range(len(paragraph_tuples)):
sent_tuples, brackets = paragraph_tuples[i]
sent_tuples = sent_tuples + [cats]
paragraph_tuples[i] = [sent_tuples, brackets]
return [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level))
for (sent_tuples, brackets) in paragraph_tuples], paragraph_tuples
@ -267,9 +261,9 @@ class GoldCorpus(object):
if len(docs) != len(paragraph_tuples):
n_annots = len(paragraph_tuples)
raise ValueError(Errors.E070.format(n_docs=len(docs), n_annots=n_annots))
return [GoldParse.from_annot_tuples(doc, sent_tuples,
return [GoldParse.from_annot_tuples(doc, sent_tuples, cats=cats,
make_projective=make_projective)
for doc, (sent_tuples, brackets)
for doc, (sent_tuples, (cats, brackets))
in zip(docs, paragraph_tuples)]
@ -285,7 +279,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
# modify words in paragraph_tuples
variant_paragraph_tuples = []
for sent_tuples, brackets in paragraph_tuples:
ids, words, tags, heads, labels, ner, cats = sent_tuples
ids, words, tags, heads, labels, ner = sent_tuples
if lower:
words = [w.lower() for w in words]
# single variants
@ -314,7 +308,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
pair_idx = pair.index(words[word_idx])
words[word_idx] = punct_choices[punct_idx][pair_idx]
variant_paragraph_tuples.append(((ids, words, tags, heads, labels, ner, cats), brackets))
variant_paragraph_tuples.append(((ids, words, tags, heads, labels, ner), brackets))
# modify raw to match variant_paragraph_tuples
if raw is not None:
variants = []
@ -333,7 +327,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
variant_raw += raw[raw_idx]
raw_idx += 1
for sent_tuples, brackets in variant_paragraph_tuples:
ids, words, tags, heads, labels, ner, cats = sent_tuples
ids, words, tags, heads, labels, ner = sent_tuples
for word in words:
match_found = False
# add identical word
@ -404,6 +398,9 @@ def json_to_tuple(doc):
paragraphs = []
for paragraph in doc["paragraphs"]:
sents = []
cats = {}
for cat in paragraph.get("cats", {}):
cats[cat["label"]] = cat["value"]
for sent in paragraph["sentences"]:
words = []
ids = []
@ -423,11 +420,7 @@ def json_to_tuple(doc):
ner.append(token.get("ner", "-"))
sents.append([
[ids, words, tags, heads, labels, ner],
sent.get("brackets", [])])
cats = {}
for cat in paragraph.get("cats", {}):
cats[cat["label"]] = cat["value"]
sents.append(cats)
[cats, sent.get("brackets", [])]])
if sents:
yield [paragraph.get("raw", None), sents]
@ -540,8 +533,8 @@ cdef class GoldParse:
DOCS: https://spacy.io/api/goldparse
"""
@classmethod
def from_annot_tuples(cls, doc, annot_tuples, make_projective=False):
_, words, tags, heads, deps, entities, cats = annot_tuples
def from_annot_tuples(cls, doc, annot_tuples, cats=None, make_projective=False):
_, words, tags, heads, deps, entities = annot_tuples
return cls(doc, words=words, tags=tags, heads=heads, deps=deps,
entities=entities, cats=cats,
make_projective=make_projective)