Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-26 17:24:41 +03:00)

API docs, docstrings and argument consistency

This commit is contained in:
parent 10b84e1e27
commit d8b519c23c
@@ -71,7 +71,6 @@ cdef class DependencyParser(Parser):

    DOCS: https://spacy.io/api/dependencyparser
    """
    # cdef classes can't have decorators, so we're defining this here
    TransitionSystem = ArcEager

    @property

@@ -105,6 +104,14 @@ cdef class DependencyParser(Parser):
        return tuple(sorted(labels))

    def score(self, examples, **kwargs):
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans
            and Scorer.score_deps.

        DOCS: https://spacy.io/api/dependencyparser#score
        """
        def dep_getter(token, attr):
            dep = getattr(token, attr)
            dep = token.vocab.strings.as_string(dep).lower()
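The `score` method added above delegates to `Scorer`. A minimal usage sketch, not from the commit itself: it assumes a trained pipeline, the gold heads/deps are illustrative, and the `Example` import path and score keys (`dep_uas`, `dep_las`) follow current spaCy v3 conventions, which may differ slightly in this snapshot.

```python
import spacy
from spacy.training import Example  # lived in spacy.gold in some v3 snapshots

nlp = spacy.load("en_core_web_sm")
text = "She ate the pizza"
# Illustrative gold annotations, not from a real treebank
annots = {"heads": [1, 1, 3, 1], "deps": ["nsubj", "ROOT", "det", "dobj"]}
# Predicted side is the parsed doc, reference side carries the gold parse
example = Example(nlp(text), Example.from_dict(nlp.make_doc(text), annots).reference)
scores = nlp.get_pipe("parser").score([example])
print(scores.get("dep_uas"), scores.get("dep_las"))
```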
@@ -86,6 +86,19 @@ class EntityLinker(Pipe):
        incl_prior: bool,
        incl_context: bool,
    ) -> None:
        """Initialize an entity linker.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model): The Thinc Model powering the pipeline component.
        name (str): The component instance name, used to add entries to the
            losses during training.
        kb (KnowledgeBase): TODO:
        labels_discard (Iterable[str]): TODO:
        incl_prior (bool): TODO:
        incl_context (bool): TODO:

        DOCS: https://spacy.io/api/entitylinker#init
        """
        self.vocab = vocab
        self.model = model
        self.name = name

@@ -119,6 +132,19 @@ class EntityLinker(Pipe):
        pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
        sgd: Optional[Optimizer] = None,
    ) -> Optimizer:
        """Initialize the pipe for training, using data examples if available.

        get_examples (Callable[[], Iterable[Example]]): Optional function that
            returns gold-standard Example objects.
        pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
            components that this component is part of. Corresponds to
            nlp.pipeline.
        sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
            create_optimizer if it doesn't exist.
        RETURNS (thinc.api.Optimizer): The optimizer.

        DOCS: https://spacy.io/api/entitylinker#begin_training
        """
        self.require_kb()
        nO = self.kb.entity_vector_length
        self.set_output(nO)

@@ -136,6 +162,20 @@ class EntityLinker(Pipe):
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Learn from a batch of documents and gold-standard information,
        updating the pipe's model. Delegates to predict and get_loss.

        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        set_annotations (bool): Whether or not to update the Example objects
            with the predictions.
        sgd (thinc.api.Optimizer): The optimizer.
        losses (Dict[str, float]): Optional record of the loss during training.
            Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/entitylinker#update
        """
        self.require_kb()
        if losses is None:
            losses = {}

@@ -215,18 +255,43 @@ class EntityLinker(Pipe):
        return loss, gradients

    def __call__(self, doc: Doc) -> Doc:
        """Apply the pipe to a Doc.

        doc (Doc): The document to process.
        RETURNS (Doc): The processed Doc.

        DOCS: https://spacy.io/api/entitylinker#call
        """
        kb_ids = self.predict([doc])
        self.set_annotations([doc], kb_ids)
        return doc

    def pipe(self, stream: Iterable[Doc], batch_size: int = 128) -> Iterator[Doc]:
    def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
        """Apply the pipe to a stream of documents. This usually happens under
        the hood when the nlp object is called on a text and all components are
        applied to the Doc.

        stream (Iterable[Doc]): A stream of documents.
        batch_size (int): The number of documents to buffer.
        YIELDS (Doc): Processed documents in order.

        DOCS: https://spacy.io/api/entitylinker#pipe
        """
        for docs in util.minibatch(stream, size=batch_size):
            kb_ids = self.predict(docs)
            self.set_annotations(docs, kb_ids)
            yield from docs

    def predict(self, docs):
        """ Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
    def predict(self, docs: Iterable[Doc]) -> List[str]:
        """Apply the pipeline's model to a batch of docs, without modifying them.
        Returns the KB IDs for each entity in each doc, including NIL if there is
        no prediction.

        docs (Iterable[Doc]): The documents to predict.
        RETURNS (List[str]): The model's prediction for each document.

        DOCS: https://spacy.io/api/entitylinker#predict
        """
        self.require_kb()
        entity_count = 0
        final_kb_ids = []
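As the `__call__` and `pipe` bodies above show, prediction and annotation are decoupled: `predict` returns a flat list of KB ID strings, one per entity across the batch ("NIL" when there is no prediction), and `set_annotations` writes them back onto the docs. A sketch of that contract, assuming a pipeline with NER and a populated `entity_linker` (the text is illustrative):

```python
doc = nlp("Douglas Adams was born in Cambridge.")  # runs NER and the linker
for ent in doc.ents:
    print(ent.text, ent.kb_id_)  # "NIL" when the linker made no prediction

# The same steps, decoupled, mirroring __call__ above:
linker = nlp.get_pipe("entity_linker")
kb_ids = linker.predict([doc])         # flat list: one KB ID string per entity
linker.set_annotations([doc], kb_ids)  # lengths must match, else Errors.E148
```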
@@ -315,7 +380,14 @@ class EntityLinker(Pipe):
            raise RuntimeError(err)
        return final_kb_ids

    def set_annotations(self, docs: Iterable[Doc], kb_ids: List[int]) -> None:
    def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
        """Modify a batch of documents, using pre-computed scores.

        docs (Iterable[Doc]): The documents to modify.
        kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict.

        DOCS: https://spacy.io/api/entitylinker#set_annotations
        """
        count_ents = len([ent for doc in docs for ent in doc.ents])
        if count_ents != len(kb_ids):
            raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))

@@ -328,6 +400,13 @@ class EntityLinker(Pipe):
                    token.ent_kb_id_ = kb_id

    def to_disk(self, path: Union[str, Path], exclude: Iterable[str] = tuple()) -> None:
        """Serialize the pipe to disk.

        path (str / Path): Path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.

        DOCS: https://spacy.io/api/entitylinker#to_disk
        """
        serialize = {}
        self.cfg["entity_width"] = self.kb.entity_vector_length
        serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)

@@ -339,6 +418,15 @@ class EntityLinker(Pipe):
    def from_disk(
        self, path: Union[str, Path], exclude: Iterable[str] = tuple()
    ) -> "EntityLinker":
        """Load the pipe from disk. Modifies the object in place and returns it.

        path (str / Path): Path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (EntityLinker): The modified EntityLinker object.

        DOCS: https://spacy.io/api/entitylinker#from_disk
        """

        def load_model(p):
            try:
                self.model.from_bytes(p.open("rb").read())

@@ -359,7 +447,7 @@ class EntityLinker(Pipe):
        util.from_disk(path, deserialize, exclude)
        return self

    def rehearse(self, examples, sgd=None, losses=None, **config):
    def rehearse(self, examples, *, sgd=None, losses=None, **config):
        raise NotImplementedError

    def add_label(self, label):
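The `to_disk`/`from_disk` docstrings above describe the usual in-place round trip. A minimal sketch (the path is illustrative):

```python
linker = nlp.get_pipe("entity_linker")
linker.to_disk("/tmp/entity_linker")             # writes cfg, model, vocab, kb
linker = linker.from_disk("/tmp/entity_linker")  # modifies in place and returns self
```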
@@ -71,6 +71,10 @@ class EntityRuler:

        nlp (Language): The shared nlp object to pass the vocab to the matchers
            and process phrase patterns.
        name (str): Instance name of the current pipeline component. Typically
            passed in automatically from the factory when the component is
            added. Used to disable the current entity ruler while creating
            phrase patterns with the nlp object.
        phrase_matcher_attr (int / str): Token attribute to match on, passed
            to the internal PhraseMatcher as `attr`.
        validate (bool): Whether patterns should be validated, passed to
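Per the docstring above, `phrase_matcher_attr` is forwarded to the internal `PhraseMatcher`. A usage sketch assuming the v3 string-name factory API (pattern contents are illustrative):

```python
# Match phrase patterns on the LOWER attribute, i.e. case-insensitively
ruler = nlp.add_pipe("entity_ruler", config={"phrase_matcher_attr": "LOWER"})
ruler.add_patterns([{"label": "ORG", "pattern": "spacy"}])
doc = nlp("SpaCy and spaCy are the same library.")
print([(ent.text, ent.label_) for ent in doc.ents])
```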
@@ -61,6 +61,17 @@ class Morphologizer(Tagger):
        labels_morph: Optional[dict] = None,
        labels_pos: Optional[dict] = None,
    ):
        """Initialize a morphologizer.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model): The Thinc Model powering the pipeline component.
        name (str): The component instance name, used to add entries to the
            losses during training.
        labels_morph (dict): TODO:
        labels_pos (dict): TODO:

        DOCS: https://spacy.io/api/morphologizer#init
        """
        self.vocab = vocab
        self.model = model
        self.name = name

@@ -77,9 +88,17 @@ class Morphologizer(Tagger):

    @property
    def labels(self):
        """RETURNS (Tuple[str]): The labels currently added to the component."""
        return tuple(self.cfg["labels_morph"].keys())

    def add_label(self, label):
        """Add a new label to the pipe.

        label (str): The label to add.
        RETURNS (int): 1

        DOCS: https://spacy.io/api/morphologizer#add_label
        """
        if not isinstance(label, str):
            raise ValueError(Errors.E187)
        if label in self.labels:

@@ -99,7 +118,20 @@ class Morphologizer(Tagger):
        self.cfg["labels_pos"][norm_label] = POS_IDS[pos]
        return 1

    def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None):
    def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None):
        """Initialize the pipe for training, using data examples if available.

        get_examples (Callable[[], Iterable[Example]]): Optional function that
            returns gold-standard Example objects.
        pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
            components that this component is part of. Corresponds to
            nlp.pipeline.
        sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
            create_optimizer if it doesn't exist.
        RETURNS (thinc.api.Optimizer): The optimizer.

        DOCS: https://spacy.io/api/morphologizer#begin_training
        """
        for example in get_examples():
            for i, token in enumerate(example.reference):
                pos = token.pos_

@@ -121,6 +153,13 @@ class Morphologizer(Tagger):
        return sgd

    def set_annotations(self, docs, batch_tag_ids):
        """Modify a batch of documents, using pre-computed scores.

        docs (Iterable[Doc]): The documents to modify.
        batch_tag_ids: The IDs to set, produced by Morphologizer.predict.

        DOCS: https://spacy.io/api/morphologizer#set_annotations
        """
        if isinstance(docs, Doc):
            docs = [docs]
        cdef Doc doc

@@ -137,6 +176,15 @@ class Morphologizer(Tagger):
        doc.is_morphed = True

    def get_loss(self, examples, scores):
        """Find the loss and gradient of loss for the batch of documents and
        their predicted scores.

        examples (Iterable[Example]): The batch of examples.
        scores: Scores representing the model's predictions.
        RETURNS (Tuple[float, float]): The loss and the gradient.

        DOCS: https://spacy.io/api/morphologizer#get_loss
        """
        loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
        truths = []
        for eg in examples:

@@ -164,6 +212,15 @@ class Morphologizer(Tagger):
        return float(loss), d_scores

    def score(self, examples, **kwargs):
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by
            Scorer.score_token_attr for the attributes "pos" and "morph" and
            Scorer.score_token_attr_per_feat for the attribute "morph".

        DOCS: https://spacy.io/api/morphologizer#score
        """
        results = {}
        results.update(Scorer.score_token_attr(examples, "pos", **kwargs))
        results.update(Scorer.score_token_attr(examples, "morph", **kwargs))

@@ -172,6 +229,13 @@ class Morphologizer(Tagger):
        return results

    def to_bytes(self, exclude=tuple()):
        """Serialize the pipe to a bytestring.

        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (bytes): The serialized object.

        DOCS: https://spacy.io/api/morphologizer#to_bytes
        """
        serialize = {}
        serialize["model"] = self.model.to_bytes
        serialize["vocab"] = self.vocab.to_bytes

@@ -179,6 +243,14 @@ class Morphologizer(Tagger):
        return util.to_bytes(serialize, exclude)

    def from_bytes(self, bytes_data, exclude=tuple()):
        """Load the pipe from a bytestring.

        bytes_data (bytes): The serialized pipe.
        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (Morphologizer): The loaded Morphologizer.

        DOCS: https://spacy.io/api/morphologizer#from_bytes
        """
        def load_model(b):
            try:
                self.model.from_bytes(b)

@@ -194,6 +266,13 @@ class Morphologizer(Tagger):
        return self

    def to_disk(self, path, exclude=tuple()):
        """Serialize the pipe to disk.

        path (str / Path): Path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.

        DOCS: https://spacy.io/api/morphologizer#to_disk
        """
        serialize = {
            "vocab": lambda p: self.vocab.to_disk(p),
            "model": lambda p: p.open("wb").write(self.model.to_bytes()),

@@ -202,6 +281,14 @@ class Morphologizer(Tagger):
        util.to_disk(path, serialize, exclude)

    def from_disk(self, path, exclude=tuple()):
        """Load the pipe from disk. Modifies the object in place and returns it.

        path (str / Path): Path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (Morphologizer): The modified Morphologizer object.

        DOCS: https://spacy.io/api/morphologizer#from_disk
        """
        def load_model(p):
            with p.open("rb") as file_:
                try:
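The `labels` property above is backed by `cfg["labels_morph"]`, and `add_label` splits any `POS` feature into `cfg["labels_pos"]`. A short sketch; the exact label string format is an assumption, since the commit leaves it as TODO:

```python
morphologizer = nlp.add_pipe("morphologizer")
morphologizer.add_label("POS=NOUN|Number=Sing")  # label format is illustrative
print(morphologizer.labels)  # tuple(cfg["labels_morph"].keys()), per the property above
```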
@@ -91,4 +91,11 @@ cdef class EntityRecognizer(Parser):
        return tuple(sorted(labels))

    def score(self, examples, **kwargs):
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.

        DOCS: https://spacy.io/api/entityrecognizer#score
        """
        return Scorer.score_spans(examples, "ents", **kwargs)
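Since `EntityRecognizer.score` is just `Scorer.score_spans(examples, "ents")`, the returned dict carries span-level precision/recall/F keys. A sketch assuming an `examples` list as in the scoring example earlier; the key names follow spaCy v3's `Scorer` conventions:

```python
ner = nlp.get_pipe("ner")
scores = ner.score(examples)  # Scorer.score_spans(examples, "ents", ...)
print(scores.get("ents_p"), scores.get("ents_r"), scores.get("ents_f"))
```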
@@ -23,7 +23,7 @@ class Pipe:

    name = None

    def __init__(self, vocab, model, **cfg):
    def __init__(self, vocab, model, name, **cfg):
        """Create a new pipe instance."""
        raise NotImplementedError

@@ -79,7 +79,7 @@ class Pipe:
    def create_optimizer(self):
        return create_default_optimizer()

    def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None):
    def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None):
        """Initialize the pipe for training, using data examples if available.
        If no model has been initialized yet, the model is added."""
        self.model.initialize()
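The base-class change above adds `name` to `Pipe.__init__` and makes `begin_training`'s extra arguments keyword-only, so subclasses are expected to follow the same shape. A skeleton sketch of a conforming custom component (not an API from the commit; `predict`/`set_annotations` would still need real implementations):

```python
from spacy.pipeline import Pipe

class CustomPipe(Pipe):
    def __init__(self, vocab, model, name="custom", **cfg):
        self.vocab = vocab
        self.model = model
        self.name = name  # used as the key into the losses dict during training
        self.cfg = dict(cfg)

    def __call__(self, doc):
        # The canonical single-doc path: predict, then set_annotations
        scores = self.predict([doc])
        self.set_annotations([doc], scores)
        return doc
```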
@@ -41,7 +41,7 @@ class Sentencizer(Pipe):
        '𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈',
        '。', '。']

    def __init__(self, name="sentencizer", *, punct_chars):
    def __init__(self, name="sentencizer", *, punct_chars=None):
        """Initialize the sentencizer.

        punct_chars (list): Punctuation characters to split on. Will be

@@ -62,8 +62,8 @@ class Sentencizer(Pipe):
    def __call__(self, doc):
        """Apply the sentencizer to a Doc and set Token.is_sent_start.

        example (Doc or Example): The document to process.
        RETURNS (Doc or Example): The processed Doc or Example.
        doc (Doc): The document to process.
        RETURNS (Doc): The processed Doc.

        DOCS: https://spacy.io/api/sentencizer#call
        """

@@ -83,14 +83,26 @@ class Sentencizer(Pipe):
        return doc

    def pipe(self, stream, batch_size=128):
        """Apply the pipe to a stream of documents. This usually happens under
        the hood when the nlp object is called on a text and all components are
        applied to the Doc.

        stream (Iterable[Doc]): A stream of documents.
        batch_size (int): The number of documents to buffer.
        YIELDS (Doc): Processed documents in order.

        DOCS: https://spacy.io/api/sentencizer#pipe
        """
        for docs in util.minibatch(stream, size=batch_size):
            predictions = self.predict(docs)
            self.set_annotations(docs, predictions)
            yield from docs

    def predict(self, docs):
        """Apply the pipeline's model to a batch of docs, without
        modifying them.
        """Apply the pipe to a batch of docs, without modifying them.

        docs (Iterable[Doc]): The documents to predict.
        RETURNS: The predictions for each document.
        """
        if not any(len(doc) for doc in docs):
            # Handle cases where there are no tokens in any docs.

@@ -117,6 +129,11 @@ class Sentencizer(Pipe):
        return guesses

    def set_annotations(self, docs, batch_tag_ids):
        """Modify a batch of documents, using pre-computed scores.

        docs (Iterable[Doc]): The documents to modify.
        batch_tag_ids: The tag IDs produced by Sentencizer.predict.
        """
        if isinstance(docs, Doc):
            docs = [docs]
        cdef Doc doc

@@ -132,6 +149,13 @@ class Sentencizer(Pipe):
                doc.c[j].sent_start = -1

    def score(self, examples, **kwargs):
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.

        DOCS: https://spacy.io/api/sentencizer#score
        """
        return Scorer.score_spans(examples, "sents", **kwargs)

    def to_bytes(self, exclude=tuple()):
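The `punct_chars` argument above now defaults to `None` (falling back to the built-in list). A usage sketch via the v3 config mechanism, with an illustrative character set:

```python
import spacy

nlp = spacy.blank("en")
# Override the default punctuation set; omit config to use the built-ins
nlp.add_pipe("sentencizer", config={"punct_chars": [".", "!", "?", "。"]})
doc = nlp("First sentence. Second one!")
print([sent.text for sent in doc.sents])
```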
@@ -45,6 +45,15 @@ class SentenceRecognizer(Tagger):
    DOCS: https://spacy.io/api/sentencerecognizer
    """
    def __init__(self, vocab, model, name="senter"):
        """Initialize a sentence recognizer.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model): The Thinc Model powering the pipeline component.
        name (str): The component instance name, used to add entries to the
            losses during training.

        DOCS: https://spacy.io/api/sentencerecognizer#init
        """
        self.vocab = vocab
        self.model = model
        self.name = name

@@ -53,12 +62,20 @@ class SentenceRecognizer(Tagger):

    @property
    def labels(self):
        """RETURNS (Tuple[str]): The labels."""
        # labels are numbered by index internally, so this matches GoldParse
        # and Example where the sentence-initial tag is 1 and other positions
        # are 0
        return tuple(["I", "S"])

    def set_annotations(self, docs, batch_tag_ids):
        """Modify a batch of documents, using pre-computed scores.

        docs (Iterable[Doc]): The documents to modify.
        batch_tag_ids: The IDs to set, produced by SentenceRecognizer.predict.

        DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
        """
        if isinstance(docs, Doc):
            docs = [docs]
        cdef Doc doc

@@ -75,6 +92,15 @@ class SentenceRecognizer(Tagger):
                doc.c[j].sent_start = -1

    def get_loss(self, examples, scores):
        """Find the loss and gradient of loss for the batch of documents and
        their predicted scores.

        examples (Iterable[Example]): The batch of examples.
        scores: Scores representing the model's predictions.
        RETURNS (Tuple[float, float]): The loss and the gradient.

        DOCS: https://spacy.io/api/sentencerecognizer#get_loss
        """
        labels = self.labels
        loss_func = SequenceCategoricalCrossentropy(names=labels, normalize=False)
        truths = []

@@ -94,7 +120,20 @@ class SentenceRecognizer(Tagger):
            raise ValueError("nan value when computing loss")
        return float(loss), d_scores

    def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None):
    def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None):
        """Initialize the pipe for training, using data examples if available.

        get_examples (Callable[[], Iterable[Example]]): Optional function that
            returns gold-standard Example objects.
        pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
            components that this component is part of. Corresponds to
            nlp.pipeline.
        sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
            create_optimizer if it doesn't exist.
        RETURNS (thinc.api.Optimizer): The optimizer.

        DOCS: https://spacy.io/api/sentencerecognizer#begin_training
        """
        self.set_output(len(self.labels))
        self.model.initialize()
        util.link_vectors_to_models(self.vocab)

@@ -106,9 +145,22 @@ class SentenceRecognizer(Tagger):
        raise NotImplementedError

    def score(self, examples, **kwargs):
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_spans.
        DOCS: https://spacy.io/api/sentencerecognizer#score
        """
        return Scorer.score_spans(examples, "sents", **kwargs)

    def to_bytes(self, exclude=tuple()):
        """Serialize the pipe to a bytestring.

        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (bytes): The serialized object.

        DOCS: https://spacy.io/api/sentencerecognizer#to_bytes
        """
        serialize = {}
        serialize["model"] = self.model.to_bytes
        serialize["vocab"] = self.vocab.to_bytes

@@ -116,6 +168,14 @@ class SentenceRecognizer(Tagger):
        return util.to_bytes(serialize, exclude)

    def from_bytes(self, bytes_data, exclude=tuple()):
        """Load the pipe from a bytestring.

        bytes_data (bytes): The serialized pipe.
        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (SentenceRecognizer): The loaded SentenceRecognizer.

        DOCS: https://spacy.io/api/sentencerecognizer#from_bytes
        """
        def load_model(b):
            try:
                self.model.from_bytes(b)

@@ -131,6 +191,13 @@ class SentenceRecognizer(Tagger):
        return self

    def to_disk(self, path, exclude=tuple()):
        """Serialize the pipe to disk.

        path (str / Path): Path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.

        DOCS: https://spacy.io/api/sentencerecognizer#to_disk
        """
        serialize = {
            "vocab": lambda p: self.vocab.to_disk(p),
            "model": lambda p: p.open("wb").write(self.model.to_bytes()),

@@ -139,6 +206,14 @@ class SentenceRecognizer(Tagger):
        util.to_disk(path, serialize, exclude)

    def from_disk(self, path, exclude=tuple()):
        """Load the pipe from disk. Modifies the object in place and returns it.

        path (str / Path): Path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (SentenceRecognizer): The modified SentenceRecognizer object.

        DOCS: https://spacy.io/api/sentencerecognizer#from_disk
        """
        def load_model(p):
            with p.open("rb") as file_:
                try:
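Per the `labels` property and `score` method above, the senter exposes a fixed two-tag scheme and scores sentence spans. A short sketch, assuming an `examples` list as before; the `sents_f` key follows spaCy v3's `Scorer` conventions:

```python
senter = nlp.add_pipe("senter")
print(senter.labels)             # ("I", "S"): sentence-internal vs. sentence-initial
scores = senter.score(examples)  # Scorer.score_spans(examples, "sents", ...)
print(scores.get("sents_f"))
```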
@@ -51,6 +51,16 @@ class Tagger(Pipe):
    DOCS: https://spacy.io/api/tagger
    """
    def __init__(self, vocab, model, name="tagger", *, set_morphology=False):
        """Initialize a part-of-speech tagger.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model): The Thinc Model powering the pipeline component.
        name (str): The component instance name, used to add entries to the
            losses during training.
        set_morphology (bool): Whether to set morphological features.

        DOCS: https://spacy.io/api/tagger#init
        """
        self.vocab = vocab
        self.model = model
        self.name = name

@@ -60,20 +70,52 @@ class Tagger(Pipe):

    @property
    def labels(self):
        """The labels currently added to the component. Note that even for a
        blank component, this will always include the built-in coarse-grained
        part-of-speech tags by default.

        RETURNS (Tuple[str]): The labels.

        DOCS: https://spacy.io/api/tagger#labels
        """
        return tuple(self.vocab.morphology.tag_names)

    def __call__(self, doc):
        """Apply the pipe to a Doc.

        doc (Doc): The document to process.
        RETURNS (Doc): The processed Doc.

        DOCS: https://spacy.io/api/tagger#call
        """
        tags = self.predict([doc])
        self.set_annotations([doc], tags)
        return doc

    def pipe(self, stream, batch_size=128):
    def pipe(self, stream, *, batch_size=128):
        """Apply the pipe to a stream of documents. This usually happens under
        the hood when the nlp object is called on a text and all components are
        applied to the Doc.

        stream (Iterable[Doc]): A stream of documents.
        batch_size (int): The number of documents to buffer.
        YIELDS (Doc): Processed documents in order.

        DOCS: https://spacy.io/api/tagger#pipe
        """
        for docs in util.minibatch(stream, size=batch_size):
            tag_ids = self.predict(docs)
            self.set_annotations(docs, tag_ids)
            yield from docs

    def predict(self, docs):
        """Apply the pipeline's model to a batch of docs, without modifying them.

        docs (Iterable[Doc]): The documents to predict.
        RETURNS: The model's prediction for each document.

        DOCS: https://spacy.io/api/tagger#predict
        """
        if not any(len(doc) for doc in docs):
            # Handle cases where there are no tokens in any docs.
            n_labels = len(self.labels)

@@ -96,6 +138,13 @@ class Tagger(Pipe):
        return guesses

    def set_annotations(self, docs, batch_tag_ids):
        """Modify a batch of documents, using pre-computed scores.

        docs (Iterable[Doc]): The documents to modify.
        batch_tag_ids: The IDs to set, produced by Tagger.predict.

        DOCS: https://spacy.io/api/tagger#set_annotations
        """
        if isinstance(docs, Doc):
            docs = [docs]
        cdef Doc doc

@@ -121,10 +170,23 @@ class Tagger(Pipe):
            doc.is_tagged = True

    def update(self, examples, *, drop=0., sgd=None, losses=None, set_annotations=False):
        """Learn from a batch of documents and gold-standard information,
        updating the pipe's model. Delegates to predict and get_loss.

        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        set_annotations (bool): Whether or not to update the Example objects
            with the predictions.
        sgd (thinc.api.Optimizer): The optimizer.
        losses (Dict[str, float]): Optional record of the loss during training.
            Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/tagger#update
        """
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)

        try:
            if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
                # Handle cases where there are no tokens in any docs.

@@ -149,9 +211,20 @@ class Tagger(Pipe):
            self.set_annotations(docs, self._scores2guesses(tag_scores))
        return losses

    def rehearse(self, examples, drop=0., sgd=None, losses=None):
        """Perform a 'rehearsal' update, where we try to match the output of
        an initial model.
    def rehearse(self, examples, *, drop=0., sgd=None, losses=None):
        """Perform a "rehearsal" update from a batch of data. Rehearsal updates
        teach the current model to make predictions similar to an initial model,
        to try to address the "catastrophic forgetting" problem. This feature is
        experimental.

        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        sgd (thinc.api.Optimizer): The optimizer.
        losses (Dict[str, float]): Optional record of the loss during training.
            Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/tagger#rehearse
        """
        try:
            docs = [eg.predicted for eg in examples]

@@ -174,6 +247,15 @@ class Tagger(Pipe):
            losses[self.name] += (gradient**2).sum()

    def get_loss(self, examples, scores):
        """Find the loss and gradient of loss for the batch of documents and
        their predicted scores.

        examples (Iterable[Example]): The batch of examples.
        scores: Scores representing the model's predictions.
        RETURNS (Tuple[float, float]): The loss and the gradient.

        DOCS: https://spacy.io/api/tagger#get_loss
        """
        loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
        truths = [eg.get_aligned("tag", as_string=True) for eg in examples]
        d_scores, loss = loss_func(scores, truths)

@@ -181,7 +263,20 @@ class Tagger(Pipe):
            raise ValueError("nan value when computing loss")
        return float(loss), d_scores

    def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None):
    def begin_training(self, get_examples=lambda: [], *, pipeline=None, sgd=None):
        """Initialize the pipe for training, using data examples if available.

        get_examples (Callable[[], Iterable[Example]]): Optional function that
            returns gold-standard Example objects.
        pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
            components that this component is part of. Corresponds to
            nlp.pipeline.
        sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
            create_optimizer if it doesn't exist.
        RETURNS (thinc.api.Optimizer): The optimizer.

        DOCS: https://spacy.io/api/tagger#begin_training
        """
        lemma_tables = ["lemma_rules", "lemma_index", "lemma_exc", "lemma_lookup"]
        if not any(table in self.vocab.lookups for table in lemma_tables):
            warnings.warn(Warnings.W022)

@@ -227,6 +322,15 @@ class Tagger(Pipe):
        return sgd

    def add_label(self, label, values=None):
        """Add a new label to the pipe.

        label (str): The label to add.
        values (Dict[int, str]): Optional values to map to the label, e.g. a
            tag map dictionary.
        RETURNS (int): 1

        DOCS: https://spacy.io/api/tagger#add_label
        """
        if not isinstance(label, str):
            raise ValueError(Errors.E187)
        if label in self.labels:

@@ -254,6 +358,14 @@ class Tagger(Pipe):
            yield

    def score(self, examples, **kwargs):
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores, produced by
            Scorer.score_token_attr for the attributes "tag", "pos" and "lemma".

        DOCS: https://spacy.io/api/tagger#score
        """
        scores = {}
        scores.update(Scorer.score_token_attr(examples, "tag", **kwargs))
        scores.update(Scorer.score_token_attr(examples, "pos", **kwargs))

@@ -261,6 +373,13 @@ class Tagger(Pipe):
        return scores

    def to_bytes(self, exclude=tuple()):
        """Serialize the pipe to a bytestring.

        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (bytes): The serialized object.

        DOCS: https://spacy.io/api/tagger#to_bytes
        """
        serialize = {}
        serialize["model"] = self.model.to_bytes
        serialize["vocab"] = self.vocab.to_bytes

@@ -272,6 +391,14 @@ class Tagger(Pipe):
        return util.to_bytes(serialize, exclude)

    def from_bytes(self, bytes_data, exclude=tuple()):
        """Load the pipe from a bytestring.

        bytes_data (bytes): The serialized pipe.
        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (Tagger): The loaded Tagger.

        DOCS: https://spacy.io/api/tagger#from_bytes
        """
        def load_model(b):
            try:
                self.model.from_bytes(b)

@@ -300,6 +427,13 @@ class Tagger(Pipe):
        return self

    def to_disk(self, path, exclude=tuple()):
        """Serialize the pipe to disk.

        path (str / Path): Path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.

        DOCS: https://spacy.io/api/tagger#to_disk
        """
        tag_map = dict(sorted(self.vocab.morphology.tag_map.items()))
        morph_rules = dict(self.vocab.morphology.exc)
        serialize = {

@@ -312,6 +446,14 @@ class Tagger(Pipe):
        util.to_disk(path, serialize, exclude)

    def from_disk(self, path, exclude=tuple()):
        """Load the pipe from disk. Modifies the object in place and returns it.

        path (str / Path): Path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (Tagger): The modified Tagger object.

        DOCS: https://spacy.io/api/tagger#from_disk
        """
        def load_model(p):
            with p.open("rb") as file_:
                try:
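The `update` docstring above spells out the losses contract: pass one dict across calls and the component accumulates under its own name. A minimal training-loop sketch; `train_examples`, the epoch count, and the hyperparameters are illustrative, and `nlp.begin_training()` follows the v3 API of this snapshot:

```python
import random
from spacy.util import minibatch

losses = {}
optimizer = nlp.begin_training()
for epoch in range(10):
    random.shuffle(train_examples)
    for batch in minibatch(train_examples, size=8):
        # Accumulates the loss under losses["tagger"]
        nlp.get_pipe("tagger").update(batch, drop=0.2, sgd=optimizer, losses=losses)
    print(epoch, losses["tagger"])
```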
@@ -77,6 +77,16 @@ class TextCategorizer(Pipe):
        *,
        labels: Iterable[str],
    ) -> None:
        """Initialize a text categorizer.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model): The Thinc Model powering the pipeline component.
        name (str): The component instance name, used to add entries to the
            losses during training.
        labels (Iterable[str]): The labels to use.

        DOCS: https://spacy.io/api/textcategorizer#init
        """
        self.vocab = vocab
        self.model = model
        self.name = name

@@ -86,6 +96,10 @@ class TextCategorizer(Pipe):

    @property
    def labels(self) -> Tuple[str]:
        """RETURNS (Tuple[str]): The labels currently added to the component.

        DOCS: https://spacy.io/api/textcategorizer#labels
        """
        return tuple(self.cfg.setdefault("labels", []))

    def require_labels(self) -> None:

@@ -97,13 +111,30 @@ class TextCategorizer(Pipe):
    def labels(self, value: Iterable[str]) -> None:
        self.cfg["labels"] = tuple(value)

    def pipe(self, stream: Iterator[str], batch_size: int = 128) -> Iterator[Doc]:
    def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
        """Apply the pipe to a stream of documents. This usually happens under
        the hood when the nlp object is called on a text and all components are
        applied to the Doc.

        stream (Iterable[Doc]): A stream of documents.
        batch_size (int): The number of documents to buffer.
        YIELDS (Doc): Processed documents in order.

        DOCS: https://spacy.io/api/textcategorizer#pipe
        """
        for docs in util.minibatch(stream, size=batch_size):
            scores = self.predict(docs)
            self.set_annotations(docs, scores)
            yield from docs

    def predict(self, docs: Iterable[Doc]):
        """Apply the pipeline's model to a batch of docs, without modifying them.

        docs (Iterable[Doc]): The documents to predict.
        RETURNS: The model's prediction for each document.

        DOCS: https://spacy.io/api/textcategorizer#predict
        """
        tensors = [doc.tensor for doc in docs]
        if not any(len(doc) for doc in docs):
            # Handle cases where there are no tokens in any docs.

@@ -115,6 +146,13 @@ class TextCategorizer(Pipe):
        return scores

    def set_annotations(self, docs: Iterable[Doc], scores) -> None:
        """Modify a batch of documents, using pre-computed scores.

        docs (Iterable[Doc]): The documents to modify.
        scores: The scores to set, produced by TextCategorizer.predict.

        DOCS: https://spacy.io/api/textcategorizer#set_annotations
        """
        for i, doc in enumerate(docs):
            for j, label in enumerate(self.labels):
                doc.cats[label] = float(scores[i, j])

@@ -128,6 +166,20 @@ class TextCategorizer(Pipe):
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Learn from a batch of documents and gold-standard information,
        updating the pipe's model. Delegates to predict and get_loss.

        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        set_annotations (bool): Whether or not to update the Example objects
            with the predictions.
        sgd (thinc.api.Optimizer): The optimizer.
        losses (Dict[str, float]): Optional record of the loss during training.
            Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/textcategorizer#update
        """
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)

@@ -155,10 +207,25 @@ class TextCategorizer(Pipe):
    def rehearse(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> None:
    ) -> Dict[str, float]:
        """Perform a "rehearsal" update from a batch of data. Rehearsal updates
        teach the current model to make predictions similar to an initial model,
        to try to address the "catastrophic forgetting" problem. This feature is
        experimental.

        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        sgd (thinc.api.Optimizer): The optimizer.
        losses (Dict[str, float]): Optional record of the loss during training.
            Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/textcategorizer#rehearse
        """
        if self._rehearsal_model is None:
            return
        try:

@@ -182,6 +249,7 @@ class TextCategorizer(Pipe):
        if losses is not None:
            losses.setdefault(self.name, 0.0)
            losses[self.name] += (gradient ** 2).sum()
        return losses

    def _examples_to_truth(
        self, examples: List[Example]

@@ -198,6 +266,15 @@ class TextCategorizer(Pipe):
        return truths, not_missing

    def get_loss(self, examples: Iterable[Example], scores) -> Tuple[float, float]:
        """Find the loss and gradient of loss for the batch of documents and
        their predicted scores.

        examples (Iterable[Example]): The batch of examples.
        scores: Scores representing the model's predictions.
        RETURNS (Tuple[float, float]): The loss and the gradient.

        DOCS: https://spacy.io/api/textcategorizer#get_loss
        """
        truths, not_missing = self._examples_to_truth(examples)
        not_missing = self.model.ops.asarray(not_missing)
        d_scores = (scores - truths) / scores.shape[0]

@@ -206,6 +283,13 @@ class TextCategorizer(Pipe):
        return float(mean_square_error), d_scores

    def add_label(self, label: str) -> int:
        """Add a new label to the pipe.

        label (str): The label to add.
        RETURNS (int): 1.

        DOCS: https://spacy.io/api/textcategorizer#add_label
        """
        if not isinstance(label, str):
            raise ValueError(Errors.E187)
        if label in self.labels:

@@ -226,10 +310,24 @@ class TextCategorizer(Pipe):

    def begin_training(
        self,
        get_examples: Callable = lambda: [],
        get_examples: Callable[[], Iterable[Example]] = lambda: [],
        *,
        pipeline: Optional[List[Tuple[str, Callable[[Doc], Doc]]]] = None,
        sgd: Optional[Optimizer] = None,
    ) -> Optimizer:
        """Initialize the pipe for training, using data examples if available.

        get_examples (Callable[[], Iterable[Example]]): Optional function that
            returns gold-standard Example objects.
        pipeline (List[Tuple[str, Callable]]): Optional list of pipeline
            components that this component is part of. Corresponds to
            nlp.pipeline.
        sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
            create_optimizer if it doesn't exist.
        RETURNS (thinc.api.Optimizer): The optimizer.

        DOCS: https://spacy.io/api/textcategorizer#begin_training
        """
        # TODO: begin_training is not guaranteed to see all data / labels ?
        examples = list(get_examples())
        for example in examples:

@@ -255,9 +353,18 @@ class TextCategorizer(Pipe):
    def score(
        self,
        examples: Iterable[Example],
        *,
        positive_label: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        positive_label (str): Optional positive label.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_cats.

        DOCS: https://spacy.io/api/textcategorizer#score
        """
        return Scorer.score_cats(
            examples,
            "cats",
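The new keyword-only `positive_label` argument above is forwarded to `Scorer.score_cats`. A sketch assuming a binary textcat with a "POSITIVE" label; the exact returned key names are an assumption based on spaCy v3's `Scorer` conventions:

```python
textcat = nlp.get_pipe("textcat")
scores = textcat.score(examples, positive_label="POSITIVE")
print(scores.get("cats_score"), scores.get("cats_score_desc"))
```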
@@ -160,7 +160,7 @@ cdef class Parser:
        self.set_annotations([doc], states)
        return doc

    def pipe(self, docs, int batch_size=256):
    def pipe(self, docs, *, int batch_size=256):
        """Process a stream of documents.

        stream: The sequence of documents to process.
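Like the other components in this commit, `Parser.pipe` makes `batch_size` keyword-only, so positional calls no longer work. A short sketch (`doc_stream` is illustrative):

```python
parser = nlp.get_pipe("parser")
docs = list(parser.pipe(doc_stream, batch_size=256))  # OK: keyword argument
# parser.pipe(doc_stream, 256) would now raise a TypeError
```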
website/docs/api/dependencymatcher.md (new file, +8 lines)

@@ -0,0 +1,8 @@
---
title: DependencyMatcher
teaser: Match sequences of tokens, based on the dependency parse
tag: class
source: spacy/matcher/dependencymatcher.pyx
---

TODO: write
@ -2,18 +2,37 @@
|
|||
title: DependencyParser
|
||||
tag: class
|
||||
source: spacy/pipeline/dep_parser.pyx
|
||||
teaser: 'Pipeline component for syntactic dependency parsing'
|
||||
api_base_class: /api/pipe
|
||||
api_string_name: parser
|
||||
api_trainable: true
|
||||
---
|
||||
|
||||
This class is a subclass of `Pipe` and follows the same API. The pipeline
|
||||
component is available in the [processing pipeline](/usage/processing-pipelines)
|
||||
via the ID `"parser"`.
|
||||
## Config and implementation {#config}
|
||||
|
||||
## Implementation and defaults {#implementation}
|
||||
The default config is defined by the pipeline component factory and describes
|
||||
how the component should be configured. You can override its settings via the
|
||||
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
|
||||
[`config.cfg` for training](/usage/training#config). See the
|
||||
[model architectures](/api/architectures) documentation for details on the
|
||||
architectures and their arguments and hyperparameters.
|
||||
|
||||
See the [model architectures](/api/architectures) documentation for details on
|
||||
the architectures and their arguments and hyperparameters. To learn more about
|
||||
how to customize the config and train custom models, check out the
|
||||
[training config](/usage/training#config) docs.
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
||||
> config = {
|
||||
> "moves": None,
|
||||
> # TODO: rest
|
||||
> "model": DEFAULT_PARSER_MODEL,
|
||||
> }
|
||||
> nlp.add_pipe("parser", config=config)
|
||||
> ```
|
||||
|
||||
| Setting | Type | Description | Default |
|
||||
| ------- | ------------------------------------------ | ----------------- | ----------------------------------------------------------------- |
|
||||
| `moves` | list | <!-- TODO: --> | `None` |
|
||||
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [TransitionBasedParser](/api/architectures#TransitionBasedParser) |
|
||||
|
||||
```python
|
||||
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/dep_parser.pyx
|
||||
|
@ -30,18 +49,27 @@ https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/dep_parser.pyx
|
|||
> # Construction via add_pipe with custom model
|
||||
> config = {"model": {"@architectures": "my_parser"}}
|
||||
> parser = nlp.add_pipe("parser", config=config)
|
||||
>
|
||||
> # Construction from class
|
||||
> from spacy.pipeline import DependencyParser
|
||||
> parser = DependencyParser(nlp.vocab, model)
|
||||
> ```
|
||||
|
||||
Create a new pipeline instance. In your application, you would normally use a
|
||||
shortcut for this and instantiate the component using its string name and
|
||||
[`nlp.add_pipe`](/api/language#add_pipe).
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------------------ | ------------------------------------------------------------------------------- |
|
||||
| `vocab` | `Vocab` | The shared vocabulary. |
|
||||
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
|
||||
| `**cfg` | - | Configuration parameters. |
|
||||
| **RETURNS** | `DependencyParser` | The newly constructed object. |
|
||||
| Name | Type | Description |
|
||||
| ----------------------------- | ------------------------------------------ | ------------------------------------------------------------------------------------------- |
|
||||
| `vocab` | `Vocab` | The shared vocabulary. |
|
||||
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
|
||||
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
|
||||
| `moves` | list | <!-- TODO: --> |
|
||||
| _keyword-only_ | | |
|
||||
| `update_with_oracle_cut_size` | int | <!-- TODO: --> |
|
||||
| `multitasks` | `Iterable` | <!-- TODO: --> |
|
||||
| `learn_tokens` | bool | <!-- TODO: --> |
|
||||
| `min_action_freq` | int | <!-- TODO: --> |
|
||||
|
||||
## DependencyParser.\_\_call\_\_ {#call tag="method"}
|
||||
|
||||
|
@ -56,8 +84,8 @@ and all pipeline components are applied to the `Doc` in order. Both
|
|||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> parser = DependencyParser(nlp.vocab)
|
||||
> doc = nlp("This is a sentence.")
|
||||
> parser = nlp.add_pipe("parser")
|
||||
> # This usually happens under the hood
|
||||
> processed = parser(doc)
|
||||
> ```
|
||||
|
@ -79,16 +107,37 @@ applied to the `Doc` in order. Both [`__call__`](/api/dependencyparser#call) and
|
|||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> parser = DependencyParser(nlp.vocab)
|
||||
> parser = nlp.add_pipe("parser")
|
||||
> for doc in parser.pipe(docs, batch_size=50):
|
||||
> pass
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------ | --------------- | ------------------------------------------------------ |
|
||||
| `stream` | `Iterable[Doc]` | A stream of documents. |
|
||||
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
|
||||
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ------------------------------------------------------ |
|
||||
| `stream` | `Iterable[Doc]` | A stream of documents. |
|
||||
| _keyword-only_ | | |
|
||||
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
|
||||
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |
|
||||
|
||||
## DependencyParser.begin_training {#begin_training tag="method"}
|
||||
|
||||
Initialize the pipe for training, using data examples if available. Return an
|
||||
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> parser = nlp.add_pipe("parser")
|
||||
> optimizer = parser.begin_training(pipeline=nlp.pipeline)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------- |
|
||||
| `get_examples` | `Callable[[], Iterable[Example]]` | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `pipeline` | `List[Tuple[str, Callable]]` | Optional list of pipeline components that this component is part of. |
|
||||
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/dependencyparser#create_optimizer) if not set. |
|
||||
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
|
||||
|
||||
## DependencyParser.predict {#predict tag="method"}
|
||||
|
||||
|
@ -97,7 +146,7 @@ Apply the pipeline's model to a batch of docs, without modifying them.
|
|||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> parser = DependencyParser(nlp.vocab)
|
||||
> parser = nlp.add_pipe("parser")
|
||||
> scores = parser.predict([doc1, doc2])
|
||||
> ```
|
||||
|
||||
|
@ -113,7 +162,7 @@ Modify a batch of documents, using pre-computed scores.
|
|||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> parser = DependencyParser(nlp.vocab)
|
||||
> parser = nlp.add_pipe("parser")
|
||||
> scores = parser.predict([doc1, doc2])
|
||||
> parser.set_annotations([doc1, doc2], scores)
|
||||
> ```
|
||||
|
@ -132,7 +181,7 @@ model. Delegates to [`predict`](/api/dependencyparser#predict) and
|
|||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> parser = DependencyParser(nlp.vocab, parser_model)
|
||||
> parser = nlp.add_pipe("parser")
|
||||
> optimizer = nlp.begin_training()
|
||||
> losses = parser.update(examples, sgd=optimizer)
|
||||
> ```
|
||||
|
@ -144,7 +193,7 @@ model. Delegates to [`predict`](/api/dependencyparser#predict) and
|
|||
| `drop` | float | The dropout rate. |
|
||||
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/dependencyparser#set_annotations). |
|
||||
| `sgd` | `Optimizer` | The [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
|
||||
| `losses` | `Dict[str, float]` | Optional record of the loss during training. The value keyed by the model's name is updated. |
|
||||
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
|
||||
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
|
||||
|
||||
## DependencyParser.get_loss {#get_loss tag="method"}
|
||||
|
@ -155,36 +204,31 @@ predicted scores.
|
|||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> parser = DependencyParser(nlp.vocab)
|
||||
> parser = nlp.add_pipe("parser")
|
||||
> scores = parser.predict([eg.predicted for eg in examples])
|
||||
> loss, d_loss = parser.get_loss(examples, scores)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------------------- | --------------------------------------------------- |
|
||||
| `examples` | `Iterable[Example]` | The batch of examples. |
|
||||
| `scores` | `syntax.StateClass` | Scores representing the model's predictions. |
|
||||
| **RETURNS** | tuple | The loss and the gradient, i.e. `(loss, gradient)`. |
|
||||
| Name | Type | Description |
|
||||
| ----------- | --------------------- | --------------------------------------------------- |
|
||||
| `examples` | `Iterable[Example]` | The batch of examples. |
|
||||
| `scores` | `syntax.StateClass` | Scores representing the model's predictions. |
|
||||
| **RETURNS** | `Tuple[float, float]` | The loss and the gradient, i.e. `(loss, gradient)`. |
|
||||
|
||||
## DependencyParser.begin_training {#begin_training tag="method"}
|
||||
## DependencyParser.score {#score tag="method" new="3"}
|
||||
|
||||
Initialize the pipe for training, using data examples if available. Return an
|
||||
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
|
||||
Score a batch of examples.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> parser = DependencyParser(nlp.vocab)
|
||||
> nlp.pipeline.append(parser)
|
||||
> optimizer = parser.begin_training(pipeline=nlp.pipeline)
|
||||
> scores = parser.score(examples)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------- | ----------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `get_examples` | `Iterable[Example]` | Optional gold-standard annotations in the form of [`Example`](/api/example) objects. |
|
||||
| `pipeline` | `List[(str, callable)]` | Optional list of pipeline components that this component is part of. |
|
||||
| `sgd` | `Optimizer` | An optional [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. Will be created via [`create_optimizer`](/api/dependencyparser#create_optimizer) if not set. |
|
||||
| **RETURNS** | `Optimizer` | An optimizer. |
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------------------- | -------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `examples` | `Iterable[Example]` | The examples to score. |
|
||||
| **RETURNS** | `Dict[str, Any]` | The scores, produced by [`Scorer.score_spans`](/api/scorer#score_spans) and [`Scorer.score_deps`](/api/scorer#score_deps). |

## DependencyParser.create_optimizer {#create_optimizer tag="method"}

@@ -194,13 +238,13 @@ component.

> #### Example
>
> ```python
> parser = DependencyParser(nlp.vocab)
> parser = nlp.add_pipe("parser")
> optimizer = parser.create_optimizer()
> ```

| Name | Type | Description |
| --- | --- | --- |
| **RETURNS** | `Optimizer` | The [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
| Name | Type | Description |
| --- | --- | --- |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |

## DependencyParser.use_params {#use_params tag="method, contextmanager"}

@@ -225,7 +269,7 @@ Add a new label to the pipe.

> #### Example
>
> ```python
> parser = DependencyParser(nlp.vocab)
> parser = nlp.add_pipe("parser")
> parser.add_label("MY_LABEL")
> ```

@@ -240,14 +284,14 @@ Serialize the pipe to disk.

> #### Example
>
> ```python
> parser = DependencyParser(nlp.vocab)
> parser = nlp.add_pipe("parser")
> parser.to_disk("/path/to/parser")
> ```

| Name | Type | Description |
| --- | --- | --- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| Name | Type | Description |
| --- | --- | --- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |

## DependencyParser.from_disk {#from_disk tag="method"}

@@ -256,14 +300,14 @@ Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> parser = DependencyParser(nlp.vocab)
> parser = nlp.add_pipe("parser")
> parser.from_disk("/path/to/parser")
> ```

| Name | Type | Description |
| --- | --- | --- |
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `DependencyParser` | The modified `DependencyParser` object. |

## DependencyParser.to_bytes {#to_bytes tag="method"}

@@ -271,16 +315,16 @@ Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> parser = DependencyParser(nlp.vocab)
> parser = nlp.add_pipe("parser")
> parser_bytes = parser.to_bytes()
> ```

Serialize the pipe to a bytestring.

| Name | Type | Description |
| --- | --- | --- |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `DependencyParser` object. |
| Name | Type | Description |
| --- | --- | --- |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `DependencyParser` object. |

## DependencyParser.from_bytes {#from_bytes tag="method"}

@@ -290,14 +334,14 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.

>
> ```python
> parser_bytes = parser.to_bytes()
> parser = DependencyParser(nlp.vocab)
> parser = nlp.add_pipe("parser")
> parser.from_bytes(parser_bytes)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `bytes_data` | bytes | The data to load from. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `DependencyParser` | The `DependencyParser` object. |

## DependencyParser.labels {#labels tag="property"}

@@ -1,23 +1,44 @@
---
title: EntityLinker
teaser:
  Functionality to disambiguate a named entity in text to a unique knowledge
  base identifier.
tag: class
source: spacy/pipeline/entity_linker.py
new: 2.2
teaser: 'Pipeline component for named entity linking and disambiguation'
api_base_class: /api/pipe
api_string_name: entity_linker
api_trainable: true
---

This class is a subclass of `Pipe` and follows the same API. The pipeline
component is available in the [processing pipeline](/usage/processing-pipelines)
via the ID `"entity_linker"`.
## Config and implementation {#config}

## Implementation and defaults {#implementation}
The default config is defined by the pipeline component factory and describes
how the component should be configured. You can override its settings via the
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
[`config.cfg` for training](/usage/training#config). See the
[model architectures](/api/architectures) documentation for details on the
architectures and their arguments and hyperparameters.

See the [model architectures](/api/architectures) documentation for details on
the architectures and their arguments and hyperparameters. To learn more about
how to customize the config and train custom models, check out the
[training config](/usage/training#config) docs.
> #### Example
>
> ```python
> from spacy.pipeline.entity_linker import DEFAULT_NEL_MODEL
> config = {
>    "kb": None,
>    "labels_discard": [],
>    "incl_prior": True,
>    "incl_context": True,
>    "model": DEFAULT_NEL_MODEL,
> }
> nlp.add_pipe("entity_linker", config=config)
> ```

| Setting | Type | Description | Default |
| --- | --- | --- | --- |
| `kb` | `KnowledgeBase` | <!-- TODO: --> | `None` |
| `labels_discard` | `Iterable[str]` | <!-- TODO: --> | `[]` |
| `incl_prior` | bool | <!-- TODO: --> | `True` |
| `incl_context` | bool | <!-- TODO: --> | `True` |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [EntityLinker](/api/architectures#EntityLinker) |

```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/entity_linker.py

@@ -34,19 +55,26 @@ https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/entity_linker.py

> # Construction via add_pipe with custom model
> config = {"model": {"@architectures": "my_el"}}
> entity_linker = nlp.add_pipe("entity_linker", config=config)
>
> # Construction from class
> from spacy.pipeline import EntityLinker
> entity_linker = EntityLinker(nlp.vocab, model)
> ```

Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#add_pipe).

| Name | Type | Description |
| --- | --- | --- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `**cfg` | - | Configuration parameters. |

| **RETURNS** | `EntityLinker` | The newly constructed object. |
| Name | Type | Description |
| --- | --- | --- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
| _keyword-only_ | | |
| `kb` | `KnowledgeBase` | <!-- TODO: --> |
| `labels_discard` | `Iterable[str]` | <!-- TODO: --> |
| `incl_prior` | bool | <!-- TODO: --> |
| `incl_context` | bool | <!-- TODO: --> |

## EntityLinker.\_\_call\_\_ {#call tag="method"}

@@ -60,8 +88,8 @@ delegate to the [`predict`](/api/entitylinker#predict) and

> #### Example
>
> ```python
> entity_linker = EntityLinker(nlp.vocab)
> doc = nlp("This is a sentence.")
> entity_linker = nlp.add_pipe("entity_linker")
> # This usually happens under the hood
> processed = entity_linker(doc)
> ```

@@ -83,91 +111,17 @@ applied to the `Doc` in order. Both [`__call__`](/api/entitylinker#call) and

> #### Example
>
> ```python
> entity_linker = EntityLinker(nlp.vocab)
> entity_linker = nlp.add_pipe("entity_linker")
> for doc in entity_linker.pipe(docs, batch_size=50):
>     pass
> ```

| Name | Type | Description |
| --- | --- | --- |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |

## EntityLinker.predict {#predict tag="method"}

Apply the pipeline's model to a batch of docs, without modifying them.

> #### Example
>
> ```python
> entity_linker = EntityLinker(nlp.vocab)
> kb_ids = entity_linker.predict([doc1, doc2])
> ```

| Name | Type | Description |
| --- | --- | --- |
| `docs` | `Iterable[Doc]` | The documents to predict. |
| **RETURNS** | `Iterable[str]` | The predicted KB identifiers for the entities in the `docs`. |

## EntityLinker.set_annotations {#set_annotations tag="method"}

Modify a batch of documents, using pre-computed entity IDs for a list of named
entities.

> #### Example
>
> ```python
> entity_linker = EntityLinker(nlp.vocab)
> kb_ids = entity_linker.predict([doc1, doc2])
> entity_linker.set_annotations([doc1, doc2], kb_ids)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `docs` | `Iterable[Doc]` | The documents to modify. |
| `kb_ids` | `Iterable[str]` | The knowledge base identifiers for the entities in the docs, predicted by `EntityLinker.predict`. |

## EntityLinker.update {#update tag="method"}

Learn from a batch of [`Example`](/api/example) objects, updating both the
pipe's entity linking model and context encoder. Delegates to
[`predict`](/api/entitylinker#predict) and
[`get_loss`](/api/entitylinker#get_loss).

> #### Example
>
> ```python
> entity_linker = EntityLinker(nlp.vocab, nel_model)
> optimizer = nlp.begin_training()
> losses = entity_linker.update(examples, sgd=optimizer)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/entitylinker#set_annotations). |
| `sgd` | `Optimizer` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. The value keyed by the model's name is updated. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |

## EntityLinker.set_kb {#set_kb tag="method"}

Define the knowledge base (KB) used for disambiguating named entities to KB
identifiers.

> #### Example
>
> ```python
> entity_linker = EntityLinker(nlp.vocab)
> entity_linker.set_kb(kb)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `kb` | `KnowledgeBase` | The [`KnowledgeBase`](/api/kb). |
| Name | Type | Description |
| --- | --- | --- |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| _keyword-only_ | | |
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |

## EntityLinker.begin_training {#begin_training tag="method"}

@@ -184,12 +138,89 @@ method, a knowledge base should have been defined with

> optimizer = entity_linker.begin_training(pipeline=nlp.pipeline)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `get_examples` | `Iterable[Example]` | Optional gold-standard annotations in the form of [`Example`](/api/example) objects. |
| `pipeline` | `List[(str, callable)]` | Optional list of pipeline components that this component is part of. |
| `sgd` | `Optimizer` | An optional [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. Will be created via [`create_optimizer`](/api/entitylinker#create_optimizer) if not set. |
| **RETURNS** | `Optimizer` | An optimizer. |
| Name | Type | Description |
| --- | --- | --- |
| `get_examples` | `Callable[[], Iterable[Example]]` | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. |
| _keyword-only_ | | |
| `pipeline` | `List[Tuple[str, Callable]]` | Optional list of pipeline components that this component is part of. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/entitylinker#create_optimizer) if not set. |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |

## EntityLinker.predict {#predict tag="method"}

Apply the pipeline's model to a batch of docs, without modifying them. Returns
the KB IDs for each entity in each doc, including `NIL` if there is no
prediction.

> #### Example
>
> ```python
> entity_linker = nlp.add_pipe("entity_linker")
> kb_ids = entity_linker.predict([doc1, doc2])
> ```

| Name | Type | Description |
| --- | --- | --- |
| `docs` | `Iterable[Doc]` | The documents to predict. |
| **RETURNS** | `List[str]` | The predicted KB identifiers for the entities in the `docs`. |
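
Since entities without a prediction come back as `NIL`, downstream code will often want to filter those out. A minimal sketch (not part of this diff; it assumes the returned IDs align one-to-one with the entities of a single processed doc):

```python
# Hypothetical post-processing: keep only entities that were actually linked.
kb_ids = entity_linker.predict([doc])
linked = [(ent.text, kb_id) for ent, kb_id in zip(doc.ents, kb_ids) if kb_id != "NIL"]
```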

## EntityLinker.set_annotations {#set_annotations tag="method"}

Modify a batch of documents, using pre-computed entity IDs for a list of named
entities.

> #### Example
>
> ```python
> entity_linker = nlp.add_pipe("entity_linker")
> kb_ids = entity_linker.predict([doc1, doc2])
> entity_linker.set_annotations([doc1, doc2], kb_ids)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `docs` | `Iterable[Doc]` | The documents to modify. |
| `kb_ids` | `List[str]` | The knowledge base identifiers for the entities in the docs, predicted by `EntityLinker.predict`. |

## EntityLinker.update {#update tag="method"}

Learn from a batch of [`Example`](/api/example) objects, updating both the
pipe's entity linking model and context encoder. Delegates to
[`predict`](/api/entitylinker#predict).

> #### Example
>
> ```python
> entity_linker = nlp.add_pipe("entity_linker")
> optimizer = nlp.begin_training()
> losses = entity_linker.update(examples, sgd=optimizer)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/entitylinker#set_annotations). |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |

## EntityLinker.set_kb {#set_kb tag="method"}

Define the knowledge base (KB) used for disambiguating named entities to KB
identifiers.

> #### Example
>
> ```python
> entity_linker = nlp.add_pipe("entity_linker")
> entity_linker.set_kb(kb)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `kb` | `KnowledgeBase` | The [`KnowledgeBase`](/api/kb). |
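
For context, the knowledge base is typically built up first and then attached to the component. A minimal sketch (not part of this diff) using the `KnowledgeBase` API, with made-up entities and probabilities:

```python
from spacy.kb import KnowledgeBase

# Toy KB; the vector length must match the entity linker's output dimension.
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=64)
kb.add_entity(entity="Q42", freq=12, entity_vector=[0.0] * 64)
kb.add_alias(alias="Douglas Adams", entities=["Q42"], probabilities=[0.9])
entity_linker.set_kb(kb)
```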

## EntityLinker.create_optimizer {#create_optimizer tag="method"}

@@ -198,13 +229,13 @@ Create an optimizer for the pipeline component.

> #### Example
>
> ```python
> entity_linker = EntityLinker(nlp.vocab)
> entity_linker = nlp.add_pipe("entity_linker")
> optimizer = entity_linker.create_optimizer()
> ```

| Name | Type | Description |
| --- | --- | --- |
| **RETURNS** | `Optimizer` | The [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
| Name | Type | Description |
| --- | --- | --- |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |

## EntityLinker.use_params {#use_params tag="method, contextmanager"}

@@ -213,7 +244,7 @@ Modify the pipe's EL model, to use the given parameter values.

> #### Example
>
> ```python
> entity_linker = EntityLinker(nlp.vocab)
> entity_linker = nlp.add_pipe("entity_linker")
> with entity_linker.use_params(optimizer.averages):
>     entity_linker.to_disk("/best_model")
> ```

@@ -229,14 +260,14 @@ Serialize the pipe to disk.

> #### Example
>
> ```python
> entity_linker = EntityLinker(nlp.vocab)
> entity_linker = nlp.add_pipe("entity_linker")
> entity_linker.to_disk("/path/to/entity_linker")
> ```

| Name | Type | Description |
| --- | --- | --- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| Name | Type | Description |
| --- | --- | --- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |

## EntityLinker.from_disk {#from_disk tag="method"}

@@ -245,15 +276,15 @@ Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> entity_linker = EntityLinker(nlp.vocab)
> entity_linker = nlp.add_pipe("entity_linker")
> entity_linker.from_disk("/path/to/entity_linker")
> ```

| Name | Type | Description |
| --- | --- | --- |
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `EntityLinker` | The modified `EntityLinker` object. |
| Name | Type | Description |
| --- | --- | --- |
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `EntityLinker` | The modified `EntityLinker` object. |

## Serialization fields {#serialization-fields}

@@ -2,18 +2,37 @@
title: EntityRecognizer
tag: class
source: spacy/pipeline/ner.pyx
teaser: 'Pipeline component for named entity recognition'
api_base_class: /api/pipe
api_string_name: ner
api_trainable: true
---

This class is a subclass of `Pipe` and follows the same API. The pipeline
component is available in the [processing pipeline](/usage/processing-pipelines)
via the ID `"ner"`.
## Config and implementation {#config}

## Implementation and defaults {#implementation}
The default config is defined by the pipeline component factory and describes
how the component should be configured. You can override its settings via the
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
[`config.cfg` for training](/usage/training#config). See the
[model architectures](/api/architectures) documentation for details on the
architectures and their arguments and hyperparameters.

See the [model architectures](/api/architectures) documentation for details on
the architectures and their arguments and hyperparameters. To learn more about
how to customize the config and train custom models, check out the
[training config](/usage/training#config) docs.
> #### Example
>
> ```python
> from spacy.pipeline.ner import DEFAULT_NER_MODEL
> config = {
>    "moves": None,
>    # TODO: rest
>    "model": DEFAULT_NER_MODEL,
> }
> nlp.add_pipe("ner", config=config)
> ```

| Setting | Type | Description | Default |
| --- | --- | --- | --- |
| `moves` | list | <!-- TODO: --> | `None` |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [TransitionBasedParser](/api/architectures#TransitionBasedParser) |

```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/ner.pyx

@@ -30,18 +49,27 @@ https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/ner.pyx

> # Construction via add_pipe with custom model
> config = {"model": {"@architectures": "my_ner"}}
> parser = nlp.add_pipe("ner", config=config)
>
> # Construction from class
> from spacy.pipeline import EntityRecognizer
> ner = EntityRecognizer(nlp.vocab, model)
> ```

Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#add_pipe).

| Name | Type | Description |
| --- | --- | --- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `**cfg` | - | Configuration parameters. |
| **RETURNS** | `EntityRecognizer` | The newly constructed object. |
| Name | Type | Description |
| --- | --- | --- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
| `moves` | list | <!-- TODO: --> |
| _keyword-only_ | | |
| `update_with_oracle_cut_size` | int | <!-- TODO: --> |
| `multitasks` | `Iterable` | <!-- TODO: --> |
| `learn_tokens` | bool | <!-- TODO: --> |
| `min_action_freq` | int | <!-- TODO: --> |

## EntityRecognizer.\_\_call\_\_ {#call tag="method"}

@@ -56,8 +84,8 @@ and all pipeline components are applied to the `Doc` in order. Both

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab)
> doc = nlp("This is a sentence.")
> ner = nlp.add_pipe("ner")
> # This usually happens under the hood
> processed = ner(doc)
> ```

@@ -79,16 +107,37 @@ applied to the `Doc` in order. Both [`__call__`](/api/entityrecognizer#call) and

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab)
> ner = nlp.add_pipe("ner")
> for doc in ner.pipe(docs, batch_size=50):
>     pass
> ```

| Name | Type | Description |
| --- | --- | --- |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |
| Name | Type | Description |
| --- | --- | --- |
| `docs` | `Iterable[Doc]` | A stream of documents. |
| _keyword-only_ | | |
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |

## EntityRecognizer.begin_training {#begin_training tag="method"}

Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.

> #### Example
>
> ```python
> ner = nlp.add_pipe("ner")
> optimizer = ner.begin_training(pipeline=nlp.pipeline)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `get_examples` | `Callable[[], Iterable[Example]]` | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. |
| _keyword-only_ | | |
| `pipeline` | `List[Tuple[str, Callable]]` | Optional list of pipeline components that this component is part of. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/entityrecognizer#create_optimizer) if not set. |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |

## EntityRecognizer.predict {#predict tag="method"}

@@ -97,7 +146,7 @@ Apply the pipeline's model to a batch of docs, without modifying them.

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab)
> ner = nlp.add_pipe("ner")
> scores = ner.predict([doc1, doc2])
> ```

@@ -113,7 +162,7 @@ Modify a batch of documents, using pre-computed scores.

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab)
> ner = nlp.add_pipe("ner")
> scores = ner.predict([doc1, doc2])
> ner.set_annotations([doc1, doc2], scores)
> ```

@@ -132,20 +181,20 @@ model. Delegates to [`predict`](/api/entityrecognizer#predict) and

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab, ner_model)
> ner = nlp.add_pipe("ner")
> optimizer = nlp.begin_training()
> losses = ner.update(examples, sgd=optimizer)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/entityrecognizer#set_annotations). |
| `sgd` | `Optimizer` | The [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. The value keyed by the model's name is updated. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
| Name | Type | Description |
| --- | --- | --- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/entityrecognizer#set_annotations). |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |

## EntityRecognizer.get_loss {#get_loss tag="method"}

@@ -155,36 +204,31 @@ predicted scores.

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab)
> ner = nlp.add_pipe("ner")
> scores = ner.predict([eg.predicted for eg in examples])
> loss, d_loss = ner.get_loss(examples, scores)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `examples` | `Iterable[Example]` | The batch of examples. |
| `scores` | `List[StateClass]` | Scores representing the model's predictions. |
| **RETURNS** | tuple | The loss and the gradient, i.e. `(loss, gradient)`. |
| Name | Type | Description |
| --- | --- | --- |
| `examples` | `Iterable[Example]` | The batch of examples. |
| `scores` | `List[StateClass]` | Scores representing the model's predictions. |
| **RETURNS** | `Tuple[float, float]` | The loss and the gradient, i.e. `(loss, gradient)`. |

## EntityRecognizer.begin_training {#begin_training tag="method"}
## EntityRecognizer.score {#score tag="method" new="3"}

Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
Score a batch of examples.

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab)
> nlp.pipeline.append(ner)
> optimizer = ner.begin_training(pipeline=nlp.pipeline)
> scores = ner.score(examples)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `get_examples` | `Iterable[Example]` | Optional gold-standard annotations in the form of [`Example`](/api/example) objects. |
| `pipeline` | `List[(str, callable)]` | Optional list of pipeline components that this component is part of. |
| `sgd` | `Optimizer` | An optional [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. Will be created via [`create_optimizer`](/api/entityrecognizer#create_optimizer) if not set. |
| **RETURNS** | `Optimizer` | An optimizer. |
| Name | Type | Description |
| --- | --- | --- |
| `examples` | `Iterable[Example]` | The examples to score. |
| **RETURNS** | `Dict[str, Any]` | The scores, produced by [`Scorer.score_spans`](/api/scorer#score_spans). |

## EntityRecognizer.create_optimizer {#create_optimizer tag="method"}

@@ -193,13 +237,13 @@ Create an optimizer for the pipeline component.

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab)
> ner = nlp.add_pipe("ner")
> optimizer = ner.create_optimizer()
> ```

| Name | Type | Description |
| --- | --- | --- |
| **RETURNS** | `Optimizer` | The [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
| Name | Type | Description |
| --- | --- | --- |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |

## EntityRecognizer.use_params {#use_params tag="method, contextmanager"}

@@ -224,7 +268,7 @@ Add a new label to the pipe.

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab)
> ner = nlp.add_pipe("ner")
> ner.add_label("MY_LABEL")
> ```

@@ -239,14 +283,14 @@ Serialize the pipe to disk.

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab)
> ner = nlp.add_pipe("ner")
> ner.to_disk("/path/to/ner")
> ```

| Name | Type | Description |
| --- | --- | --- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| Name | Type | Description |
| --- | --- | --- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |

## EntityRecognizer.from_disk {#from_disk tag="method"}

@@ -255,14 +299,14 @@ Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab)
> ner = nlp.add_pipe("ner")
> ner.from_disk("/path/to/ner")
> ```

| Name | Type | Description |
| --- | --- | --- |
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `EntityRecognizer` | The modified `EntityRecognizer` object. |

## EntityRecognizer.to_bytes {#to_bytes tag="method"}

@@ -270,16 +314,16 @@ Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> ner = EntityRecognizer(nlp.vocab)
> ner = nlp.add_pipe("ner")
> ner_bytes = ner.to_bytes()
> ```

Serialize the pipe to a bytestring.

| Name | Type | Description |
| --- | --- | --- |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `EntityRecognizer` object. |
| Name | Type | Description |
| --- | --- | --- |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `EntityRecognizer` object. |

## EntityRecognizer.from_bytes {#from_bytes tag="method"}

@@ -289,14 +333,14 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.

>
> ```python
> ner_bytes = ner.to_bytes()
> ner = EntityRecognizer(nlp.vocab)
> ner = nlp.add_pipe("ner")
> ner.from_bytes(ner_bytes)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `bytes_data` | bytes | The data to load from. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `EntityRecognizer` | The `EntityRecognizer` object. |

## EntityRecognizer.labels {#labels tag="property"}

@@ -3,17 +3,48 @@ title: EntityRuler
tag: class
source: spacy/pipeline/entityruler.py
new: 2.1
teaser: 'Pipeline component for rule-based named entity recognition'
api_string_name: entity_ruler
api_trainable: false
---

The EntityRuler lets you add spans to the [`Doc.ents`](/api/doc#ents) using
The entity ruler lets you add spans to the [`Doc.ents`](/api/doc#ents) using
token-based rules or exact phrase matches. It can be combined with the
statistical [`EntityRecognizer`](/api/entityrecognizer) to boost accuracy, or
used on its own to implement a purely rule-based entity recognition system. The
pipeline component is available in the
[processing pipeline](/usage/processing-pipelines) via the ID `"entity_ruler"`.
For usage examples, see the docs on
used on its own to implement a purely rule-based entity recognition system. For
usage examples, see the docs on
[rule-based entity recognition](/usage/rule-based-matching#entityruler).
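
To show the typical setup at a glance (a sketch that is not part of this diff; it assumes a trained pipeline such as `en_core_web_sm` is installed), the ruler is usually added before the statistical NER so its matches take precedence:

```python
import spacy

nlp = spacy.load("en_core_web_sm")
# Adding the ruler before "ner" lets rule-based matches win over the model.
ruler = nlp.add_pipe("entity_ruler", before="ner")
ruler.add_patterns([{"label": "ORG", "pattern": "Explosion"}])
doc = nlp("Explosion builds spaCy.")
print([(ent.text, ent.label_) for ent in doc.ents])
```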

## Config and implementation {#config}

The default config is defined by the pipeline component factory and describes
how the component should be configured. You can override its settings via the
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
[`config.cfg` for training](/usage/training#config).

> #### Example
>
> ```python
> config = {
>    "phrase_matcher_attr": None,
>    "validation": True,
>    "overwrite_ents": False,
>    "ent_id_sep": "||",
> }
> nlp.add_pipe("entity_ruler", config=config)
> ```

| Setting | Type | Description | Default |
| --- | --- | --- | --- |
| `phrase_matcher_attr` | str | Optional attribute name to match on for the internal [`PhraseMatcher`](/api/phrasematcher), e.g. `LOWER` to match on the lowercase token text. | `None` |
| `validation` | bool | Whether patterns should be validated, passed to Matcher and PhraseMatcher as `validate`. | `False` |
| `overwrite_ents` | bool | If existing entities are present, e.g. entities added by the model, overwrite them by matches if necessary. | `False` |
| `ent_id_sep` | str | Separator used internally for entity IDs. | `"||"` |

```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/entityruler.py
```

## EntityRuler.\_\_init\_\_ {#init tag="method"}

Initialize the entity ruler. If patterns are supplied here, they need to be a

@@ -32,15 +63,16 @@ be a token pattern (list) or a phrase pattern (string). For example:

> ruler = EntityRuler(nlp, overwrite_ents=True)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `nlp` | `Language` | The shared nlp object to pass the vocab to the matchers and process phrase patterns. |
| `patterns` | iterable | Optional patterns to load in. |
| `phrase_matcher_attr` | int / str | Optional attr to pass to the internal [`PhraseMatcher`](/api/phrasematcher). Defaults to `None`. |
| `validate` | bool | Whether patterns should be validated, passed to Matcher and PhraseMatcher as `validate`. Defaults to `False`. |
| `overwrite_ents` | bool | If existing entities are present, e.g. entities added by the model, overwrite them by matches if necessary. Defaults to `False`. |
| `**cfg` | - | Other config parameters. If pipeline component is loaded as part of a model pipeline, this will include all keyword arguments passed to `spacy.load`. |
| **RETURNS** | `EntityRuler` | The newly constructed object. |
| Name | Type | Description |
| --- | --- | --- |
| `nlp` | `Language` | The shared nlp object to pass the vocab to the matchers and process phrase patterns. |
| `name` <Tag variant="new">3</Tag> | str | Instance name of the current pipeline component. Typically passed in automatically from the factory when the component is added. Used to disable the current entity ruler while creating phrase patterns with the nlp object. |
| _keyword-only_ | | |
| `phrase_matcher_attr` | int / str | Optional attribute name to match on for the internal [`PhraseMatcher`](/api/phrasematcher), e.g. `LOWER` to match on the lowercase token text. Defaults to `None`. |
| `validate` | bool | Whether patterns should be validated, passed to Matcher and PhraseMatcher as `validate`. Defaults to `False`. |
| `overwrite_ents` | bool | If existing entities are present, e.g. entities added by the model, overwrite them by matches if necessary. Defaults to `False`. |
| `ent_id_sep` | str | Separator used internally for entity IDs. Defaults to `"||"`. |
| `patterns` | iterable | Optional patterns to load in on initialization. |

## EntityRuler.\_\_len\_\_ {#len tag="method"}

@@ -49,7 +81,7 @@ The number of all patterns added to the entity ruler.

> #### Example
>
> ```python
> ruler = EntityRuler(nlp)
> ruler = nlp.add_pipe("entity_ruler")
> assert len(ruler) == 0
> ruler.add_patterns([{"label": "ORG", "pattern": "Apple"}])
> assert len(ruler) == 1

@@ -66,7 +98,7 @@ Whether a label is present in the patterns.

> #### Example
>
> ```python
> ruler = EntityRuler(nlp)
> ruler = nlp.add_pipe("entity_ruler")
> ruler.add_patterns([{"label": "ORG", "pattern": "Apple"}])
> assert "ORG" in ruler
> assert not "PERSON" in ruler

@@ -116,7 +148,7 @@ of dicts) or a phrase pattern (string). For more details, see the usage guide on

> {"label": "ORG", "pattern": "Apple"},
> {"label": "GPE", "pattern": [{"lower": "san"}, {"lower": "francisco"}]}
> ]
> ruler = EntityRuler(nlp)
> ruler = nlp.add_pipe("entity_ruler")
> ruler.add_patterns(patterns)
> ```

@@ -134,7 +166,7 @@ only the patterns are saved as JSONL. If a directory name is provided, a

> #### Example
>
> ```python
> ruler = EntityRuler(nlp)
> ruler = nlp.add_pipe("entity_ruler")
> ruler.to_disk("/path/to/patterns.jsonl")  # saves patterns only
> ruler.to_disk("/path/to/entity_ruler")  # saves patterns and config
> ```

@@ -153,7 +185,7 @@ configuration.

> #### Example
>
> ```python
> ruler = EntityRuler(nlp)
> ruler = nlp.add_pipe("entity_ruler")
> ruler.from_disk("/path/to/patterns.jsonl")  # loads patterns only
> ruler.from_disk("/path/to/entity_ruler")  # loads patterns and config
> ```

@@ -170,7 +202,7 @@ Serialize the entity ruler patterns to a bytestring.

> #### Example
>
> ```python
> ruler = EntityRuler(nlp)
> ruler = nlp.add_pipe("entity_ruler")
> ruler_bytes = ruler.to_bytes()
> ```

@@ -186,14 +218,14 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.

>
> ```python
> ruler_bytes = ruler.to_bytes()
> ruler = EntityRuler(nlp)
> ruler = nlp.add_pipe("entity_ruler")
> ruler.from_bytes(ruler_bytes)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `patterns_bytes` | bytes | The bytestring to load. |
| **RETURNS** | `EntityRuler` | The modified `EntityRuler` object. |
| Name | Type | Description |
| --- | --- | --- |
| `bytes_data` | bytes | The bytestring to load. |
| **RETURNS** | `EntityRuler` | The modified `EntityRuler` object. |

## EntityRuler.labels {#labels tag="property"}

@@ -495,6 +495,51 @@ As of spaCy v3.0, the `disable_pipes` method has been renamed to `select_pipes`:

| `enable` | str / list | Name(s) of pipeline components that will not be disabled. |
| **RETURNS** | `DisabledPipes` | The disabled pipes that can be restored by calling the object's `.restore()` method. |
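
For illustration (a sketch based on the v3 API above, not part of this diff), the returned `DisabledPipes` object is typically used as a context manager so the pipes are restored automatically:

```python
# Run only the NER component; everything else is restored after the block.
with nlp.select_pipes(enable="ner"):
    doc = nlp("Apple is looking at buying a U.K. startup.")
```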

## Language.get_factory_meta {#get_factory_meta tag="classmethod" new="3"}

Get the factory meta information for a given pipeline component name. Expects
the name of the component **factory**. The factory meta is an instance of the
[`FactoryMeta`](/api/language#factorymeta) dataclass and contains the
information about the component and its defaults provided by the
[`@Language.component`](/api/language#component) or
[`@Language.factory`](/api/language#factory) decorator.

> #### Example
>
> ```python
> factory_meta = Language.get_factory_meta("ner")
> assert factory_meta.factory == "ner"
> print(factory_meta.default_config)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `name` | str | The factory name. |
| **RETURNS** | [`FactoryMeta`](#factorymeta) | The factory meta. |

## Language.get_pipe_meta {#get_pipe_meta tag="method" new="3"}

Get the factory meta information for a given pipeline component name. Expects
the name of the component **instance** in the pipeline. The factory meta is an
instance of the [`FactoryMeta`](/api/language#factorymeta) dataclass and
contains the information about the component and its defaults provided by the
[`@Language.component`](/api/language#component) or
[`@Language.factory`](/api/language#factory) decorator.

> #### Example
>
> ```python
> nlp.add_pipe("ner", name="entity_recognizer")
> factory_meta = nlp.get_pipe_meta("entity_recognizer")
> assert factory_meta.factory == "ner"
> print(factory_meta.default_config)
> ```

| Name | Type | Description |
| --- | --- | --- |
| `name` | str | The pipeline component name. |
| **RETURNS** | [`FactoryMeta`](#factorymeta) | The factory meta. |

## Language.meta {#meta tag="property"}

Custom meta data for the Language class. If a model is loaded, contains meta

@@ -622,6 +667,7 @@ available to the loaded object.

| `pipe_names` <Tag variant="new">2</Tag> | `List[str]` | List of pipeline component names, in order. |
| `pipe_labels` <Tag variant="new">2.2</Tag> | `Dict[str, List[str]]` | List of labels set by the pipeline components, if available, keyed by component name. |
| `pipe_factories` <Tag variant="new">2.2</Tag> | `Dict[str, str]` | Dictionary of pipeline component names, mapped to their factory names. |
| `factories` | `Dict[str, Callable]` | All available factory functions, keyed by name. |
| `factory_names` <Tag variant="new">3</Tag> | `List[str]` | List of all available factory names. |
| `path` <Tag variant="new">2</Tag> | `Path` | Path to the model data directory, if a model is loaded. Otherwise `None`. |

@@ -712,3 +758,19 @@ serialization by passing in the string names via the `exclude` argument.

| `tokenizer` | Tokenization rules and exceptions. |
| `meta` | The meta data, available as `Language.meta`. |
| ... | String names of pipeline components, e.g. `"ner"`. |
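
For instance (a sketch, not part of this diff), individual fields can be skipped when saving or loading:

```python
# Serialize everything except the NER component; the same keys work for from_disk.
nlp.to_disk("/path/to/pipeline", exclude=["ner"])
```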

## FactoryMeta {#factorymeta new="3" tag="dataclass"}

The `FactoryMeta` contains the information about the component and its defaults
provided by the [`@Language.component`](/api/language#component) or
[`@Language.factory`](/api/language#factory) decorator. It's created whenever a
component is added to the pipeline and stored on the `Language` class for each
component instance and factory instance.

| Name | Type | Description |
| --- | --- | --- |
| `factory` | str | The name of the registered component factory. |
| `default_config` | `Dict[str, Any]` | The default config, describing the default values of the factory arguments. |
| `assigns` | `Iterable[str]` | `Doc` or `Token` attributes assigned by this component, e.g. `["token.ent_id"]`. Used for pipeline analysis. <!-- TODO: link to something --> |
| `requires` | `Iterable[str]` | `Doc` or `Token` attributes required by this component, e.g. `["token.ent_id"]`. Used for pipeline analysis. <!-- TODO: link to something --> |
| `retokenizes` | bool | Whether the component changes tokenization. Used for pipeline analysis. <!-- TODO: link to something --> |
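
As an illustration (a sketch, not part of this diff), the meta is populated from the decorator arguments when a custom factory is registered:

```python
from spacy.language import Language

@Language.factory("my_component", assigns=["doc.ents"], retokenizes=False)
def create_my_component(nlp, name):
    def my_component(doc):
        return doc
    return my_component

meta = Language.get_factory_meta("my_component")
print(meta.assigns, meta.retokenizes)
```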

@@ -5,6 +5,8 @@ tag: class
source: spacy/lemmatizer.py
---

<!-- TODO: rewrite once it's converted to pipe -->

The `Lemmatizer` supports simple part-of-speech-sensitive suffix rules and
lookup tables.
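
As a quick illustration (a sketch, not part of this diff, using the lookup-based API of this pre-conversion version), a standalone lemmatizer can be driven by a lookup table:

```python
from spacy.lemmatizer import Lemmatizer
from spacy.lookups import Lookups

# Hypothetical toy table; real pipelines ship much larger lookups.
lookups = Lookups()
lookups.add_table("lemma_lookup", {"dogs": "dog"})
lemmatizer = Lemmatizer(lookups)
print(lemmatizer.lookup("dogs"))  # "dog"
```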

@@ -142,11 +142,12 @@ patterns = [[{"TEXT": "Google"}, {"TEXT": "Now"}], [{"TEXT": "GoogleNow"}]]

</Infobox>

| Name | Type | Description |
| --- | --- | --- |
| `match_id` | str | An ID for the thing you're matching. |
| `patterns` | list | Match pattern. A pattern consists of a list of dicts, where each dict describes a token. |
| `on_match` | callable or `None` | Callback function to act on matches. Takes the arguments `matcher`, `doc`, `i` and `matches`. |
| Name | Type | Description |
| --- | --- | --- |
| `match_id` | str | An ID for the thing you're matching. |
| `patterns` | list | Match pattern. A pattern consists of a list of dicts, where each dict describes a token. |
| _keyword-only_ | | |
| `on_match` | callable or `None` | Callback function to act on matches. Takes the arguments `matcher`, `doc`, `i` and `matches`. |
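
To make the keyword-only change concrete (a sketch, not part of this diff), `on_match` is now passed by name:

```python
from spacy.matcher import Matcher

matcher = Matcher(nlp.vocab)
patterns = [[{"TEXT": "Google"}, {"TEXT": "Now"}], [{"TEXT": "GoogleNow"}]]
# on_match must now be passed as a keyword argument.
matcher.add("GoogleNow", patterns, on_match=None)
```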

## Matcher.remove {#remove tag="method" new="2"}

@@ -3,22 +3,38 @@ title: Morphologizer
tag: class
source: spacy/pipeline/morphologizer.pyx
new: 3
teaser: 'Pipeline component for predicting morphological features'
api_base_class: /api/tagger
api_string_name: morphologizer
api_trainable: true
---

A trainable pipeline component to predict morphological features and
coarse-grained POS tags following the Universal Dependencies
[UPOS](https://universaldependencies.org/u/pos/index.html) and
[FEATS](https://universaldependencies.org/format.html#morphological-annotation)
annotation guidelines. This class is a subclass of `Pipe` and follows the same
API. The pipeline component is available in the
[processing pipeline](/usage/processing-pipelines) via the ID `"morphologizer"`.
annotation guidelines.
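
As a quick illustration (a sketch, not part of this diff, assuming a pipeline with a trained morphologizer), the predictions end up on the token attributes:

```python
doc = nlp("She was reading the paper.")
for token in doc:
    # e.g. "reading" -> VERB with features like Aspect=Prog|VerbForm=Part
    print(token.text, token.pos_, token.morph)
```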
|
||||
|
||||
## Implementation and defaults {#implementation}
## Config and implementation {#config}

See the [model architectures](/api/architectures) documentation for details on
the architectures and their arguments and hyperparameters. To learn more about
how to customize the config and train custom models, check out the
[training config](/usage/training#config) docs.
The default config is defined by the pipeline component factory and describes
how the component should be configured. You can override its settings via the
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
[`config.cfg` for training](/usage/training#config). See the
[model architectures](/api/architectures) documentation for details on the
architectures and their arguments and hyperparameters.

> #### Example
>
> ```python
> from spacy.pipeline.morphologizer import DEFAULT_MORPH_MODEL
> config = {"model": DEFAULT_MORPH_MODEL}
> nlp.add_pipe("morphologizer", config=config)
> ```

| Setting | Type | Description | Default |
| ------- | ------------------------------------------ | ----------------- | ----------------------------------- |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [Tagger](/api/architectures#Tagger) |

```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/morphologizer.pyx
```
@ -31,20 +47,30 @@ Initialize the morphologizer.

> #### Example
>
> ```python
> # Construction via add_pipe
> # Construction via add_pipe with default model
> morphologizer = nlp.add_pipe("morphologizer")
>
> # Construction via create_pipe with custom model
> config = {"model": {"@architectures": "my_morphologizer"}}
> morphologizer = nlp.add_pipe("morphologizer", config=config)
>
> # Construction from class
> from spacy.pipeline import Morphologizer
> morphologizer = Morphologizer(nlp.vocab, model)
> ```

Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#add_pipe).

| Name | Type | Description |
| ----------- | --------------- | -------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `**cfg` | - | Configuration parameters. |
| **RETURNS** | `Morphologizer` | The newly constructed object. |
| Name | Type | Description |
| -------------- | ------- | ---------------------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
| _keyword-only_ | | |
| `labels_morph` | dict | <!-- TODO: --> |
| `labels_pos` | dict | <!-- TODO: --> |
## Morphologizer.\_\_call\_\_ {#call tag="method"}

@ -58,8 +84,8 @@ delegate to the [`predict`](/api/morphologizer#predict) and

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> doc = nlp("This is a sentence.")
> morphologizer = nlp.add_pipe("morphologizer")
> # This usually happens under the hood
> processed = morphologizer(doc)
> ```

@ -81,16 +107,38 @@ applied to the `Doc` in order. Both [`__call__`](/api/morphologizer#call) and

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> morphologizer = nlp.add_pipe("morphologizer")
> for doc in morphologizer.pipe(docs, batch_size=50):
>     pass
> ```

| Name | Type | Description |
| ------------ | --------------- | ------------------------------------------------------ |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |
| Name | Type | Description |
| -------------- | --------------- | ------------------------------------------------------ |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| _keyword-only_ | | |
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |

## Morphologizer.begin_training {#begin_training tag="method"}

Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.

> #### Example
>
> ```python
> morphologizer = nlp.add_pipe("morphologizer")
> nlp.pipeline.append(morphologizer)
> optimizer = morphologizer.begin_training(pipeline=nlp.pipeline)
> ```

| Name | Type | Description |
| -------------- | --------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------- |
| `get_examples` | `Callable[[], Iterable[Example]]` | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. |
| _keyword-only_ | | |
| `pipeline` | `List[Tuple[str, Callable]]` | Optional list of pipeline components that this component is part of. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/morphologizer#create_optimizer) if not set. |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |

## Morphologizer.predict {#predict tag="method"}

@ -99,7 +147,7 @@ Apply the pipeline's model to a batch of docs, without modifying them.

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> morphologizer = nlp.add_pipe("morphologizer")
> scores = morphologizer.predict([doc1, doc2])
> ```

@ -115,7 +163,7 @@ Modify a batch of documents, using pre-computed scores.

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> morphologizer = nlp.add_pipe("morphologizer")
> scores = morphologizer.predict([doc1, doc2])
> morphologizer.set_annotations([doc1, doc2], scores)
> ```

@ -134,20 +182,20 @@ pipe's model. Delegates to [`predict`](/api/morphologizer#predict) and

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab, morphologizer_model)
> morphologizer = nlp.add_pipe("morphologizer")
> optimizer = nlp.begin_training()
> losses = morphologizer.update(examples, sgd=optimizer)
> ```

| Name | Type | Description |
| ----------------- | ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/morphologizer#set_annotations). |
| `sgd` | `Optimizer` | The [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. The value keyed by the model's name is updated. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
| Name | Type | Description |
| ----------------- | --------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/morphologizer#set_annotations). |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. The value keyed by the model's name is updated. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
## Morphologizer.get_loss {#get_loss tag="method"}

@ -157,36 +205,16 @@ predicted scores.

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> morphologizer = nlp.add_pipe("morphologizer")
> scores = morphologizer.predict([eg.predicted for eg in examples])
> loss, d_loss = morphologizer.get_loss(examples, scores)
> ```

| Name | Type | Description |
| ----------- | ------------------- | --------------------------------------------------- |
| `examples` | `Iterable[Example]` | The batch of examples. |
| `scores` | - | Scores representing the model's predictions. |
| **RETURNS** | tuple | The loss and the gradient, i.e. `(loss, gradient)`. |

## Morphologizer.begin_training {#begin_training tag="method"}

Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> nlp.pipeline.append(morphologizer)
> optimizer = morphologizer.begin_training(pipeline=nlp.pipeline)
> ```

| Name | Type | Description |
| -------------- | ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------- |
| `get_examples` | `Iterable[Example]` | Optional gold-standard annotations in the form of [`Example`](/api/example) objects. |
| `pipeline` | `List[(str, callable)]` | Optional list of pipeline components that this component is part of. |
| `sgd` | `Optimizer` | An optional [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. Will be created via [`create_optimizer`](/api/morphologizer#create_optimizer) if not set. |
| **RETURNS** | `Optimizer` | An optimizer. |
| Name | Type | Description |
| ----------- | --------------------- | --------------------------------------------------- |
| `examples` | `Iterable[Example]` | The batch of examples. |
| `scores` | - | Scores representing the model's predictions. |
| **RETURNS** | `Tuple[float, float]` | The loss and the gradient, i.e. `(loss, gradient)`. |

## Morphologizer.create_optimizer {#create_optimizer tag="method"}

@ -195,13 +223,13 @@ Create an optimizer for the pipeline component.

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> morphologizer = nlp.add_pipe("morphologizer")
> optimizer = morphologizer.create_optimizer()
> ```

| Name | Type | Description |
| ----------- | ----------- | ---------------------------------------------------------------- |
| **RETURNS** | `Optimizer` | The [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
| Name | Type | Description |
| ----------- | --------------------------------------------------- | -------------- |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
## Morphologizer.use_params {#use_params tag="method, contextmanager"}

@ -210,7 +238,7 @@ Modify the pipe's model, to use the given parameter values.

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> morphologizer = nlp.add_pipe("morphologizer")
> with morphologizer.use_params():
>     morphologizer.to_disk("/best_model")
> ```

@ -227,7 +255,7 @@ both `pos` and `morph`, the label should include the UPOS as the feature `POS`.

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> morphologizer = nlp.add_pipe("morphologizer")
> morphologizer.add_label("Mood=Ind|POS=VERB|Tense=Past|VerbForm=Fin")
> ```

@ -242,14 +270,14 @@ Serialize the pipe to disk.

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> morphologizer = nlp.add_pipe("morphologizer")
> morphologizer.to_disk("/path/to/morphologizer")
> ```

| Name | Type | Description |
| --------- | ------------ | ------------------------------------------------------------------------------------------------------------------------ |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| Name | Type | Description |
| --------- | --------------- | ------------------------------------------------------------------------------------------------------------------------ |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |

## Morphologizer.from_disk {#from_disk tag="method"}

@ -258,14 +286,14 @@ Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> morphologizer = nlp.add_pipe("morphologizer")
> morphologizer.from_disk("/path/to/morphologizer")
> ```

| Name | Type | Description |
| ----------- | --------------- | ---------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `Morphologizer` | The modified `Morphologizer` object. |

## Morphologizer.to_bytes {#to_bytes tag="method"}

@ -273,16 +301,16 @@ Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> morphologizer = Morphologizer(nlp.vocab)
> morphologizer = nlp.add_pipe("morphologizer")
> morphologizer_bytes = morphologizer.to_bytes()
> ```

Serialize the pipe to a bytestring.

| Name | Type | Description |
| ----------- | ----- | --------------------------------------------------------------------------- |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `Morphologizer` object. |
| Name | Type | Description |
| ----------- | --------------- | --------------------------------------------------------------------------- |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `Morphologizer` object. |

## Morphologizer.from_bytes {#from_bytes tag="method"}

@ -292,14 +320,14 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.

>
> ```python
> morphologizer_bytes = morphologizer.to_bytes()
> morphologizer = Morphologizer(nlp.vocab)
> morphologizer = nlp.add_pipe("morphologizer")
> morphologizer.from_bytes(morphologizer_bytes)
> ```

| Name | Type | Description |
| ------------ | --------------- | ---------------------------------------------------------------------------- |
| `bytes_data` | bytes | The data to load from. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `Morphologizer` | The `Morphologizer` object. |

## Morphologizer.labels {#labels tag="property"}
@ -165,11 +165,12 @@ patterns = [nlp("health care reform"), nlp("healthcare reform")]

</Infobox>

| Name | Type | Description |
| ---------- | ------------------ | ---------------------------------------------------------------------------------------------- |
| `match_id` | str | An ID for the thing you're matching. |
| `docs` | list | `Doc` objects of the phrases to match. |
| `on_match` | callable or `None` | Callback function to act on matches. Takes the arguments `matcher`, `doc`, `i` and `matches`. |
| Name | Type | Description |
| -------------- | ------------------ | ---------------------------------------------------------------------------------------------- |
| `match_id` | str | An ID for the thing you're matching. |
| `docs` | list | `Doc` objects of the phrases to match. |
| _keyword-only_ | | |
| `on_match` | callable or `None` | Callback function to act on matches. Takes the arguments `matcher`, `doc`, `i` and `matches`. |
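As with `Matcher.add`, the callback is now keyword-only. A minimal sketch of
the updated call (the match ID is illustrative, and an existing `nlp` object is
assumed):

```python
from spacy.matcher import PhraseMatcher

matcher = PhraseMatcher(nlp.vocab)  # assumes an existing `nlp` object
patterns = [nlp("health care reform"), nlp("healthcare reform")]

def on_match(matcher, doc, i, matches):
    match_id, start, end = matches[i]
    print("Matched:", doc[start:end].text)

# `on_match` must now be passed by keyword
matcher.add("HEALTH", patterns, on_match=on_match)
```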
## PhraseMatcher.remove {#remove tag="method" new="2.2"}

website/docs/api/pipe.md (new file)
@ -0,0 +1,6 @@
---
title: Pipe
tag: class
---

TODO: write
@ -3,19 +3,35 @@ title: SentenceRecognizer
tag: class
source: spacy/pipeline/senter.pyx
new: 3
teaser: 'Pipeline component for sentence segmentation'
api_base_class: /api/tagger
api_string_name: senter
api_trainable: true
---

A trainable pipeline component for sentence segmentation. For a simpler,
rule-based strategy, see the [`Sentencizer`](/api/sentencizer). This class is a
subclass of `Pipe` and follows the same API. The component is also available via
the string name `"senter"`.
rule-based strategy, see the [`Sentencizer`](/api/sentencizer).
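As a usage sketch, the predicted boundaries are exposed via `doc.sents`, just
like the parser's. This assumes a loaded pipeline that ships a trained
`senter` component and that the component is enabled:

```python
import spacy

# Assumption: the installed pipeline includes a trained "senter"
# component that can replace the parser for sentence boundaries
nlp = spacy.load("en_core_web_sm", disable=["parser"])
nlp.enable_pipe("senter")
doc = nlp("This is a sentence. This is another one.")
print([sent.text for sent in doc.sents])
```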
## Implementation and defaults {#implementation}
## Config and implementation {#config}

See the [model architectures](/api/architectures) documentation for details on
the architectures and their arguments and hyperparameters. To learn more about
how to customize the config and train custom models, check out the
[training config](/usage/training#config) docs.
The default config is defined by the pipeline component factory and describes
how the component should be configured. You can override its settings via the
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
[`config.cfg` for training](/usage/training#config). See the
[model architectures](/api/architectures) documentation for details on the
architectures and their arguments and hyperparameters.

> #### Example
>
> ```python
> from spacy.pipeline.senter import DEFAULT_SENTER_MODEL
> config = {"model": DEFAULT_SENTER_MODEL}
> nlp.add_pipe("senter", config=config)
> ```

| Setting | Type | Description | Default |
| ------- | ------------------------------------------ | ----------------- | ----------------------------------- |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [Tagger](/api/architectures#Tagger) |

```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/senter.pyx
```
@ -28,8 +44,322 @@ Initialize the sentence recognizer.

> #### Example
>
> ```python
> # Construction via add_pipe
> # Construction via add_pipe with default model
> senter = nlp.add_pipe("senter")
>
> # Construction via create_pipe with custom model
> config = {"model": {"@architectures": "my_senter"}}
> senter = nlp.add_pipe("senter", config=config)
>
> # Construction from class
> from spacy.pipeline import SentenceRecognizer
> senter = SentenceRecognizer(nlp.vocab, model)
> ```

<!-- TODO: document, similar to other trainable pipeline components -->

Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#add_pipe).

| Name | Type | Description |
| ------- | ------- | ---------------------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |

## SentenceRecognizer.\_\_call\_\_ {#call tag="method"}

Apply the pipe to one document. The document is modified in place, and returned.
This usually happens under the hood when the `nlp` object is called on a text
and all pipeline components are applied to the `Doc` in order. Both
[`__call__`](/api/sentencerecognizer#call) and
[`pipe`](/api/sentencerecognizer#pipe) delegate to the
[`predict`](/api/sentencerecognizer#predict) and
[`set_annotations`](/api/sentencerecognizer#set_annotations) methods.

> #### Example
>
> ```python
> doc = nlp("This is a sentence.")
> senter = nlp.add_pipe("senter")
> # This usually happens under the hood
> processed = senter(doc)
> ```

| Name | Type | Description |
| ----------- | ----- | ------------------------ |
| `doc` | `Doc` | The document to process. |
| **RETURNS** | `Doc` | The processed document. |

## SentenceRecognizer.pipe {#pipe tag="method"}

Apply the pipe to a stream of documents. This usually happens under the hood
when the `nlp` object is called on a text and all pipeline components are
applied to the `Doc` in order. Both [`__call__`](/api/sentencerecognizer#call)
and [`pipe`](/api/sentencerecognizer#pipe) delegate to the
[`predict`](/api/sentencerecognizer#predict) and
[`set_annotations`](/api/sentencerecognizer#set_annotations) methods.

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> for doc in senter.pipe(docs, batch_size=50):
>     pass
> ```

| Name | Type | Description |
| -------------- | --------------- | ------------------------------------------------------ |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| _keyword-only_ | | |
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |

## SentenceRecognizer.begin_training {#begin_training tag="method"}

Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> optimizer = senter.begin_training(pipeline=nlp.pipeline)
> ```

| Name | Type | Description |
| -------------- | --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------ |
| `get_examples` | `Callable[[], Iterable[Example]]` | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. |
| _keyword-only_ | | |
| `pipeline` | `List[Tuple[str, Callable]]` | Optional list of pipeline components that this component is part of. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/sentencerecognizer#create_optimizer) if not set. |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |

## SentenceRecognizer.predict {#predict tag="method"}

Apply the pipeline's model to a batch of docs, without modifying them.

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> scores = senter.predict([doc1, doc2])
> ```

| Name | Type | Description |
| ----------- | --------------- | ----------------------------------------- |
| `docs` | `Iterable[Doc]` | The documents to predict. |
| **RETURNS** | - | The model's prediction for each document. |

## SentenceRecognizer.set_annotations {#set_annotations tag="method"}

Modify a batch of documents, using pre-computed scores.

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> scores = senter.predict([doc1, doc2])
> senter.set_annotations([doc1, doc2], scores)
> ```

| Name | Type | Description |
| -------- | --------------- | ------------------------------------------------------------ |
| `docs` | `Iterable[Doc]` | The documents to modify. |
| `scores` | - | The scores to set, produced by `SentenceRecognizer.predict`. |
## SentenceRecognizer.update {#update tag="method"}

Learn from a batch of documents and gold-standard information, updating the
pipe's model. Delegates to [`predict`](/api/sentencerecognizer#predict) and
[`get_loss`](/api/sentencerecognizer#get_loss).

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> optimizer = nlp.begin_training()
> losses = senter.update(examples, sgd=optimizer)
> ```

| Name | Type | Description |
| ----------------- | --------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/sentencerecognizer#set_annotations). |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. The value keyed by the model's name is updated. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |

## SentenceRecognizer.rehearse {#rehearse tag="method,experimental"}

Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
current model to make predictions similar to an initial model, to try to address
the "catastrophic forgetting" problem. This feature is experimental.

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> optimizer = nlp.begin_training()
> losses = senter.rehearse(examples, sgd=optimizer)
> ```

| Name | Type | Description |
| -------------- | --------------------------------------------------- | -------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |

## SentenceRecognizer.get_loss {#get_loss tag="method"}

Find the loss and gradient of loss for the batch of documents and their
predicted scores.

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> scores = senter.predict([eg.predicted for eg in examples])
> loss, d_loss = senter.get_loss(examples, scores)
> ```

| Name | Type | Description |
| ----------- | --------------------- | --------------------------------------------------- |
| `examples` | `Iterable[Example]` | The batch of examples. |
| `scores` | - | Scores representing the model's predictions. |
| **RETURNS** | `Tuple[float, float]` | The loss and the gradient, i.e. `(loss, gradient)`. |

## SentenceRecognizer.score {#score tag="method" new="3"}

Score a batch of examples.

> #### Example
>
> ```python
> scores = senter.score(examples)
> ```

| Name | Type | Description |
| ----------- | ------------------- | -------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The examples to score. |
| **RETURNS** | `Dict[str, Any]` | The scores, produced by [`Scorer.score_spans`](/api/scorer#score_spans). |

## SentenceRecognizer.create_optimizer {#create_optimizer tag="method"}

Create an optimizer for the pipeline component.

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> optimizer = senter.create_optimizer()
> ```

| Name | Type | Description |
| ----------- | --------------------------------------------------- | -------------- |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |

## SentenceRecognizer.use_params {#use_params tag="method, contextmanager"}

Modify the pipe's model, to use the given parameter values.

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> with senter.use_params():
>     senter.to_disk("/best_model")
> ```

| Name | Type | Description |
| -------- | ---- | --------------------------------------------------------------------------------------------------------------- |
| `params` | - | The parameter values to use in the model. At the end of the context, the original parameters are restored. |

## SentenceRecognizer.to_disk {#to_disk tag="method"}

Serialize the pipe to disk.

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> senter.to_disk("/path/to/senter")
> ```

| Name | Type | Description |
| --------- | --------------- | ------------------------------------------------------------------------------------------------------------------------ |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |

## SentenceRecognizer.from_disk {#from_disk tag="method"}

Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> senter.from_disk("/path/to/senter")
> ```

| Name | Type | Description |
| ----------- | -------------------- | ---------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `SentenceRecognizer` | The modified `SentenceRecognizer` object. |

## SentenceRecognizer.to_bytes {#to_bytes tag="method"}

> #### Example
>
> ```python
> senter = nlp.add_pipe("senter")
> senter_bytes = senter.to_bytes()
> ```

Serialize the pipe to a bytestring.

| Name | Type | Description |
| ----------- | --------------- | ---------------------------------------------------------------------------- |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `SentenceRecognizer` object. |

## SentenceRecognizer.from_bytes {#from_bytes tag="method"}

Load the pipe from a bytestring. Modifies the object in place and returns it.

> #### Example
>
> ```python
> senter_bytes = senter.to_bytes()
> senter = nlp.add_pipe("senter")
> senter.from_bytes(senter_bytes)
> ```

| Name | Type | Description |
| ------------ | -------------------- | ---------------------------------------------------------------------------- |
| `bytes_data` | bytes | The data to load from. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `SentenceRecognizer` | The `SentenceRecognizer` object. |

## Serialization fields {#serialization-fields}

During serialization, spaCy will export several data fields used to restore
different aspects of the object. If needed, you can exclude them from
serialization by passing in the string names via the `exclude` argument.

> #### Example
>
> ```python
> data = senter.to_disk("/path", exclude=["vocab"])
> ```

| Name | Description |
| ------- | --------------------------------------------------------------- |
| `vocab` | The shared [`Vocab`](/api/vocab). |
| `cfg` | The config file. You usually don't want to exclude this. |
| `model` | The binary model data. You usually don't want to exclude this. |
@ -1,15 +1,40 @@
---
title: Sentencizer
tag: class
source: spacy/pipeline/pipes.pyx
source: spacy/pipeline/sentencizer.pyx
teaser: 'Pipeline component for rule-based sentence boundary detection'
api_base_class: /api/pipe
api_string_name: sentencizer
api_trainable: false
---

A simple pipeline component, to allow custom sentence boundary detection logic
that doesn't require the dependency parse. By default, sentence segmentation is
performed by the [`DependencyParser`](/api/dependencyparser), so the
`Sentencizer` lets you implement a simpler, rule-based strategy that doesn't
require a statistical model to be loaded. The component is also available via
the string name `"sentencizer"`.
require a statistical model to be loaded.
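As a usage sketch, the component can be dropped into a blank pipeline, which is
the typical setup when no trained model is involved:

```python
import spacy

nlp = spacy.blank("en")  # no trained components required
nlp.add_pipe("sentencizer")
doc = nlp("This is a sentence. This is another one.")
print([sent.text for sent in doc.sents])
```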
## Config and implementation {#config}

The default config is defined by the pipeline component factory and describes
how the component should be configured. You can override its settings via the
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
[`config.cfg` for training](/usage/training#config).

> #### Example
>
> ```python
> config = {"punct_chars": None}
> nlp.add_pipe("sentencizer", config=config)
> ```

| Setting | Type | Description | Default |
| ------------- | ----------- | --------------------------------------------------------------------------------------------------------------- | ------- |
| `punct_chars` | `List[str]` | Optional custom list of punctuation characters that mark sentence ends. See below for defaults if not set. | `None` |

```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/sentencizer.pyx
```
## Sentencizer.\_\_init\_\_ {#init tag="method"}

@ -20,12 +45,16 @@ Initialize the sentencizer.

> ```python
> # Construction via add_pipe
> sentencizer = nlp.add_pipe("sentencizer")
>
> # Construction from class
> from spacy.pipeline import Sentencizer
> sentencizer = Sentencizer()
> ```

| Name | Type | Description |
| ------------- | ------------- | -------------------------------------------------------------------------------------------------- |
| `punct_chars` | list | Optional custom list of punctuation characters that mark sentence ends. See below for defaults. |
| **RETURNS** | `Sentencizer` | The newly constructed object. |
| Name | Type | Description |
| -------------- | ----------- | -------------------------------------------------------------------------------------------------- |
| _keyword-only_ | | |
| `punct_chars` | `List[str]` | Optional custom list of punctuation characters that mark sentence ends. See below for defaults. |

```python
### punct_chars defaults
```
@ -63,6 +92,42 @@ the component has been added to the pipeline using

| `doc` | `Doc` | The `Doc` object to process, e.g. the `Doc` in the pipeline. |
| **RETURNS** | `Doc` | The modified `Doc` with added sentence boundaries. |

## Sentencizer.pipe {#pipe tag="method"}

Apply the pipe to a stream of documents. This usually happens under the hood
when the `nlp` object is called on a text and all pipeline components are
applied to the `Doc` in order.

> #### Example
>
> ```python
> sentencizer = nlp.add_pipe("sentencizer")
> for doc in sentencizer.pipe(docs, batch_size=50):
>     pass
> ```

| Name | Type | Description |
| -------------- | --------------- | ------------------------------------------------------ |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| _keyword-only_ | | |
| `batch_size` | int | The number of documents to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | The processed documents in order. |

## Sentencizer.score {#score tag="method" new="3"}

Score a batch of examples.

> #### Example
>
> ```python
> scores = sentencizer.score(examples)
> ```

| Name | Type | Description |
| ----------- | ------------------- | -------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The examples to score. |
| **RETURNS** | `Dict[str, Any]` | The scores, produced by [`Scorer.score_spans`](/api/scorer#score_spans). |

## Sentencizer.to_disk {#to_disk tag="method"}

Save the sentencizer settings (punctuation characters) to a directory. Will create
@ -72,13 +137,14 @@ a file `sentencizer.json`. This also happens automatically when you save an

> #### Example
>
> ```python
> sentencizer = Sentencizer(punct_chars=[".", "?", "!", "。"])
> sentencizer.to_disk("/path/to/sentencizer.jsonl")
> config = {"punct_chars": [".", "?", "!", "。"]}
> sentencizer = nlp.add_pipe("sentencizer", config=config)
> sentencizer.to_disk("/path/to/sentencizer.json")
> ```

| Name | Type | Description |
| ------ | ------------ | --------------------------------------------------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a file, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| Name | Type | Description |
| ------ | ------------ | -------------------------------------------------------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a JSON file, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |

## Sentencizer.from_disk {#from_disk tag="method"}

@ -89,7 +155,7 @@ added to its pipeline.

> #### Example
>
> ```python
> sentencizer = Sentencizer()
> sentencizer = nlp.add_pipe("sentencizer")
> sentencizer.from_disk("/path/to/sentencizer.json")
> ```

@ -105,7 +171,8 @@ Serialize the sentencizer settings to a bytestring.

> #### Example
>
> ```python
> sentencizer = Sentencizer(punct_chars=[".", "?", "!", "。"])
> config = {"punct_chars": [".", "?", "!", "。"]}
> sentencizer = nlp.add_pipe("sentencizer", config=config)
> sentencizer_bytes = sentencizer.to_bytes()
> ```

@ -121,7 +188,7 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.

>
> ```python
> sentencizer_bytes = sentencizer.to_bytes()
> sentencizer = Sentencizer()
> sentencizer = nlp.add_pipe("sentencizer")
> sentencizer.from_bytes(sentencizer_bytes)
> ```
@ -2,11 +2,40 @@
title: Tagger
tag: class
source: spacy/pipeline/tagger.pyx
teaser: 'Pipeline component for part-of-speech tagging'
api_base_class: /api/pipe
api_string_name: tagger
api_trainable: true
---

This class is a subclass of `Pipe` and follows the same API. The pipeline
component is available in the [processing pipeline](/usage/processing-pipelines)
via the ID `"tagger"`.
## Config and implementation {#config}

The default config is defined by the pipeline component factory and describes
how the component should be configured. You can override its settings via the
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
[`config.cfg` for training](/usage/training#config). See the
[model architectures](/api/architectures) documentation for details on the
architectures and their arguments and hyperparameters.

> #### Example
>
> ```python
> from spacy.pipeline.tagger import DEFAULT_TAGGER_MODEL
> config = {
>     "set_morphology": False,
>     "model": DEFAULT_TAGGER_MODEL,
> }
> nlp.add_pipe("tagger", config=config)
> ```

| Setting | Type | Description | Default |
| ---------------- | ------------------------------------------ | -------------------------------------- | ----------------------------------- |
| `set_morphology` | bool | Whether to set morphological features. | `False` |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [Tagger](/api/architectures#Tagger) |

```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/tagger.pyx
```
## Tagger.\_\_init\_\_ {#init tag="method"}

@ -18,19 +47,24 @@ via the ID `"tagger"`.

>
> # Construction via create_pipe with custom model
> config = {"model": {"@architectures": "my_tagger"}}
> tagger = nlp.add_pipe("tagger", config)
> tagger = nlp.add_pipe("tagger", config=config)
>
> # Construction from class
> from spacy.pipeline import Tagger
> tagger = Tagger(nlp.vocab, model)
> ```

Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#add_pipe).

| Name | Type | Description |
| ----------- | -------- | --------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `**cfg` | - | Configuration parameters. |
| **RETURNS** | `Tagger` | The newly constructed object. |
| Name | Type | Description |
| ---------------- | ------- | ------------------------------------------------------------------------------------------------ |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
| _keyword-only_ | | |
| `set_morphology` | bool | Whether to set morphological features. |

## Tagger.\_\_call\_\_ {#call tag="method"}

@ -44,8 +78,8 @@ and all pipeline components are applied to the `Doc` in order. Both

> #### Example
>
> ```python
> tagger = Tagger(nlp.vocab)
> doc = nlp("This is a sentence.")
> tagger = nlp.add_pipe("tagger")
> # This usually happens under the hood
> processed = tagger(doc)
> ```

@ -66,16 +100,37 @@ applied to the `Doc` in order. Both [`__call__`](/api/tagger#call) and

> #### Example
>
> ```python
> tagger = Tagger(nlp.vocab)
> tagger = nlp.add_pipe("tagger")
> for doc in tagger.pipe(docs, batch_size=50):
>     pass
> ```

| Name | Type | Description |
| ------------ | --------------- | ------------------------------------------------------ |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |
| Name | Type | Description |
| -------------- | --------------- | ------------------------------------------------------ |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| _keyword-only_ | | |
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |

## Tagger.begin_training {#begin_training tag="method"}

Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.

> #### Example
>
> ```python
> tagger = nlp.add_pipe("tagger")
> optimizer = tagger.begin_training(pipeline=nlp.pipeline)
> ```

| Name | Type | Description |
| -------------- | --------------------------------------------------- | --------------------------------------------------------------------------------------------------------------- |
| `get_examples` | `Callable[[], Iterable[Example]]` | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. |
| _keyword-only_ | | |
| `pipeline` | `List[Tuple[str, Callable]]` | Optional list of pipeline components that this component is part of. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/tagger#create_optimizer) if not set. |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |

## Tagger.predict {#predict tag="method"}

@ -84,7 +139,7 @@ Apply the pipeline's model to a batch of docs, without modifying them.

> #### Example
>
> ```python
> tagger = Tagger(nlp.vocab)
> tagger = nlp.add_pipe("tagger")
> scores = tagger.predict([doc1, doc2])
> ```

@ -100,7 +155,7 @@ Modify a batch of documents, using pre-computed scores.

> #### Example
>
> ```python
> tagger = Tagger(nlp.vocab)
> tagger = nlp.add_pipe("tagger")
> scores = tagger.predict([doc1, doc2])
> tagger.set_annotations([doc1, doc2], scores)
> ```
@ -119,20 +174,43 @@ pipe's model. Delegates to [`predict`](/api/tagger#predict) and

> #### Example
>
> ```python
> tagger = Tagger(nlp.vocab, tagger_model)
> tagger = nlp.add_pipe("tagger")
> optimizer = nlp.begin_training()
> losses = tagger.update(examples, sgd=optimizer)
> ```

| Name | Type | Description |
| ----------------- | ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/tagger#set_annotations). |
| `sgd` | `Optimizer` | The [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. The value keyed by the model's name is updated. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
| Name | Type | Description |
| ----------------- | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/tagger#set_annotations). |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. The value keyed by the model's name is updated. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |

## Tagger.rehearse {#rehearse tag="method,experimental"}

Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
current model to make predictions similar to an initial model, to try to address
the "catastrophic forgetting" problem. This feature is experimental.

> #### Example
>
> ```python
> tagger = nlp.add_pipe("tagger")
> optimizer = nlp.begin_training()
> losses = tagger.rehearse(examples, sgd=optimizer)
> ```

| Name | Type | Description |
| -------------- | --------------------------------------------------- | -------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |

## Tagger.get_loss {#get_loss tag="method"}

@ -142,36 +220,31 @@ predicted scores.

> #### Example
>
> ```python
> tagger = Tagger(nlp.vocab)
> tagger = nlp.add_pipe("tagger")
> scores = tagger.predict([eg.predicted for eg in examples])
> loss, d_loss = tagger.get_loss(examples, scores)
> ```

| Name | Type | Description |
| ----------- | ------------------- | --------------------------------------------------- |
| `examples` | `Iterable[Example]` | The batch of examples. |
| `scores` | - | Scores representing the model's predictions. |
| **RETURNS** | tuple | The loss and the gradient, i.e. `(loss, gradient)`. |
| Name | Type | Description |
| ----------- | --------------------- | --------------------------------------------------- |
| `examples` | `Iterable[Example]` | The batch of examples. |
| `scores` | - | Scores representing the model's predictions. |
| **RETURNS** | `Tuple[float, float]` | The loss and the gradient, i.e. `(loss, gradient)`. |
## Tagger.begin_training {#begin_training tag="method"}
## Tagger.score {#score tag="method" new="3"}

Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
Score a batch of examples.

> #### Example
>
> ```python
> tagger = Tagger(nlp.vocab)
> nlp.pipeline.append(tagger)
> optimizer = tagger.begin_training(pipeline=nlp.pipeline)
> scores = tagger.score(examples)
> ```

| Name | Type | Description |
| -------------- | ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `get_examples` | `Iterable[Example]` | Optional gold-standard annotations in the form of [`Example`](/api/example) objects. |
| `pipeline` | `List[(str, callable)]` | Optional list of pipeline components that this component is part of. |
| `sgd` | `Optimizer` | An optional [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. Will be created via [`create_optimizer`](/api/tagger#create_optimizer) if not set. |
| **RETURNS** | `Optimizer` | An optimizer. |
| Name | Type | Description |
| ----------- | ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The examples to score. |
| **RETURNS** | `Dict[str, Any]` | The scores, produced by [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attributes `"pos"`, `"tag"` and `"lemma"`. |

## Tagger.create_optimizer {#create_optimizer tag="method"}

@ -180,13 +253,13 @@ Create an optimizer for the pipeline component.

> #### Example
>
> ```python
> tagger = Tagger(nlp.vocab)
> tagger = nlp.add_pipe("tagger")
> optimizer = tagger.create_optimizer()
> ```

| Name | Type | Description |
| ----------- | ----------- | ---------------------------------------------------------------- |
| **RETURNS** | `Optimizer` | The [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
| Name | Type | Description |
| ----------- | --------------------------------------------------- | -------------- |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
## Tagger.use_params {#use_params tag="method, contextmanager"}
|
||||
|
||||
|
@ -195,7 +268,7 @@ Modify the pipe's model, to use the given parameter values.
|
|||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> tagger = Tagger(nlp.vocab)
|
||||
> tagger = nlp.add_pipe("tagger")
|
||||
> with tagger.use_params():
|
||||
> tagger.to_disk("/best_model")
|
||||
> ```
|
||||
@@ -212,14 +285,14 @@ Add a new label to the pipe.

>
> ```python
> from spacy.symbols import POS
> tagger = Tagger(nlp.vocab)
> tagger.add_label("MY_LABEL", {POS: 'NOUN'})
> tagger = nlp.add_pipe("tagger")
> tagger.add_label("MY_LABEL", {POS: "NOUN"})
> ```

| Name | Type | Description |
| -------- | ---- | --------------------------------------------------------------- |
| `label` | str | The label to add. |
| `values` | dict | Optional values to map to the label, e.g. a tag map dictionary. |

| Name | Type | Description |
| -------- | ---------------- | --------------------------------------------------------------- |
| `label` | str | The label to add. |
| `values` | `Dict[int, str]` | Optional values to map to the label, e.g. a tag map dictionary. |
## Tagger.to_disk {#to_disk tag="method"}

@@ -228,14 +301,14 @@ Serialize the pipe to disk.

> #### Example
>
> ```python
> tagger = Tagger(nlp.vocab)
> tagger = nlp.add_pipe("tagger")
> tagger.to_disk("/path/to/tagger")
> ```

| Name | Type | Description |
| --------- | ------------ | --------------------------------------------------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |

| Name | Type | Description |
| --------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
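For example, a minimal sketch that skips the shared vocabulary when writing to
disk (see the serialization fields section below for the available field
names):

```python
# A minimal sketch, assuming `tagger` is the pipeline's Tagger component:
# exclude the shared "vocab" serialization field when saving.
tagger.to_disk("/path/to/tagger", exclude=["vocab"])
```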
## Tagger.from_disk {#from_disk tag="method"}

@@ -244,31 +317,31 @@ Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> tagger = Tagger(nlp.vocab)
> tagger = nlp.add_pipe("tagger")
> tagger.from_disk("/path/to/tagger")
> ```

| Name | Type | Description |
| ----------- | ------------ | -------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `Tagger` | The modified `Tagger` object. |

| Name | Type | Description |
| ----------- | --------------- | -------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `Tagger` | The modified `Tagger` object. |

## Tagger.to_bytes {#to_bytes tag="method"}

> #### Example
>
> ```python
> tagger = Tagger(nlp.vocab)
> tagger = nlp.add_pipe("tagger")
> tagger_bytes = tagger.to_bytes()
> ```

Serialize the pipe to a bytestring.

| Name | Type | Description |
| ----------- | ----- | ------------------------------------------------------------------------- |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `Tagger` object. |

| Name | Type | Description |
| ----------- | --------------- | ------------------------------------------------------------------------- |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `Tagger` object. |

## Tagger.from_bytes {#from_bytes tag="method"}

@@ -278,15 +351,15 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.

>
> ```python
> tagger_bytes = tagger.to_bytes()
> tagger = Tagger(nlp.vocab)
> tagger = nlp.add_pipe("tagger")
> tagger.from_bytes(tagger_bytes)
> ```

| Name | Type | Description |
| ------------ | -------- | ------------------------------------------------------------------------- |
| `bytes_data` | bytes | The data to load from. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `Tagger` | The `Tagger` object. |

| Name | Type | Description |
| ------------ | --------------- | ------------------------------------------------------------------------- |
| `bytes_data` | bytes | The data to load from. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `Tagger` | The `Tagger` object. |

## Tagger.labels {#labels tag="property"}

@@ -301,9 +374,9 @@ tags by default, e.g. `VERB`, `NOUN` and so on.

> assert "MY_LABEL" in tagger.labels
> ```

| Name | Type | Description |
| ----------- | ----- | ---------------------------------- |
| **RETURNS** | tuple | The labels added to the component. |

| Name | Type | Description |
| ----------- | ------------ | ---------------------------------- |
| **RETURNS** | `Tuple[str]` | The labels added to the component. |

## Serialization fields {#serialization-fields}
@@ -3,18 +3,36 @@ title: TextCategorizer

tag: class
source: spacy/pipeline/textcat.py
new: 2
teaser: 'Pipeline component for text classification'
api_base_class: /api/pipe
api_string_name: textcat
api_trainable: true
---

This class is a subclass of `Pipe` and follows the same API. The pipeline
component is available in the [processing pipeline](/usage/processing-pipelines)
via the ID `"textcat"`.

## Config and implementation {#config}

## Implementation and defaults {#implementation}

The default config is defined by the pipeline component factory and describes
how the component should be configured. You can override its settings via the
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
[`config.cfg` for training](/usage/training#config). See the
[model architectures](/api/architectures) documentation for details on the
architectures and their arguments and hyperparameters.

See the [model architectures](/api/architectures) documentation for details on
the architectures and their arguments and hyperparameters. To learn more about
how to customize the config and train custom models, check out the
[training config](/usage/training#config) docs.

> #### Example
>
> ```python
> from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
> config = {
>    "labels": [],
>    "model": DEFAULT_TEXTCAT_MODEL,
> }
> nlp.add_pipe("textcat", config=config)
> ```

| Setting | Type | Description | Default |
| -------- | ------------------------------------------ | ------------------ | ----------------------------------------------------- |
| `labels` | `Iterable[str]` | The labels to use. | `[]` |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The model to use. | [TextCatEnsemble](/api/architectures#TextCatEnsemble) |

```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/textcat.py
```
@@ -31,18 +49,23 @@ https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/textcat.py

> # Construction via add_pipe with custom model
> config = {"model": {"@architectures": "my_textcat"}}
> parser = nlp.add_pipe("textcat", config=config)
>
> # Construction from class
> from spacy.pipeline import TextCategorizer
> textcat = TextCategorizer(nlp.vocab, model)
> ```

Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#create_pipe).

| Name | Type | Description |
| ----------- | ----------------- | ------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | `Model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `**cfg` | - | Configuration parameters. |
| **RETURNS** | `TextCategorizer` | The newly constructed object. |

| Name | Type | Description |
| -------------- | ------------------------------------------ | ------------------------------------------------------------------------------------------- |
| `vocab` | `Vocab` | The shared vocabulary. |
| `model` | [`Model`](https://thinc.ai/docs/api-model) | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. |
| `name` | str | String name of the component instance. Used to add entries to the `losses` during training. |
| _keyword-only_ | | |
| `labels` | `Iterable[str]` | The labels to use. |

<!-- TODO move to config page

### Architectures {#architectures new="2.1"}

@@ -73,8 +96,8 @@ delegate to the [`predict`](/api/textcategorizer#predict) and

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> doc = nlp("This is a sentence.")
> textcat = nlp.add_pipe("textcat")
> # This usually happens under the hood
> processed = textcat(doc)
> ```
@@ -96,16 +119,37 @@ applied to the `Doc` in order. Both [`__call__`](/api/textcategorizer#call) and

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> textcat = nlp.add_pipe("textcat")
> for doc in textcat.pipe(docs, batch_size=50):
>    pass
> ```

| Name | Type | Description |
| ------------ | --------------- | ------------------------------------------------------ |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| `batch_size` | int | The number of texts to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | Processed documents in the order of the original text. |

| Name | Type | Description |
| -------------- | --------------- | ----------------------------------------------------- |
| `stream` | `Iterable[Doc]` | A stream of documents. |
| _keyword-only_ | | |
| `batch_size` | int | The number of documents to buffer. Defaults to `128`. |
| **YIELDS** | `Doc` | The processed documents in order. |
## TextCategorizer.begin_training {#begin_training tag="method"}

Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.

> #### Example
>
> ```python
> textcat = nlp.add_pipe("textcat")
> optimizer = textcat.begin_training(pipeline=nlp.pipeline)
> ```

| Name | Type | Description |
| -------------- | --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------ |
| `get_examples` | `Callable[[], Iterable[Example]]` | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. |
| _keyword-only_ | | |
| `pipeline` | `List[Tuple[str, Callable]]` | Optional list of pipeline components that this component is part of. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/textcategorizer#create_optimizer) if not set. |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
## TextCategorizer.predict {#predict tag="method"}

@@ -114,7 +158,7 @@ Apply the pipeline's model to a batch of docs, without modifying them.

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> textcat = nlp.add_pipe("textcat")
> scores = textcat.predict([doc1, doc2])
> ```

@@ -130,7 +174,7 @@ Modify a batch of documents, using pre-computed scores.

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> textcat = nlp.add_pipe("textcat")
> scores = textcat.predict(docs)
> textcat.set_annotations(docs, scores)
> ```

@@ -149,20 +193,43 @@ pipe's model. Delegates to [`predict`](/api/textcategorizer#predict) and

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab, textcat_model)
> textcat = nlp.add_pipe("textcat")
> optimizer = nlp.begin_training()
> losses = textcat.update(examples, sgd=optimizer)
> ```
| Name | Type | Description |
| ----------------- | ------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/textcategorizer#set_annotations). |
| `sgd` | `Optimizer` | The [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. The value keyed by the model's name is updated. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |

| Name | Type | Description |
| ----------------- | --------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `set_annotations` | bool | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](/api/textcategorizer#set_annotations). |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
## TextCategorizer.rehearse {#rehearse tag="method,experimental"}

Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
current model to make predictions similar to an initial model, to try to address
the "catastrophic forgetting" problem. This feature is experimental.

> #### Example
>
> ```python
> textcat = nlp.add_pipe("textcat")
> optimizer = nlp.begin_training()
> losses = textcat.rehearse(examples, sgd=optimizer)
> ```

| Name | Type | Description |
| -------------- | --------------------------------------------------- | ----------------------------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
| _keyword-only_ | | |
| `drop` | float | The dropout rate. |
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
## TextCategorizer.get_loss {#get_loss tag="method"}

@@ -172,36 +239,32 @@ predicted scores.

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> textcat = nlp.add_pipe("textcat")
> scores = textcat.predict([eg.predicted for eg in examples])
> loss, d_loss = textcat.get_loss(examples, scores)
> ```

| Name | Type | Description |
| ----------- | ------------------- | --------------------------------------------------- |
| `examples` | `Iterable[Example]` | The batch of examples. |
| `scores` | - | Scores representing the model's predictions. |
| **RETURNS** | tuple | The loss and the gradient, i.e. `(loss, gradient)`. |

| Name | Type | Description |
| ----------- | --------------------- | --------------------------------------------------- |
| `examples` | `Iterable[Example]` | The batch of examples. |
| `scores` | - | Scores representing the model's predictions. |
| **RETURNS** | `Tuple[float, float]` | The loss and the gradient, i.e. `(loss, gradient)`. |
## TextCategorizer.begin_training {#begin_training tag="method"}

## TextCategorizer.score {#score tag="method" new="3"}

Initialize the pipe for training, using data examples if available. Return an
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.

Score a batch of examples.

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> nlp.pipeline.append(textcat)
> optimizer = textcat.begin_training(pipeline=nlp.pipeline)
> scores = textcat.score(examples)
> ```
| Name | Type | Description |
| -------------- | ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `get_examples` | `Iterable[Example]` | Optional gold-standard annotations in the form of [`Example`](/api/example) objects. |
| `pipeline` | `List[(str, callable)]` | Optional list of pipeline components that this component is part of. |
| `sgd` | `Optimizer` | An optional [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. Will be created via [`create_optimizer`](/api/textcategorizer#create_optimizer) if not set. |
| **RETURNS** | `Optimizer` | An optimizer. |

| Name | Type | Description |
| ---------------- | ------------------- | ---------------------------------------------------------------------- |
| `examples` | `Iterable[Example]` | The examples to score. |
| _keyword-only_ | | |
| `positive_label` | str | Optional positive label. |
| **RETURNS** | `Dict[str, Any]` | The scores, produced by [`Scorer.score_cats`](/api/scorer#score_cats). |
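As a usage sketch (assuming a trained pipeline, a list of `Example` objects
and a `"POSITIVE"` label; the exact score keys come from `Scorer.score_cats`
and are an assumption here):

```python
# A minimal sketch, assuming `examples` holds Example objects with gold
# categories and that the pipeline was trained with a "POSITIVE" label;
# the "cats_score" key name follows Scorer.score_cats.
scores = textcat.score(examples, positive_label="POSITIVE")
print(scores.get("cats_score"), scores.get("cats_score_desc"))
```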
## TextCategorizer.create_optimizer {#create_optimizer tag="method"}

@@ -210,29 +273,13 @@ Create an optimizer for the pipeline component.

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> textcat = nlp.add_pipe("textcat")
> optimizer = textcat.create_optimizer()
> ```

| Name | Type | Description |
| ----------- | ----------- | --------------------------------------------------------------- |
| **RETURNS** | `Optimizer` | The [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |

## TextCategorizer.use_params {#use_params tag="method, contextmanager"}

Modify the pipe's model, to use the given parameter values.

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> with textcat.use_params(optimizer.averages):
>    textcat.to_disk("/best_model")
> ```

| Name | Type | Description |
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- |
| `params` | dict | The parameter values to use in the model. At the end of the context, the original parameters are restored. |

| Name | Type | Description |
| ----------- | --------------------------------------------------- | -------------- |
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
## TextCategorizer.add_label {#add_label tag="method"}

@@ -241,7 +288,7 @@ Add a new label to the pipe.

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> textcat = nlp.add_pipe("textcat")
> textcat.add_label("MY_LABEL")
> ```

@@ -249,6 +296,22 @@ Add a new label to the pipe.

| ------- | ---- | ----------------- |
| `label` | str | The label to add. |

## TextCategorizer.use_params {#use_params tag="method, contextmanager"}

Modify the pipe's model, to use the given parameter values.

> #### Example
>
> ```python
> textcat = nlp.add_pipe("textcat")
> with textcat.use_params():
>    textcat.to_disk("/best_model")
> ```

| Name | Type | Description |
| -------- | ---- | ---------------------------------------------------------------------------------------------------------- |
| `params` | - | The parameter values to use in the model. At the end of the context, the original parameters are restored. |
## TextCategorizer.to_disk {#to_disk tag="method"}

Serialize the pipe to disk.

@@ -256,14 +319,14 @@ Serialize the pipe to disk.

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> textcat = nlp.add_pipe("textcat")
> textcat.to_disk("/path/to/textcat")
> ```

| Name | Type | Description |
| --------- | ------------ | --------------------------------------------------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |

| Name | Type | Description |
| --------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |

## TextCategorizer.from_disk {#from_disk tag="method"}

@@ -272,14 +335,14 @@ Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> textcat = nlp.add_pipe("textcat")
> textcat.from_disk("/path/to/textcat")
> ```

| Name | Type | Description |
| ----------- | ----------------- | -------------------------------------------------------------------------- |
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `TextCategorizer` | The modified `TextCategorizer` object. |

## TextCategorizer.to_bytes {#to_bytes tag="method"}

@@ -287,16 +350,16 @@ Load the pipe from disk. Modifies the object in place and returns it.

> #### Example
>
> ```python
> textcat = TextCategorizer(nlp.vocab)
> textcat = nlp.add_pipe("textcat")
> textcat_bytes = textcat.to_bytes()
> ```

Serialize the pipe to a bytestring.

| Name | Type | Description |
| ----------- | ----- | ------------------------------------------------------------------------- |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `TextCategorizer` object. |

| Name | Type | Description |
| ----------- | --------------- | ------------------------------------------------------------------------- |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | bytes | The serialized form of the `TextCategorizer` object. |

## TextCategorizer.from_bytes {#from_bytes tag="method"}

@@ -306,14 +369,14 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.

>
> ```python
> textcat_bytes = textcat.to_bytes()
> textcat = TextCategorizer(nlp.vocab)
> textcat = nlp.add_pipe("textcat")
> textcat.from_bytes(textcat_bytes)
> ```

| Name | Type | Description |
| ------------ | ----------------- | ------------------------------------------------------------------------- |
| `bytes_data` | bytes | The data to load from. |
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
| **RETURNS** | `TextCategorizer` | The `TextCategorizer` object. |

## TextCategorizer.labels {#labels tag="property"}
@@ -2,17 +2,10 @@

title: Tok2Vec
source: spacy/pipeline/tok2vec.py
new: 3
teaser: null
api_base_class: /api/pipe
api_string_name: tok2vec
api_trainable: true
---

<!-- TODO: document -->

## Implementation and defaults {#implementation}

See the [model architectures](/api/architectures) documentation for details on
the architectures and their arguments and hyperparameters. To learn more about
how to customize the config and train custom models, check out the
[training config](/usage/training#config) docs.

```python
https://github.com/explosion/spaCy/blob/develop/spacy/pipeline/tok2vec.py
```

TODO:
@@ -31,11 +31,12 @@ loaded in via [`Language.from_disk`](/api/language#from_disk).

> nlp = spacy.load("en_core_web_sm", disable=["parser", "tagger"])
> ```

| Name | Type | Description |
| ----------- | ------------ | --------------------------------------------------------------------------------- |
| `name` | str / `Path` | Model to load, i.e. package name or path. |
| `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
| **RETURNS** | `Language` | A `Language` object with the loaded model. |

| Name | Type | Description |
| ------------------------------------------ | ----------------- | --------------------------------------------------------------------------------- |
| `name` | str / `Path` | Model to load, i.e. package name or path. |
| `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
| `component_cfg` <Tag variant="new">3</Tag> | `Dict[str, dict]` | Optional config overrides for pipeline components, keyed by component names. |
| **RETURNS** | `Language` | A `Language` object with the loaded model. |
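For example, a minimal sketch of the `component_cfg` argument described in the
table above (the override setting shown here is hypothetical; real keys depend
on the component's config):

```python
# A minimal sketch, assuming the component_cfg argument described above
# (new in v3); "my_setting" is a hypothetical override key.
import spacy

nlp = spacy.load(
    "en_core_web_sm",
    disable=["parser"],
    component_cfg={"tagger": {"my_setting": True}},
)
```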
Essentially, `spacy.load()` is a convenience wrapper that reads the language ID
and pipeline components from a model's `meta.json`, initializes the `Language`

@@ -43,10 +44,10 @@ class, loads in the model data and returns it.

```python
### Abstract example
cls = util.get_lang_class(lang) # get language for ID, e.g. 'en'
nlp = cls() # initialise the language
cls = util.get_lang_class(lang)  # get language for ID, e.g. "en"
nlp = cls()  # initialize the language
for name in pipeline:
    nlp.add_pipe(name) # add component to pipeline
    nlp.add_pipe(name)  # add component to pipeline
nlp.from_disk(model_data_path) # load in model data
```

@@ -58,15 +59,14 @@ Create a blank model of a given language class. This function is the twin of

> #### Example
>
> ```python
> nlp_en = spacy.blank("en")
> nlp_de = spacy.blank("de")
> nlp_en = spacy.blank("en") # equivalent to English()
> nlp_de = spacy.blank("de") # equivalent to German()
> ```

| Name | Type | Description |
| ----------- | ----------- | ------------------------------------------------------------------------------------------------ |
| `name` | str | [ISO code](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) of the language class to load. |
| `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
| **RETURNS** | `Language` | An empty `Language` object of the appropriate subclass. |

| Name | Type | Description |
| ----------- | ---------- | ------------------------------------------------------------------------------------------------ |
| `name` | str | [ISO code](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) of the language class to load. |
| **RETURNS** | `Language` | An empty `Language` object of the appropriate subclass. |
#### spacy.info {#spacy.info tag="function"}

@@ -80,13 +80,14 @@ meta data as a dictionary instead, you can use the `meta` attribute on your

> ```python
> spacy.info()
> spacy.info("en_core_web_sm")
> spacy.info(markdown=True)
> markdown = spacy.info(markdown=True, silent=True)
> ```

| Name | Type | Description |
| ---------- | ---- | ------------------------------------------------ |
| `model` | str | A model, i.e. a package name or path (optional). |
| `markdown` | bool | Print information as Markdown. |
| `silent` | bool | Don't print anything, just return. |

### spacy.explain {#spacy.explain tag="function"}

@@ -30,13 +30,14 @@ you can add vectors to later.

> vectors = Vectors(data=data, keys=keys)
> ```

| Name | Type | Description |
| ----------- | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `data` | `ndarray[ndim=1, dtype='float32']` | The vector data. |
| `keys` | iterable | A sequence of keys aligned with the data. |
| `shape` | tuple | Size of the table as `(n_entries, n_columns)`, the number of entries and number of columns. Not required if you're initializing the object with `data` and `keys`. |
| `name` | str | A name to identify the vectors table. |
| **RETURNS** | `Vectors` | The newly created object. |

| Name | Type | Description |
| -------------- | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| _keyword-only_ | | |
| `shape` | tuple | Size of the table as `(n_entries, n_columns)`, the number of entries and number of columns. Not required if you're initializing the object with `data` and `keys`. |
| `data` | `ndarray[ndim=1, dtype='float32']` | The vector data. |
| `keys` | iterable | A sequence of keys aligned with the data. |
| `name` | str | A name to identify the vectors table. |
| **RETURNS** | `Vectors` | The newly created object. |
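For example, a minimal sketch using the keyword-only `shape` argument described
above to create an empty table that keys and vectors can be added to later:

```python
# A minimal sketch: an empty 10,000 x 300 vectors table with no keys yet.
from spacy.vectors import Vectors

vectors = Vectors(shape=(10000, 300), name="my_vectors")
```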
## Vectors.\_\_getitem\_\_ {#getitem tag="method"}

@@ -138,12 +139,13 @@ mapping separately. If you need to manage the strings, you should use the

> nlp.vocab.vectors.add("dog", row=0)
> ```

| Name | Type | Description |
| ----------- | ---------------------------------- | ----------------------------------------------------- |
| `key` | str / int | The key to add. |
| `vector` | `ndarray[ndim=1, dtype='float32']` | An optional vector to add for the key. |
| `row` | int | An optional row number of a vector to map the key to. |
| **RETURNS** | int | The row the vector was added to. |

| Name | Type | Description |
| -------------- | ---------------------------------- | ----------------------------------------------------- |
| `key` | str / int | The key to add. |
| _keyword-only_ | | |
| `vector` | `ndarray[ndim=1, dtype='float32']` | An optional vector to add for the key. |
| `row` | int | An optional row number of a vector to map the key to. |
| **RETURNS** | int | The row the vector was added to. |

## Vectors.resize {#resize tag="method"}

@@ -225,13 +227,14 @@ Look up one or more keys by row, or vice versa.

> keys = nlp.vocab.vectors.find(rows=[18, 256, 985])
> ```

| Name | Type | Description |
| ----------- | --------- | ------------------------------------------------------------------------ |
| `key` | str / int | Find the row that the given key points to. Returns int, `-1` if missing. |
| `keys` | iterable | Find rows that the keys point to. Returns `ndarray`. |
| `row` | int | Find the first key that points to the row. Returns int. |
| `rows` | iterable | Find the keys that point to the rows. Returns ndarray. |
| **RETURNS** | - | The requested key, keys, row or rows. |

| Name | Type | Description |
| -------------- | --------- | ------------------------------------------------------------------------ |
| _keyword-only_ | | |
| `key` | str / int | Find the row that the given key points to. Returns int, `-1` if missing. |
| `keys` | iterable | Find rows that the keys point to. Returns `ndarray`. |
| `row` | int | Find the first key that points to the row. Returns int. |
| `rows` | iterable | Find the keys that point to the rows. Returns ndarray. |
| **RETURNS** | - | The requested key, keys, row or rows. |
## Vectors.shape {#shape tag="property"}

@@ -318,13 +321,14 @@ performed in chunks, to avoid consuming too much memory. You can set the

> most_similar = nlp.vocab.vectors.most_similar(queries, n=10)
> ```

| Name | Type | Description |
| ------------ | --------- | ------------------------------------------------------------------ |
| `queries` | `ndarray` | An array with one or more vectors. |
| `batch_size` | int | The batch size to use. Defaults to `1024`. |
| `n` | int | The number of entries to return for each query. Defaults to `1`. |
| `sort` | bool | Whether to sort the entries returned by score. Defaults to `True`. |
| **RETURNS** | tuple | The most similar entries as a `(keys, best_rows, scores)` tuple. |

| Name | Type | Description |
| -------------- | --------- | ------------------------------------------------------------------ |
| `queries` | `ndarray` | An array with one or more vectors. |
| _keyword-only_ | | |
| `batch_size` | int | The batch size to use. Defaults to `1024`. |
| `n` | int | The number of entries to return for each query. Defaults to `1`. |
| `sort` | bool | Whether to sort the entries returned by score. Defaults to `True`. |
| **RETURNS** | tuple | The most similar entries as a `(keys, best_rows, scores)` tuple. |
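As a usage sketch (assuming a pipeline with word vectors loaded), query the
table for the nearest neighbors of a single vector:

```python
# A minimal sketch, assuming nlp has vectors: find the 10 entries most
# similar to the vector for "dog". Returns a (keys, best_rows, scores) tuple.
import numpy

queries = numpy.asarray([nlp.vocab["dog"].vector])
keys, best_rows, scores = nlp.vocab.vectors.most_similar(queries, n=10)
```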
## Vectors.to_disk {#to_disk tag="method"}

@@ -136,10 +136,11 @@ have to call this to change the size of the vectors. Only one of the `width` and

> nlp.vocab.reset_vectors(width=300)
> ```

| Name | Type | Description |
| ------- | ---- | -------------------------------------- |
| `width` | int | The new width (keyword argument only). |
| `shape` | int | The new shape (keyword argument only). |

| Name | Type | Description |
| -------------- | ---- | -------------------------------------- |
| _keyword-only_ | | |
| `width` | int | The new width (keyword argument only). |
| `shape` | int | The new shape (keyword argument only). |

## Vocab.prune_vectors {#prune_vectors tag="method" new="2"}

@@ -446,19 +446,6 @@ annotations:

</Infobox>

> - **Training data**: The training examples.
> - **Text and label**: The current example.
> - **Doc**: A `Doc` object created from the example text.
> - **Example**: An `Example` object holding both predictions and gold-standard
>   annotations.
> - **nlp**: The `nlp` object with the model.
> - **Optimizer**: A function that holds state between updates.
> - **Update**: Update the model's weights.

<!-- TODO: update graphic & related text -->

![The training loop](../images/training-loop.svg)

Of course, it's not enough to only show a model a single example once.
Especially if you only have few examples, you'll want to train for a **number of
iterations**. At each iteration, the training data is **shuffled** to ensure the

@@ -469,12 +456,16 @@ it harder for the model to memorize the training data. For example, a `0.25`
dropout means that each feature or internal representation has a 1/4 likelihood
of being dropped.
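As a minimal sketch (assuming `examples` is a batch of `Example` objects and
the pipeline has been set up for training), the dropout rate is passed as the
`drop` argument to the update step:

```python
# A minimal sketch: drop=0.25 gives each feature or internal representation
# a 1/4 chance of being dropped during this update.
optimizer = nlp.begin_training()
losses = {}
nlp.update(examples, drop=0.25, sgd=optimizer, losses=losses)
```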
> - [`begin_training`](/api/language#begin_training): Start the training and
>   return an [`Optimizer`](https://thinc.ai/docs/api-optimizers) object to
>   update the model's weights.
> - [`update`](/api/language#update): Update the model with the training
>   examples.
> - [`to_disk`](/api/language#to_disk): Save the updated model to a directory.

> - [`nlp`](/api/language): The `nlp` object with the model.
> - [`nlp.begin_training`](/api/language#begin_training): Start the training and
>   return an optimizer to update the model's weights.
> - [`Optimizer`](https://thinc.ai/docs/api-optimizers): Function that holds
>   state between updates.
> - [`nlp.update`](/api/language#update): Update model with examples.
> - [`Example`](/api/example): object holding predictions and gold-standard
>   annotations.
> - [`nlp.to_disk`](/api/language#to_disk): Save the updated model to a
>   directory.

```python
### Example training loop
```
@@ -15,24 +15,183 @@ menu:

## Backwards Incompatibilities {#incompat}

### Removed deprecated methods, attributes and arguments {#incompat-removed}

### Removed or renamed objects, methods, attributes and arguments {#incompat-removed}

| Removed | Replacement |
| -------------------------------------------------------- | ----------------------------------------- |
| `GoldParse` | [`Example`](/api/example) |
| `spacy debug-data` | [`spacy debug data`](/api/cli#debug-data) |
| `spacy link`, `util.set_data_path`, `util.get_data_path` | not needed, model symlinks are deprecated |
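For instance, a minimal sketch of replacing `GoldParse` with the v3 `Example`
API (the import path moved during v3 development, and the annotation keys
shown are assumptions; check your installed version):

```python
# A minimal sketch, assuming the v3 Example API that replaces GoldParse.
# The import path has been spacy.gold in earlier v3 previews.
from spacy.training import Example

doc = nlp.make_doc("I like London.")
example = Example.from_dict(doc, {"entities": [(7, 13, "LOC")]})
```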
### Removed deprecated methods, attributes and arguments {#incompat-removed-deprecated}

The following deprecated methods, attributes and arguments were removed in v3.0.

Most of them have been deprecated for quite a while now and many would
previously raise errors. Many of them were also mostly internals. If you've been
working with more recent versions of spaCy v2.x, it's unlikely that your code
relied on them.

Most of them have been **deprecated for a while** and many would previously
raise errors. Many of them were also mostly internals. If you've been working
with more recent versions of spaCy v2.x, it's **unlikely** that your code relied
on them.

| Class | Removed |
| --------------------- | ------------------------------------------------------- |
| [`Doc`](/api/doc) | `Doc.tokens_from_list`, `Doc.merge` |
| [`Span`](/api/span) | `Span.merge`, `Span.upper`, `Span.lower`, `Span.string` |
| [`Token`](/api/token) | `Token.string` |

<!-- TODO: complete (see release notes Dropbox Paper doc) -->

| Removed | Replacement |
| ----------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `Doc.tokens_from_list` | [`Doc.__init__`](/api/doc#init) |
| `Doc.merge`, `Span.merge` | [`Doc.retokenize`](/api/doc#retokenize) |
| `Token.string`, `Span.string`, `Span.upper`, `Span.lower` | [`Span.text`](/api/span#attributes), [`Token.text`](/api/token#attributes) |
| `Language.tagger`, `Language.parser`, `Language.entity` | [`Language.get_pipe`](/api/language#get_pipe) |
| keyword-arguments like `vocab=False` on `to_disk`, `from_disk`, `to_bytes`, `from_bytes` | `exclude=["vocab"]` |
| `n_threads` argument on [`Tokenizer`](/api/tokenizer), [`Matcher`](/api/matcher), [`PhraseMatcher`](/api/phrasematcher) | `n_process` |
| `SentenceSegmenter` hook, `SimilarityHook` | [user hooks](/usage/processing-pipelines#custom-components-user-hooks), [`Sentencizer`](/api/sentencizer), [`SentenceRecognizer`](/api/sentencerecognizer) |
## Migrating from v2.x {#migrating}

### Custom pipeline components and factories {#migrating-pipeline-components}

Custom pipeline components now have to be registered explicitly using the
[`@Language.component`](/api/language#component) or
[`@Language.factory`](/api/language#factory) decorator. For simple functions
that take a `Doc` and return it, all you have to do is add the
`@Language.component` decorator to it and assign it a name:

```diff
### Stateless function components
+ from spacy.language import Language

+ @Language.component("my_component")
def my_component(doc):
    return doc
```

For class components that are initialized with settings and/or the shared `nlp`
object, you can use the `@Language.factory` decorator. Also make sure that
the method used to initialize the factory has **two named arguments**: `nlp`
(the current `nlp` object) and `name` (the string name of the component
instance).

```diff
### Stateful class components
+ from spacy.language import Language

+ @Language.factory("my_component")
class MyComponent:
-    def __init__(self, nlp):
+    def __init__(self, nlp, name):
         self.nlp = nlp

     def __call__(self, doc):
         return doc
```

Instead of decorating your class, you could also add a factory function that
takes the arguments `nlp` and `name` and returns an instance of your component:

```diff
### Stateful class components with factory function
+ from spacy.language import Language

+ @Language.factory("my_component")
+ def create_my_component(nlp, name):
+     return MyComponent(nlp)

class MyComponent:
    def __init__(self, nlp):
        self.nlp = nlp

    def __call__(self, doc):
        return doc
```

The `@Language.component` and `@Language.factory` decorators now take care of
adding an entry to the component factories, so spaCy knows how to load a
component back in from its string name. You won't have to write to
`Language.factories` manually anymore.

```diff
- Language.factories["my_component"] = lambda nlp, **cfg: MyComponent(nlp)
```
#### Adding components to the pipeline {#migrating-add-pipe}

The [`nlp.add_pipe`](/api/language#add_pipe) method now takes the **string
name** of the component factory instead of a callable component. This allows
spaCy to track and serialize components that have been added and their settings.

```diff
+ @Language.component("my_component")
def my_component(doc):
    return doc

- nlp.add_pipe(my_component)
+ nlp.add_pipe("my_component")
```

[`nlp.add_pipe`](/api/language#add_pipe) now also returns the pipeline component
itself, so you can access its attributes. The
[`nlp.create_pipe`](/api/language#create_pipe) method is now mostly internals
and you typically shouldn't have to use it in your code.

```diff
- parser = nlp.create_pipe("parser")
- nlp.add_pipe(parser)
+ parser = nlp.add_pipe("parser")
```
### Training models {#migrating-training}

To train your models, you should now pretty much always use the
[`spacy train`](/api/cli#train) CLI. You shouldn't have to put together your own
training scripts anymore, unless you _really_ want to. The training commands now
use a [flexible config file](/usage/training#config) that describes all training
settings and hyperparameters, as well as your pipeline, model components and
architectures to use. The `--code` argument lets you pass in code containing
[custom registered functions](/usage/training#custom-code) that you can
reference in your config.

#### Binary .spacy training data format {#migrating-training-format}

spaCy now uses a new
[binary training data format](/api/data-formats#binary-training), which is much
smaller and consists of `Doc` objects, serialized via the
[`DocBin`](/api/docbin). You can convert your existing JSON-formatted data using
the [`spacy convert`](/api/cli#convert) command, which outputs `.spacy` files:

```bash
$ python -m spacy convert ./training.json ./output
```
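You can also write the binary format directly. A minimal sketch, assuming the
v3 `DocBin` API and a blank English pipeline:

```python
# A minimal sketch: serialize a Doc to the binary .spacy training format.
import spacy
from spacy.tokens import DocBin

nlp = spacy.blank("en")
doc_bin = DocBin(docs=[nlp("This is an example.")])
doc_bin.to_disk("./train.spacy")
```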
#### Training config {#migrating-training-config}

<!-- TODO: update once we have recommended "getting started with a new config" workflow -->

```diff
### {wrap="true"}
- python -m spacy train en ./output ./train.json ./dev.json --pipeline tagger,parser --cnn-window 1 --bilstm-depth 0
+ python -m spacy train ./train.spacy ./dev.spacy ./config.cfg --output ./output
```

<Project id="some_example_project">

The easiest way to get started with an end-to-end training process is to clone a
[project](/usage/projects) template. Projects let you manage multi-step
workflows, from data preprocessing to training and packaging your model.

</Project>

#### Migrating training scripts to CLI command and config {#migrating-training-scripts}

<!-- TODO: write -->

#### Packaging models {#migrating-training-packaging}

The [`spacy package`](/api/cli#package) command now automatically builds the
installable `.tar.gz` sdist of the Python package, so you don't have to run this
step manually anymore. You can disable the behavior by setting the `--no-sdist`
flag.

```diff
python -m spacy package ./model ./packages
- cd /output/en_model-0.0.0
- python setup.py sdist
```

## Migration notes for plugin maintainers {#plugins}

Thanks to everyone who's been contributing to the spaCy ecosystem by developing
@@ -14,6 +14,10 @@ function getNodeTitle({ childMdx }) {

    return (frontmatter.title || '').replace("'", '’')
}

function findNode(pages, slug) {
    return slug ? pages.find(({ node }) => node.fields.slug === slug) : null
}

exports.createPages = ({ graphql, actions }) => {
    const { createPage } = actions

@@ -70,6 +74,9 @@ exports.createPages = ({ graphql, actions }) => {

                        title
                        teaser
                        source
                        api_base_class
                        api_string_name
                        api_trainable
                        tag
                        new
                        next

@@ -115,10 +122,18 @@ exports.createPages = ({ graphql, actions }) => {

                const section = frontmatter.section || page.node.relativeDirectory
                const sectionMeta = sections[section] || {}
                const title = getNodeTitle(page.node)
                const next = frontmatter.next
                    ? pages.find(({ node }) => node.fields.slug === frontmatter.next)
                    : null

                const next = findNode(pages, frontmatter.next)
                const baseClass = findNode(pages, frontmatter.api_base_class)
                const apiDetails = {
                    stringName: frontmatter.api_string_name,
                    baseClass: baseClass
                        ? {
                              title: getNodeTitle(baseClass.node),
                              slug: frontmatter.api_base_class,
                          }
                        : null,
                    trainable: frontmatter.api_trainable,
                }
                createPage({
                    path: replacePath(page.node.fields.slug),
                    component: DEFAULT_TEMPLATE,

@@ -131,6 +146,7 @@ exports.createPages = ({ graphql, actions }) => {

                        sectionTitle: sectionMeta.title,
                        menu: frontmatter.menu || [],
                        teaser: frontmatter.teaser,
                        apiDetails,
                        source: frontmatter.source,
                        sidebar: frontmatter.sidebar,
                        tag: frontmatter.tag,
@@ -67,6 +67,7 @@

        {
            "label": "Containers",
            "items": [
                { "text": "Language", "url": "/api/language" },
                { "text": "Doc", "url": "/api/doc" },
                { "text": "Token", "url": "/api/token" },
                { "text": "Span", "url": "/api/span" },

@@ -78,7 +79,6 @@

        {
            "label": "Pipeline",
            "items": [
                { "text": "Language", "url": "/api/language" },
                { "text": "Tokenizer", "url": "/api/tokenizer" },
                { "text": "Tok2Vec", "url": "/api/tok2vec" },
                { "text": "Lemmatizer", "url": "/api/lemmatizer" },

@@ -86,14 +86,21 @@

                { "text": "Tagger", "url": "/api/tagger" },
                { "text": "DependencyParser", "url": "/api/dependencyparser" },
                { "text": "EntityRecognizer", "url": "/api/entityrecognizer" },
                { "text": "EntityRuler", "url": "/api/entityruler" },
                { "text": "EntityLinker", "url": "/api/entitylinker" },
                { "text": "TextCategorizer", "url": "/api/textcategorizer" },
                { "text": "Matcher", "url": "/api/matcher" },
                { "text": "PhraseMatcher", "url": "/api/phrasematcher" },
                { "text": "EntityRuler", "url": "/api/entityruler" },
                { "text": "Sentencizer", "url": "/api/sentencizer" },
                { "text": "SentenceRecognizer", "url": "/api/sentencerecognizer" },
                { "text": "Other Functions", "url": "/api/pipeline-functions" }
                { "text": "Other Functions", "url": "/api/pipeline-functions" },
                { "text": "Pipe", "url": "/api/pipe" }
            ]
        },
        {
            "label": "Matchers",
            "items": [
                { "text": "Matcher", "url": "/api/matcher" },
                { "text": "PhraseMatcher", "url": "/api/phrasematcher" },
                { "text": "DependencyMatcher", "url": "/api/dependencymatcher" }
            ]
        },
        {
@@ -87,6 +87,17 @@ const Link = ({

    )
}

export const OptionalLink = ({ to, href, children, ...props }) => {
    const dest = to || href
    return dest ? (
        <Link to={dest} {...props}>
            {children}
        </Link>
    ) : (
        children || null
    )
}

Link.defaultProps = {
    hidden: false,
    hideIcon: false,
@@ -4,43 +4,105 @@ import classNames from 'classnames'

import Button from './button'
import Tag from './tag'
import { H1 } from './typography'
import { OptionalLink } from './link'
import { InlineCode } from './code'
import { H1, Label, InlineList, Help } from './typography'
import Icon from './icon'

import classes from '../styles/title.module.sass'

const Title = ({ id, title, tag, version, teaser, source, image, children, ...props }) => (
    <header className={classes.root}>
        {(image || source) && (
            <div className={classes.corner}>
                {source && (
                    <Button to={source} icon="code">
                        Source
                    </Button>
                )}

                {image && (
                    <div className={classes.image}>
                        <img src={image} width={100} height={100} alt="" />
                    </div>
                )}
            </div>
const MetaItem = ({ label, url, children, help }) => (
    <span>
        <Label className={classes.label}>{label}:</Label>
        <OptionalLink to={url}>{children}</OptionalLink>
        {help && (
            <>
                {' '}
                <Help>{help}</Help>
            </>
        )}
        <H1 className={classes.h1} id={id} {...props}>
            {title}
        </H1>
        {tag && <Tag spaced>{tag}</Tag>}
        {version && (
            <Tag variant="new" spaced>
                {version}
            </Tag>
        )}

        {teaser && <div className={classNames('heading-teaser', classes.teaser)}>{teaser}</div>}

        {children}
    </header>
    </span>
)

const Title = ({
    id,
    title,
    tag,
    version,
    teaser,
    source,
    image,
    apiDetails,
    children,
    ...props
}) => {
    const hasApiDetails = Object.values(apiDetails).some(v => v)
    const metaIconProps = { className: classes.metaIcon, width: 18 }
    return (
        <header className={classes.root}>
            {(image || source) && (
                <div className={classes.corner}>
                    {source && (
                        <Button to={source} icon="code">
                            Source
                        </Button>
                    )}

                    {image && (
                        <div className={classes.image}>
                            <img src={image} width={100} height={100} alt="" />
                        </div>
                    )}
                </div>
            )}
            <H1 className={classes.h1} id={id} {...props}>
                {title}
            </H1>
            {(tag || version) && (
                <div className={classes.tags}>
                    {tag && <Tag spaced>{tag}</Tag>}
                    {version && (
                        <Tag variant="new" spaced>
                            {version}
                        </Tag>
                    )}
                </div>
            )}

            {hasApiDetails && (
                <InlineList Component="div" className={classes.teaser}>
                    {apiDetails.stringName && (
                        <MetaItem
                            label="String name"
                            //help="String name of the component to use with nlp.add_pipe"
                        >
                            <InlineCode>{apiDetails.stringName}</InlineCode>
                        </MetaItem>
                    )}
                    {apiDetails.baseClass && (
                        <MetaItem label="Base class" url={apiDetails.baseClass.slug}>
                            <InlineCode>{apiDetails.baseClass.title}</InlineCode>
                        </MetaItem>
                    )}
                    {apiDetails.trainable != null && (
                        <MetaItem label="Trainable">
                            <span aria-label={apiDetails.trainable ? 'yes' : 'no'}>
                                {apiDetails.trainable ? (
                                    <Icon name="yes" variant="success" {...metaIconProps} />
                                ) : (
                                    <Icon name="no" {...metaIconProps} />
                                )}
                            </span>
                        </MetaItem>
                    )}
                </InlineList>
            )}
            {teaser && <div className={classNames('heading-teaser', classes.teaser)}>{teaser}</div>}
            {children}
        </header>
    )
}

Title.propTypes = {
    title: PropTypes.string,
    tag: PropTypes.string,

@@ -59,9 +59,9 @@ export const InlineList = ({ Component = 'p', gutterBottom = true, className, ch

    return <Component className={listClassNames}>{children}</Component>
}

export const Help = ({ children }) => (
export const Help = ({ children, size = 16 }) => (
    <span className={classes.help} data-tooltip={children}>
        <Icon name="help2" width={16} />
        <Icon name="help2" width={size} />
    </span>
)
@@ -25,6 +25,9 @@

    font-family: var(--font-secondary)
    color: var(--color-theme)

    & > td:nth-child(2) a
        border: 0

.td
    padding: 1rem

@@ -80,6 +83,13 @@

    white-space: nowrap
    z-index: 5

    &:first-child // directly after thead/tr
        border-bottom: 0

    td em
        top: -7px

// Responsive table
// Shadows adapted from "CSS only Responsive Tables" by David Bushell
// http://codepen.io/dbushell/pen/wGaamR
@@ -34,3 +34,16 @@

.corner
    float: right

.label
    display: inline
    margin-right: var(--spacing-xs)

.meta-icon
    position: relative
    top: -2px

.tags
    display: inline-block
    position: relative
    top: 0.5rem
@@ -30,6 +30,7 @@ const Docs = ({ pageContext, children }) => (

        menu,
        theme,
        version,
        apiDetails,
    } = pageContext
    const { sidebars = [], modelsRepo, languages, nightly } = site.siteMetadata
    const isModels = section === 'models'

@@ -102,6 +103,7 @@ const Docs = ({ pageContext, children }) => (

        tag={tag}
        version={version}
        id="_title"
        apiDetails={apiDetails}
    />
    {children}
    {subFooter}
Reference in New Issue
Block a user