mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix non-projective label filtering
This commit is contained in:
commit
f57bfbccdc
371
spacy/cli/ud_train.py
Normal file
371
spacy/cli/ud_train.py
Normal file
|
@ -0,0 +1,371 @@
|
||||||
|
'''Train for CONLL 2017 UD treebank evaluation. Takes .conllu files, writes
|
||||||
|
.conllu format for development data, allowing the official scorer to be used.
|
||||||
|
'''
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
import plac
|
||||||
|
import tqdm
|
||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
|
||||||
|
import spacy
|
||||||
|
import spacy.util
|
||||||
|
from ..tokens import Token, Doc
|
||||||
|
from ..gold import GoldParse
|
||||||
|
from ..util import compounding, minibatch_by_words
|
||||||
|
from ..syntax.nonproj import projectivize
|
||||||
|
from ..matcher import Matcher
|
||||||
|
from .. import displacy
|
||||||
|
from collections import defaultdict, Counter
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
import random
|
||||||
|
import numpy.random
|
||||||
|
import cytoolz
|
||||||
|
|
||||||
|
from . import conll17_ud_eval
|
||||||
|
|
||||||
|
from .. import lang
|
||||||
|
from .. import lang
|
||||||
|
from ..lang import zh
|
||||||
|
from ..lang import ja
|
||||||
|
|
||||||
|
|
||||||
|
################
|
||||||
|
# Data reading #
|
||||||
|
################
|
||||||
|
|
||||||
|
space_re = re.compile('\s+')
|
||||||
|
def split_text(text):
|
||||||
|
return [space_re.sub(' ', par.strip()) for par in text.split('\n\n')]
|
||||||
|
|
||||||
|
|
||||||
|
def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
|
||||||
|
max_doc_length=None, limit=None):
|
||||||
|
'''Read the CONLLU format into (Doc, GoldParse) tuples. If raw_text=True,
|
||||||
|
include Doc objects created using nlp.make_doc and then aligned against
|
||||||
|
the gold-standard sequences. If oracle_segments=True, include Doc objects
|
||||||
|
created from the gold-standard segments. At least one must be True.'''
|
||||||
|
if not raw_text and not oracle_segments:
|
||||||
|
raise ValueError("At least one of raw_text or oracle_segments must be True")
|
||||||
|
paragraphs = split_text(text_file.read())
|
||||||
|
conllu = read_conllu(conllu_file)
|
||||||
|
# sd is spacy doc; cd is conllu doc
|
||||||
|
# cs is conllu sent, ct is conllu token
|
||||||
|
docs = []
|
||||||
|
golds = []
|
||||||
|
for doc_id, (text, cd) in enumerate(zip(paragraphs, conllu)):
|
||||||
|
sent_annots = []
|
||||||
|
for cs in cd:
|
||||||
|
sent = defaultdict(list)
|
||||||
|
for id_, word, lemma, pos, tag, morph, head, dep, _, space_after in cs:
|
||||||
|
if '.' in id_:
|
||||||
|
continue
|
||||||
|
if '-' in id_:
|
||||||
|
continue
|
||||||
|
id_ = int(id_)-1
|
||||||
|
head = int(head)-1 if head != '0' else id_
|
||||||
|
sent['words'].append(word)
|
||||||
|
sent['tags'].append(tag)
|
||||||
|
sent['heads'].append(head)
|
||||||
|
sent['deps'].append('ROOT' if dep == 'root' else dep)
|
||||||
|
sent['spaces'].append(space_after == '_')
|
||||||
|
sent['entities'] = ['-'] * len(sent['words'])
|
||||||
|
sent['heads'], sent['deps'] = projectivize(sent['heads'],
|
||||||
|
sent['deps'])
|
||||||
|
if oracle_segments:
|
||||||
|
docs.append(Doc(nlp.vocab, words=sent['words'], spaces=sent['spaces']))
|
||||||
|
golds.append(GoldParse(docs[-1], **sent))
|
||||||
|
|
||||||
|
sent_annots.append(sent)
|
||||||
|
if raw_text and max_doc_length and len(sent_annots) >= max_doc_length:
|
||||||
|
doc, gold = _make_gold(nlp, None, sent_annots)
|
||||||
|
sent_annots = []
|
||||||
|
docs.append(doc)
|
||||||
|
golds.append(gold)
|
||||||
|
if limit and len(docs) >= limit:
|
||||||
|
return docs, golds
|
||||||
|
|
||||||
|
if raw_text and sent_annots:
|
||||||
|
doc, gold = _make_gold(nlp, None, sent_annots)
|
||||||
|
docs.append(doc)
|
||||||
|
golds.append(gold)
|
||||||
|
if limit and len(docs) >= limit:
|
||||||
|
return docs, golds
|
||||||
|
return docs, golds
|
||||||
|
|
||||||
|
|
||||||
|
def read_conllu(file_):
|
||||||
|
docs = []
|
||||||
|
sent = []
|
||||||
|
doc = []
|
||||||
|
for line in file_:
|
||||||
|
if line.startswith('# newdoc'):
|
||||||
|
if doc:
|
||||||
|
docs.append(doc)
|
||||||
|
doc = []
|
||||||
|
elif line.startswith('#'):
|
||||||
|
continue
|
||||||
|
elif not line.strip():
|
||||||
|
if sent:
|
||||||
|
doc.append(sent)
|
||||||
|
sent = []
|
||||||
|
else:
|
||||||
|
sent.append(list(line.strip().split('\t')))
|
||||||
|
if len(sent[-1]) != 10:
|
||||||
|
print(repr(line))
|
||||||
|
raise ValueError
|
||||||
|
if sent:
|
||||||
|
doc.append(sent)
|
||||||
|
if doc:
|
||||||
|
docs.append(doc)
|
||||||
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
def _make_gold(nlp, text, sent_annots):
|
||||||
|
# Flatten the conll annotations, and adjust the head indices
|
||||||
|
flat = defaultdict(list)
|
||||||
|
for sent in sent_annots:
|
||||||
|
flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
|
||||||
|
for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
|
||||||
|
flat[field].extend(sent[field])
|
||||||
|
# Construct text if necessary
|
||||||
|
assert len(flat['words']) == len(flat['spaces'])
|
||||||
|
if text is None:
|
||||||
|
text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
|
||||||
|
doc = nlp.make_doc(text)
|
||||||
|
flat.pop('spaces')
|
||||||
|
gold = GoldParse(doc, **flat)
|
||||||
|
return doc, gold
|
||||||
|
|
||||||
|
#############################
|
||||||
|
# Data transforms for spaCy #
|
||||||
|
#############################
|
||||||
|
|
||||||
|
def golds_to_gold_tuples(docs, golds):
|
||||||
|
'''Get out the annoying 'tuples' format used by begin_training, given the
|
||||||
|
GoldParse objects.'''
|
||||||
|
tuples = []
|
||||||
|
for doc, gold in zip(docs, golds):
|
||||||
|
text = doc.text
|
||||||
|
ids, words, tags, heads, labels, iob = zip(*gold.orig_annot)
|
||||||
|
sents = [((ids, words, tags, heads, labels, iob), [])]
|
||||||
|
tuples.append((text, sents))
|
||||||
|
return tuples
|
||||||
|
|
||||||
|
|
||||||
|
##############
|
||||||
|
# Evaluation #
|
||||||
|
##############
|
||||||
|
|
||||||
|
def evaluate(nlp, text_loc, gold_loc, sys_loc, limit=None):
|
||||||
|
with text_loc.open('r', encoding='utf8') as text_file:
|
||||||
|
texts = split_text(text_file.read())
|
||||||
|
docs = list(nlp.pipe(texts))
|
||||||
|
with sys_loc.open('w', encoding='utf8') as out_file:
|
||||||
|
write_conllu(docs, out_file)
|
||||||
|
with gold_loc.open('r', encoding='utf8') as gold_file:
|
||||||
|
gold_ud = conll17_ud_eval.load_conllu(gold_file)
|
||||||
|
with sys_loc.open('r', encoding='utf8') as sys_file:
|
||||||
|
sys_ud = conll17_ud_eval.load_conllu(sys_file)
|
||||||
|
scores = conll17_ud_eval.evaluate(gold_ud, sys_ud)
|
||||||
|
return docs, scores
|
||||||
|
|
||||||
|
|
||||||
|
def write_conllu(docs, file_):
|
||||||
|
merger = Matcher(docs[0].vocab)
|
||||||
|
merger.add('SUBTOK', None, [{'DEP': 'subtok', 'op': '+'}])
|
||||||
|
for i, doc in enumerate(docs):
|
||||||
|
matches = merger(doc)
|
||||||
|
spans = [doc[start:end+1] for _, start, end in matches]
|
||||||
|
offsets = [(span.start_char, span.end_char) for span in spans]
|
||||||
|
for start_char, end_char in offsets:
|
||||||
|
doc.merge(start_char, end_char)
|
||||||
|
file_.write("# newdoc id = {i}\n".format(i=i))
|
||||||
|
for j, sent in enumerate(doc.sents):
|
||||||
|
file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
|
||||||
|
file_.write("# text = {text}\n".format(text=sent.text))
|
||||||
|
for k, token in enumerate(sent):
|
||||||
|
file_.write(token._.get_conllu_lines(k) + '\n')
|
||||||
|
file_.write('\n')
|
||||||
|
|
||||||
|
|
||||||
|
def print_progress(itn, losses, ud_scores):
|
||||||
|
fields = {
|
||||||
|
'dep_loss': losses.get('parser', 0.0),
|
||||||
|
'tag_loss': losses.get('tagger', 0.0),
|
||||||
|
'words': ud_scores['Words'].f1 * 100,
|
||||||
|
'sents': ud_scores['Sentences'].f1 * 100,
|
||||||
|
'tags': ud_scores['XPOS'].f1 * 100,
|
||||||
|
'uas': ud_scores['UAS'].f1 * 100,
|
||||||
|
'las': ud_scores['LAS'].f1 * 100,
|
||||||
|
}
|
||||||
|
header = ['Epoch', 'Loss', 'LAS', 'UAS', 'TAG', 'SENT', 'WORD']
|
||||||
|
if itn == 0:
|
||||||
|
print('\t'.join(header))
|
||||||
|
tpl = '\t'.join((
|
||||||
|
'{:d}',
|
||||||
|
'{dep_loss:.1f}',
|
||||||
|
'{las:.1f}',
|
||||||
|
'{uas:.1f}',
|
||||||
|
'{tags:.1f}',
|
||||||
|
'{sents:.1f}',
|
||||||
|
'{words:.1f}',
|
||||||
|
))
|
||||||
|
print(tpl.format(itn, **fields))
|
||||||
|
|
||||||
|
#def get_sent_conllu(sent, sent_id):
|
||||||
|
# lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
|
||||||
|
|
||||||
|
def get_token_conllu(token, i):
|
||||||
|
if token._.begins_fused:
|
||||||
|
n = 1
|
||||||
|
while token.nbor(n)._.inside_fused:
|
||||||
|
n += 1
|
||||||
|
id_ = '%d-%d' % (i, i+n)
|
||||||
|
lines = [id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_']
|
||||||
|
else:
|
||||||
|
lines = []
|
||||||
|
if token.head.i == token.i:
|
||||||
|
head = 0
|
||||||
|
else:
|
||||||
|
head = i + (token.head.i - token.i) + 1
|
||||||
|
fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
|
||||||
|
str(head), token.dep_.lower(), '_', '_']
|
||||||
|
lines.append('\t'.join(fields))
|
||||||
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
Token.set_extension('get_conllu_lines', method=get_token_conllu)
|
||||||
|
Token.set_extension('begins_fused', default=False)
|
||||||
|
Token.set_extension('inside_fused', default=False)
|
||||||
|
|
||||||
|
|
||||||
|
##################
|
||||||
|
# Initialization #
|
||||||
|
##################
|
||||||
|
|
||||||
|
|
||||||
|
def load_nlp(corpus, config):
|
||||||
|
lang = corpus.split('_')[0]
|
||||||
|
nlp = spacy.blank(lang)
|
||||||
|
if config.vectors:
|
||||||
|
nlp.vocab.from_disk(Path(config.vectors) / 'vocab')
|
||||||
|
return nlp
|
||||||
|
|
||||||
|
def initialize_pipeline(nlp, docs, golds, config, device):
|
||||||
|
nlp.add_pipe(nlp.create_pipe('parser'))
|
||||||
|
if config.multitask_tag:
|
||||||
|
nlp.parser.add_multitask_objective('tag')
|
||||||
|
if config.multitask_sent:
|
||||||
|
nlp.parser.add_multitask_objective('sent_start')
|
||||||
|
nlp.add_pipe(nlp.create_pipe('tagger'))
|
||||||
|
for gold in golds:
|
||||||
|
for tag in gold.tags:
|
||||||
|
if tag is not None:
|
||||||
|
nlp.tagger.add_label(tag)
|
||||||
|
return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds), device=device)
|
||||||
|
|
||||||
|
|
||||||
|
########################
|
||||||
|
# Command line helpers #
|
||||||
|
########################
|
||||||
|
|
||||||
|
class Config(object):
|
||||||
|
def __init__(self, vectors=None, max_doc_length=10, multitask_tag=True,
|
||||||
|
multitask_sent=True, nr_epoch=30, batch_size=1000, dropout=0.2):
|
||||||
|
for key, value in locals().items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, loc):
|
||||||
|
with Path(loc).open('r', encoding='utf8') as file_:
|
||||||
|
cfg = json.load(file_)
|
||||||
|
return cls(**cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(object):
|
||||||
|
def __init__(self, path, section):
|
||||||
|
self.path = path
|
||||||
|
self.section = section
|
||||||
|
self.conllu = None
|
||||||
|
self.text = None
|
||||||
|
for file_path in self.path.iterdir():
|
||||||
|
name = file_path.parts[-1]
|
||||||
|
if section in name and name.endswith('conllu'):
|
||||||
|
self.conllu = file_path
|
||||||
|
elif section in name and name.endswith('txt'):
|
||||||
|
self.text = file_path
|
||||||
|
if self.conllu is None:
|
||||||
|
msg = "Could not find .txt file in {path} for {section}"
|
||||||
|
raise IOError(msg.format(section=section, path=path))
|
||||||
|
if self.text is None:
|
||||||
|
msg = "Could not find .txt file in {path} for {section}"
|
||||||
|
self.lang = self.conllu.parts[-1].split('-')[0].split('_')[0]
|
||||||
|
|
||||||
|
|
||||||
|
class TreebankPaths(object):
|
||||||
|
def __init__(self, ud_path, treebank, **cfg):
|
||||||
|
self.train = Dataset(ud_path / treebank, 'train')
|
||||||
|
self.dev = Dataset(ud_path / treebank, 'dev')
|
||||||
|
self.lang = self.train.lang
|
||||||
|
|
||||||
|
|
||||||
|
@plac.annotations(
|
||||||
|
ud_dir=("Path to Universal Dependencies corpus", "positional", None, Path),
|
||||||
|
corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc",
|
||||||
|
"positional", None, str),
|
||||||
|
parses_dir=("Directory to write the development parses", "positional", None, Path),
|
||||||
|
config=("Path to json formatted config file", "positional"),
|
||||||
|
limit=("Size limit", "option", "n", int),
|
||||||
|
use_gpu=("Use GPU", "option", "g", int)
|
||||||
|
)
|
||||||
|
def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1):
|
||||||
|
spacy.util.fix_random_seed()
|
||||||
|
lang.zh.Chinese.Defaults.use_jieba = False
|
||||||
|
lang.ja.Japanese.Defaults.use_janome = False
|
||||||
|
|
||||||
|
config = Config.load(config)
|
||||||
|
paths = TreebankPaths(ud_dir, corpus)
|
||||||
|
if not (parses_dir / corpus).exists():
|
||||||
|
(parses_dir / corpus).mkdir()
|
||||||
|
print("Train and evaluate", corpus, "using lang", paths.lang)
|
||||||
|
nlp = load_nlp(paths.lang, config)
|
||||||
|
|
||||||
|
docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
|
||||||
|
max_doc_length=config.max_doc_length, limit=limit)
|
||||||
|
|
||||||
|
optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu)
|
||||||
|
|
||||||
|
batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001)
|
||||||
|
for i in range(config.nr_epoch):
|
||||||
|
docs = [nlp.make_doc(doc.text) for doc in docs]
|
||||||
|
Xs = list(zip(docs, golds))
|
||||||
|
random.shuffle(Xs)
|
||||||
|
batches = minibatch_by_words(Xs, size=batch_sizes)
|
||||||
|
losses = {}
|
||||||
|
n_train_words = sum(len(doc) for doc in docs)
|
||||||
|
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||||
|
for batch in batches:
|
||||||
|
batch_docs, batch_gold = zip(*batch)
|
||||||
|
pbar.update(sum(len(doc) for doc in batch_docs))
|
||||||
|
nlp.update(batch_docs, batch_gold, sgd=optimizer,
|
||||||
|
drop=config.dropout, losses=losses)
|
||||||
|
|
||||||
|
out_path = parses_dir / corpus / 'epoch-{i}.conllu'.format(i=i)
|
||||||
|
with nlp.use_params(optimizer.averages):
|
||||||
|
parsed_docs, scores = evaluate(nlp, paths.dev.text, paths.dev.conllu, out_path)
|
||||||
|
print_progress(i, losses, scores)
|
||||||
|
_render_parses(i, parsed_docs[:50])
|
||||||
|
|
||||||
|
|
||||||
|
def _render_parses(i, to_render):
|
||||||
|
to_render[0].user_data['title'] = "Batch %d" % i
|
||||||
|
with Path('/tmp/parses.html').open('w') as file_:
|
||||||
|
html = displacy.render(to_render[:5], style='dep', page=True)
|
||||||
|
file_.write(html)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
plac.call(main)
|
|
@ -462,7 +462,7 @@ class Language(object):
|
||||||
self._optimizer = sgd
|
self._optimizer = sgd
|
||||||
for name, proc in self.pipeline:
|
for name, proc in self.pipeline:
|
||||||
if hasattr(proc, 'begin_training'):
|
if hasattr(proc, 'begin_training'):
|
||||||
proc.begin_training(get_gold_tuples(),
|
proc.begin_training(get_gold_tuples,
|
||||||
pipeline=self.pipeline,
|
pipeline=self.pipeline,
|
||||||
sgd=self._optimizer,
|
sgd=self._optimizer,
|
||||||
**cfg)
|
**cfg)
|
||||||
|
|
|
@ -172,7 +172,7 @@ class Pipe(object):
|
||||||
return create_default_optimizer(self.model.ops,
|
return create_default_optimizer(self.model.ops,
|
||||||
**self.cfg.get('optimizer', {}))
|
**self.cfg.get('optimizer', {}))
|
||||||
|
|
||||||
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
|
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Initialize the pipe for training, using data exampes if available.
|
"""Initialize the pipe for training, using data exampes if available.
|
||||||
If no model has been initialized yet, the model is added."""
|
If no model has been initialized yet, the model is added."""
|
||||||
|
@ -374,7 +374,7 @@ class Tensorizer(Pipe):
|
||||||
loss = (d_scores**2).sum()
|
loss = (d_scores**2).sum()
|
||||||
return loss, d_scores
|
return loss, d_scores
|
||||||
|
|
||||||
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
|
def begin_training(self, gold_tuples=lambda: [], pipeline=None, sgd=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Allocate models, pre-process training data and acquire an
|
"""Allocate models, pre-process training data and acquire an
|
||||||
optimizer.
|
optimizer.
|
||||||
|
@ -498,11 +498,11 @@ class Tagger(Pipe):
|
||||||
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
|
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
|
||||||
return float(loss), d_scores
|
return float(loss), d_scores
|
||||||
|
|
||||||
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
|
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
orig_tag_map = dict(self.vocab.morphology.tag_map)
|
orig_tag_map = dict(self.vocab.morphology.tag_map)
|
||||||
new_tag_map = OrderedDict()
|
new_tag_map = OrderedDict()
|
||||||
for raw_text, annots_brackets in gold_tuples:
|
for raw_text, annots_brackets in get_gold_tuples():
|
||||||
for annots, brackets in annots_brackets:
|
for annots, brackets in annots_brackets:
|
||||||
ids, words, tags, heads, deps, ents = annots
|
ids, words, tags, heads, deps, ents = annots
|
||||||
for tag in tags:
|
for tag in tags:
|
||||||
|
@ -673,9 +673,9 @@ class MultitaskObjective(Tagger):
|
||||||
def set_annotations(self, docs, dep_ids, tensors=None):
|
def set_annotations(self, docs, dep_ids, tensors=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def begin_training(self, gold_tuples=tuple(), pipeline=None, tok2vec=None,
|
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, tok2vec=None,
|
||||||
sgd=None, **kwargs):
|
sgd=None, **kwargs):
|
||||||
gold_tuples = nonproj.preprocess_training_data(gold_tuples)
|
gold_tuples = nonproj.preprocess_training_data(get_gold_tuples())
|
||||||
for raw_text, annots_brackets in gold_tuples:
|
for raw_text, annots_brackets in gold_tuples:
|
||||||
for annots, brackets in annots_brackets:
|
for annots, brackets in annots_brackets:
|
||||||
ids, words, tags, heads, deps, ents = annots
|
ids, words, tags, heads, deps, ents = annots
|
||||||
|
@ -898,7 +898,7 @@ class TextCategorizer(Pipe):
|
||||||
self.labels.append(label)
|
self.labels.append(label)
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
|
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None):
|
||||||
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
|
if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
|
||||||
token_vector_width = pipeline[0].model.nO
|
token_vector_width = pipeline[0].model.nO
|
||||||
else:
|
else:
|
||||||
|
@ -925,10 +925,10 @@ cdef class DependencyParser(Parser):
|
||||||
labeller = MultitaskObjective(self.vocab, target=target)
|
labeller = MultitaskObjective(self.vocab, target=target)
|
||||||
self._multitasks.append(labeller)
|
self._multitasks.append(labeller)
|
||||||
|
|
||||||
def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
|
def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg):
|
||||||
for labeller in self._multitasks:
|
for labeller in self._multitasks:
|
||||||
tok2vec = self.model[0]
|
tok2vec = self.model[0]
|
||||||
labeller.begin_training(gold_tuples, pipeline=pipeline,
|
labeller.begin_training(get_gold_tuples, pipeline=pipeline,
|
||||||
tok2vec=tok2vec, sgd=sgd)
|
tok2vec=tok2vec, sgd=sgd)
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
|
@ -946,10 +946,10 @@ cdef class EntityRecognizer(Parser):
|
||||||
labeller = MultitaskObjective(self.vocab, target=target)
|
labeller = MultitaskObjective(self.vocab, target=target)
|
||||||
self._multitasks.append(labeller)
|
self._multitasks.append(labeller)
|
||||||
|
|
||||||
def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
|
def init_multitask_objectives(self, get_gold_tuples, pipeline, sgd=None, **cfg):
|
||||||
for labeller in self._multitasks:
|
for labeller in self._multitasks:
|
||||||
tok2vec = self.model[0]
|
tok2vec = self.model[0]
|
||||||
labeller.begin_training(gold_tuples, pipeline=pipeline,
|
labeller.begin_training(get_gold_tuples, pipeline=pipeline,
|
||||||
tok2vec=tok2vec)
|
tok2vec=tok2vec)
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
|
|
|
@ -164,16 +164,17 @@ cdef void sum_state_features(float* output,
|
||||||
cdef const float* feature
|
cdef const float* feature
|
||||||
padding = cached
|
padding = cached
|
||||||
cached += F * O
|
cached += F * O
|
||||||
|
cdef int id_stride = F*O
|
||||||
|
cdef float one = 1.
|
||||||
for b in range(B):
|
for b in range(B):
|
||||||
for f in range(F):
|
for f in range(F):
|
||||||
if token_ids[f] < 0:
|
if token_ids[f] < 0:
|
||||||
feature = &padding[f*O]
|
feature = &padding[f*O]
|
||||||
else:
|
else:
|
||||||
idx = token_ids[f] * F * O + f*O
|
idx = token_ids[f] * id_stride + f*O
|
||||||
feature = &cached[idx]
|
feature = &cached[idx]
|
||||||
for i in range(O):
|
openblas.simple_axpy(&output[b*O], O,
|
||||||
output[i] += feature[i]
|
feature, one)
|
||||||
output += O
|
|
||||||
token_ids += F
|
token_ids += F
|
||||||
|
|
||||||
|
|
||||||
|
@ -726,7 +727,7 @@ cdef class Parser:
|
||||||
lower, stream, drop=0.0)
|
lower, stream, drop=0.0)
|
||||||
return (tokvecs, bp_tokvecs), state2vec, upper
|
return (tokvecs, bp_tokvecs), state2vec, upper
|
||||||
|
|
||||||
nr_feature = 13
|
nr_feature = 8
|
||||||
|
|
||||||
def get_token_ids(self, states):
|
def get_token_ids(self, states):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
|
@ -821,15 +822,13 @@ cdef class Parser:
|
||||||
copy_array(larger.b[:smaller.nO], smaller.b)
|
copy_array(larger.b[:smaller.nO], smaller.b)
|
||||||
self.model[-1]._layers[-1] = larger
|
self.model[-1]._layers[-1] = larger
|
||||||
|
|
||||||
def begin_training(self, gold_tuples, pipeline=None, sgd=None, **cfg):
|
def begin_training(self, get_gold_tuples, pipeline=None, sgd=None, **cfg):
|
||||||
if 'model' in cfg:
|
if 'model' in cfg:
|
||||||
self.model = cfg['model']
|
self.model = cfg['model']
|
||||||
gold_tuples = nonproj.preprocess_training_data(gold_tuples,
|
cfg.setdefault('min_action_freq', 30)
|
||||||
label_freq_cutoff=100)
|
actions = self.moves.get_actions(gold_parses=get_gold_tuples(),
|
||||||
actions = self.moves.get_actions(gold_parses=gold_tuples)
|
min_freq=cfg.get('min_action_freq', 30))
|
||||||
for action, labels in actions.items():
|
self.moves.initialize_actions(actions)
|
||||||
for label in labels:
|
|
||||||
self.moves.add_action(action, label)
|
|
||||||
cfg.setdefault('token_vector_width', 128)
|
cfg.setdefault('token_vector_width', 128)
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
cfg['pretrained_dims'] = self.vocab.vectors_length
|
cfg['pretrained_dims'] = self.vocab.vectors_length
|
||||||
|
@ -839,7 +838,7 @@ cdef class Parser:
|
||||||
self.model[1].begin_training(
|
self.model[1].begin_training(
|
||||||
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
|
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
|
||||||
if pipeline is not None:
|
if pipeline is not None:
|
||||||
self.init_multitask_objectives(gold_tuples, pipeline, sgd=sgd, **cfg)
|
self.init_multitask_objectives(get_gold_tuples, pipeline, sgd=sgd, **cfg)
|
||||||
link_vectors_to_models(self.vocab)
|
link_vectors_to_models(self.vocab)
|
||||||
else:
|
else:
|
||||||
if sgd is None:
|
if sgd is None:
|
||||||
|
@ -853,7 +852,7 @@ cdef class Parser:
|
||||||
# Defined in subclasses, to avoid circular import
|
# Defined in subclasses, to avoid circular import
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
|
def init_multitask_objectives(self, get_gold_tuples, pipeline, **cfg):
|
||||||
'''Setup models for secondary objectives, to benefit from multi-task
|
'''Setup models for secondary objectives, to benefit from multi-task
|
||||||
learning. This method is intended to be overridden by subclasses.
|
learning. This method is intended to be overridden by subclasses.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user