mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	* Add toggle for OrigArcEager system
This commit is contained in:
		
							parent
							
								
									ea8a103007
								
							
						
					
					
						commit
						bcfdf126a4
					
				| 
						 | 
					@ -17,6 +17,7 @@ import spacy.util
 | 
				
			||||||
from spacy.en import English
 | 
					from spacy.en import English
 | 
				
			||||||
from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
 | 
					from spacy.en.pos import POS_TEMPLATES, POS_TAGS, setup_model_dir
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from spacy.syntax.orig_arc_eager import OrigArcEager
 | 
				
			||||||
from spacy.syntax.util import Config
 | 
					from spacy.syntax.util import Config
 | 
				
			||||||
from spacy.gold import read_json_file
 | 
					from spacy.gold import read_json_file
 | 
				
			||||||
from spacy.gold import GoldParse
 | 
					from spacy.gold import GoldParse
 | 
				
			||||||
| 
						 | 
					@ -78,7 +79,8 @@ def _merge_sents(sents):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
 | 
					def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
 | 
				
			||||||
          seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
 | 
					          seed=0, gold_preproc=False, n_sents=0, corruption_level=0,
 | 
				
			||||||
          beam_width=1, verbose=False):
 | 
					          beam_width=1, verbose=False,
 | 
				
			||||||
 | 
					          use_orig_arc_eager=False):
 | 
				
			||||||
    dep_model_dir = path.join(model_dir, 'deps')
 | 
					    dep_model_dir = path.join(model_dir, 'deps')
 | 
				
			||||||
    pos_model_dir = path.join(model_dir, 'pos')
 | 
					    pos_model_dir = path.join(model_dir, 'pos')
 | 
				
			||||||
    ner_model_dir = path.join(model_dir, 'ner')
 | 
					    ner_model_dir = path.join(model_dir, 'ner')
 | 
				
			||||||
| 
						 | 
					@ -92,6 +94,9 @@ def train(Language, gold_tuples, model_dir, n_iter=15, feat_set=u'basic',
 | 
				
			||||||
    os.mkdir(pos_model_dir)
 | 
					    os.mkdir(pos_model_dir)
 | 
				
			||||||
    os.mkdir(ner_model_dir)
 | 
					    os.mkdir(ner_model_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if use_orig_arc_eager:
 | 
				
			||||||
 | 
					        Language.ParserTransitionSystem = OrigArcEager
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
 | 
					    setup_model_dir(sorted(POS_TAGS.keys()), POS_TAGS, POS_TEMPLATES, pos_model_dir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
 | 
					    Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
 | 
				
			||||||
| 
						 | 
					@ -204,18 +209,20 @@ def write_parses(Language, dev_loc, model_dir, out_loc, beam_width=None):
 | 
				
			||||||
    n_iter=("Number of training iterations", "option", "i", int),
 | 
					    n_iter=("Number of training iterations", "option", "i", int),
 | 
				
			||||||
    beam_width=("Number of candidates to maintain in the beam", "option", "k", int),
 | 
					    beam_width=("Number of candidates to maintain in the beam", "option", "k", int),
 | 
				
			||||||
    verbose=("Verbose error reporting", "flag", "v", bool),
 | 
					    verbose=("Verbose error reporting", "flag", "v", bool),
 | 
				
			||||||
    debug=("Debug mode", "flag", "d", bool)
 | 
					    debug=("Debug mode", "flag", "d", bool),
 | 
				
			||||||
 | 
					    use_orig_arc_eager=("Use the original, monotonic arc-eager system", "flag", "m", bool)
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
 | 
					def main(train_loc, dev_loc, model_dir, n_sents=0, n_iter=15, out_loc="", verbose=False,
 | 
				
			||||||
         debug=False, corruption_level=0.0, gold_preproc=False, beam_width=1,
 | 
					         debug=False, corruption_level=0.0, gold_preproc=False, beam_width=1,
 | 
				
			||||||
         eval_only=False):
 | 
					         eval_only=False, use_orig_arc_eager=False):
 | 
				
			||||||
    if not eval_only:
 | 
					    if not eval_only:
 | 
				
			||||||
        gold_train = list(read_json_file(train_loc))
 | 
					        gold_train = list(read_json_file(train_loc))
 | 
				
			||||||
        train(English, gold_train, model_dir,
 | 
					        train(English, gold_train, model_dir,
 | 
				
			||||||
              feat_set='basic' if not debug else 'debug',
 | 
					              feat_set='basic' if not debug else 'debug',
 | 
				
			||||||
              gold_preproc=gold_preproc, n_sents=n_sents,
 | 
					              gold_preproc=gold_preproc, n_sents=n_sents,
 | 
				
			||||||
              corruption_level=corruption_level, n_iter=n_iter,
 | 
					              corruption_level=corruption_level, n_iter=n_iter,
 | 
				
			||||||
              beam_width=beam_width, verbose=verbose)
 | 
					              beam_width=beam_width, verbose=verbose,
 | 
				
			||||||
 | 
					              use_orig_arc_eager=use_orig_arc_eager)
 | 
				
			||||||
    #if out_loc:
 | 
					    #if out_loc:
 | 
				
			||||||
    #    write_parses(English, dev_loc, model_dir, out_loc, beam_width=beam_width)
 | 
					    #    write_parses(English, dev_loc, model_dir, out_loc, beam_width=beam_width)
 | 
				
			||||||
    scorer = evaluate(English, list(read_json_file(dev_loc)),
 | 
					    scorer = evaluate(English, list(read_json_file(dev_loc)),
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										4
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								setup.py
									
									
									
									
									
								
							| 
						 | 
					@ -154,7 +154,9 @@ MOD_NAMES = ['spacy.parts_of_speech', 'spacy.strings',
 | 
				
			||||||
             'spacy._ml', 'spacy.tokenizer', 'spacy.en.attrs',
 | 
					             'spacy._ml', 'spacy.tokenizer', 'spacy.en.attrs',
 | 
				
			||||||
             'spacy.en.pos', 'spacy.syntax.parser', 
 | 
					             'spacy.en.pos', 'spacy.syntax.parser', 
 | 
				
			||||||
             'spacy.syntax.transition_system',
 | 
					             'spacy.syntax.transition_system',
 | 
				
			||||||
             'spacy.syntax.arc_eager', 'spacy.syntax._parse_features',
 | 
					             'spacy.syntax.arc_eager',
 | 
				
			||||||
 | 
					             'spacy.syntax.orig_arc_eager', 
 | 
				
			||||||
 | 
					             'spacy.syntax._parse_features',
 | 
				
			||||||
             'spacy.gold', 'spacy.orth', 
 | 
					             'spacy.gold', 'spacy.orth', 
 | 
				
			||||||
             'spacy.syntax.ner']
 | 
					             'spacy.syntax.ner']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user