mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			76 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			76 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import unicode_literals, print_function
 | |
| import json
 | |
| import pathlib
 | |
| import random
 | |
| 
 | |
| import spacy
 | |
| from spacy.pipeline import DependencyParser
 | |
| from spacy.gold import GoldParse
 | |
| from spacy.tokens import Doc
 | |
| 
 | |
| 
 | |
def train_parser(nlp, train_data, left_labels, right_labels, n_iter=1000):
    """Train a new DependencyParser on a handful of annotated sentences.

    Args:
        nlp: Loaded pipeline object; only ``nlp.vocab`` is used here.
        train_data: Iterable of ``(words, heads, deps)`` triples. Each
            triple is turned into a ``Doc`` (from ``words``) and a
            ``GoldParse`` (from ``heads`` and ``deps``).
        left_labels: Labels forwarded to ``DependencyParser`` as
            ``left_labels``.
        right_labels: Labels forwarded to ``DependencyParser`` as
            ``right_labels``.
        n_iter: Number of passes over ``train_data``. Defaults to 1000,
            the value that used to be hard-coded.

    Returns:
        The parser, after ``parser.model.end_training()`` has been called.
    """
    parser = DependencyParser(
        nlp.vocab,
        left_labels=left_labels,
        right_labels=right_labels)
    # Shuffle a private copy: random.shuffle works in place, and the caller's
    # list should not come back reordered as a side effect of training.
    examples = list(train_data)
    for _ in range(n_iter):
        random.shuffle(examples)
        for words, heads, deps in examples:
            doc = Doc(nlp.vocab, words=words)
            gold = GoldParse(doc, heads=heads, deps=deps)
            parser.update(doc, gold)
    parser.model.end_training()
    return parser
| 
 | |
| 
 | |
def _collect_arc_labels(train_data):
    """Return ``(left_labels, right_labels)`` as sorted lists.

    For every token index ``i`` with head index ``head``, the token's label
    goes into the left set when ``i < head`` and into the right set when
    ``i > head``. Tokens whose head is themselves add no label.
    """
    left_labels = set()
    right_labels = set()
    for _, heads, deps in train_data:
        for i, (head, dep) in enumerate(zip(heads, deps)):
            if i < head:
                left_labels.add(dep)
            elif i > head:
                right_labels.add(dep)
    return sorted(left_labels), sorted(right_labels)


def main(model_dir=None):
    """Train a toy dependency parser, print a sample parse, optionally save.

    Args:
        model_dir: Optional output directory (string or path). When given,
            it is created if missing, and the parser config is written to
            ``config.json`` and the model to ``model`` inside it. When
            ``None``, nothing is written to disk.

    Raises:
        IOError: If ``model_dir`` exists but is not a directory.
    """
    if model_dir is not None:
        model_dir = pathlib.Path(model_dir)
        if not model_dir.exists():
            # parents=True so a nested output path does not fail.
            model_dir.mkdir(parents=True)
        # Explicit check instead of `assert`, which is removed under -O.
        if not model_dir.is_dir():
            raise IOError("model_dir is not a directory: %s" % model_dir)

    nlp = spacy.load('en', tagger=False, parser=False, entity=False, add_vectors=False)

    # Each example is (words, heads, deps), one entry per token.
    train_data = [
        (
            ['They', 'trade',  'mortgage', '-', 'backed', 'securities', '.'],
            [1, 1, 4, 4, 5, 1, 1],
            ['nsubj', 'ROOT', 'compound', 'punct', 'nmod', 'dobj', 'punct']
        ),
        (
            ['I', 'like', 'London', 'and', 'Berlin', '.'],
            [1, 1, 1, 2, 2, 1],
            ['nsubj', 'ROOT', 'dobj', 'cc', 'conj', 'punct']
        )
    ]
    left_labels, right_labels = _collect_arc_labels(train_data)
    parser = train_parser(nlp, train_data, left_labels, right_labels)

    # Parse a sentence that is not in the training data and show the result.
    doc = Doc(nlp.vocab, words=['I', 'like', 'securities', '.'])
    parser(doc)
    for word in doc:
        print(word.text, word.dep_, word.head.text)

    if model_dir is not None:
        with (model_dir / 'config.json').open('w') as file_:
            json.dump(parser.cfg, file_)
        parser.model.dump(str(model_dir / 'model'))
| 
 | |
| 
 | |
if __name__ == '__main__':
    import sys

    # Optional first command-line argument: directory to save the trained
    # parser into. With no argument nothing is written to disk.
    main(sys.argv[1] if len(sys.argv) > 1 else None)
    # Output recorded with the original example:
    # I nsubj like
    # like ROOT like
    # securities dobj like
    # . cc securities