From f4040a98f0292361049926523ea6f1ee3e12eb8e Mon Sep 17 00:00:00 2001 From: Matthw Honnibal Date: Mon, 7 Oct 2019 16:49:00 +0200 Subject: [PATCH] Fix passing of cats in gold.pyx --- spacy/gold.pyx | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/spacy/gold.pyx b/spacy/gold.pyx index d10df2324..2fa789006 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -56,10 +56,10 @@ def tags_to_entities(tags): def merge_sents(sents): m_deps = [[], [], [], [], [], []] + m_cats = {} m_brackets = [] - m_cats = sents.pop() i = 0 - for (ids, words, tags, heads, labels, ner), brackets in sents: + for (ids, words, tags, heads, labels, ner), (cats, brackets) in sents: m_deps[0].extend(id_ + i for id_ in ids) m_deps[1].extend(words) m_deps[2].extend(tags) @@ -68,9 +68,9 @@ def merge_sents(sents): m_deps[5].extend(ner) m_brackets.extend((b["first"] + i, b["last"] + i, b["label"]) for b in brackets) + m_cats.update(cats) i += len(ids) - m_deps.append(m_cats) - return [(m_deps, m_brackets)] + return [(m_deps, (m_cats, m_brackets))] def align(tokens_a, tokens_b): @@ -201,7 +201,6 @@ class GoldCorpus(object): n = 0 i = 0 for raw_text, paragraph_tuples in self.train_tuples: - cats = paragraph_tuples.pop() for sent_tuples, brackets in paragraph_tuples: n += len(sent_tuples[1]) if self.limit and i >= self.limit: @@ -253,11 +252,6 @@ class GoldCorpus(object): return [nlp.make_doc(raw_text)], paragraph_tuples else: raw_text, paragraph_tuples = make_orth_variants(nlp, None, paragraph_tuples, orth_variant_level=orth_variant_level) - cats = paragraph_tuples.pop() - for i in range(len(paragraph_tuples)): - sent_tuples, brackets = paragraph_tuples[i] - sent_tuples = sent_tuples + [cats] - paragraph_tuples[i] = [sent_tuples, brackets] return [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level)) for (sent_tuples, brackets) in paragraph_tuples], paragraph_tuples @@ -267,9 +261,9 @@ class GoldCorpus(object): if len(docs) != len(paragraph_tuples): n_annots = len(paragraph_tuples) raise ValueError(Errors.E070.format(n_docs=len(docs), n_annots=n_annots)) - return [GoldParse.from_annot_tuples(doc, sent_tuples, + return [GoldParse.from_annot_tuples(doc, sent_tuples, cats=cats, make_projective=make_projective) - for doc, (sent_tuples, brackets) + for doc, (sent_tuples, (cats, brackets)) in zip(docs, paragraph_tuples)] @@ -285,7 +279,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0): # modify words in paragraph_tuples variant_paragraph_tuples = [] for sent_tuples, brackets in paragraph_tuples: - ids, words, tags, heads, labels, ner, cats = sent_tuples + ids, words, tags, heads, labels, ner = sent_tuples if lower: words = [w.lower() for w in words] # single variants @@ -314,7 +308,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0): pair_idx = pair.index(words[word_idx]) words[word_idx] = punct_choices[punct_idx][pair_idx] - variant_paragraph_tuples.append(((ids, words, tags, heads, labels, ner, cats), brackets)) + variant_paragraph_tuples.append(((ids, words, tags, heads, labels, ner), brackets)) # modify raw to match variant_paragraph_tuples if raw is not None: variants = [] @@ -333,7 +327,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0): variant_raw += raw[raw_idx] raw_idx += 1 for sent_tuples, brackets in variant_paragraph_tuples: - ids, words, tags, heads, labels, ner, cats = sent_tuples + ids, words, tags, heads, labels, ner = sent_tuples for word in words: match_found = False # add identical word @@ -404,6 +398,9 @@ def json_to_tuple(doc): paragraphs = [] for paragraph in doc["paragraphs"]: sents = [] + cats = {} + for cat in paragraph.get("cats", {}): + cats[cat["label"]] = cat["value"] for sent in paragraph["sentences"]: words = [] ids = [] @@ -423,11 +420,7 @@ def json_to_tuple(doc): ner.append(token.get("ner", "-")) sents.append([ [ids, words, tags, heads, labels, ner], - sent.get("brackets", [])]) - cats = {} - for cat in paragraph.get("cats", {}): - cats[cat["label"]] = cat["value"] - sents.append(cats) + [cats, sent.get("brackets", [])]]) if sents: yield [paragraph.get("raw", None), sents] @@ -540,8 +533,8 @@ cdef class GoldParse: DOCS: https://spacy.io/api/goldparse """ @classmethod - def from_annot_tuples(cls, doc, annot_tuples, make_projective=False): - _, words, tags, heads, deps, entities, cats = annot_tuples + def from_annot_tuples(cls, doc, annot_tuples, cats=None, make_projective=False): + _, words, tags, heads, deps, entities = annot_tuples return cls(doc, words=words, tags=tags, heads=heads, deps=deps, entities=entities, cats=cats, make_projective=make_projective)