Merge pull request #6109 from svlandeg/feature/2rename

commit 60a317520a
Ines Montani, 2020-09-23 09:47:12 +02:00, committed by GitHub
21 changed files with 115 additions and 98 deletions

View File

@ -378,7 +378,7 @@ def git_sparse_checkout(repo, subpath, dest, branch):
# Looking for this 'rev-list' command in the git --help? Hah.
cmd = f"git -C {tmp_dir} rev-list --objects --all --missing=print -- {subpath}"
ret = run_command(cmd, capture=True)
git_repo = _from_http_to_git(repo)
git_repo = _http_to_git(repo)
# Now pass those missings into another bit of git internals
missings = " ".join([x[1:] for x in ret.stdout.split() if x.startswith("?")])
if not missings:
@ -414,7 +414,7 @@ def get_git_version(
return (int(version[0]), int(version[1]))
def _from_http_to_git(repo: str) -> str:
def _http_to_git(repo: str) -> str:
if repo.startswith("http://"):
repo = repo.replace(r"http://", r"https://")
if repo.startswith(r"https://"):

View File

@ -9,7 +9,7 @@ import sys
from ._util import app, Arg, Opt
from ..training import docs_to_json
from ..tokens import DocBin
from ..training.converters import iob2docs, conll_ner2docs, json2docs, conllu2docs
from ..training.converters import iob_to_docs, conll_ner_to_docs, json_to_docs, conllu_to_docs
# Converters are matched by file extension except for ner/iob, which are
@ -18,12 +18,12 @@ from ..training.converters import iob2docs, conll_ner2docs, json2docs, conllu2do
# imported from /converters.
CONVERTERS = {
"conllubio": conllu2docs,
"conllu": conllu2docs,
"conll": conllu2docs,
"ner": conll_ner2docs,
"iob": iob2docs,
"json": json2docs,
"conllubio": conllu_to_docs,
"conllu": conllu_to_docs,
"conll": conllu_to_docs,
"ner": conll_ner_to_docs,
"iob": iob_to_docs,
"json": json_to_docs,
}
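
For context, a rough sketch of how this table is meant to be consumed (hypothetical driver code, not the actual `spacy convert` CLI, which, as the comment above notes, also matches converters by file extension):

```python
# Hypothetical usage sketch: look up a converter by name and run it over
# the raw file contents. The driver function and its defaults are assumptions.
from pathlib import Path
from spacy.training.converters import conllu_to_docs, iob_to_docs

CONVERTERS = {"conllu": conllu_to_docs, "iob": iob_to_docs}  # abridged

def convert_file(path: str, converter: str = "conllu", n_sents: int = 10):
    input_data = Path(path).read_text(encoding="utf8")
    func = CONVERTERS[converter]
    return list(func(input_data, n_sents=n_sents))
```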

View File

@ -69,7 +69,7 @@ class Warnings:
"in problems with the vocab further on in the pipeline.")
W030 = ("Some entities could not be aligned in the text \"{text}\" with "
"entities \"{entities}\". Use "
"`spacy.training.biluo_tags_from_offsets(nlp.make_doc(text), entities)`"
"`spacy.training.offsets_to_biluo_tags(nlp.make_doc(text), entities)`"
" to check the alignment. Misaligned entities ('-') will be "
"ignored during training.")
W033 = ("Training a new {model} using a model with no lexeme normalization "

View File

@ -3,7 +3,7 @@ from spacy.pipeline import Pipe
from spacy.matcher import PhraseMatcher, Matcher
from spacy.tokens import Doc, Span, DocBin
from spacy.training import Example, Corpus
from spacy.training.converters import json2docs
from spacy.training.converters import json_to_docs
from spacy.vocab import Vocab
from spacy.lang.en import English
from spacy.util import minibatch, ensure_path, load_model
@ -425,7 +425,7 @@ def test_issue4402():
attrs = ["ORTH", "SENT_START", "ENT_IOB", "ENT_TYPE"]
with make_tempdir() as tmpdir:
output_file = tmpdir / "test4402.spacy"
docs = json2docs([json_data])
docs = json_to_docs([json_data])
data = DocBin(docs=docs, attrs=attrs).to_bytes()
with output_file.open("wb") as file_:
file_.write(data)

View File

@ -1,7 +1,7 @@
import pytest
from spacy.tokens import Doc, Span, DocBin
from spacy.training import Example
from spacy.training.converters.conllu2docs import conllu2docs
from spacy.training.converters.conllu_to_docs import conllu_to_docs
from spacy.lang.en import English
from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab
@ -82,7 +82,7 @@ def test_issue4651_without_phrase_matcher_attr():
def test_issue4665():
"""
conllu2json should not raise an exception if the HEAD column contains an
conllu_to_docs should not raise an exception if the HEAD column contains an
underscore
"""
input_data = """
@ -105,7 +105,7 @@ def test_issue4665():
17 . _ PUNCT . _ _ punct _ _
18 ] _ PUNCT -RRB- _ _ punct _ _
"""
conllu2docs(input_data)
conllu_to_docs(input_data)
def test_issue4674():

View File

@ -1,7 +1,7 @@
import pytest
from click import NoSuchOption
from spacy.training import docs_to_json, biluo_tags_from_offsets
from spacy.training.converters import iob2docs, conll_ner2docs, conllu2docs
from spacy.training import docs_to_json, offsets_to_biluo_tags
from spacy.training.converters import iob_to_docs, conll_ner_to_docs, conllu_to_docs
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
from spacy.cli.init_config import init_config, RECOMMENDATIONS
from spacy.cli._util import validate_project_commands, parse_config_overrides
@ -15,7 +15,7 @@ import os
from .util import make_tempdir
def test_cli_converters_conllu2json():
def test_cli_converters_conllu_to_docs():
# from NorNE: https://github.com/ltgoslo/norne/blob/3d23274965f513f23aa48455b28b1878dad23c05/ud/nob/no_bokmaal-ud-dev.conllu
lines = [
"1\tDommer\tdommer\tNOUN\t_\tDefinite=Ind|Gender=Masc|Number=Sing\t2\tappos\t_\tO",
@ -24,7 +24,7 @@ def test_cli_converters_conllu2json():
"4\tavstår\tavstå\tVERB\t_\tMood=Ind|Tense=Pres|VerbForm=Fin\t0\troot\t_\tO",
]
input_data = "\n".join(lines)
converted_docs = conllu2docs(input_data, n_sents=1)
converted_docs = conllu_to_docs(input_data, n_sents=1)
assert len(converted_docs) == 1
converted = [docs_to_json(converted_docs)]
assert converted[0]["id"] == 0
@ -40,7 +40,7 @@ def test_cli_converters_conllu2json():
ent_offsets = [
(e[0], e[1], e[2]) for e in converted[0]["paragraphs"][0]["entities"]
]
biluo_tags = biluo_tags_from_offsets(converted_docs[0], ent_offsets, missing="O")
biluo_tags = offsets_to_biluo_tags(converted_docs[0], ent_offsets, missing="O")
assert biluo_tags == ["O", "B-PER", "L-PER", "O"]
@ -63,9 +63,9 @@ def test_cli_converters_conllu2json():
),
],
)
def test_cli_converters_conllu2json_name_ner_map(lines):
def test_cli_converters_conllu_to_docs_name_ner_map(lines):
input_data = "\n".join(lines)
converted_docs = conllu2docs(
converted_docs = conllu_to_docs(
input_data, n_sents=1, ner_map={"PER": "PERSON", "BAD": ""}
)
assert len(converted_docs) == 1
@ -84,11 +84,11 @@ def test_cli_converters_conllu2json_name_ner_map(lines):
ent_offsets = [
(e[0], e[1], e[2]) for e in converted[0]["paragraphs"][0]["entities"]
]
biluo_tags = biluo_tags_from_offsets(converted_docs[0], ent_offsets, missing="O")
biluo_tags = offsets_to_biluo_tags(converted_docs[0], ent_offsets, missing="O")
assert biluo_tags == ["O", "B-PERSON", "L-PERSON", "O", "O"]
def test_cli_converters_conllu2json_subtokens():
def test_cli_converters_conllu_to_docs_subtokens():
# https://raw.githubusercontent.com/ohenrik/nb_news_ud_sm/master/original_data/no-ud-dev-ner.conllu
lines = [
"1\tDommer\tdommer\tNOUN\t_\tDefinite=Ind|Gender=Masc|Number=Sing\t2\tappos\t_\tname=O",
@ -99,7 +99,7 @@ def test_cli_converters_conllu2json_subtokens():
"5\t.\t$.\tPUNCT\t_\t_\t4\tpunct\t_\tname=O",
]
input_data = "\n".join(lines)
converted_docs = conllu2docs(
converted_docs = conllu_to_docs(
input_data, n_sents=1, merge_subtokens=True, append_morphology=True
)
assert len(converted_docs) == 1
@ -133,11 +133,11 @@ def test_cli_converters_conllu2json_subtokens():
ent_offsets = [
(e[0], e[1], e[2]) for e in converted[0]["paragraphs"][0]["entities"]
]
biluo_tags = biluo_tags_from_offsets(converted_docs[0], ent_offsets, missing="O")
biluo_tags = offsets_to_biluo_tags(converted_docs[0], ent_offsets, missing="O")
assert biluo_tags == ["O", "U-PER", "O", "O"]
def test_cli_converters_iob2json():
def test_cli_converters_iob_to_docs():
lines = [
"I|O like|O London|I-GPE and|O New|B-GPE York|I-GPE City|I-GPE .|O",
"I|O like|O London|B-GPE and|O New|B-GPE York|I-GPE City|I-GPE .|O",
@ -145,7 +145,7 @@ def test_cli_converters_iob2json():
"I|PRP|O like|VBP|O London|NNP|B-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O",
]
input_data = "\n".join(lines)
converted_docs = iob2docs(input_data, n_sents=10)
converted_docs = iob_to_docs(input_data, n_sents=10)
assert len(converted_docs) == 1
converted = docs_to_json(converted_docs)
assert converted["id"] == 0
@ -162,7 +162,7 @@ def test_cli_converters_iob2json():
assert ent.text in ["New York City", "London"]
def test_cli_converters_conll_ner2json():
def test_cli_converters_conll_ner_to_docs():
lines = [
"-DOCSTART- -X- O O",
"",
@ -212,7 +212,7 @@ def test_cli_converters_conll_ner2json():
".\t.\t_\tO",
]
input_data = "\n".join(lines)
converted_docs = conll_ner2docs(input_data, n_sents=10)
converted_docs = conll_ner_to_docs(input_data, n_sents=10)
assert len(converted_docs) == 1
converted = docs_to_json(converted_docs)
assert converted["id"] == 0

View File

@ -2,7 +2,7 @@ from numpy.testing import assert_almost_equal, assert_array_almost_equal
import pytest
from pytest import approx
from spacy.training import Example
from spacy.training.iob_utils import biluo_tags_from_offsets
from spacy.training.iob_utils import offsets_to_biluo_tags
from spacy.scorer import Scorer, ROCAUCScore
from spacy.scorer import _roc_auc_score, _roc_curve
from spacy.lang.en import English
@ -186,7 +186,7 @@ def test_ner_per_type(en_vocab):
words=input_.split(" "),
ents=[("CARDINAL", 0, 1), ("CARDINAL", 2, 3)],
)
entities = biluo_tags_from_offsets(doc, annot["entities"])
entities = offsets_to_biluo_tags(doc, annot["entities"])
example = Example.from_dict(doc, {"entities": entities})
# a hack for sentence boundaries
example.predicted[1].is_sent_start = False
@ -211,7 +211,7 @@ def test_ner_per_type(en_vocab):
words=input_.split(" "),
ents=[("ORG", 0, 1), ("GPE", 5, 6), ("ORG", 6, 7)],
)
entities = biluo_tags_from_offsets(doc, annot["entities"])
entities = offsets_to_biluo_tags(doc, annot["entities"])
example = Example.from_dict(doc, {"entities": entities})
# a hack for sentence boundaries
example.predicted[1].is_sent_start = False

View File

@ -1,9 +1,9 @@
import numpy
from spacy.training import biluo_tags_from_offsets, offsets_from_biluo_tags, Alignment
from spacy.training import spans_from_biluo_tags, iob_to_biluo
from spacy.training import offsets_to_biluo_tags, biluo_tags_to_offsets, Alignment
from spacy.training import biluo_tags_to_spans, iob_to_biluo
from spacy.training import Corpus, docs_to_json
from spacy.training.example import Example
from spacy.training.converters import json2docs
from spacy.training.converters import json_to_docs
from spacy.training.augment import make_orth_variants_example
from spacy.lang.en import English
from spacy.tokens import Doc, DocBin
@ -69,7 +69,7 @@ def test_gold_biluo_U(en_vocab):
spaces = [True, True, True, False, True]
doc = Doc(en_vocab, words=words, spaces=spaces)
entities = [(len("I flew to "), len("I flew to London"), "LOC")]
tags = biluo_tags_from_offsets(doc, entities)
tags = offsets_to_biluo_tags(doc, entities)
assert tags == ["O", "O", "O", "U-LOC", "O"]
@ -78,7 +78,7 @@ def test_gold_biluo_BL(en_vocab):
spaces = [True, True, True, True, False, True]
doc = Doc(en_vocab, words=words, spaces=spaces)
entities = [(len("I flew to "), len("I flew to San Francisco"), "LOC")]
tags = biluo_tags_from_offsets(doc, entities)
tags = offsets_to_biluo_tags(doc, entities)
assert tags == ["O", "O", "O", "B-LOC", "L-LOC", "O"]
@ -87,7 +87,7 @@ def test_gold_biluo_BIL(en_vocab):
spaces = [True, True, True, True, True, False, True]
doc = Doc(en_vocab, words=words, spaces=spaces)
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
tags = biluo_tags_from_offsets(doc, entities)
tags = offsets_to_biluo_tags(doc, entities)
assert tags == ["O", "O", "O", "B-LOC", "I-LOC", "L-LOC", "O"]
@ -100,7 +100,7 @@ def test_gold_biluo_overlap(en_vocab):
(len("I flew to "), len("I flew to San Francisco"), "LOC"),
]
with pytest.raises(ValueError):
biluo_tags_from_offsets(doc, entities)
offsets_to_biluo_tags(doc, entities)
def test_gold_biluo_misalign(en_vocab):
@ -109,7 +109,7 @@ def test_gold_biluo_misalign(en_vocab):
doc = Doc(en_vocab, words=words, spaces=spaces)
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
with pytest.warns(UserWarning):
tags = biluo_tags_from_offsets(doc, entities)
tags = offsets_to_biluo_tags(doc, entities)
assert tags == ["O", "O", "O", "-", "-", "-"]
@ -155,7 +155,7 @@ def test_example_from_dict_some_ner(en_vocab):
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_json2docs_no_ner(en_vocab):
def test_json_to_docs_no_ner(en_vocab):
data = [
{
"id": 1,
@ -191,7 +191,7 @@ def test_json2docs_no_ner(en_vocab):
],
}
]
docs = json2docs(data)
docs = json_to_docs(data)
assert len(docs) == 1
for doc in docs:
assert not doc.has_annotation("ENT_IOB")
@ -358,9 +358,9 @@ def test_roundtrip_offsets_biluo_conversion(en_tokenizer):
biluo_tags = ["O", "O", "O", "B-LOC", "L-LOC", "O", "U-GPE", "O"]
offsets = [(10, 24, "LOC"), (29, 35, "GPE")]
doc = en_tokenizer(text)
biluo_tags_converted = biluo_tags_from_offsets(doc, offsets)
biluo_tags_converted = offsets_to_biluo_tags(doc, offsets)
assert biluo_tags_converted == biluo_tags
offsets_converted = offsets_from_biluo_tags(doc, biluo_tags)
offsets_converted = biluo_tags_to_offsets(doc, biluo_tags)
offsets_converted = [ent for ent in offsets if ent[2]]
assert offsets_converted == offsets
@ -368,7 +368,7 @@ def test_roundtrip_offsets_biluo_conversion(en_tokenizer):
def test_biluo_spans(en_tokenizer):
doc = en_tokenizer("I flew to Silicon Valley via London.")
biluo_tags = ["O", "O", "O", "B-LOC", "L-LOC", "O", "U-GPE", "O"]
spans = spans_from_biluo_tags(doc, biluo_tags)
spans = biluo_tags_to_spans(doc, biluo_tags)
spans = [span for span in spans if span.label_]
assert len(spans) == 2
assert spans[0].text == "Silicon Valley"

View File

@ -2,8 +2,8 @@ from .corpus import Corpus # noqa: F401
from .example import Example, validate_examples # noqa: F401
from .align import Alignment # noqa: F401
from .iob_utils import iob_to_biluo, biluo_to_iob # noqa: F401
from .iob_utils import biluo_tags_from_offsets, offsets_from_biluo_tags # noqa: F401
from .iob_utils import spans_from_biluo_tags, tags_to_entities # noqa: F401
from .iob_utils import offsets_to_biluo_tags, biluo_tags_to_offsets # noqa: F401
from .iob_utils import biluo_tags_to_spans, tags_to_entities # noqa: F401
from .gold_io import docs_to_json, read_json_file # noqa: F401
from .batchers import minibatch_by_padded_size, minibatch_by_words # noqa: F401
from .loggers import console_logger, wandb_logger # noqa: F401

View File

@ -1,4 +1,4 @@
from .iob2docs import iob2docs # noqa: F401
from .conll_ner2docs import conll_ner2docs # noqa: F401
from .json2docs import json2docs # noqa: F401
from .conllu2docs import conllu2docs # noqa: F401
from .iob_to_docs import iob_to_docs # noqa: F401
from .conll_ner_to_docs import conll_ner_to_docs # noqa: F401
from .json_to_docs import json_to_docs # noqa: F401
from .conllu_to_docs import conllu_to_docs # noqa: F401

View File

@ -7,7 +7,7 @@ from ...tokens import Doc, Span
from ...util import load_model
def conll_ner2docs(
def conll_ner_to_docs(
input_data, n_sents=10, seg_sents=False, model=None, no_print=False, **kwargs
):
"""

View File

@ -1,13 +1,13 @@
import re
from .conll_ner2docs import n_sents_info
from ...training import iob_to_biluo, spans_from_biluo_tags
from .conll_ner_to_docs import n_sents_info
from ...training import iob_to_biluo, biluo_tags_to_spans
from ...tokens import Doc, Token, Span
from ...vocab import Vocab
from wasabi import Printer
def conllu2docs(
def conllu_to_docs(
input_data,
n_sents=10,
append_morphology=False,
@ -78,7 +78,7 @@ def read_conllx(
if lines:
while lines[0].startswith("#"):
lines.pop(0)
doc = doc_from_conllu_sentence(
doc = conllu_sentence_to_doc(
vocab,
lines,
ner_tag_pattern,
@ -128,7 +128,7 @@ def get_entities(lines, tag_pattern, ner_map=None):
return iob_to_biluo(iob)
def doc_from_conllu_sentence(
def conllu_sentence_to_doc(
vocab,
lines,
ner_tag_pattern,
@ -215,7 +215,7 @@ def doc_from_conllu_sentence(
doc[i]._.merged_lemma = lemmas[i]
doc[i]._.merged_spaceafter = spaces[i]
ents = get_entities(lines, ner_tag_pattern, ner_map)
doc.ents = spans_from_biluo_tags(doc, ents)
doc.ents = biluo_tags_to_spans(doc, ents)
if merge_subtokens:
doc = merge_conllu_subtokens(lines, doc)

View File

@ -1,13 +1,13 @@
from wasabi import Printer
from .conll_ner2docs import n_sents_info
from .conll_ner_to_docs import n_sents_info
from ...vocab import Vocab
from ...training import iob_to_biluo, tags_to_entities
from ...tokens import Doc, Span
from ...util import minibatch
def iob2docs(input_data, n_sents=10, no_print=False, *args, **kwargs):
def iob_to_docs(input_data, n_sents=10, no_print=False, *args, **kwargs):
"""
Convert IOB files with one sentence per line and tags separated with '|'
into Doc objects so they can be saved. IOB and IOB2 are accepted.
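
A minimal usage sketch for the renamed converter, reusing the pipe-separated IOB line format from the CLI tests above (illustrative only):

```python
# Illustrative only: one sentence per line, tokens separated by spaces,
# fields separated by "|" (either word|iob or word|pos|iob).
from spacy.training.converters import iob_to_docs

input_data = "I|O like|O London|B-GPE and|O New|B-GPE York|I-GPE City|I-GPE .|O"
docs = list(iob_to_docs(input_data, n_sents=10, no_print=True))

assert len(docs) == 1
for ent in docs[0].ents:
    assert ent.text in ("London", "New York City")
```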

View File

@ -1,12 +1,12 @@
import srsly
from ..gold_io import json_iterate, json_to_annotations
from ..example import annotations2doc
from ..example import annotations_to_doc
from ..example import _fix_legacy_dict_data, _parse_example_dict_data
from ...util import load_model
from ...lang.xx import MultiLanguage
def json2docs(input_data, model=None, **kwargs):
def json_to_docs(input_data, model=None, **kwargs):
nlp = load_model(model) if model is not None else MultiLanguage()
if not isinstance(input_data, bytes):
if not isinstance(input_data, str):
@ -17,6 +17,6 @@ def json2docs(input_data, model=None, **kwargs):
for json_para in json_to_annotations(json_doc):
example_dict = _fix_legacy_dict_data(json_para)
tok_dict, doc_dict = _parse_example_dict_data(example_dict)
doc = annotations2doc(nlp.vocab, tok_dict, doc_dict)
doc = annotations_to_doc(nlp.vocab, tok_dict, doc_dict)
docs.append(doc)
return docs

View File

@ -7,13 +7,13 @@ from ..tokens.span cimport Span
from ..tokens.span import Span
from ..attrs import IDS
from .align import Alignment
from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc
from .iob_utils import spans_from_biluo_tags
from .iob_utils import biluo_to_iob, offsets_to_biluo_tags, doc_to_biluo_tags
from .iob_utils import biluo_tags_to_spans
from ..errors import Errors, Warnings
from ..pipeline._parser_internals import nonproj
cpdef Doc annotations2doc(vocab, tok_annot, doc_annot):
cpdef Doc annotations_to_doc(vocab, tok_annot, doc_annot):
""" Create a Doc from dictionaries with token and doc annotations. """
attrs, array = _annot2array(vocab, tok_annot, doc_annot)
output = Doc(vocab, words=tok_annot["ORTH"], spaces=tok_annot["SPACY"])
@ -92,7 +92,7 @@ cdef class Example:
tok_dict["SPACY"] = [tok.whitespace_ for tok in predicted]
return Example(
predicted,
annotations2doc(predicted.vocab, tok_dict, doc_dict)
annotations_to_doc(predicted.vocab, tok_dict, doc_dict)
)
@property
@ -176,7 +176,7 @@ cdef class Example:
return [None] * len(self.x) # should this be 'missing' instead of 'None' ?
x_ents = self.get_aligned_spans_y2x(self.y.ents)
# Default to 'None' for missing values
x_tags = biluo_tags_from_offsets(
x_tags = offsets_to_biluo_tags(
self.x,
[(e.start_char, e.end_char, e.label_) for e in x_ents],
missing=None
@ -195,7 +195,7 @@ cdef class Example:
return {
"doc_annotation": {
"cats": dict(self.reference.cats),
"entities": biluo_tags_from_doc(self.reference),
"entities": doc_to_biluo_tags(self.reference),
"links": self._links_to_dict()
},
"token_annotation": {
@ -295,12 +295,12 @@ def _add_entities_to_doc(doc, ner_data):
elif isinstance(ner_data[0], tuple):
return _add_entities_to_doc(
doc,
biluo_tags_from_offsets(doc, ner_data)
offsets_to_biluo_tags(doc, ner_data)
)
elif isinstance(ner_data[0], str) or ner_data[0] is None:
return _add_entities_to_doc(
doc,
spans_from_biluo_tags(doc, ner_data)
biluo_tags_to_spans(doc, ner_data)
)
elif isinstance(ner_data[0], Span):
# Ugh, this is super messy. Really hard to set O entities
@ -388,7 +388,7 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
# This is annoying but to convert the offsets we need a Doc
# that has the target tokenization.
reference = Doc(vocab, words=words, spaces=spaces)
biluo = biluo_tags_from_offsets(reference, biluo_or_offsets)
biluo = offsets_to_biluo_tags(reference, biluo_or_offsets)
else:
biluo = biluo_or_offsets
ent_iobs = []

View File

@ -3,7 +3,7 @@ import srsly
from .. import util
from ..errors import Warnings
from ..tokens import Doc
from .iob_utils import biluo_tags_from_offsets, tags_to_entities
from .iob_utils import offsets_to_biluo_tags, tags_to_entities
import json
@ -32,7 +32,7 @@ def docs_to_json(docs, doc_id=0, ner_missing_tag="O"):
if ent.kb_id_:
link_dict = {(ent.start_char, ent.end_char): {ent.kb_id_: 1.0}}
json_para["links"].append(link_dict)
biluo_tags = biluo_tags_from_offsets(doc, json_para["entities"], missing=ner_missing_tag)
biluo_tags = offsets_to_biluo_tags(doc, json_para["entities"], missing=ner_missing_tag)
attrs = ("TAG", "POS", "MORPH", "LEMMA", "DEP", "ENT_IOB")
include_annotation = {attr: doc.has_annotation(attr) for attr in attrs}
for j, sent in enumerate(doc.sents):

View File

@ -51,7 +51,11 @@ def _consume_ent(tags):
def biluo_tags_from_doc(doc, missing="O"):
return biluo_tags_from_offsets(
return doc_to_biluo_tags(doc, missing)
def doc_to_biluo_tags(doc, missing="O"):
return offsets_to_biluo_tags(
doc,
[(ent.start_char, ent.end_char, ent.label_) for ent in doc.ents],
missing=missing,
@ -59,6 +63,10 @@ def biluo_tags_from_doc(doc, missing="O"):
def biluo_tags_from_offsets(doc, entities, missing="O"):
return offsets_to_biluo_tags(doc, entities, missing)
def offsets_to_biluo_tags(doc, entities, missing="O"):
"""Encode labelled spans into per-token tags, using the
Begin/In/Last/Unit/Out scheme (BILUO).
@ -80,7 +88,7 @@ def biluo_tags_from_offsets(doc, entities, missing="O"):
>>> text = 'I like London.'
>>> entities = [(len('I like '), len('I like London'), 'LOC')]
>>> doc = nlp.tokenizer(text)
>>> tags = biluo_tags_from_offsets(doc, entities)
>>> tags = offsets_to_biluo_tags(doc, entities)
>>> assert tags == ["O", "O", 'U-LOC', "O"]
"""
# Ensure no overlapping entity labels exist
@ -144,6 +152,10 @@ def biluo_tags_from_offsets(doc, entities, missing="O"):
def spans_from_biluo_tags(doc, tags):
return biluo_tags_to_spans(doc, tags)
def biluo_tags_to_spans(doc, tags):
"""Encode per-token tags following the BILUO scheme into Span object, e.g.
to overwrite the doc.ents.
@ -162,6 +174,10 @@ def spans_from_biluo_tags(doc, tags):
def offsets_from_biluo_tags(doc, tags):
return biluo_tags_to_offsets(doc, tags)
def biluo_tags_to_offsets(doc, tags):
"""Encode per-token tags following the BILUO scheme into entity offsets.
doc (Doc): The document that the BILUO tags refer to.
@ -172,7 +188,7 @@ def offsets_from_biluo_tags(doc, tags):
`end` will be character-offset integers denoting the slice into the
original string.
"""
spans = spans_from_biluo_tags(doc, tags)
spans = biluo_tags_to_spans(doc, tags)
return [(span.start_char, span.end_char, span.label_) for span in spans]
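
Taken together, the old helper names in this module are kept as thin wrappers around the renamed functions, so both spellings resolve to the same implementation. A small illustrative check (a sketch, assuming a blank English pipeline):

```python
# Illustrative: the legacy names now simply delegate to the renamed helpers.
import spacy
from spacy.training.iob_utils import biluo_tags_from_offsets, offsets_to_biluo_tags

nlp = spacy.blank("en")
doc = nlp("I like London.")
entities = [(7, 13, "LOC")]

assert biluo_tags_from_offsets(doc, entities) == offsets_to_biluo_tags(doc, entities)
```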

View File

@ -275,7 +275,7 @@ $ python -m spacy convert ./data.json ./output.spacy
> entity label, prefixed by the BILUO marker. For example `"B-ORG"` describes
> the first token of a multi-token `ORG` entity and `"U-PERSON"` a single token
> representing a `PERSON` entity. The
> [`biluo_tags_from_offsets`](/api/top-level#biluo_tags_from_offsets) function
> [`offsets_to_biluo_tags`](/api/top-level#offsets_to_biluo_tags) function
> can help you convert entity offsets to the right format.
```python

View File

@ -619,7 +619,7 @@ sequences in the batch.
## Training data and alignment {#gold source="spacy/training"}
### training.biluo_tags_from_offsets {#biluo_tags_from_offsets tag="function"}
### training.offsets_to_biluo_tags {#offsets_to_biluo_tags tag="function"}
Encode labelled spans into per-token tags, using the
[BILUO scheme](/usage/linguistic-features#accessing-ner) (Begin, In, Last, Unit,
@ -635,11 +635,11 @@ single-token entity.
> #### Example
>
> ```python
> from spacy.training import biluo_tags_from_offsets
> from spacy.training import offsets_to_biluo_tags
>
> doc = nlp("I like London.")
> entities = [(7, 13, "LOC")]
> tags = biluo_tags_from_offsets(doc, entities)
> tags = offsets_to_biluo_tags(doc, entities)
> assert tags == ["O", "O", "U-LOC", "O"]
> ```
@ -649,7 +649,7 @@ single-token entity.
| `entities` | A sequence of `(start, end, label)` triples. `start` and `end` should be character-offset integers denoting the slice into the original string. ~~List[Tuple[int, int, Union[str, int]]]~~ |
| **RETURNS** | A list of strings, describing the [BILUO](/usage/linguistic-features#accessing-ner) tags. ~~List[str]~~ |
### training.offsets_from_biluo_tags {#offsets_from_biluo_tags tag="function"}
### training.biluo_tags_to_offsets {#biluo_tags_to_offsets tag="function"}
Encode per-token tags following the
[BILUO scheme](/usage/linguistic-features#accessing-ner) into entity offsets.
@ -657,11 +657,11 @@ Encode per-token tags following the
> #### Example
>
> ```python
> from spacy.training import offsets_from_biluo_tags
> from spacy.training import biluo_tags_to_offsets
>
> doc = nlp("I like London.")
> tags = ["O", "O", "U-LOC", "O"]
> entities = offsets_from_biluo_tags(doc, tags)
> entities = biluo_tags_to_offsets(doc, tags)
> assert entities == [(7, 13, "LOC")]
> ```
@ -671,7 +671,7 @@ Encode per-token tags following the
| `entities` | A sequence of [BILUO](/usage/linguistic-features#accessing-ner) tags with each tag describing one token. Each tag string will be of the form of either `""`, `"O"` or `"{action}-{label}"`, where action is one of `"B"`, `"I"`, `"L"`, `"U"`. ~~List[str]~~ |
| **RETURNS** | A sequence of `(start, end, label)` triples. `start` and `end` will be character-offset integers denoting the slice into the original string. ~~List[Tuple[int, int, str]]~~ |
### training.spans_from_biluo_tags {#spans_from_biluo_tags tag="function" new="2.1"}
### training.biluo_tags_to_spans {#biluo_tags_to_spans tag="function" new="2.1"}
Encode per-token tags following the
[BILUO scheme](/usage/linguistic-features#accessing-ner) into
@ -681,11 +681,11 @@ token-based tags, e.g. to overwrite the `doc.ents`.
> #### Example
>
> ```python
> from spacy.training import spans_from_biluo_tags
> from spacy.training import biluo_tags_to_spans
>
> doc = nlp("I like London.")
> tags = ["O", "O", "U-LOC", "O"]
> doc.ents = spans_from_biluo_tags(doc, tags)
> doc.ents = biluo_tags_to_spans(doc, tags)
> ```
| Name | Description |

View File

@ -1501,7 +1501,7 @@ add those entities to the `doc.ents`, you can wrap it in a custom pipeline
component function and pass it the token texts from the `Doc` object received by
the component.
The [`training.spans_from_biluo_tags`](/api/top-level#spans_from_biluo_tags) is very
The [`training.biluo_tags_to_spans`](/api/top-level#biluo_tags_to_spans) is very
helpful here, because it takes a `Doc` object and token-based BILUO tags and
returns a sequence of `Span` objects in the `Doc` with added labels. So all your
wrapper has to do is compute the entity spans and overwrite the `doc.ents`.
@ -1516,14 +1516,14 @@ wrapper has to do is compute the entity spans and overwrite the `doc.ents`.
```python
### {highlight="1,8-9"}
import your_custom_entity_recognizer
from spacy.training import offsets_from_biluo_tags
from spacy.training import biluo_tags_to_spans
from spacy.language import Language
@Language.component("custom_ner_wrapper")
def custom_ner_wrapper(doc):
words = [token.text for token in doc]
custom_entities = your_custom_entity_recognizer(words)
doc.ents = spans_from_biluo_tags(doc, custom_entities)
doc.ents = biluo_tags_to_spans(doc, custom_entities)
return doc
```

View File

@ -971,16 +971,17 @@ python -m spacy package ./output ./packages
#### Data utilities and gold module {#migrating-gold}
The `spacy.gold` module has been renamed to `spacy.training`. This mostly
The `spacy.gold` module has been renamed to `spacy.training` and the conversion
utilities now follow the naming format of `x_to_y`. This mostly
affects internals, but if you've been using the span offset conversion utilities
[`biluo_tags_from_offsets`](/api/top-level#biluo_tags_from_offsets),
[`offsets_from_biluo_tags`](/api/top-level#offsets_from_biluo_tags) or
[`spans_from_biluo_tags`](/api/top-level#spans_from_biluo_tags), you'll have to
change your imports:
[`offsets_to_biluo_tags`](/api/top-level#offsets_to_biluo_tags),
[`biluo_tags_to_offsets`](/api/top-level#biluo_tags_to_offsets) or
[`biluo_tags_to_spans`](/api/top-level#biluo_tags_to_spans), you'll have to
change your names and imports:
```diff
- from spacy.gold import biluo_tags_from_offsets, spans_from_biluo_tags
+ from spacy.training import biluo_tags_from_offsets, spans_from_biluo_tags
- from spacy.gold import biluo_tags_from_offsets, offsets_from_biluo_tags, spans_from_biluo_tags
+ from spacy.training import offsets_to_biluo_tags, biluo_tags_to_offsets, biluo_tags_to_spans
```
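
Call sites change the same way; the renamed helpers take the same arguments as before. A short sketch, mirroring the example from the API docs above:

```python
import spacy
from spacy.training import offsets_to_biluo_tags, biluo_tags_to_spans

nlp = spacy.blank("en")
doc = nlp("I like London.")
entities = [(7, 13, "LOC")]

# previously: tags = biluo_tags_from_offsets(doc, entities)
tags = offsets_to_biluo_tags(doc, entities)   # ["O", "O", "U-LOC", "O"]

# previously: doc.ents = spans_from_biluo_tags(doc, tags)
doc.ents = biluo_tags_to_spans(doc, tags)
```
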
#### Migration notes for plugin maintainers {#migrating-plugins}