diff --git a/examples/training/conllu.py b/examples/training/conllu.py
index 605905361..5063cca21 100644
--- a/examples/training/conllu.py
+++ b/examples/training/conllu.py
@@ -18,6 +18,7 @@ from spacy.syntax.nonproj import projectivize
 from collections import defaultdict, Counter
 from timeit import default_timer as timer
 from spacy.matcher import Matcher
+from spacy.morphology import Fused_begin, Fused_inside
 
 import itertools
 import random
@@ -84,18 +85,28 @@ def read_data(nlp, conllu_file, text_file, raw_text=True, oracle_segments=False,
         sent_annots = []
         for cs in cd:
             sent = defaultdict(list)
+            fused_ids = set()
             for id_, word, lemma, pos, tag, morph, head, dep, _, space_after in cs:
                 if '.' in id_:
                     continue
                 if '-' in id_:
+                    fuse_start, fuse_end = id_.split('-')
+                    for sub_id in range(int(fuse_start), int(fuse_end)+1):
+                        fused_ids.add(str(sub_id))
+                    sent['tokens'].append(word)
                     continue
+                if id_ not in fused_ids:
+                    sent['tokens'].append(word)
+                    if space_after == '_':
+                        sent['tokens'][-1] += ' '
+                elif id_ == fuse_end and space_after == '_':
+                    sent['tokens'][-1] += ' '
                 id_ = int(id_)-1
                 head = int(head)-1 if head != '0' else id_
                 sent['words'].append(word)
                 sent['tags'].append(tag)
                 sent['heads'].append(head)
                 sent['deps'].append('ROOT' if dep == 'root' else dep)
-                sent['spaces'].append(space_after == '_')
             sent['entities'] = ['-'] * len(sent['words'])
             sent['heads'], sent['deps'] = projectivize(sent['heads'], sent['deps'])
@@ -153,14 +164,13 @@ def _make_gold(nlp, text, sent_annots):
     flat = defaultdict(list)
     for sent in sent_annots:
         flat['heads'].extend(len(flat['words'])+head for head in sent['heads'])
-        for field in ['words', 'tags', 'deps', 'entities', 'spaces']:
+        for field in ['words', 'tags', 'deps', 'entities', 'tokens']:
             flat[field].extend(sent[field])
     # Construct text if necessary
-    assert len(flat['words']) == len(flat['spaces'])
     if text is None:
-        text = ''.join(word+' '*space for word, space in zip(flat['words'], flat['spaces']))
+        text = ''.join(flat['tokens'])
     doc = nlp.make_doc(text)
-    flat.pop('spaces')
+    flat.pop('tokens')
     gold = GoldParse(doc, **flat)
     return doc, gold
 
@@ -210,12 +220,39 @@ def write_conllu(docs, file_):
         file_.write("# newdoc id = {i}\n".format(i=i))
         for j, sent in enumerate(doc.sents):
             file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
-            file_.write("# text = {text}\n".format(text=sent.text))
+            file_.write('# text = {text}\n'.format(text=sent.text))
             for k, token in enumerate(sent):
-                file_.write(token._.get_conllu_lines(k) + '\n')
+                file_.write(_get_token_conllu(token, k, len(sent)) + '\n')
             file_.write('\n')
 
 
+def _get_token_conllu(token, k, sent_len):
+    if token.check_morph(Fused_begin) and (k+1 < sent_len):
+        n = 1
+        text = [token.text]
+        while token.nbor(n).check_morph(Fused_inside):
+            text.append(token.nbor(n).text)
+            n += 1
+        id_ = '%d-%d' % (k+1, (k+n))
+        fields = [id_, ''.join(text)] + ['_'] * 8
+        lines = ['\t'.join(fields)]
+    else:
+        lines = []
+    if token.head.i == token.i:
+        head = 0
+    else:
+        head = k + (token.head.i - token.i) + 1
+    fields = [str(k+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
+              str(head), token.dep_.lower(), '_', '_']
+    if token.check_morph(Fused_begin) and (k+1 < sent_len):
+        if k == 0:
+            fields[1] = token.norm_[0].upper() + token.norm_[1:]
+        else:
+            fields[1] = token.norm_
+    lines.append('\t'.join(fields))
+    return '\n'.join(lines)
+
+
 def print_progress(itn, losses, ud_scores):
     fields = {
         'dep_loss': losses.get('parser', 0.0),
@@ -240,31 +277,6 @@ def print_progress(itn, losses, ud_scores):
     ))
     print(tpl.format(itn, **fields))
 
 
-#def get_sent_conllu(sent, sent_id):
-#    lines = ["# sent_id = {sent_id}".format(sent_id=sent_id)]
-
-def get_token_conllu(token, i):
-    if token._.begins_fused:
-        n = 1
-        while token.nbor(n)._.inside_fused:
-            n += 1
-        id_ = '%d-%d' % (k, k+n)
-        lines = [id_, token.text, '_', '_', '_', '_', '_', '_', '_', '_']
-    else:
-        lines = []
-    if token.head.i == token.i:
-        head = 0
-    else:
-        head = i + (token.head.i - token.i) + 1
-    fields = [str(i+1), token.text, token.lemma_, token.pos_, token.tag_, '_',
-              str(head), token.dep_.lower(), '_', '_']
-    lines.append('\t'.join(fields))
-    return '\n'.join(lines)
-
-Token.set_extension('get_conllu_lines', method=get_token_conllu)
-Token.set_extension('begins_fused', default=False)
-Token.set_extension('inside_fused', default=False)
-
 ##################
 # Initialization #
@@ -278,14 +290,63 @@ def load_nlp(corpus, config):
         nlp.vocab.from_disk(config.vectors / 'vocab')
     return nlp
 
 
+def extract_tokenizer_exceptions(paths):
+    with paths.train.conllu.open() as file_:
+        conllu = read_conllu(file_)
+    fused = defaultdict(lambda: defaultdict(list))
+    for doc in conllu:
+        for sent in doc:
+            for i, token in enumerate(sent):
+                if '-' in token[0]:
+                    start, end = token[0].split('-')
+                    length = int(end) - int(start)
+                    subtokens = sent[i+1 : i+1+length+1]
+                    forms = [t[1].lower() for t in subtokens]
+                    fused[token[1]][tuple(forms)].append(subtokens)
+    exc = {}
+    for word, expansions in fused.items():
+        by_freq = [(len(occurs), key, occurs) for key, occurs in expansions.items()]
+        freq, key, occurs = max(by_freq)
+        if word == ''.join(key):
+            # Happy case: we get a perfect split, with each letter accounted for.
+            analysis = [{'ORTH': subtoken} for subtoken in key]
+        elif len(word) == sum(len(subtoken) for subtoken in key):
+            # Unideal, but at least lengths match.
+            analysis = []
+            remain = word
+            for subtoken in key:
+                analysis.append({'ORTH': remain[:len(subtoken)]})
+                remain = remain[len(subtoken):]
+            assert len(remain) == 0, (word, key, remain)
+        else:
+            # Let's say word is 6 long, and there are three subtokens. The orths
+            # *must* equal the original string. Arbitrarily, split [4, 1, 1]
+            first = word[:len(word)-(len(key)-1)]
+            subtokens = [first]
+            remain = word[len(first):]
+            for i in range(1, len(key)):
+                subtokens.append(remain[:1])
+                remain = remain[1:]
+            assert len(remain) == 0, (word, subtokens, remain)
+            analysis = [{'ORTH': subtoken} for subtoken in subtokens]
+        for i, token in enumerate(occurs[0]):
+            analysis[i]['NORM'] = token[1]
+        analysis[0]['morphology'] = [Fused_begin]
+        for subtoken in analysis[1:]:
+            subtoken['morphology'] = [Fused_inside]
+        exc[word] = analysis
+    return exc
+
 def initialize_pipeline(nlp, docs, golds, config):
+    nlp.add_pipe(nlp.create_pipe('tagger'))
     nlp.add_pipe(nlp.create_pipe('parser'))
+    nlp.parser.moves.add_action(2, 'subtok')
     if config.multitask_tag:
         nlp.parser.add_multitask_objective('tag')
     if config.multitask_sent:
         nlp.parser.add_multitask_objective('sent_start')
-    nlp.parser.moves.add_action(2, 'subtok')
-    nlp.add_pipe(nlp.create_pipe('tagger'))
+    if config.multitask_dep:
+        nlp.parser.add_multitask_objective('dep')
    for gold in golds:
         for tag in gold.tags:
             if tag is not None:
@@ -308,6 +369,7 @@ def initialize_pipeline(nlp, docs, golds, config):
 class Config(object):
     vectors = attr.ib(default=None)
     max_doc_length = attr.ib(default=10)
+    multitask_dep = attr.ib(default=True)
     multitask_tag = attr.ib(default=True)
     multitask_sent = attr.ib(default=True)
     nr_epoch = attr.ib(default=30)
@@ -362,7 +424,9 @@ def main(ud_dir, parses_dir, config, corpus, limit=0):
         (parses_dir / corpus).mkdir()
     print("Train and evaluate", corpus, "using lang", paths.lang)
     nlp = load_nlp(paths.lang, config)
-
+    tokenizer_exceptions = extract_tokenizer_exceptions(paths)
+    for orth, subtokens in tokenizer_exceptions.items():
+        nlp.tokenizer.add_special_case(orth, subtokens)
     docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(),
                             max_doc_length=config.max_doc_length, limit=limit)
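
Editor's note: the first line that _get_token_conllu() emits for a fused token follows the CoNLL-U multi-word token convention, i.e. a FROM-TO id and the fused surface form, with the remaining eight columns left as underscores. A hypothetical sketch of that range line, not taken from a real run (the word 'zum' is an illustrative stand-in):

k, n = 0, 2                      # first token of the sentence, spanning two subtokens
surface = 'zum'                  # hypothetical German fusion of 'zu' + 'dem'
# Mirrors fields = [id_, ''.join(text)] + ['_'] * 8 in _get_token_conllu()
range_line = '\t'.join(['%d-%d' % (k + 1, k + n), surface] + ['_'] * 8)
print(range_line)                # 1-2   zum   _   _   _   _   _   _   _   _

The per-subtoken lines follow underneath, one for each word inside the fused span.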
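Editor's note: extract_tokenizer_exceptions() has to keep the subtoken ORTH values concatenating back to the fused surface form, which is why it has three splitting branches. A minimal standalone sketch of that logic follows; the helper name split_fused_orths and the example words are illustrative, not part of the patch:

# Sketch of the orth-splitting fallback: given a fused surface form and the
# norms of its subtokens, return ORTH pieces that concatenate to the surface.
def split_fused_orths(word, norms):
    if word == ''.join(norms):
        # Happy case: the norms already account for every character.
        return list(norms)
    elif len(word) == sum(len(norm) for norm in norms):
        # Lengths match, so cut the surface form at the same offsets.
        pieces, remain = [], word
        for norm in norms:
            pieces.append(remain[:len(norm)])
            remain = remain[len(norm):]
        return pieces
    else:
        # Arbitrary fallback: the first subtoken gets everything except one
        # trailing character per remaining subtoken, e.g. a 6-character word
        # over three subtokens splits [4, 1, 1].
        first = word[:len(word) - (len(norms) - 1)]
        return [first] + list(word[len(first):])

print(split_fused_orths('cannot', ('can', 'not')))      # ['can', 'not']
print(split_fused_orths('dámelo', ('da', 'me', 'lo')))  # ['dá', 'me', 'lo']
print(split_fused_orths('beim', ('bei', 'dem')))        # ['bei', 'm']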