Mirror of https://github.com/explosion/spaCy.git, synced 2025-11-04 09:57:26 +03:00
Fix --gold-preproc train CLI command (#4392)

* Fix get labels for textcat
* Fix char_embed for gpu
* Revert "Fix char_embed for gpu" (this reverts commit 055b9a9e85)
* Fix passing of cats in gold.pyx
* Revert "Match pop with append for training format (#4516)" (this reverts commit 8e7414dace)
* Fix popping gold parses
* Fix handling of cats in gold tuples
* Fix name
* Fix ner_multitask_objective script
* Add test for 4402
This commit is contained in:
parent 8e7414dace
commit f8d740bfb1
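In broad strokes, this commit reworks spaCy's internal gold-tuple training format. The cats dict used by the text classifier previously rode along as an extra trailing element of each paragraph's sentence list, so every consumer had to pop() it off before iterating and append() it back afterwards (and one missed append, as in the reverted #4516, corrupted the shared data). After this commit, cats travel with each sentence as a (cats, brackets) pair. A toy sketch of the two layouts, using invented placeholder values rather than spaCy's real data:

# Old layout: cats tacked on as a trailing element of the sentence list.
ids, words = [0, 1], ["London", "calling"]
tags, heads = ["NNP", "VBG"], [1, 1]
labels, ner = ["nsubj", "ROOT"], ["U-GPE", "O"]
brackets, cats = [], {"MUSIC": 1.0}
old_paragraph = [((ids, words, tags, heads, labels, ner), brackets), cats]

# New layout: cats stored per sentence, next to the brackets, so the list
# holds only (sent_tuples, (cats, brackets)) pairs and never needs mutating.
new_paragraph = [((ids, words, tags, heads, labels, ner), (cats, brackets))]
for sent_tuples, (sent_cats, sent_brackets) in new_paragraph:
    print(sent_tuples[1], sent_cats)  # ['London', 'calling'] {'MUSIC': 1.0}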
@@ -18,21 +18,21 @@ during training. We discard the auxiliary model before run-time.
 The specific example here is not necessarily a good idea --- but it shows
 how an arbitrary objective function for some word can be used.
 
-Developed for spaCy 2.0.6 and last tested for 2.2.2
+Developed and tested for spaCy 2.0.6. Updated for v2.2.2
 """
 import random
 import plac
 import spacy
 import os.path
-from spacy.gold import read_json_file, GoldParse
-
 from spacy.tokens import Doc
+from spacy.gold import read_json_file, GoldParse
 
 random.seed(0)
 
 PWD = os.path.dirname(__file__)
 
-TRAIN_DATA = list(read_json_file(os.path.join(PWD, "training-data.json")))
+TRAIN_DATA = list(read_json_file(
+    os.path.join(PWD, "ner_example_data", "ner-sent-per-line.json")))
 
 
 def get_position_label(i, words, tags, heads, labels, ents):
@@ -57,22 +57,17 @@ def main(n_iter=10):
     ner = nlp.create_pipe("ner")
     ner.add_multitask_objective(get_position_label)
     nlp.add_pipe(ner)
+    print(nlp.pipeline)
 
-    _, sents = TRAIN_DATA[0]
-    print("Create data, # of sentences =", len(sents) - 1) # not counting the cats attribute
+    print("Create data", len(TRAIN_DATA))
     optimizer = nlp.begin_training(get_gold_tuples=lambda: TRAIN_DATA)
     for itn in range(n_iter):
         random.shuffle(TRAIN_DATA)
         losses = {}
-        for raw_text, annots_brackets in TRAIN_DATA:
-            cats = annots_brackets.pop()
-            for annotations, _ in annots_brackets:
-                annotations.append(cats)  # temporarily add it here for from_annot_tuples to work
+        for text, annot_brackets in TRAIN_DATA:
+            for annotations, _ in annot_brackets:
                 doc = Doc(nlp.vocab, words=annotations[1])
                 gold = GoldParse.from_annot_tuples(doc, annotations)
-                annotations.pop()  # restore data
-
                 nlp.update(
                     [doc],  # batch of texts
                     [gold],  # batch of annotations
@@ -80,11 +75,11 @@ def main(n_iter=10):
                     sgd=optimizer,  # callable to update weights
                     losses=losses,
                 )
-            annots_brackets.append(cats)  # restore data
         print(losses.get("nn_labeller", 0.0), losses["ner"])
 
     # test the trained model
     for text, _ in TRAIN_DATA:
+        if text is not None:
             doc = nlp(text)
             print("Entities", [(ent.text, ent.label_) for ent in doc.ents])
             print("Tokens", [(t.text, t.ent_type_, t.ent_iob) for t in doc])
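Putting the script's new inner loop together: a minimal sketch of one training step over data in the new format, assuming spaCy 2.2.x is installed (the single annotation tuple below is invented):

import spacy
from spacy.gold import GoldParse
from spacy.tokens import Doc

nlp = spacy.blank("en")
# One paragraph: (raw_text, [(sent_tuples, (cats, brackets)), ...])
TRAIN_DATA = [(None, [(([0, 1], ["hello", "world"], ["UH", "NN"],
                        [0, 0], ["ROOT", "dep"], ["O", "O"]), ({}, []))])]
for text, annot_brackets in TRAIN_DATA:
    for annotations, _ in annot_brackets:
        doc = Doc(nlp.vocab, words=annotations[1])
        # No cats pop/append dance: the tuple is exactly what the new
        # from_annot_tuples signature expects.
        gold = GoldParse.from_annot_tuples(doc, annotations)
        print(gold.words)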
@@ -55,22 +55,22 @@ def tags_to_entities(tags):
 
 
 def merge_sents(sents):
-    m_sents = [[], [], [], [], [], []]
+    m_deps = [[], [], [], [], [], []]
+    m_cats = {}
     m_brackets = []
-    m_cats = sents.pop()
     i = 0
-    for (ids, words, tags, heads, labels, ner), brackets in sents:
-        m_sents[0].extend(id_ + i for id_ in ids)
-        m_sents[1].extend(words)
-        m_sents[2].extend(tags)
-        m_sents[3].extend(head + i for head in heads)
-        m_sents[4].extend(labels)
-        m_sents[5].extend(ner)
+    for (ids, words, tags, heads, labels, ner), (cats, brackets) in sents:
+        m_deps[0].extend(id_ + i for id_ in ids)
+        m_deps[1].extend(words)
+        m_deps[2].extend(tags)
+        m_deps[3].extend(head + i for head in heads)
+        m_deps[4].extend(labels)
+        m_deps[5].extend(ner)
         m_brackets.extend((b["first"] + i, b["last"] + i, b["label"])
                           for b in brackets)
+        m_cats.update(cats)
         i += len(ids)
-    sents.append(m_cats)  # restore original data
-    return [[(m_sents, m_brackets)], m_cats]
+    return [(m_deps, (m_cats, m_brackets))]
 
 
 _NORM_MAP = {"``": '"', "''": '"'}
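To see what the rewritten merge_sents produces, here is a pure-Python re-run of its logic on two invented sentences (no spaCy import needed):

sents = [
    (([0, 1], ["Good", "morning"], ["JJ", "NN"], [1, 1],
      ["amod", "ROOT"], ["O", "O"]), ({"GREETING": 1.0}, [])),
    (([0], ["Bye"], ["UH"], [0], ["ROOT"], ["O"]), ({"FAREWELL": 1.0}, [])),
]
m_deps, m_cats, m_brackets, i = [[], [], [], [], [], []], {}, [], 0
for (ids, words, tags, heads, labels, ner), (cats, brackets) in sents:
    m_deps[0].extend(id_ + i for id_ in ids)   # token ids offset per sentence
    m_deps[1].extend(words)
    m_deps[2].extend(tags)
    m_deps[3].extend(head + i for head in heads)
    m_deps[4].extend(labels)
    m_deps[5].extend(ner)
    m_brackets.extend((b["first"] + i, b["last"] + i, b["label"]) for b in brackets)
    m_cats.update(cats)                        # cats merged into one dict
    i += len(ids)
merged = [(m_deps, (m_cats, m_brackets))]
print(merged[0][0][1])  # ['Good', 'morning', 'Bye']
print(merged[0][1][0])  # {'GREETING': 1.0, 'FAREWELL': 1.0}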
@@ -242,13 +242,11 @@ class GoldCorpus(object):
         n = 0
         i = 0
         for raw_text, paragraph_tuples in self.train_tuples:
-            cats = paragraph_tuples.pop()
             for sent_tuples, brackets in paragraph_tuples:
                 n += len(sent_tuples[1])
                 if self.limit and i >= self.limit:
                     break
                 i += 1
-            paragraph_tuples.append(cats) # restore original data
         return n
 
     def train_docs(self, nlp, gold_preproc=False, max_length=None,
@@ -289,36 +287,26 @@ class GoldCorpus(object):
 
     @classmethod
     def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc, noise_level=0.0, orth_variant_level=0.0):
-        cats = paragraph_tuples.pop()
         if raw_text is not None:
             raw_text, paragraph_tuples = make_orth_variants(nlp, raw_text, paragraph_tuples, orth_variant_level=orth_variant_level)
             raw_text = add_noise(raw_text, noise_level)
-            result = [nlp.make_doc(raw_text)], paragraph_tuples
+            return [nlp.make_doc(raw_text)], paragraph_tuples
         else:
             docs = []
             raw_text, paragraph_tuples = make_orth_variants(nlp, None, paragraph_tuples, orth_variant_level=orth_variant_level)
-            result = [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level))
+            return [Doc(nlp.vocab, words=add_noise(sent_tuples[1], noise_level))
                     for (sent_tuples, brackets) in paragraph_tuples], paragraph_tuples
-        paragraph_tuples.append(cats)
-        return result
 
 
     @classmethod
     def _make_golds(cls, docs, paragraph_tuples, make_projective):
-        cats = paragraph_tuples.pop()
         if len(docs) != len(paragraph_tuples):
             n_annots = len(paragraph_tuples)
             raise ValueError(Errors.E070.format(n_docs=len(docs), n_annots=n_annots))
-        result = []
-        for doc, brack_annot in zip(docs, paragraph_tuples):
-            if len(brack_annot) == 1:
-                brack_annot = brack_annot[0]
-            sent_tuples, brackets = brack_annot
-            sent_tuples.append(cats)
-            result.append(GoldParse.from_annot_tuples(doc, sent_tuples, make_projective=make_projective))
-            sent_tuples.pop()
-        paragraph_tuples.append(cats)
-        return result
+        return [GoldParse.from_annot_tuples(doc, sent_tuples, cats=cats,
+                                                make_projective=make_projective)
+                    for doc, (sent_tuples, (cats, brackets))
+                    in zip(docs, paragraph_tuples)]
 
 
 def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
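The new _make_golds collapses the pop/append bookkeeping into a single list comprehension with nested unpacking. The same destructuring pattern on toy stand-ins (every name below is illustrative):

docs = ["doc1", "doc2"]  # stand-ins for real Doc objects
paragraph_tuples = [
    ("sent_tuples_1", ({"POSITIVE": 1.0}, [])),
    ("sent_tuples_2", ({"NEGATIVE": 1.0}, [])),
]
pairs = [(doc, sent_tuples, cats)
         for doc, (sent_tuples, (cats, brackets)) in zip(docs, paragraph_tuples)]
print(pairs[0])  # ('doc1', 'sent_tuples_1', {'POSITIVE': 1.0})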
@@ -333,7 +321,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
     # modify words in paragraph_tuples
     variant_paragraph_tuples = []
     for sent_tuples, brackets in paragraph_tuples:
-        ids, words, tags, heads, labels, ner, cats = sent_tuples
+        ids, words, tags, heads, labels, ner = sent_tuples
         if lower:
             words = [w.lower() for w in words]
         # single variants
@@ -362,7 +350,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
                                 pair_idx = pair.index(words[word_idx])
                     words[word_idx] = punct_choices[punct_idx][pair_idx]
 
-        variant_paragraph_tuples.append(((ids, words, tags, heads, labels, ner, cats), brackets))
+        variant_paragraph_tuples.append(((ids, words, tags, heads, labels, ner), brackets))
     # modify raw to match variant_paragraph_tuples
     if raw is not None:
         variants = []
@@ -381,7 +369,7 @@ def make_orth_variants(nlp, raw, paragraph_tuples, orth_variant_level=0.0):
             variant_raw += raw[raw_idx]
             raw_idx += 1
         for sent_tuples, brackets in variant_paragraph_tuples:
-            ids, words, tags, heads, labels, ner, cats = sent_tuples
+            ids, words, tags, heads, labels, ner = sent_tuples
             for word in words:
                 match_found = False
                 # add identical word
@@ -452,6 +440,9 @@ def json_to_tuple(doc):
     paragraphs = []
     for paragraph in doc["paragraphs"]:
         sents = []
+        cats = {}
+        for cat in paragraph.get("cats", {}):
+            cats[cat["label"]] = cat["value"]
         for sent in paragraph["sentences"]:
             words = []
             ids = []
@@ -471,11 +462,7 @@ def json_to_tuple(doc):
                 ner.append(token.get("ner", "-"))
             sents.append([
                 [ids, words, tags, heads, labels, ner],
-                sent.get("brackets", [])])
-        cats = {}
-        for cat in paragraph.get("cats", {}):
-            cats[cat["label"]] = cat["value"]
-        sents.append(cats)
+                [cats, sent.get("brackets", [])]])
         if sents:
             yield [paragraph.get("raw", None), sents]
 
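For orientation, a toy paragraph in spaCy's JSON training format and a rough re-implementation of the lines json_to_tuple now runs: the paragraph-level cats list becomes a dict that is copied into every sentence entry. (Field values are invented; "head" is a relative offset in this format.)

paragraph = {
    "raw": "I like it",
    "cats": [{"label": "POSITIVE", "value": 1.0}],
    "sentences": [{"tokens": [
        {"id": 0, "orth": "I", "tag": "PRP", "head": 1, "dep": "nsubj", "ner": "O"},
        {"id": 1, "orth": "like", "tag": "VBP", "head": 0, "dep": "ROOT", "ner": "O"},
        {"id": 2, "orth": "it", "tag": "PRP", "head": -1, "dep": "dobj", "ner": "O"},
    ]}],
}
cats = {c["label"]: c["value"] for c in paragraph.get("cats", [])}
sents = []
for sent in paragraph["sentences"]:
    toks = sent["tokens"]
    deps = [[t["id"] for t in toks], [t["orth"] for t in toks],
            [t["tag"] for t in toks], [t["id"] + t.get("head", 0) for t in toks],
            [t["dep"] for t in toks], [t.get("ner", "-") for t in toks]]
    sents.append([deps, [cats, sent.get("brackets", [])]])
print(sents[0][1][0])  # {'POSITIVE': 1.0}, the same cats dict in every sentence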
@@ -589,8 +576,8 @@ cdef class GoldParse:
     DOCS: https://spacy.io/api/goldparse
     """
     @classmethod
-    def from_annot_tuples(cls, doc, annot_tuples, make_projective=False):
-        _, words, tags, heads, deps, entities, cats = annot_tuples
+    def from_annot_tuples(cls, doc, annot_tuples, cats=None, make_projective=False):
+        _, words, tags, heads, deps, entities = annot_tuples
         return cls(doc, words=words, tags=tags, heads=heads, deps=deps,
                    entities=entities, cats=cats,
                    make_projective=make_projective)
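With the new signature, callers hand cats over explicitly instead of smuggling them through the annotation tuple. A minimal usage sketch, assuming spaCy 2.2.x (the annotation values are invented):

import spacy
from spacy.gold import GoldParse
from spacy.tokens import Doc

nlp = spacy.blank("en")
annot_tuples = ([0, 1], ["very", "nice"], ["RB", "JJ"], [1, 1],
                ["advmod", "ROOT"], ["O", "O"])  # six fields, no cats inside
doc = Doc(nlp.vocab, words=annot_tuples[1])
gold = GoldParse.from_annot_tuples(doc, annot_tuples, cats={"POSITIVE": 1.0})
print(gold.cats)  # {'POSITIVE': 1.0}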
@@ -598,11 +598,10 @@ class Language(object):
         # Populate vocab
         else:
             for _, annots_brackets in get_gold_tuples():
-                cats = annots_brackets.pop()
+                _ = annots_brackets.pop()
                 for annots, _ in annots_brackets:
                     for word in annots[1]:
                         _ = self.vocab[word]  # noqa: F841
-                annots_brackets.append(cats)  # restore original data
         if cfg.get("device", -1) >= 0:
             util.use_gpu(cfg["device"])
             if self.vocab.vectors.data.shape[1] >= 1:
@@ -517,7 +517,6 @@ class Tagger(Pipe):
         orig_tag_map = dict(self.vocab.morphology.tag_map)
         new_tag_map = OrderedDict()
         for raw_text, annots_brackets in get_gold_tuples():
-            cats = annots_brackets.pop()
             for annots, brackets in annots_brackets:
                 ids, words, tags, heads, deps, ents = annots
                 for tag in tags:
@@ -525,7 +524,6 @@ class Tagger(Pipe):
                         new_tag_map[tag] = orig_tag_map[tag]
                     else:
                         new_tag_map[tag] = {POS: X}
-            annots_brackets.append(cats)  # restore original data
         cdef Vocab vocab = self.vocab
         if new_tag_map:
             vocab.morphology = Morphology(vocab.strings, new_tag_map,
@@ -704,14 +702,12 @@ class MultitaskObjective(Tagger):
                        sgd=None, **kwargs):
         gold_tuples = nonproj.preprocess_training_data(get_gold_tuples())
         for raw_text, annots_brackets in gold_tuples:
-            cats = annots_brackets.pop()
             for annots, brackets in annots_brackets:
                 ids, words, tags, heads, deps, ents = annots
                 for i in range(len(ids)):
                     label = self.make_label(i, words, tags, heads, deps, ents)
                     if label is not None and label not in self.labels:
                         self.labels[label] = len(self.labels)
-            annots_brackets.append(cats)
         if self.model is True:
             token_vector_width = util.env_opt("token_vector_width")
             self.model = self.Model(len(self.labels), tok2vec=tok2vec)
@@ -1037,8 +1033,8 @@ class TextCategorizer(Pipe):
         return 1
 
     def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
-        for raw_text, annots_brackets in get_gold_tuples():
-            cats = annots_brackets[-1]
+        for raw_text, annot_brackets in get_gold_tuples():
+            for _, (cats, _2) in annot_brackets:
                 for cat in cats:
                     self.add_label(cat)
         if self.model is True:
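The TextCategorizer now reads its label set from the per-sentence cats dicts rather than from a trailing list element. A hedged sketch of the same collection logic on toy gold tuples:

gold_tuples = [
    ("some text", [("sent_tuples", ({"POSITIVE": 1.0, "NEGATIVE": 0.0}, []))]),
]
labels = set()
for raw_text, annot_brackets in gold_tuples:
    for _, (cats, _brackets) in annot_brackets:
        labels.update(cats)  # the real pipe calls self.add_label(cat) per key
print(sorted(labels))  # ['NEGATIVE', 'POSITIVE']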
@@ -342,7 +342,6 @@ cdef class ArcEager(TransitionSystem):
             actions[RIGHT][label] = 1
             actions[REDUCE][label] = 1
         for raw_text, sents in kwargs.get('gold_parses', []):
-            cats = sents.pop()
             for (ids, words, tags, heads, labels, iob), ctnts in sents:
                 heads, labels = nonproj.projectivize(heads, labels)
                 for child, head, label in zip(ids, heads, labels):
@@ -356,7 +355,6 @@ cdef class ArcEager(TransitionSystem):
                     elif head > child:
                         actions[LEFT][label] += 1
                         actions[SHIFT][''] += 1
-            sents.append(cats)  # restore original data
         if min_freq is not None:
             for action, label_freqs in actions.items():
                 for label, freq in list(label_freqs.items()):
@@ -73,14 +73,12 @@ cdef class BiluoPushDown(TransitionSystem):
                 actions[action][entity_type] = 1
         moves = ('M', 'B', 'I', 'L', 'U')
         for raw_text, sents in kwargs.get('gold_parses', []):
-            cats = sents.pop()
             for (ids, words, tags, heads, labels, biluo), _ in sents:
                 for i, ner_tag in enumerate(biluo):
                     if ner_tag != 'O' and ner_tag != '-':
                         _, label = ner_tag.split('-', 1)
                         for action in (BEGIN, IN, LAST, UNIT):
                             actions[action][label] += 1
-            sents.append(cats)  # restore original data
         return actions
 
     @property
@@ -606,13 +606,11 @@ cdef class Parser:
             doc_sample = []
             gold_sample = []
             for raw_text, annots_brackets in islice(get_gold_tuples(), 1000):
-                cats = annots_brackets.pop()
                 for annots, brackets in annots_brackets:
                     ids, words, tags, heads, deps, ents = annots
                     doc_sample.append(Doc(self.vocab, words=words))
                     gold_sample.append(GoldParse(doc_sample[-1], words=words, tags=tags,
                                                  heads=heads, deps=deps, entities=ents))
-                annots_brackets.append(cats)  # restore original data
             self.model.begin_training(doc_sample, gold_sample)
             if pipeline is not None:
                 self.init_multitask_objectives(get_gold_tuples, pipeline, sgd=sgd, **cfg)
@@ -97,7 +97,6 @@ def preprocess_training_data(gold_tuples, label_freq_cutoff=30):
     freqs = {}
     for raw_text, sents in gold_tuples:
         prepro_sents = []
-        cats = sents.pop()
         for (ids, words, tags, heads, labels, iob), ctnts in sents:
             proj_heads, deco_labels = projectivize(heads, labels)
             # set the label to ROOT for each root dependent
@@ -110,8 +109,6 @@ def preprocess_training_data(gold_tuples, label_freq_cutoff=30):
                         freqs[label] = freqs.get(label, 0) + 1
             prepro_sents.append(
                 ((ids, words, tags, proj_heads, deco_labels, iob), ctnts))
-        sents.append(cats)
-        prepro_sents.append(cats)
         preprocessed.append((raw_text, prepro_sents))
     if label_freq_cutoff > 0:
         return _filter_labels(preprocessed, label_freq_cutoff, freqs)
@@ -212,7 +209,6 @@ def _filter_labels(gold_tuples, cutoff, freqs):
     filtered = []
     for raw_text, sents in gold_tuples:
         filtered_sents = []
-        cats = sents.pop()
         for (ids, words, tags, heads, labels, iob), ctnts in sents:
             filtered_labels = []
             for label in labels:
@@ -222,7 +218,5 @@ def _filter_labels(gold_tuples, cutoff, freqs):
                     filtered_labels.append(label)
             filtered_sents.append(
                 ((ids, words, tags, heads, filtered_labels, iob), ctnts))
-        sents.append(cats)
-        filtered_sents.append(cats)
         filtered.append((raw_text, filtered_sents))
     return filtered
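Taken together, the hunks above delete the same pop/append dance from every consumer of the gold tuples (Tagger, MultitaskObjective, TextCategorizer, ArcEager, BiluoPushDown, Parser, nonproj). The pattern was fragile because all consumers mutated one shared list; a toy illustration of the failure mode the new format removes:

sents = ["sent1", "sent2", {"SOME_CAT": 1.0}]  # old-style list with trailing cats

def count_sents(sents):
    cats = sents.pop()        # strip the trailing cats dict...
    n = len(sents)            # ...any exception here would skip the restore
    sents.append(cats)        # ...and every consumer must remember this line
    return n

count_sents(sents)
print(sents[-1])  # {'SOME_CAT': 1.0}, intact only because nothing went wrong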