mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	* Fix get labels for textcat * Fix char_embed for gpu * Revert "Fix char_embed for gpu" This reverts commit 055b9a9e85. * Fix passing of cats in gold.pyx * Revert "Match pop with append for training format (#4516)" This reverts commit 8e7414dace. * Fix popping gold parses * Fix handling of cats in gold tuples * Fix name * Fix ner_multitask_objective script * Add test for 4402
		
			
				
	
	
		
			90 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			90 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""This example shows how to add a multi-task objective that is trained
alongside the entity recognizer. This is an alternative to adding features
to the model.

The multi-task idea is to train an auxiliary model to predict some attribute,
with weights shared between the auxiliary model and the main model. In this
example, we're predicting the position of the word in the document.

The model that predicts the position of the word encourages the convolutional
layers to include the position information in their representation. The
information is then available to the main model, as a feature.

The overall idea is that we might know something about what sort of features
we'd like the CNN to extract. The multi-task objectives can encourage the
extraction of this type of feature. The multi-task objective is only used
during training. We discard the auxiliary model before run-time.

The specific example here is not necessarily a good idea --- but it shows
how an arbitrary objective function for some word can be used.

Developed and tested for spaCy 2.0.6. Updated for v2.2.2
"""
import random
import plac
import spacy
import os.path
from spacy.tokens import Doc
from spacy.gold import read_json_file, GoldParse

# Seed the RNG so the per-epoch shuffle (and thus training order) is
# reproducible across runs.
random.seed(0)

# Directory containing this script; used to locate the bundled example data.
PWD = os.path.dirname(__file__)

# Training corpus in spaCy's JSON training format, materialised as a list so
# it can be shuffled in place each epoch by main().
TRAIN_DATA = list(read_json_file(
    os.path.join(PWD, "ner_example_data", "ner-sent-per-line.json")))
def get_position_label(i, words, tags, heads, labels, ents):
 | 
						|
    """Return labels indicating the position of the word in the document.
 | 
						|
    """
 | 
						|
    if len(words) < 20:
 | 
						|
        return "short-doc"
 | 
						|
    elif i == 0:
 | 
						|
        return "first-word"
 | 
						|
    elif i < 10:
 | 
						|
        return "early-word"
 | 
						|
    elif i < 20:
 | 
						|
        return "mid-word"
 | 
						|
    elif i == len(words) - 1:
 | 
						|
        return "last-word"
 | 
						|
    else:
 | 
						|
        return "late-word"
 | 
						|
 | 
						|
 | 
						|
def main(n_iter=10):
    """Train a blank English NER pipeline with an auxiliary objective.

    n_iter: number of passes over TRAIN_DATA.

    After training, runs the model over the raw training texts and prints
    the predicted entities and per-token NER attributes.
    """
    nlp = spacy.blank("en")
    ner = nlp.create_pipe("ner")
    # The auxiliary model shares weights with the NER model; it is only used
    # during training and discarded at run-time.
    ner.add_multitask_objective(get_position_label)
    nlp.add_pipe(ner)
    print(nlp.pipeline)

    print("Create data", len(TRAIN_DATA))
    optimizer = nlp.begin_training(get_gold_tuples=lambda: TRAIN_DATA)

    for epoch in range(n_iter):
        random.shuffle(TRAIN_DATA)
        losses = {}
        # Each corpus entry is (raw_text, paragraph_tuples); each paragraph
        # entry is (sentence_annotations, brackets).
        for raw_text, paragraph_tuples in TRAIN_DATA:
            for sent_tuples, _ in paragraph_tuples:
                # sent_tuples[1] holds the word strings for this sentence.
                doc = Doc(nlp.vocab, words=sent_tuples[1])
                gold = GoldParse.from_annot_tuples(doc, sent_tuples)
                nlp.update(
                    [doc],  # batch of texts
                    [gold],  # batch of annotations
                    drop=0.2,  # dropout - make it harder to memorise data
                    sgd=optimizer,  # callable to update weights
                    losses=losses,
                )
        print(losses.get("nn_labeller", 0.0), losses["ner"])

    # test the trained model
    for raw_text, _ in TRAIN_DATA:
        if raw_text is not None:
            doc = nlp(raw_text)
            print("Entities", [(ent.text, ent.label_) for ent in doc.ents])
            print("Tokens", [(t.text, t.ent_type_, t.ent_iob) for t in doc])
 | 
						|
 | 
						|
if __name__ == "__main__":
    # plac builds a CLI from main()'s signature, so n_iter can be overridden
    # from the command line.
    plac.call(main)
 |