spaCy (https://github.com/explosion/spaCy.git), commit 74ee456374
Merge branch 'develop' of https://github.com/explosion/spaCy into develop

@@ -1,6 +1,6 @@
 # fmt: off
 __title__ = "spacy-nightly"
-__version__ = "3.0.0a23"
+__version__ = "3.0.0a24"
 __release__ = True
 __download_url__ = "https://github.com/explosion/spacy-models/releases/download"
 __compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"

@@ -88,7 +88,6 @@ def get_compatibility() -> dict:


 def get_version(model: str, comp: dict) -> str:
-    model = get_base_version(model)
     if model not in comp:
         msg.fail(
             f"No compatible package found for '{model}' (spaCy v{about.__version__})",

@@ -91,7 +91,9 @@ def info_model(model: str, *, silent: bool = True) -> Dict[str, Any]:
         meta["source"] = str(model_path.resolve())
     else:
         meta["source"] = str(model_path)
-    return {k: v for k, v in meta.items() if k not in ("accuracy", "speed")}
+    return {
+        k: v for k, v in meta.items() if k not in ("accuracy", "performance", "speed")
+    }


 def get_markdown(data: Dict[str, Any], title: Optional[str] = None) -> str:

@@ -97,6 +97,7 @@ def train(
     dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
     batcher = T_cfg["batcher"]
     train_logger = T_cfg["logger"]
+    before_to_disk = create_before_to_disk_callback(T_cfg["before_to_disk"])
     # Components that shouldn't be updated during training
     frozen_components = T_cfg["frozen_components"]
     # Sourced components that require resume_training

@@ -167,6 +168,7 @@ def train(
                     with nlp.select_pipes(disable=frozen_components):
                         update_meta(T_cfg, nlp, info)
                     with nlp.use_params(optimizer.averages):
+                        nlp = before_to_disk(nlp)
                         nlp.to_disk(output_path / "model-best")
                 progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
                 progress.set_description(f"Epoch {info['epoch']}")

@@ -179,6 +181,7 @@ def train(
                 f"Aborting and saving the final best model. "
                 f"Encountered exception: {str(e)}"
             )
+            nlp = before_to_disk(nlp)
             nlp.to_disk(output_path / "model-final")
         raise e
     finally:

@@ -233,6 +236,21 @@ def create_evaluation_callback(
     return evaluate


+def create_before_to_disk_callback(
+    callback: Optional[Callable[[Language], Language]]
+) -> Callable[[Language], Language]:
+    def before_to_disk(nlp: Language) -> Language:
+        if not callback:
+            return nlp
+        modified_nlp = callback(nlp)
+        if not isinstance(modified_nlp, Language):
+            err = Errors.E914.format(name="before_to_disk", value=type(modified_nlp))
+            raise ValueError(err)
+        return modified_nlp
+
+    return before_to_disk
+
+
 def train_while_improving(
     nlp: Language,
     optimizer: Optimizer,

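The wrapper above only validates the callback's return value; as a minimal sketch, a user-supplied `before_to_disk` callback could look like the following. The registry and function names here are illustrative only (not part of this commit) and assume spaCy v3's `callbacks` registry:

```python
# Illustrative before_to_disk callback; names are examples only.
import spacy
from spacy.language import Language


@spacy.registry.callbacks("example_before_to_disk.v1")
def create_example_callback():
    def before_to_disk(nlp: Language) -> Language:
        # Strip temporary metadata before the pipeline is serialized
        nlp.meta.pop("debug_notes", None)
        # Must return the nlp object, otherwise E914 is raised
        return nlp

    return before_to_disk
```

It would then be referenced from the config as a `[training.before_to_disk]` block pointing at the registered function.
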
@@ -72,6 +72,8 @@ frozen_components = []
 dev_corpus = "corpora.dev"
 # Location in the config where the train corpus is defined
 train_corpus = "corpora.train"
+# Optional callback before nlp object is saved to disk after training
+before_to_disk = null

 [training.logger]
 @loggers = "spacy.ConsoleLogger.v1"

@@ -480,6 +480,9 @@ class Errors:
     E201 = ("Span index out of range.")

     # TODO: fix numbering after merging develop into master
+    E914 = ("Executing {name} callback failed. Expected the function to "
+            "return the nlp object but got: {value}. Maybe you forgot to return "
+            "the modified object in your function?")
     E915 = ("Can't use score '{name}' to calculate final weighted score. Expected "
             "float or int but got: {score_type}. To exclude the score from the "
             "final score, set its weight to null in the [training.score_weights] "

@@ -693,6 +696,12 @@ class Errors:
     E1009 = ("String for hash '{val}' not found in StringStore. Set the value "
             "through token.morph_ instead or add the string to the "
             "StringStore with `nlp.vocab.strings.add(string)`.")
+    E1010 = ("Unable to set entity information for token {i} which is included "
+            "in more than one span in entities, blocked, missing or outside.")
+    E1011 = ("Unsupported default '{default}' in doc.set_ents. Available "
+            "options: {modes}")
+    E1012 = ("Entity spans and blocked/missing/outside spans should be "
+            "provided to doc.set_ents as lists of `Span` objects.")


 @add_codes

@@ -182,8 +182,7 @@ class ModelMetaSchema(BaseModel):
     sources: Optional[Union[List[StrictStr], List[Dict[str, str]]]] = Field(None, title="Training data sources")
     vectors: Dict[str, Any] = Field({}, title="Included word vectors")
     labels: Dict[str, List[str]] = Field({}, title="Component labels, keyed by component name")
-    accuracy: Dict[str, Union[float, Dict[str, float]]] = Field({}, title="Accuracy numbers")
-    speed: Dict[str, Union[float, int]] = Field({}, title="Speed evaluation numbers")
+    performance: Dict[str, Union[float, Dict[str, float]]] = Field({}, title="Accuracy and speed numbers")
     spacy_git_version: StrictStr = Field("", title="Commit of spaCy version used")
     # fmt: on

@@ -217,6 +216,7 @@ class ConfigSchemaTraining(BaseModel):
     optimizer: Optimizer = Field(..., title="The optimizer to use")
     logger: Logger = Field(..., title="The logger to track training progress")
     frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
+    before_to_disk: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after training, before it's saved to disk")
     # fmt: on

     class Config:

@@ -29,10 +29,10 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
     ner.begin_training(lambda: [_ner_example(ner)])
     ner(doc)

-    doc.ents = [(doc.vocab.strings["ANIMAL"], 3, 4)]
+    doc.ents = [("ANIMAL", 3, 4)]
     assert [w.ent_iob_ for w in doc] == ["O", "O", "O", "B"]

-    doc.ents = [(doc.vocab.strings["WORD"], 0, 2)]
+    doc.ents = [("WORD", 0, 2)]
     assert [w.ent_iob_ for w in doc] == ["B", "I", "O", "O"]

@@ -152,7 +152,7 @@ def test_doc_api_set_ents(en_tokenizer):
     assert len(tokens.ents) == 0
     tokens.ents = [(tokens.vocab.strings["PRODUCT"], 2, 4)]
     assert len(list(tokens.ents)) == 1
-    assert [t.ent_iob for t in tokens] == [0, 0, 3, 1, 0, 0, 0, 0]
+    assert [t.ent_iob for t in tokens] == [2, 2, 3, 1, 2, 2, 2, 2]
     assert tokens.ents[0].label_ == "PRODUCT"
     assert tokens.ents[0].start == 2
     assert tokens.ents[0].end == 4

@@ -427,7 +427,7 @@ def test_has_annotation(en_vocab):
     doc[0].lemma_ = "a"
     doc[0].dep_ = "dep"
     doc[0].head = doc[1]
-    doc.ents = [Span(doc, 0, 1, label="HELLO")]
+    doc.set_ents([Span(doc, 0, 1, label="HELLO")], default="missing")

     for attr in attrs:
         assert doc.has_annotation(attr)

@@ -457,7 +457,74 @@ def test_is_flags_deprecated(en_tokenizer):
         doc.is_sentenced


-def test_doc_set_ents():
+def test_doc_set_ents(en_tokenizer):
+    # set ents
+    doc = en_tokenizer("a b c d e")
+    doc.set_ents([Span(doc, 0, 1, 10), Span(doc, 1, 3, 11)])
+    assert [t.ent_iob for t in doc] == [3, 3, 1, 2, 2]
+    assert [t.ent_type for t in doc] == [10, 11, 11, 0, 0]
+
+    # add ents, invalid IOB repaired
+    doc = en_tokenizer("a b c d e")
+    doc.set_ents([Span(doc, 0, 1, 10), Span(doc, 1, 3, 11)])
+    doc.set_ents([Span(doc, 0, 2, 12)], default="unmodified")
+    assert [t.ent_iob for t in doc] == [3, 1, 3, 2, 2]
+    assert [t.ent_type for t in doc] == [12, 12, 11, 0, 0]
+
+    # missing ents
+    doc = en_tokenizer("a b c d e")
+    doc.set_ents([Span(doc, 0, 1, 10), Span(doc, 1, 3, 11)], missing=[doc[4:5]])
+    assert [t.ent_iob for t in doc] == [3, 3, 1, 2, 0]
+    assert [t.ent_type for t in doc] == [10, 11, 11, 0, 0]
+
+    # outside ents
+    doc = en_tokenizer("a b c d e")
+    doc.set_ents(
+        [Span(doc, 0, 1, 10), Span(doc, 1, 3, 11)],
+        outside=[doc[4:5]],
+        default="missing",
+    )
+    assert [t.ent_iob for t in doc] == [3, 3, 1, 0, 2]
+    assert [t.ent_type for t in doc] == [10, 11, 11, 0, 0]
+
+    # blocked ents
+    doc = en_tokenizer("a b c d e")
+    doc.set_ents([], blocked=[doc[1:2], doc[3:5]], default="unmodified")
+    assert [t.ent_iob for t in doc] == [0, 3, 0, 3, 3]
+    assert [t.ent_type for t in doc] == [0, 0, 0, 0, 0]
+    assert doc.ents == tuple()
+
+    # invalid IOB repaired after blocked
+    doc.ents = [Span(doc, 3, 5, "ENT")]
+    assert [t.ent_iob for t in doc] == [2, 2, 2, 3, 1]
+    doc.set_ents([], blocked=[doc[3:4]], default="unmodified")
+    assert [t.ent_iob for t in doc] == [2, 2, 2, 3, 3]
+
+    # all types
+    doc = en_tokenizer("a b c d e")
+    doc.set_ents(
+        [Span(doc, 0, 1, 10)],
+        blocked=[doc[1:2]],
+        missing=[doc[2:3]],
+        outside=[doc[3:4]],
+        default="unmodified",
+    )
+    assert [t.ent_iob for t in doc] == [3, 3, 0, 2, 0]
+    assert [t.ent_type for t in doc] == [10, 0, 0, 0, 0]
+
+    doc = en_tokenizer("a b c d e")
+    # single span instead of a list
+    with pytest.raises(ValueError):
+        doc.set_ents([], missing=doc[1:2])
+    # invalid default mode
+    with pytest.raises(ValueError):
+        doc.set_ents([], missing=[doc[1:2]], default="none")
+    # conflicting/overlapping specifications
+    with pytest.raises(ValueError):
+        doc.set_ents([], missing=[doc[1:2]], outside=[doc[1:2]])
+
+
+def test_doc_ents_setter():
     """Test that both strings and integers can be used to set entities in
     tuple format via doc.ents."""
     words = ["a", "b", "c", "d", "e"]

@@ -168,7 +168,7 @@ def test_accept_blocked_token():
     ner2 = nlp2.create_pipe("ner", config=config)

     # set "New York" to a blocked entity
-    doc2.ents = [(0, 3, 5)]
+    doc2.set_ents([], blocked=[doc2[3:5]], default="unmodified")
     assert [token.ent_iob_ for token in doc2] == ["", "", "", "B", "B"]
     assert [token.ent_type_ for token in doc2] == ["", "", "", "", ""]

@@ -358,5 +358,5 @@ class BlockerComponent1:
         self.name = name

     def __call__(self, doc):
-        doc.ents = [(0, self.start, self.end)]
+        doc.set_ents([], blocked=[doc[self.start:self.end]], default="unmodified")
         return doc

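For reference, a hedged sketch of how a custom pipeline component might use the blocked-span form shown in the test component above; the component name is illustrative and assumes spaCy v3's `@Language.component` decorator:

```python
# Illustrative component using doc.set_ents with blocked spans (example names only).
from spacy.language import Language


@Language.component("block_first_token")
def block_first_token(doc):
    # Prevent the statistical NER from ever labelling the first token,
    # leaving all other entity annotation untouched.
    doc.set_ents([], blocked=[doc[0:1]], default="unmodified")
    return doc
```
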
@@ -7,6 +7,8 @@ from libc.stdint cimport int32_t, uint64_t

 import copy
 from collections import Counter
+from enum import Enum
+import itertools
 import numpy
 import srsly
 from thinc.api import get_array_module

@@ -86,6 +88,17 @@ cdef attr_t get_token_attr_for_matcher(const TokenC* token, attr_id_t feat_name)
     return get_token_attr(token, feat_name)


+class SetEntsDefault(str, Enum):
+    blocked = "blocked"
+    missing = "missing"
+    outside = "outside"
+    unmodified = "unmodified"
+
+    @classmethod
+    def values(cls):
+        return list(cls.__members__.keys())
+
+
 cdef class Doc:
     """A sequence of Token objects. Access sentences and named entities, export
     annotations to numpy arrays, losslessly serialize to compressed binary

@@ -660,50 +673,100 @@ cdef class Doc:
         # TODO:
         # 1. Test basic data-driven ORTH gazetteer
         # 2. Test more nuanced date and currency regex
-        tokens_in_ents = {}
-        cdef attr_t entity_type
-        cdef attr_t kb_id
-        cdef int ent_start, ent_end, token_index
+        cdef attr_t entity_type, kb_id
+        cdef int ent_start, ent_end
+        ent_spans = []
         for ent_info in ents:
             entity_type_, kb_id, ent_start, ent_end = get_entity_info(ent_info)
             if isinstance(entity_type_, str):
                 self.vocab.strings.add(entity_type_)
-            entity_type = self.vocab.strings.as_int(entity_type_)
-            for token_index in range(ent_start, ent_end):
-                if token_index in tokens_in_ents:
-                    raise ValueError(Errors.E103.format(
-                        span1=(tokens_in_ents[token_index][0],
-                               tokens_in_ents[token_index][1],
-                               self.vocab.strings[tokens_in_ents[token_index][2]]),
-                        span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
-                tokens_in_ents[token_index] = (ent_start, ent_end, entity_type, kb_id)
-        cdef int i
-        for i in range(self.length):
-            # default values
-            entity_type = 0
-            kb_id = 0
-
-            # Set ent_iob to Missing (0) by default unless this token was nered before
-            ent_iob = 0
-            if self.c[i].ent_iob != 0:
-                ent_iob = 2
-
-            # overwrite if the token was part of a specified entity
-            if i in tokens_in_ents.keys():
-                ent_start, ent_end, entity_type, kb_id = tokens_in_ents[i]
-                if entity_type is None or entity_type <= 0:
-                    # Blocking this token from being overwritten by downstream NER
-                    ent_iob = 3
-                elif ent_start == i:
-                    # Marking the start of an entity
-                    ent_iob = 3
-                else:
-                    # Marking the inside of an entity
-                    ent_iob = 1
-
-            self.c[i].ent_type = entity_type
-            self.c[i].ent_kb_id = kb_id
-            self.c[i].ent_iob = ent_iob
+            span = Span(self, ent_start, ent_end, label=entity_type_, kb_id=kb_id)
+            ent_spans.append(span)
+        self.set_ents(ent_spans, default=SetEntsDefault.outside)
+
+    def set_ents(self, entities, *, blocked=None, missing=None, outside=None, default=SetEntsDefault.outside):
+        """Set entity annotation.
+
+        entities (List[Span]): Spans with labels to set as entities.
+        blocked (Optional[List[Span]]): Spans to set as 'blocked' (never an
+            entity) for spacy's built-in NER component. Other components may
+            ignore this setting.
+        missing (Optional[List[Span]]): Spans with missing/unknown entity
+            information.
+        outside (Optional[List[Span]]): Spans outside of entities (O in IOB).
+        default (str): How to set entity annotation for tokens outside of any
+            provided spans. Options: "blocked", "missing", "outside" and
+            "unmodified" (preserve current state). Defaults to "outside".
+        """
+        if default not in SetEntsDefault.values():
+            raise ValueError(Errors.E1011.format(default=default, modes=", ".join(SetEntsDefault)))
+
+        # Ignore spans with missing labels
+        entities = [ent for ent in entities if ent.label > 0]
+
+        if blocked is None:
+            blocked = tuple()
+        if missing is None:
+            missing = tuple()
+        if outside is None:
+            outside = tuple()
+
+        # Find all tokens covered by spans and check that none are overlapping
+        cdef int i
+        seen_tokens = set()
+        for span in itertools.chain.from_iterable([entities, blocked, missing, outside]):
+            if not isinstance(span, Span):
+                raise ValueError(Errors.E1012.format(span=span))
+            for i in range(span.start, span.end):
+                if i in seen_tokens:
+                    raise ValueError(Errors.E1010.format(i=i))
+                seen_tokens.add(i)
+
+        # Set all specified entity information
+        for span in entities:
+            for i in range(span.start, span.end):
+                if i == span.start:
+                    self.c[i].ent_iob = 3
+                else:
+                    self.c[i].ent_iob = 1
+                self.c[i].ent_type = span.label
+                self.c[i].ent_kb_id = span.kb_id
+        for span in blocked:
+            for i in range(span.start, span.end):
+                self.c[i].ent_iob = 3
+                self.c[i].ent_type = 0
+        for span in missing:
+            for i in range(span.start, span.end):
+                self.c[i].ent_iob = 0
+                self.c[i].ent_type = 0
+        for span in outside:
+            for i in range(span.start, span.end):
+                self.c[i].ent_iob = 2
+                self.c[i].ent_type = 0
+
+        # Set tokens outside of all provided spans
+        if default != SetEntsDefault.unmodified:
+            for i in range(self.length):
+                if i not in seen_tokens:
+                    self.c[i].ent_type = 0
+                    if default == SetEntsDefault.outside:
+                        self.c[i].ent_iob = 2
+                    elif default == SetEntsDefault.missing:
+                        self.c[i].ent_iob = 0
+                    elif default == SetEntsDefault.blocked:
+                        self.c[i].ent_iob = 3
+
+        # Fix any resulting inconsistent annotation
+        for i in range(self.length - 1):
+            # I must follow B or I: convert I to B
+            if (self.c[i].ent_iob == 0 or self.c[i].ent_iob == 2) and \
+                    self.c[i+1].ent_iob == 1:
+                self.c[i+1].ent_iob = 3
+            # Change of type with BI or II: convert second I to B
+            if self.c[i].ent_type != self.c[i+1].ent_type and \
+                    (self.c[i].ent_iob == 3 or self.c[i].ent_iob == 1) and \
+                    self.c[i+1].ent_iob == 1:
+                self.c[i+1].ent_iob = 3

 @property
 def noun_chunks(self):

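A small usage sketch of the new `Doc.set_ents` method added above (not part of the diff; assumes a blank English pipeline):

```python
# Sketch: setting and blocking entities with Doc.set_ents.
import spacy
from spacy.tokens import Span

nlp = spacy.blank("en")
doc = nlp("Mr. Best flew to New York on Saturday morning.")
# Mark "Mr. Best" as PERSON and leave every other token unmodified
doc.set_ents([Span(doc, 0, 2, "PERSON")], default="unmodified")
# Block "New York" from ever being labelled by the statistical NER
doc.set_ents([], blocked=[doc[4:6]], default="unmodified")
print([(ent.text, ent.label_) for ent in doc.ents])  # [('Mr. Best', 'PERSON')]
```
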
@@ -288,6 +288,7 @@ def _annot2array(vocab, tok_annot, doc_annot):


 def _add_entities_to_doc(doc, ner_data):
+    print(ner_data)
     if ner_data is None:
         return
     elif ner_data == []:

@@ -303,9 +304,14 @@ def _add_entities_to_doc(doc, ner_data):
             biluo_tags_to_spans(doc, ner_data)
         )
     elif isinstance(ner_data[0], Span):
-        # Ugh, this is super messy. Really hard to set O entities
-        doc.ents = ner_data
-        doc.ents = [span for span in ner_data if span.label_]
+        entities = []
+        missing = []
+        for span in ner_data:
+            if span.label:
+                entities.append(span)
+            else:
+                missing.append(span)
+        doc.set_ents(entities, missing=missing)
     else:
         raise ValueError(Errors.E973)

@@ -151,9 +151,10 @@ def biluo_tags_to_spans(doc: Doc, tags: Iterable[str]) -> List[Span]:

     doc (Doc): The document that the BILUO tags refer to.
     entities (iterable): A sequence of BILUO tags with each tag describing one
-        token. Each tags string will be of the form of either "", "O" or
+        token. Each tag string will be of the form of either "", "O" or
         "{action}-{label}", where action is one of "B", "I", "L", "U".
-    RETURNS (list): A sequence of Span objects.
+    RETURNS (list): A sequence of Span objects. Each token with a missing IOB
+        tag is returned as a Span with an empty label.
     """
     token_offsets = tags_to_entities(tags)
     spans = []

@@ -186,22 +187,18 @@ def tags_to_entities(tags: Iterable[str]) -> List[Tuple[str, int, int]]:
     entities = []
     start = None
     for i, tag in enumerate(tags):
-        if tag is None:
-            continue
-        if tag.startswith("O"):
+        if tag is None or tag.startswith("-"):
             # TODO: We shouldn't be getting these malformed inputs. Fix this.
             if start is not None:
                 start = None
             else:
                 entities.append(("", i, i))
-            continue
-        elif tag == "-":
-            continue
+        elif tag.startswith("O"):
+            pass
         elif tag.startswith("I"):
             if start is None:
                 raise ValueError(Errors.E067.format(start="I", tags=tags[: i + 1]))
-            continue
-        if tag.startswith("U"):
+        elif tag.startswith("U"):
             entities.append((tag[2:], i, i))
         elif tag.startswith("B"):
             start = i

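A short sketch of what the relaxed handling of missing ("-") BILUO tags means in practice; the import path is assumed (it may differ in older nightlies) and the pipeline is a blank English model:

```python
# Sketch: missing ("-") BILUO tags now come back as empty-label spans
# instead of being dropped (import path assumed; not part of this commit).
import spacy
from spacy.training import biluo_tags_to_spans

nlp = spacy.blank("en")
doc = nlp("I like London .")
spans = biluo_tags_to_spans(doc, ["O", "O", "U-LOC", "-"])
print([(span.text, span.label_) for span in spans])
# Expected along the lines of [('London', 'LOC'), ('.', '')]
```
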
@@ -180,26 +180,27 @@ single corpus once and then divide it up into `train` and `dev` partitions.
 This section defines settings and controls for the training and evaluation
 process that are used when you run [`spacy train`](/api/cli#train).

 | Name                  | Description |
 | --------------------- | ----------- |
 | `accumulate_gradient` | Whether to divide the batch up into substeps. Defaults to `1`. ~~int~~ |
 | `batcher`             | Callable that takes an iterator of [`Doc`](/api/doc) objects and yields batches of `Doc`s. Defaults to [`batch_by_words`](/api/top-level#batch_by_words). ~~Callable[[Iterator[Doc], Iterator[List[Doc]]]]~~ |
+| `before_to_disk`      | Optional callback to modify `nlp` object right before it is saved to disk during and after training. Can be used to remove or reset config values or disable components. Defaults to `null`. ~~Optional[Callable[[Language], Language]]~~ |
 | `dev_corpus`          | Dot notation of the config location defining the dev corpus. Defaults to `corpora.dev`. ~~str~~ |
 | `dropout`             | The dropout rate. Defaults to `0.1`. ~~float~~ |
 | `eval_frequency`      | How often to evaluate during training (steps). Defaults to `200`. ~~int~~ |
 | `frozen_components`   | Pipeline component names that are "frozen" and shouldn't be updated during training. See [here](/usage/training#config-components) for details. Defaults to `[]`. ~~List[str]~~ |
 | `gpu_allocator`       | Library for cupy to route GPU memory allocation to. Can be `"pytorch"` or `"tensorflow"`. Defaults to variable `${system.gpu_allocator}`. ~~str~~ |
 | `init_tok2vec`        | Optional path to pretrained tok2vec weights created with [`spacy pretrain`](/api/cli#pretrain). Defaults to variable `${paths.init_tok2vec}`. ~~Optional[str]~~ |
 | `lookups`             | Additional lexeme and vocab data from [`spacy-lookups-data`](https://github.com/explosion/spacy-lookups-data). Defaults to `null`. ~~Optional[Lookups]~~ |
 | `max_epochs`          | Maximum number of epochs to train for. Defaults to `0`. ~~int~~ |
 | `max_steps`           | Maximum number of update steps to train for. Defaults to `20000`. ~~int~~ |
 | `optimizer`           | The optimizer. The learning rate schedule and other settings can be configured as part of the optimizer. Defaults to [`Adam`](https://thinc.ai/docs/api-optimizers#adam). ~~Optimizer~~ |
 | `patience`            | How many steps to continue without improvement in evaluation score. Defaults to `1600`. ~~int~~ |
 | `raw_text`            | Optional path to a jsonl file with unlabelled text documents for a [rehearsal](/api/language#rehearse) step. Defaults to variable `${paths.raw}`. ~~Optional[str]~~ |
 | `score_weights`       | Score names shown in metrics mapped to their weight towards the final weighted score. See [here](/usage/training#metrics) for details. Defaults to `{}`. ~~Dict[str, float]~~ |
 | `seed`                | The random seed. Defaults to variable `${system.seed}`. ~~int~~ |
 | `train_corpus`        | Dot notation of the config location defining the train corpus. Defaults to `corpora.train`. ~~str~~ |
 | `vectors`             | Name or path of pipeline containing pretrained word vectors to use, e.g. created with [`init vocab`](/api/cli#init-vocab). Defaults to `null`. ~~Optional[str]~~ |

 ### pretraining {#config-pretraining tag="section,optional"}

@@ -275,8 +276,8 @@ $ python -m spacy convert ./data.json ./output.spacy
 > entity label, prefixed by the BILUO marker. For example `"B-ORG"` describes
 > the first token of a multi-token `ORG` entity and `"U-PERSON"` a single token
 > representing a `PERSON` entity. The
-> [`offsets_to_biluo_tags`](/api/top-level#offsets_to_biluo_tags) function
-> can help you convert entity offsets to the right format.
+> [`offsets_to_biluo_tags`](/api/top-level#offsets_to_biluo_tags) function can
+> help you convert entity offsets to the right format.

 ```python
 ### Example structure

@@ -518,7 +519,7 @@ source of truth** used for loading a pipeline.
 >     "ner": ["PERSON", "ORG", "PRODUCT"],
 >     "textcat": ["POSITIVE", "NEGATIVE"]
 >   },
->   "accuracy": {
+>   "performance": {
 >     "ents_f": 82.7300930714,
 >     "ents_p": 82.135523614,
 >     "ents_r": 83.3333333333,

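Since `accuracy` and `speed` are merged into a single `performance` block in the meta, downstream code can fall back the same way the website does; a hedged sketch (the pipeline name is only an example, any installed package works):

```python
# Sketch: reading the renamed "performance" section from a pipeline's meta,
# with a fallback for packages that still use "accuracy".
import spacy

nlp = spacy.load("en_core_web_sm")
perf = nlp.meta.get("performance") or nlp.meta.get("accuracy", {})
print(perf.get("ents_f"))
```
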
@@ -219,6 +219,30 @@ alignment mode `"strict".
 | `alignment_mode` | How character indices snap to token boundaries. Options: `"strict"` (no snapping), `"contract"` (span of all tokens completely within the character span), `"expand"` (span of all tokens at least partially covered by the character span). Defaults to `"strict"`. ~~str~~ |
 | **RETURNS**      | The newly constructed object or `None`. ~~Optional[Span]~~ |

+## Doc.set_ents {#ents tag="method" new="3"}
+
+Set the named entities in the document.
+
+> #### Example
+>
+> ```python
+> from spacy.tokens import Span
+> doc = nlp("Mr. Best flew to New York on Saturday morning.")
+> doc.set_ents([Span(doc, 0, 2, "PERSON")])
+> ents = list(doc.ents)
+> assert ents[0].label_ == "PERSON"
+> assert ents[0].text == "Mr. Best"
+> ```
+
+| Name           | Description |
+| -------------- | ----------- |
+| entities       | Spans with labels to set as entities. ~~List[Span]~~ |
+| _keyword-only_ | |
+| blocked        | Spans to set as "blocked" (never an entity) for spacy's built-in NER component. Other components may ignore this setting. ~~Optional[List[Span]]~~ |
+| missing        | Spans with missing/unknown entity information. ~~Optional[List[Span]]~~ |
+| outside        | Spans outside of entities (O in IOB). ~~Optional[List[Span]]~~ |
+| default        | How to set entity annotation for tokens outside of any provided spans. Options: "blocked", "missing", "outside" and "unmodified" (preserve current state). Defaults to "outside". ~~str~~ |
+
 ## Doc.similarity {#similarity tag="method" model="vectors"}

 Make a semantic similarity estimate. The default estimate is cosine similarity

@@ -542,7 +566,6 @@ objects, if the entity recognizer has been applied.
 > ```python
 > doc = nlp("Mr. Best flew to New York on Saturday morning.")
 > ents = list(doc.ents)
-> assert ents[0].label == 346
 > assert ents[0].label_ == "PERSON"
 > assert ents[0].text == "Mr. Best"
 > ```

@@ -1,13 +1,13 @@
 import { Help } from 'components/typography'; import Link from 'components/link'

-<!-- TODO: update, add project template -->
+<!-- TODO: update numbers -->

 <figure>

-| System                                                      | Parser | Tagger | NER  | WPS<br />CPU <Help>words per second on CPU, higher is better</Help> | WPS<br/>GPU <Help>words per second on GPU, higher is better</Help> |
+| Pipeline                                                    | Parser | Tagger | NER  | WPS<br />CPU <Help>words per second on CPU, higher is better</Help> | WPS<br/>GPU <Help>words per second on GPU, higher is better</Help> |
 | ----------------------------------------------------------- | -----: | -----: | ---: | -------------------------------------------------------------------: | ------------------------------------------------------------------: |
 | [`en_core_web_trf`](/models/en#en_core_web_trf) (spaCy v3)  |        |        |      |                                                                       | 6k |
-| [`en_core_web_lg`](/models/en#en_core_web_lg) (spaCy v3)    |        |        |      |                                                                       |    |
+| [`en_core_web_lg`](/models/en#en_core_web_lg) (spaCy v3)    |   92.1 |   97.4 | 87.0 |                                                                    7k |    |
 | `en_core_web_lg` (spaCy v2)                                  |   91.9 |   97.2 | 85.9 |                                                                   10k |    |

 <figcaption class="caption">

@@ -21,10 +21,10 @@ import { Help } from 'components/typography'; import Link from 'components/link'

 <figure>

-| Named Entity Recognition Model                                                  | OntoNotes | CoNLL '03 |
+| Named Entity Recognition System                                                 | OntoNotes | CoNLL '03 |
 | ------------------------------------------------------------------------------ | --------: | --------: |
 | spaCy RoBERTa (2020)                                                            |           |      92.2 |
-| spaCy CNN (2020)                                                                |           |      88.4 |
+| spaCy CNN (2020)                                                                |      85.3 |      88.4 |
 | spaCy CNN (2017)                                                                |      86.4 |           |
 | [Stanza](https://stanfordnlp.github.io/stanza/) (StanfordNLP)<sup>1</sup>       |      88.8 |      92.1 |
 | <Link to="https://github.com/flairNLP/flair" hideIcon>Flair</Link><sup>2</sup>  |      89.7 |      93.1 |

@@ -235,8 +235,6 @@ The `Transformer` component sets the
 [`Doc._.trf_data`](/api/transformer#custom_attributes) extension attribute,
 which lets you access the transformers outputs at runtime.

-<!-- TODO: update/confirm once we have final models trained -->
-
 ```cli
 $ python -m spacy download en_core_trf_lg
 ```

@@ -63,7 +63,7 @@ import Benchmarks from 'usage/\_benchmarks-models.md'

 <figure>

-| System                                                                          |  UAS |  LAS |
+| Dependency Parsing System                                                       |  UAS |  LAS |
 | ------------------------------------------------------------------------------ | ---: | ---: |
 | spaCy RoBERTa (2020)<sup>1</sup>                                                | 96.8 | 95.0 |
 | spaCy CNN (2020)<sup>1</sup>                                                    | 93.7 | 91.8 |

@@ -1654,9 +1654,12 @@ The [`SentenceRecognizer`](/api/sentencerecognizer) is a simple statistical
 component that only provides sentence boundaries. Along with being faster and
 smaller than the parser, its primary advantage is that it's easier to train
 because it only requires annotated sentence boundaries rather than full
-dependency parses.
-
-<!-- TODO: update/confirm usage once we have final models trained -->
+dependency parses. spaCy's [trained pipelines](/models) include both a parser
+and a trained sentence segmenter, which is
+[disabled](/usage/processing-pipelines#disabling) by default. If you only need
+sentence boundaries and no parser, you can use the `enable` and `disable`
+arguments on [`spacy.load`](/api/top-level#spacy.load) to enable the senter and
+disable the parser.

 > #### senter vs. parser
 >

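One way to get the behaviour described in that docs change with APIs available in v3 — a hedged sketch, assuming an installed trained pipeline (the name is only an example):

```python
# Sketch: sentence boundaries from the senter instead of the parser.
import spacy

nlp = spacy.load("en_core_web_sm", disable=["parser"])
nlp.enable_pipe("senter")  # the senter ships disabled by default
doc = nlp("This is a sentence. This is another one.")
print([sent.text for sent in doc.sents])
```
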
@@ -253,8 +253,6 @@ different mechanisms you can use:
 Disabled and excluded component names can be provided to
 [`spacy.load`](/api/top-level#spacy.load) as a list.

-<!-- TODO: update with info on our models shipped with optional components -->
-
 > #### 💡 Optional pipeline components
 >
 > The `disable` mechanism makes it easy to distribute pipeline packages with

@@ -262,6 +260,11 @@ Disabled and excluded component names can be provided to
 > your pipeline may include a statistical _and_ a rule-based component for
 > sentence segmentation, and you can choose which one to run depending on your
 > use case.
+>
+> For example, spaCy's [trained pipelines](/models) like
+> [`en_core_web_sm`](/models/en#en_core_web_sm) contain both a `parser` and
+> `senter` that perform sentence segmentation, but the `senter` is disabled by
+> default.

 ```python
 # Load the pipeline without the entity recognizer

@@ -733,7 +733,10 @@ workflows, but only one can be tracked by DVC.
 <Infobox title="This section is still under construction" emoji="🚧" variant="warning">

 The Prodigy integration will require a nightly version of Prodigy that supports
-spaCy v3+.
+spaCy v3+. You can already use annotations created with Prodigy in spaCy v3 by
+exporting your data with
+[`data-to-spacy`](https://prodi.gy/docs/recipes#data-to-spacy) and running
+[`spacy convert`](/api/cli#convert) to convert it to the binary format.

 </Infobox>

@@ -32,11 +32,17 @@ const MODEL_META = {
     las: 'Labelled dependencies',
     token_acc: 'Tokenization',
     tok: 'Tokenization',
+    lemma: 'Statistical lemmatization',
+    morph: 'Morphological analysis',
     tags_acc: 'Part-of-speech tags (fine grained tags, Token.tag)',
     tag: 'Part-of-speech tags (fine grained tags, Token.tag)',
+    pos: 'Part-of-speech tags (coarse grained tags, Token.pos)',
     ents_f: 'Named entities (F-score)',
     ents_p: 'Named entities (precision)',
     ents_r: 'Named entities (recall)',
+    ner_f: 'Named entities (F-score)',
+    ner_p: 'Named entities (precision)',
+    ner_r: 'Named entities (recall)',
     sent_f: 'Sentence segmentation (F-score)',
     sent_p: 'Sentence segmentation (precision)',
     sent_r: 'Sentence segmentation (recall)',

|
||||||
}
|
}
|
||||||
|
|
||||||
function formatAccuracy(data) {
|
function formatAccuracy(data) {
|
||||||
|
const exclude = ['speed']
|
||||||
if (!data) return []
|
if (!data) return []
|
||||||
return Object.keys(data)
|
return Object.keys(data)
|
||||||
.map(label => {
|
.map(label => {
|
||||||
const value = data[label]
|
const value = data[label]
|
||||||
return isNaN(value)
|
return isNaN(value) || exclude.includes(label)
|
||||||
? null
|
? null
|
||||||
: {
|
: {
|
||||||
label,
|
label,
|
||||||
|
@@ -109,6 +116,7 @@ function formatModelMeta(data) {
         version: data.version,
         sizeFull: data.size,
         pipeline: data.pipeline,
+        components: data.components,
         notes: data.notes,
         description: data.description,
         sources: data.sources,

@@ -117,7 +125,8 @@ function formatModelMeta(data) {
         license: data.license,
         labels: isEmptyObj(data.labels) ? null : data.labels,
         vectors: formatVectors(data.vectors),
-        accuracy: formatAccuracy(data.accuracy),
+        // TODO: remove accuracy fallback
+        accuracy: formatAccuracy(data.accuracy || data.performance),
     }
 }