mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-20 21:40:35 +03:00
Fix passing of cats in gold.pyx
This commit is contained in:
parent
a132da1558
commit
f4040a98f0
|
@ -56,10 +56,10 @@ def tags_to_entities(tags):
|
|||
|
||||
def merge_sents(sents):
|
||||
m_deps = [[], [], [], [], [], []]
|
||||
m_cats = {}
|
||||
m_brackets = []
|
||||
m_cats = sents.pop()
|
||||
i = 0
|
||||
for (ids, words, tags, heads, labels, ner), brackets in sents:
|
||||
for (ids, words, tags, heads, labels, ner), (cats, brackets) in sents:
|
||||
m_deps[0].extend(id_ + i for id_ in ids)
|
||||
m_deps[1].extend(words)
|
||||
m_deps[2].extend(tags)
|
||||
|
@ -68,9 +68,9 @@ def merge_sents(sents):
|
|||
m_deps[5].extend(ner)
|
||||
m_brackets.extend((b["first"] + i, b["last"] + i, b["label"])
|
||||
for b in brackets)
|
||||
m_cats.update(cats)
|
||||
i += len(ids)
|
||||
m_deps.append(m_cats)
|
||||
return [(m_deps, m_brackets)]
|
||||
return [(m_deps, (m_cats, m_brackets))]
|
||||
|
||||
|
||||
def align(tokens_a, tokens_b):
|
||||
|
@ -201,7 +201,6 @@ class GoldCorpus(object):
|
|||
n = 0
|
||||
i = 0
|
||||
for raw_text, paragraph_tuples in self.train_tuples:
|
||||
cats = paragraph_tuples.pop()
|
||||
for sent_tuples, brackets in paragraph_tuples:
|
||||
n += len(sent_tuples[1])
|
||||
if self.limit and i >= self.limit:
|
||||
|
@ -253,11 +252,6 @@ class GoldCorpus(object):
|
|||
return [nlp.make_doc(raw_text)], paragraph_tuples
|
||||
else:
|
||||
raw_text, paragraph_tuples = make_orth_variants(nlp, None, paragraph_tuples, orth_variant_level=orth_variant_level)
|
||||
cats = paragraph_tuples.pop()
|
||||
for i in range(len(paragraph_tuples)):
|
||||
sent_tuples, brackets = paragraph_tuples[i]
|
||||
sent_tuples = sent_tuples + [cats]
|
||||
paragraph_tuples[i] = [sent_tuples, brackets]
|
||||
return [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level))
|
||||
for (sent_tuples, brackets) in paragraph_tuples], paragraph_tuples
|
||||
|
||||
|
@ -267,9 +261,9 @@ class GoldCorpus(object):
|
|||
if len(docs) != len(paragraph_tuples):
|
||||
n_annots = len(paragraph_tuples)
|
||||
raise ValueError(Errors.E070.format(n_docs=len(docs), n_annots=n_annots))
|
||||
return [GoldParse.from_annot_tuples(doc, sent_tuples,
|
||||
return [GoldParse.from_annot_tuples(doc, sent_tuples, cats=cats,
|
||||
make_projective=make_projective)
|
||||
for doc, (sent_tuples, brackets)
|
||||
for doc, (sent_tuples, (cats, brackets))
|
||||
in zip(docs, paragraph_tuples)]
|
||||
|
||||
|
||||
|
@ -285,7 +279,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
|
|||
# modify words in paragraph_tuples
|
||||
variant_paragraph_tuples = []
|
||||
for sent_tuples, brackets in paragraph_tuples:
|
||||
ids, words, tags, heads, labels, ner, cats = sent_tuples
|
||||
ids, words, tags, heads, labels, ner = sent_tuples
|
||||
if lower:
|
||||
words = [w.lower() for w in words]
|
||||
# single variants
|
||||
|
@ -314,7 +308,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
|
|||
pair_idx = pair.index(words[word_idx])
|
||||
words[word_idx] = punct_choices[punct_idx][pair_idx]
|
||||
|
||||
variant_paragraph_tuples.append(((ids, words, tags, heads, labels, ner, cats), brackets))
|
||||
variant_paragraph_tuples.append(((ids, words, tags, heads, labels, ner), brackets))
|
||||
# modify raw to match variant_paragraph_tuples
|
||||
if raw is not None:
|
||||
variants = []
|
||||
|
@ -333,7 +327,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
|
|||
variant_raw += raw[raw_idx]
|
||||
raw_idx += 1
|
||||
for sent_tuples, brackets in variant_paragraph_tuples:
|
||||
ids, words, tags, heads, labels, ner, cats = sent_tuples
|
||||
ids, words, tags, heads, labels, ner = sent_tuples
|
||||
for word in words:
|
||||
match_found = False
|
||||
# add identical word
|
||||
|
@ -404,6 +398,9 @@ def json_to_tuple(doc):
|
|||
paragraphs = []
|
||||
for paragraph in doc["paragraphs"]:
|
||||
sents = []
|
||||
cats = {}
|
||||
for cat in paragraph.get("cats", {}):
|
||||
cats[cat["label"]] = cat["value"]
|
||||
for sent in paragraph["sentences"]:
|
||||
words = []
|
||||
ids = []
|
||||
|
@ -423,11 +420,7 @@ def json_to_tuple(doc):
|
|||
ner.append(token.get("ner", "-"))
|
||||
sents.append([
|
||||
[ids, words, tags, heads, labels, ner],
|
||||
sent.get("brackets", [])])
|
||||
cats = {}
|
||||
for cat in paragraph.get("cats", {}):
|
||||
cats[cat["label"]] = cat["value"]
|
||||
sents.append(cats)
|
||||
[cats, sent.get("brackets", [])]])
|
||||
if sents:
|
||||
yield [paragraph.get("raw", None), sents]
|
||||
|
||||
|
@ -540,8 +533,8 @@ cdef class GoldParse:
|
|||
DOCS: https://spacy.io/api/goldparse
|
||||
"""
|
||||
@classmethod
|
||||
def from_annot_tuples(cls, doc, annot_tuples, make_projective=False):
|
||||
_, words, tags, heads, deps, entities, cats = annot_tuples
|
||||
def from_annot_tuples(cls, doc, annot_tuples, cats=None, make_projective=False):
|
||||
_, words, tags, heads, deps, entities = annot_tuples
|
||||
return cls(doc, words=words, tags=tags, heads=heads, deps=deps,
|
||||
entities=entities, cats=cats,
|
||||
make_projective=make_projective)
|
||||
|
|
Loading…
Reference in New Issue
Block a user