mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Merge branch 'master' into spacy.io
This commit is contained in:
commit
21ade53ef7
|
@ -43,7 +43,11 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
|
||||||
# nlp.create_pipe works for built-ins that are registered with spaCy
|
# nlp.create_pipe works for built-ins that are registered with spaCy
|
||||||
if "textcat" not in nlp.pipe_names:
|
if "textcat" not in nlp.pipe_names:
|
||||||
textcat = nlp.create_pipe(
|
textcat = nlp.create_pipe(
|
||||||
"textcat", config={"architecture": "simple_cnn", "exclusive_classes": True}
|
"textcat",
|
||||||
|
config={
|
||||||
|
"exclusive_classes": True,
|
||||||
|
"architecture": "simple_cnn",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
nlp.add_pipe(textcat, last=True)
|
nlp.add_pipe(textcat, last=True)
|
||||||
# otherwise, get it, so we can add labels to it
|
# otherwise, get it, so we can add labels to it
|
||||||
|
@ -56,7 +60,9 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
|
||||||
|
|
||||||
# load the IMDB dataset
|
# load the IMDB dataset
|
||||||
print("Loading IMDB data...")
|
print("Loading IMDB data...")
|
||||||
(train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=n_texts)
|
(train_texts, train_cats), (dev_texts, dev_cats) = load_data()
|
||||||
|
train_texts = train_texts[:n_texts]
|
||||||
|
train_cats = train_cats[:n_texts]
|
||||||
print(
|
print(
|
||||||
"Using {} examples ({} training, {} evaluation)".format(
|
"Using {} examples ({} training, {} evaluation)".format(
|
||||||
n_texts, len(train_texts), len(dev_texts)
|
n_texts, len(train_texts), len(dev_texts)
|
||||||
|
|
|
@ -43,8 +43,9 @@ redirects = [
|
||||||
{from = "/usage/lightning-tour", to = "/usage/spacy-101#lightning-tour"},
|
{from = "/usage/lightning-tour", to = "/usage/spacy-101#lightning-tour"},
|
||||||
{from = "/usage/linguistic-features#rule-based-matching", to = "/usage/rule-based-matching"},
|
{from = "/usage/linguistic-features#rule-based-matching", to = "/usage/rule-based-matching"},
|
||||||
{from = "/models/comparison", to = "/models"},
|
{from = "/models/comparison", to = "/models"},
|
||||||
{from = "/api/#section-cython", to = "/api/cython"},
|
{from = "/api/#section-cython", to = "/api/cython", force = true},
|
||||||
{from = "/api/#cython", to = "/api/cython"},
|
{from = "/api/#cython", to = "/api/cython", force = true},
|
||||||
|
{from = "/api/sentencesegmenter", to="/api/sentencizer"},
|
||||||
{from = "/universe", to = "/universe/project/:id", query = {id = ":id"}, force = true},
|
{from = "/universe", to = "/universe/project/:id", query = {id = ":id"}, force = true},
|
||||||
{from = "/universe", to = "/universe/category/:category", query = {category = ":category"}, force = true},
|
{from = "/universe", to = "/universe/category/:category", query = {category = ":category"}, force = true},
|
||||||
]
|
]
|
||||||
|
|
94
spacy/_ml.py
94
spacy/_ml.py
|
@ -81,18 +81,6 @@ def _zero_init(model):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@layerize
|
|
||||||
def _preprocess_doc(docs, drop=0.0):
|
|
||||||
keys = [doc.to_array(LOWER) for doc in docs]
|
|
||||||
# The dtype here matches what thinc is expecting -- which differs per
|
|
||||||
# platform (by int definition). This should be fixed once the problem
|
|
||||||
# is fixed on Thinc's side.
|
|
||||||
lengths = numpy.array([arr.shape[0] for arr in keys], dtype=numpy.int_)
|
|
||||||
keys = numpy.concatenate(keys)
|
|
||||||
vals = numpy.zeros(keys.shape, dtype='f')
|
|
||||||
return (keys, vals, lengths), None
|
|
||||||
|
|
||||||
|
|
||||||
def with_cpu(ops, model):
|
def with_cpu(ops, model):
|
||||||
"""Wrap a model that should run on CPU, transferring inputs and outputs
|
"""Wrap a model that should run on CPU, transferring inputs and outputs
|
||||||
as necessary."""
|
as necessary."""
|
||||||
|
@ -133,20 +121,31 @@ def _to_device(ops, X):
|
||||||
return ops.asarray(X)
|
return ops.asarray(X)
|
||||||
|
|
||||||
|
|
||||||
@layerize
|
class extract_ngrams(Model):
|
||||||
def _preprocess_doc_bigrams(docs, drop=0.0):
|
def __init__(self, ngram_size, attr=LOWER):
|
||||||
unigrams = [doc.to_array(LOWER) for doc in docs]
|
Model.__init__(self)
|
||||||
ops = Model.ops
|
self.ngram_size = ngram_size
|
||||||
bigrams = [ops.ngrams(2, doc_unis) for doc_unis in unigrams]
|
self.attr = attr
|
||||||
keys = [ops.xp.concatenate(feats) for feats in zip(unigrams, bigrams)]
|
|
||||||
keys, vals = zip(*[ops.xp.unique(k, return_counts=True) for k in keys])
|
def begin_update(self, docs, drop=0.0):
|
||||||
# The dtype here matches what thinc is expecting -- which differs per
|
batch_keys = []
|
||||||
# platform (by int definition). This should be fixed once the problem
|
batch_vals = []
|
||||||
# is fixed on Thinc's side.
|
for doc in docs:
|
||||||
lengths = ops.asarray([arr.shape[0] for arr in keys], dtype=numpy.int_)
|
unigrams = doc.to_array([self.attr])
|
||||||
keys = ops.xp.concatenate(keys)
|
ngrams = [unigrams]
|
||||||
vals = ops.asarray(ops.xp.concatenate(vals), dtype="f")
|
for n in range(2, self.ngram_size + 1):
|
||||||
return (keys, vals, lengths), None
|
ngrams.append(self.ops.ngrams(n, unigrams))
|
||||||
|
keys = self.ops.xp.concatenate(ngrams)
|
||||||
|
keys, vals = self.ops.xp.unique(keys, return_counts=True)
|
||||||
|
batch_keys.append(keys)
|
||||||
|
batch_vals.append(vals)
|
||||||
|
# The dtype here matches what thinc is expecting -- which differs per
|
||||||
|
# platform (by int definition). This should be fixed once the problem
|
||||||
|
# is fixed on Thinc's side.
|
||||||
|
lengths = self.ops.asarray([arr.shape[0] for arr in batch_keys], dtype=numpy.int_)
|
||||||
|
batch_keys = self.ops.xp.concatenate(batch_keys)
|
||||||
|
batch_vals = self.ops.asarray(self.ops.xp.concatenate(batch_vals), dtype="f")
|
||||||
|
return (batch_keys, batch_vals, lengths), None
|
||||||
|
|
||||||
|
|
||||||
@describe.on_data(
|
@describe.on_data(
|
||||||
|
@ -486,16 +485,6 @@ def zero_init(model):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@layerize
|
|
||||||
def preprocess_doc(docs, drop=0.0):
|
|
||||||
keys = [doc.to_array([LOWER]) for doc in docs]
|
|
||||||
ops = Model.ops
|
|
||||||
lengths = ops.asarray([arr.shape[0] for arr in keys])
|
|
||||||
keys = ops.xp.concatenate(keys)
|
|
||||||
vals = ops.allocate(keys.shape[0]) + 1
|
|
||||||
return (keys, vals, lengths), None
|
|
||||||
|
|
||||||
|
|
||||||
def getitem(i):
|
def getitem(i):
|
||||||
def getitem_fwd(X, drop=0.0):
|
def getitem_fwd(X, drop=0.0):
|
||||||
return X[i], None
|
return X[i], None
|
||||||
|
@ -602,10 +591,8 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
||||||
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
|
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
|
||||||
)
|
)
|
||||||
|
|
||||||
linear_model = (
|
linear_model = build_bow_text_classifier(
|
||||||
_preprocess_doc
|
nr_class, ngram_size=cfg.get("ngram_size", 1), exclusive_classes=False)
|
||||||
>> with_cpu(Model.ops, LinearModel(nr_class))
|
|
||||||
)
|
|
||||||
if cfg.get('exclusive_classes'):
|
if cfg.get('exclusive_classes'):
|
||||||
output_layer = Softmax(nr_class, nr_class * 2)
|
output_layer = Softmax(nr_class, nr_class * 2)
|
||||||
else:
|
else:
|
||||||
|
@ -623,6 +610,33 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def build_bow_text_classifier(nr_class, ngram_size=1, exclusive_classes=False,
|
||||||
|
no_output_layer=False, **cfg):
|
||||||
|
with Model.define_operators({">>": chain}):
|
||||||
|
model = (
|
||||||
|
extract_ngrams(ngram_size, attr=ORTH)
|
||||||
|
>> with_cpu(Model.ops,
|
||||||
|
LinearModel(nr_class)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not no_output_layer:
|
||||||
|
model = model >> (cpu_softmax if exclusive_classes else logistic)
|
||||||
|
model.nO = nr_class
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@layerize
|
||||||
|
def cpu_softmax(X, drop=0.):
|
||||||
|
ops = NumpyOps()
|
||||||
|
|
||||||
|
Y = ops.softmax(X)
|
||||||
|
|
||||||
|
def cpu_softmax_backward(dY, sgd=None):
|
||||||
|
return dY
|
||||||
|
|
||||||
|
return ops.softmax(X), cpu_softmax_backward
|
||||||
|
|
||||||
|
|
||||||
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, **cfg):
|
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, **cfg):
|
||||||
"""
|
"""
|
||||||
Build a simple CNN text classifier, given a token-to-vector model as inputs.
|
Build a simple CNN text classifier, given a token-to-vector model as inputs.
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
|
||||||
__title__ = "spacy"
|
__title__ = "spacy"
|
||||||
__version__ = "2.1.2"
|
__version__ = "2.1.3"
|
||||||
__summary__ = "Industrial-strength Natural Language Processing (NLP) with Python and Cython"
|
__summary__ = "Industrial-strength Natural Language Processing (NLP) with Python and Cython"
|
||||||
__uri__ = "https://spacy.io"
|
__uri__ = "https://spacy.io"
|
||||||
__author__ = "Explosion AI"
|
__author__ = "Explosion AI"
|
||||||
|
|
|
@ -15,7 +15,7 @@ from .tokenizer import Tokenizer
|
||||||
from .vocab import Vocab
|
from .vocab import Vocab
|
||||||
from .lemmatizer import Lemmatizer
|
from .lemmatizer import Lemmatizer
|
||||||
from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer
|
from .pipeline import DependencyParser, Tensorizer, Tagger, EntityRecognizer
|
||||||
from .pipeline import SimilarityHook, TextCategorizer, SentenceSegmenter
|
from .pipeline import SimilarityHook, TextCategorizer, Sentencizer
|
||||||
from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
|
from .pipeline import merge_noun_chunks, merge_entities, merge_subtokens
|
||||||
from .pipeline import EntityRuler
|
from .pipeline import EntityRuler
|
||||||
from .compat import izip, basestring_
|
from .compat import izip, basestring_
|
||||||
|
@ -119,7 +119,7 @@ class Language(object):
|
||||||
"ner": lambda nlp, **cfg: EntityRecognizer(nlp.vocab, **cfg),
|
"ner": lambda nlp, **cfg: EntityRecognizer(nlp.vocab, **cfg),
|
||||||
"similarity": lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg),
|
"similarity": lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg),
|
||||||
"textcat": lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg),
|
"textcat": lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg),
|
||||||
"sentencizer": lambda nlp, **cfg: SentenceSegmenter(nlp.vocab, **cfg),
|
"sentencizer": lambda nlp, **cfg: Sentencizer(**cfg),
|
||||||
"merge_noun_chunks": lambda nlp, **cfg: merge_noun_chunks,
|
"merge_noun_chunks": lambda nlp, **cfg: merge_noun_chunks,
|
||||||
"merge_entities": lambda nlp, **cfg: merge_entities,
|
"merge_entities": lambda nlp, **cfg: merge_entities,
|
||||||
"merge_subtokens": lambda nlp, **cfg: merge_subtokens,
|
"merge_subtokens": lambda nlp, **cfg: merge_subtokens,
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from .pipes import Tagger, DependencyParser, EntityRecognizer
|
from .pipes import Tagger, DependencyParser, EntityRecognizer
|
||||||
from .pipes import TextCategorizer, Tensorizer, Pipe
|
from .pipes import TextCategorizer, Tensorizer, Pipe, Sentencizer
|
||||||
from .entityruler import EntityRuler
|
from .entityruler import EntityRuler
|
||||||
from .hooks import SentenceSegmenter, SimilarityHook
|
from .hooks import SentenceSegmenter, SimilarityHook
|
||||||
from .functions import merge_entities, merge_noun_chunks, merge_subtokens
|
from .functions import merge_entities, merge_noun_chunks, merge_subtokens
|
||||||
|
@ -15,6 +15,7 @@ __all__ = [
|
||||||
"Tensorizer",
|
"Tensorizer",
|
||||||
"Pipe",
|
"Pipe",
|
||||||
"EntityRuler",
|
"EntityRuler",
|
||||||
|
"Sentencizer",
|
||||||
"SentenceSegmenter",
|
"SentenceSegmenter",
|
||||||
"SimilarityHook",
|
"SimilarityHook",
|
||||||
"merge_entities",
|
"merge_entities",
|
||||||
|
|
|
@ -191,7 +191,7 @@ class EntityRuler(object):
|
||||||
**kwargs: Other config paramters, mostly for consistency.
|
**kwargs: Other config paramters, mostly for consistency.
|
||||||
RETURNS (EntityRuler): The loaded entity ruler.
|
RETURNS (EntityRuler): The loaded entity ruler.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/entityruler
|
DOCS: https://spacy.io/api/entityruler#to_disk
|
||||||
"""
|
"""
|
||||||
path = ensure_path(path)
|
path = ensure_path(path)
|
||||||
path = path.with_suffix(".jsonl")
|
path = path.with_suffix(".jsonl")
|
||||||
|
|
|
@ -15,8 +15,6 @@ class SentenceSegmenter(object):
|
||||||
initialization, or assign a new strategy to the .strategy attribute.
|
initialization, or assign a new strategy to the .strategy attribute.
|
||||||
Sentence detection strategies should be generators that take `Doc` objects
|
Sentence detection strategies should be generators that take `Doc` objects
|
||||||
and yield `Span` objects for each sentence.
|
and yield `Span` objects for each sentence.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/sentencesegmenter
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "sentencizer"
|
name = "sentencizer"
|
||||||
|
@ -35,12 +33,12 @@ class SentenceSegmenter(object):
|
||||||
def split_on_punct(doc):
|
def split_on_punct(doc):
|
||||||
start = 0
|
start = 0
|
||||||
seen_period = False
|
seen_period = False
|
||||||
for i, word in enumerate(doc):
|
for i, token in enumerate(doc):
|
||||||
if seen_period and not word.is_punct:
|
if seen_period and not token.is_punct:
|
||||||
yield doc[start : word.i]
|
yield doc[start : token.i]
|
||||||
start = word.i
|
start = token.i
|
||||||
seen_period = False
|
seen_period = False
|
||||||
elif word.text in [".", "!", "?"]:
|
elif token.text in [".", "!", "?"]:
|
||||||
seen_period = True
|
seen_period = True
|
||||||
if start < len(doc):
|
if start < len(doc):
|
||||||
yield doc[start : len(doc)]
|
yield doc[start : len(doc)]
|
||||||
|
|
|
@ -25,6 +25,7 @@ from ..attrs import POS, ID
|
||||||
from ..parts_of_speech import X
|
from ..parts_of_speech import X
|
||||||
from .._ml import Tok2Vec, build_tagger_model
|
from .._ml import Tok2Vec, build_tagger_model
|
||||||
from .._ml import build_text_classifier, build_simple_cnn_text_classifier
|
from .._ml import build_text_classifier, build_simple_cnn_text_classifier
|
||||||
|
from .._ml import build_bow_text_classifier
|
||||||
from .._ml import link_vectors_to_models, zero_init, flatten
|
from .._ml import link_vectors_to_models, zero_init, flatten
|
||||||
from .._ml import masked_language_model, create_default_optimizer
|
from .._ml import masked_language_model, create_default_optimizer
|
||||||
from ..errors import Errors, TempErrors
|
from ..errors import Errors, TempErrors
|
||||||
|
@ -876,6 +877,8 @@ class TextCategorizer(Pipe):
|
||||||
if cfg.get("architecture") == "simple_cnn":
|
if cfg.get("architecture") == "simple_cnn":
|
||||||
tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg)
|
tok2vec = Tok2Vec(token_vector_width, embed_size, **cfg)
|
||||||
return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg)
|
return build_simple_cnn_text_classifier(tok2vec, nr_class, **cfg)
|
||||||
|
elif cfg.get("architecture") == "bow":
|
||||||
|
return build_bow_text_classifier(nr_class, **cfg)
|
||||||
else:
|
else:
|
||||||
return build_text_classifier(nr_class, **cfg)
|
return build_text_classifier(nr_class, **cfg)
|
||||||
|
|
||||||
|
@ -1058,4 +1061,90 @@ cdef class EntityRecognizer(Parser):
|
||||||
if move[0] in ("B", "I", "L", "U")))
|
if move[0] in ("B", "I", "L", "U")))
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer"]
|
class Sentencizer(object):
|
||||||
|
"""Segment the Doc into sentences using a rule-based strategy.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/sentencizer
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "sentencizer"
|
||||||
|
default_punct_chars = [".", "!", "?"]
|
||||||
|
|
||||||
|
def __init__(self, punct_chars=None, **kwargs):
|
||||||
|
"""Initialize the sentencizer.
|
||||||
|
|
||||||
|
punct_chars (list): Punctuation characters to split on. Will be
|
||||||
|
serialized with the nlp object.
|
||||||
|
RETURNS (Sentencizer): The sentencizer component.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/sentencizer#init
|
||||||
|
"""
|
||||||
|
self.punct_chars = punct_chars or self.default_punct_chars
|
||||||
|
|
||||||
|
def __call__(self, doc):
|
||||||
|
"""Apply the sentencizer to a Doc and set Token.is_sent_start.
|
||||||
|
|
||||||
|
doc (Doc): The document to process.
|
||||||
|
RETURNS (Doc): The processed Doc.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/sentencizer#call
|
||||||
|
"""
|
||||||
|
start = 0
|
||||||
|
seen_period = False
|
||||||
|
for i, token in enumerate(doc):
|
||||||
|
is_in_punct_chars = token.text in self.punct_chars
|
||||||
|
token.is_sent_start = i == 0
|
||||||
|
if seen_period and not token.is_punct and not is_in_punct_chars:
|
||||||
|
doc[start].is_sent_start = True
|
||||||
|
start = token.i
|
||||||
|
seen_period = False
|
||||||
|
elif is_in_punct_chars:
|
||||||
|
seen_period = True
|
||||||
|
if start < len(doc):
|
||||||
|
doc[start].is_sent_start = True
|
||||||
|
return doc
|
||||||
|
|
||||||
|
def to_bytes(self, **kwargs):
|
||||||
|
"""Serialize the sentencizer to a bytestring.
|
||||||
|
|
||||||
|
RETURNS (bytes): The serialized object.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/sentencizer#to_bytes
|
||||||
|
"""
|
||||||
|
return srsly.msgpack_dumps({"punct_chars": self.punct_chars})
|
||||||
|
|
||||||
|
def from_bytes(self, bytes_data, **kwargs):
|
||||||
|
"""Load the sentencizer from a bytestring.
|
||||||
|
|
||||||
|
bytes_data (bytes): The data to load.
|
||||||
|
returns (Sentencizer): The loaded object.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/sentencizer#from_bytes
|
||||||
|
"""
|
||||||
|
cfg = srsly.msgpack_loads(bytes_data)
|
||||||
|
self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def to_disk(self, path, exclude=tuple(), **kwargs):
|
||||||
|
"""Serialize the sentencizer to disk.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/sentencizer#to_disk
|
||||||
|
"""
|
||||||
|
path = util.ensure_path(path)
|
||||||
|
path = path.with_suffix(".json")
|
||||||
|
srsly.write_json(path, {"punct_chars": self.punct_chars})
|
||||||
|
|
||||||
|
|
||||||
|
def from_disk(self, path, exclude=tuple(), **kwargs):
|
||||||
|
"""Load the sentencizer from disk.
|
||||||
|
|
||||||
|
DOCS: https://spacy.io/api/sentencizer#from_disk
|
||||||
|
"""
|
||||||
|
path = util.ensure_path(path)
|
||||||
|
path = path.with_suffix(".json")
|
||||||
|
cfg = srsly.read_json(path)
|
||||||
|
self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "Sentencizer"]
|
||||||
|
|
|
@ -574,11 +574,12 @@ cdef class Parser:
|
||||||
cfg.setdefault('min_action_freq', 30)
|
cfg.setdefault('min_action_freq', 30)
|
||||||
actions = self.moves.get_actions(gold_parses=get_gold_tuples(),
|
actions = self.moves.get_actions(gold_parses=get_gold_tuples(),
|
||||||
min_freq=cfg.get('min_action_freq', 30))
|
min_freq=cfg.get('min_action_freq', 30))
|
||||||
previous_labels = dict(self.moves.labels)
|
for action, labels in self.moves.labels.items():
|
||||||
|
actions.setdefault(action, {})
|
||||||
|
for label, freq in labels.items():
|
||||||
|
if label not in actions[action]:
|
||||||
|
actions[action][label] = freq
|
||||||
self.moves.initialize_actions(actions)
|
self.moves.initialize_actions(actions)
|
||||||
for action, label_freqs in previous_labels.items():
|
|
||||||
for label in label_freqs:
|
|
||||||
self.moves.add_action(action, label)
|
|
||||||
cfg.setdefault('token_vector_width', 96)
|
cfg.setdefault('token_vector_width', 96)
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.model, cfg = self.Model(self.moves.n_moves, **cfg)
|
self.model, cfg = self.Model(self.moves.n_moves, **cfg)
|
||||||
|
|
|
@ -8,7 +8,8 @@ from spacy.attrs import NORM
|
||||||
from spacy.gold import GoldParse
|
from spacy.gold import GoldParse
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
from spacy.pipeline import DependencyParser
|
from spacy.pipeline import DependencyParser, EntityRecognizer
|
||||||
|
from spacy.util import fix_random_seed
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -19,18 +20,6 @@ def vocab():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def parser(vocab):
|
def parser(vocab):
|
||||||
parser = DependencyParser(vocab)
|
parser = DependencyParser(vocab)
|
||||||
parser.cfg["token_vector_width"] = 8
|
|
||||||
parser.cfg["hidden_width"] = 30
|
|
||||||
parser.cfg["hist_size"] = 0
|
|
||||||
parser.add_label("left")
|
|
||||||
parser.begin_training([], **parser.cfg)
|
|
||||||
sgd = Adam(NumpyOps(), 0.001)
|
|
||||||
|
|
||||||
for i in range(10):
|
|
||||||
losses = {}
|
|
||||||
doc = Doc(vocab, words=["a", "b", "c", "d"])
|
|
||||||
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
|
|
||||||
parser.update([doc], [gold], sgd=sgd, losses=losses)
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,25 +27,23 @@ def test_init_parser(parser):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# TODO: This is flakey, because it depends on what the parser first learns.
|
def _train_parser(parser):
|
||||||
# TODO: This now seems to be implicated in segfaults. Not sure what's up!
|
fix_random_seed(1)
|
||||||
@pytest.mark.skip
|
parser.add_label("left")
|
||||||
|
parser.begin_training([], **parser.cfg)
|
||||||
|
sgd = Adam(NumpyOps(), 0.001)
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
losses = {}
|
||||||
|
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
||||||
|
gold = GoldParse(doc, heads=[1, 1, 3, 3], deps=["left", "ROOT", "left", "ROOT"])
|
||||||
|
parser.update([doc], [gold], sgd=sgd, losses=losses)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def test_add_label(parser):
|
def test_add_label(parser):
|
||||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
parser = _train_parser(parser)
|
||||||
doc = parser(doc)
|
|
||||||
assert doc[0].head.i == 1
|
|
||||||
assert doc[0].dep_ == "left"
|
|
||||||
assert doc[1].head.i == 1
|
|
||||||
assert doc[2].head.i == 3
|
|
||||||
assert doc[2].head.i == 3
|
|
||||||
parser.add_label("right")
|
parser.add_label("right")
|
||||||
doc = Doc(parser.vocab, words=["a", "b", "c", "d"])
|
|
||||||
doc = parser(doc)
|
|
||||||
assert doc[0].head.i == 1
|
|
||||||
assert doc[0].dep_ == "left"
|
|
||||||
assert doc[1].head.i == 1
|
|
||||||
assert doc[2].head.i == 3
|
|
||||||
assert doc[2].head.i == 3
|
|
||||||
sgd = Adam(NumpyOps(), 0.001)
|
sgd = Adam(NumpyOps(), 0.001)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
losses = {}
|
losses = {}
|
||||||
|
@ -69,3 +56,15 @@ def test_add_label(parser):
|
||||||
doc = parser(doc)
|
doc = parser(doc)
|
||||||
assert doc[0].dep_ == "right"
|
assert doc[0].dep_ == "right"
|
||||||
assert doc[2].dep_ == "left"
|
assert doc[2].dep_ == "left"
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_label_deserializes_correctly():
|
||||||
|
ner1 = EntityRecognizer(Vocab())
|
||||||
|
ner1.add_label("C")
|
||||||
|
ner1.add_label("B")
|
||||||
|
ner1.add_label("A")
|
||||||
|
ner1.begin_training([])
|
||||||
|
ner2 = EntityRecognizer(Vocab()).from_bytes(ner1.to_bytes())
|
||||||
|
assert ner1.moves.n_moves == ner2.moves.n_moves
|
||||||
|
for i in range(ner1.moves.n_moves):
|
||||||
|
assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)
|
||||||
|
|
87
spacy/tests/pipeline/test_sentencizer.py
Normal file
87
spacy/tests/pipeline/test_sentencizer.py
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
# coding: utf8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from spacy.pipeline import Sentencizer
|
||||||
|
from spacy.tokens import Doc
|
||||||
|
|
||||||
|
|
||||||
|
def test_sentencizer(en_vocab):
|
||||||
|
doc = Doc(en_vocab, words=["Hello", "!", "This", "is", "a", "test", "."])
|
||||||
|
sentencizer = Sentencizer()
|
||||||
|
doc = sentencizer(doc)
|
||||||
|
assert doc.is_sentenced
|
||||||
|
sent_starts = [t.is_sent_start for t in doc]
|
||||||
|
assert sent_starts == [True, False, True, False, False, False, False]
|
||||||
|
assert len(list(doc.sents)) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"words,sent_starts,n_sents",
|
||||||
|
[
|
||||||
|
# The expected result here is that the duplicate punctuation gets merged
|
||||||
|
# onto the same sentence and no one-token sentence is created for them.
|
||||||
|
(
|
||||||
|
["Hello", "!", ".", "Test", ".", ".", "ok"],
|
||||||
|
[True, False, False, True, False, False, True],
|
||||||
|
3,
|
||||||
|
),
|
||||||
|
# We also want to make sure ¡ and ¿ aren't treated as sentence end
|
||||||
|
# markers, even though they're punctuation
|
||||||
|
(
|
||||||
|
["¡", "Buen", "día", "!", "Hola", ",", "¿", "qué", "tal", "?"],
|
||||||
|
[True, False, False, False, True, False, False, False, False, False],
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
# The Token.is_punct check ensures that quotes are handled as well
|
||||||
|
(
|
||||||
|
['"', "Nice", "!", '"', "I", "am", "happy", "."],
|
||||||
|
[True, False, False, False, True, False, False, False],
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_sentencizer_complex(en_vocab, words, sent_starts, n_sents):
|
||||||
|
doc = Doc(en_vocab, words=words)
|
||||||
|
sentencizer = Sentencizer()
|
||||||
|
doc = sentencizer(doc)
|
||||||
|
assert doc.is_sentenced
|
||||||
|
assert [t.is_sent_start for t in doc] == sent_starts
|
||||||
|
assert len(list(doc.sents)) == n_sents
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"punct_chars,words,sent_starts,n_sents",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
["~", "?"],
|
||||||
|
["Hello", "world", "~", "A", ".", "B", "."],
|
||||||
|
[True, False, False, True, False, False, False],
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
# Even thought it's not common, the punct_chars should be able to
|
||||||
|
# handle any tokens
|
||||||
|
(
|
||||||
|
[".", "ö"],
|
||||||
|
["Hello", ".", "Test", "ö", "Ok", "."],
|
||||||
|
[True, False, True, False, True, False],
|
||||||
|
3,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_sentencizer_custom_punct(en_vocab, punct_chars, words, sent_starts, n_sents):
|
||||||
|
doc = Doc(en_vocab, words=words)
|
||||||
|
sentencizer = Sentencizer(punct_chars=punct_chars)
|
||||||
|
doc = sentencizer(doc)
|
||||||
|
assert doc.is_sentenced
|
||||||
|
assert [t.is_sent_start for t in doc] == sent_starts
|
||||||
|
assert len(list(doc.sents)) == n_sents
|
||||||
|
|
||||||
|
|
||||||
|
def test_sentencizer_serialize_bytes(en_vocab):
|
||||||
|
punct_chars = [".", "~", "+"]
|
||||||
|
sentencizer = Sentencizer(punct_chars=punct_chars)
|
||||||
|
assert sentencizer.punct_chars == punct_chars
|
||||||
|
bytes_data = sentencizer.to_bytes()
|
||||||
|
new_sentencizer = Sentencizer().from_bytes(bytes_data)
|
||||||
|
assert new_sentencizer.punct_chars == punct_chars
|
22
spacy/tests/regression/test_issue3468.py
Normal file
22
spacy/tests/regression/test_issue3468.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
# coding: utf8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from spacy.lang.en import English
|
||||||
|
from spacy.tokens import Doc
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue3468():
|
||||||
|
"""Test that sentence boundaries are set correctly so Doc.is_sentenced can
|
||||||
|
be restored after serialization."""
|
||||||
|
nlp = English()
|
||||||
|
nlp.add_pipe(nlp.create_pipe("sentencizer"))
|
||||||
|
doc = nlp("Hello world")
|
||||||
|
assert doc[0].is_sent_start
|
||||||
|
assert doc.is_sentenced
|
||||||
|
assert len(list(doc.sents)) == 1
|
||||||
|
doc_bytes = doc.to_bytes()
|
||||||
|
new_doc = Doc(nlp.vocab).from_bytes(doc_bytes)
|
||||||
|
assert new_doc[0].is_sent_start
|
||||||
|
assert new_doc.is_sentenced
|
||||||
|
assert len(list(new_doc.sents)) == 1
|
|
@ -230,7 +230,7 @@ cdef class Doc:
|
||||||
defined as having at least one of the following:
|
defined as having at least one of the following:
|
||||||
|
|
||||||
a) An entry "sents" in doc.user_hooks";
|
a) An entry "sents" in doc.user_hooks";
|
||||||
b) sent.is_parsed is set to True;
|
b) Doc.is_parsed is set to True;
|
||||||
c) At least one token other than the first where sent_start is not None.
|
c) At least one token other than the first where sent_start is not None.
|
||||||
"""
|
"""
|
||||||
if "sents" in self.user_hooks:
|
if "sents" in self.user_hooks:
|
||||||
|
|
|
@ -441,6 +441,7 @@ cdef class Token:
|
||||||
|
|
||||||
property sent_start:
|
property sent_start:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
|
"""Deprecated: use Token.is_sent_start instead."""
|
||||||
# Raising a deprecation warning here causes errors for autocomplete
|
# Raising a deprecation warning here causes errors for autocomplete
|
||||||
# Handle broken backwards compatibility case: doc[0].sent_start
|
# Handle broken backwards compatibility case: doc[0].sent_start
|
||||||
# was False.
|
# was False.
|
||||||
|
|
|
@ -1,78 +0,0 @@
|
||||||
---
|
|
||||||
title: SentenceSegmenter
|
|
||||||
tag: class
|
|
||||||
source: spacy/pipeline/hooks.py
|
|
||||||
---
|
|
||||||
|
|
||||||
A simple spaCy hook, to allow custom sentence boundary detection logic that
|
|
||||||
doesn't require the dependency parse. By default, sentence segmentation is
|
|
||||||
performed by the [`DependencyParser`](/api/dependencyparser), so the
|
|
||||||
`SentenceSegmenter` lets you implement a simpler, rule-based strategy that
|
|
||||||
doesn't require a statistical model to be loaded. The component is also
|
|
||||||
available via the string name `"sentencizer"`. After initialization, it is
|
|
||||||
typically added to the processing pipeline using
|
|
||||||
[`nlp.add_pipe`](/api/language#add_pipe).
|
|
||||||
|
|
||||||
## SentenceSegmenter.\_\_init\_\_ {#init tag="method"}
|
|
||||||
|
|
||||||
Initialize the sentence segmenter. To change the sentence boundary detection
|
|
||||||
strategy, pass a generator function `strategy` on initialization, or assign a
|
|
||||||
new strategy to the `.strategy` attribute. Sentence detection strategies should
|
|
||||||
be generators that take `Doc` objects and yield `Span` objects for each
|
|
||||||
sentence.
|
|
||||||
|
|
||||||
> #### Example
|
|
||||||
>
|
|
||||||
> ```python
|
|
||||||
> # Construction via create_pipe
|
|
||||||
> sentencizer = nlp.create_pipe("sentencizer")
|
|
||||||
>
|
|
||||||
> # Construction from class
|
|
||||||
> from spacy.pipeline import SentenceSegmenter
|
|
||||||
> sentencizer = SentenceSegmenter(nlp.vocab)
|
|
||||||
> ```
|
|
||||||
|
|
||||||
| Name | Type | Description |
|
|
||||||
| ----------- | ------------------- | ----------------------------------------------------------- |
|
|
||||||
| `vocab` | `Vocab` | The shared vocabulary. |
|
|
||||||
| `strategy` | unicode / callable | The segmentation strategy to use. Defaults to `"on_punct"`. |
|
|
||||||
| **RETURNS** | `SentenceSegmenter` | The newly constructed object. |
|
|
||||||
|
|
||||||
## SentenceSegmenter.\_\_call\_\_ {#call tag="method"}
|
|
||||||
|
|
||||||
Apply the sentence segmenter on a `Doc`. Typically, this happens automatically
|
|
||||||
after the component has been added to the pipeline using
|
|
||||||
[`nlp.add_pipe`](/api/language#add_pipe).
|
|
||||||
|
|
||||||
> #### Example
|
|
||||||
>
|
|
||||||
> ```python
|
|
||||||
> from spacy.lang.en import English
|
|
||||||
>
|
|
||||||
> nlp = English()
|
|
||||||
> sentencizer = nlp.create_pipe("sentencizer")
|
|
||||||
> nlp.add_pipe(sentencizer)
|
|
||||||
> doc = nlp(u"This is a sentence. This is another sentence.")
|
|
||||||
> assert list(doc.sents) == 2
|
|
||||||
> ```
|
|
||||||
|
|
||||||
| Name | Type | Description |
|
|
||||||
| ----------- | ----- | ------------------------------------------------------------ |
|
|
||||||
| `doc` | `Doc` | The `Doc` object to process, e.g. the `Doc` in the pipeline. |
|
|
||||||
| **RETURNS** | `Doc` | The modified `Doc` with added sentence boundaries. |
|
|
||||||
|
|
||||||
## SentenceSegmenter.split_on_punct {#split_on_punct tag="staticmethod"}
|
|
||||||
|
|
||||||
Split the `Doc` on punctuation characters `.`, `!` and `?`. This is the default
|
|
||||||
strategy used by the `SentenceSegmenter.`
|
|
||||||
|
|
||||||
| Name | Type | Description |
|
|
||||||
| ---------- | ------ | ------------------------------ |
|
|
||||||
| `doc` | `Doc` | The `Doc` object to process. |
|
|
||||||
| **YIELDS** | `Span` | The sentences in the document. |
|
|
||||||
|
|
||||||
## Attributes {#attributes}
|
|
||||||
|
|
||||||
| Name | Type | Description |
|
|
||||||
| ---------- | -------- | ------------------------------------------------------------------- |
|
|
||||||
| `strategy` | callable | The segmentation strategy. Can be overwritten after initialization. |
|
|
136
website/docs/api/sentencizer.md
Normal file
136
website/docs/api/sentencizer.md
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
---
|
||||||
|
title: Sentencizer
|
||||||
|
tag: class
|
||||||
|
source: spacy/pipeline/pipes.pyx
|
||||||
|
---
|
||||||
|
|
||||||
|
A simple pipeline component, to allow custom sentence boundary detection logic
|
||||||
|
that doesn't require the dependency parse. By default, sentence segmentation is
|
||||||
|
performed by the [`DependencyParser`](/api/dependencyparser), so the
|
||||||
|
`Sentencizer` lets you implement a simpler, rule-based strategy that doesn't
|
||||||
|
require a statistical model to be loaded. The component is also available via
|
||||||
|
the string name `"sentencizer"`. After initialization, it is typically added to
|
||||||
|
the processing pipeline using [`nlp.add_pipe`](/api/language#add_pipe).
|
||||||
|
|
||||||
|
<Infobox title="Important note" variant="warning">
|
||||||
|
|
||||||
|
Compared to the previous `SentenceSegmenter` class, the `Sentencizer` component
|
||||||
|
doesn't add a hook to `doc.user_hooks["sents"]`. Instead, it iterates over the
|
||||||
|
tokens in the `Doc` and sets the `Token.is_sent_start` property. The
|
||||||
|
`SentenceSegmenter` is still available if you import it directly:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from spacy.pipeline import SentenceSegmenter
|
||||||
|
```
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
|
## Sentencizer.\_\_init\_\_ {#init tag="method"}
|
||||||
|
|
||||||
|
Initialize the sentencizer.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> # Construction via create_pipe
|
||||||
|
> sentencizer = nlp.create_pipe("sentencizer")
|
||||||
|
>
|
||||||
|
> # Construction from class
|
||||||
|
> from spacy.pipeline import Sentencizer
|
||||||
|
> sentencizer = Sentencizer()
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ------------- | ------------- | ------------------------------------------------------------------------------------------------------ |
|
||||||
|
| `punct_chars` | list | Optional custom list of punctuation characters that mark sentence ends. Defaults to `[".", "!", "?"].` |
|
||||||
|
| **RETURNS** | `Sentencizer` | The newly constructed object. |
|
||||||
|
|
||||||
|
## Sentencizer.\_\_call\_\_ {#call tag="method"}
|
||||||
|
|
||||||
|
Apply the sentencizer on a `Doc`. Typically, this happens automatically after
|
||||||
|
the component has been added to the pipeline using
|
||||||
|
[`nlp.add_pipe`](/api/language#add_pipe).
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> from spacy.lang.en import English
|
||||||
|
>
|
||||||
|
> nlp = English()
|
||||||
|
> sentencizer = nlp.create_pipe("sentencizer")
|
||||||
|
> nlp.add_pipe(sentencizer)
|
||||||
|
> doc = nlp(u"This is a sentence. This is another sentence.")
|
||||||
|
> assert list(doc.sents) == 2
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | ----- | ------------------------------------------------------------ |
|
||||||
|
| `doc` | `Doc` | The `Doc` object to process, e.g. the `Doc` in the pipeline. |
|
||||||
|
| **RETURNS** | `Doc` | The modified `Doc` with added sentence boundaries. |
|
||||||
|
|
||||||
|
## Sentencizer.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
|
Save the sentencizer settings (punctuation characters) a directory. Will create
|
||||||
|
a file `sentencizer.json`. This also happens automatically when you save an
|
||||||
|
`nlp` object with a sentencizer added to its pipeline.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> sentencizer = Sentencizer(punct_chars=[".", "?", "!", "。"])
|
||||||
|
> sentencizer.to_disk("/path/to/sentencizer.jsonl")
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ------ | ---------------- | ---------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `path` | unicode / `Path` | A path to a file, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||||
|
|
||||||
|
## Sentencizer.from_disk {#from_disk tag="method"}
|
||||||
|
|
||||||
|
Load the sentencizer settings from a file. Expects a JSON file. This also
|
||||||
|
happens automatically when you load an `nlp` object or model with a sentencizer
|
||||||
|
added to its pipeline.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> sentencizer = Sentencizer()
|
||||||
|
> sentencizer.from_disk("/path/to/sentencizer.json")
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | ---------------- | -------------------------------------------------------------------------- |
|
||||||
|
| `path` | unicode / `Path` | A path to a JSON file. Paths may be either strings or `Path`-like objects. |
|
||||||
|
| **RETURNS** | `Sentencizer` | The modified `Sentencizer` object. |
|
||||||
|
|
||||||
|
## Sentencizer.to_bytes {#to_bytes tag="method"}
|
||||||
|
|
||||||
|
Serialize the sentencizer settings to a bytestring.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> sentencizer = Sentencizer(punct_chars=[".", "?", "!", "。"])
|
||||||
|
> sentencizer_bytes = sentencizer.to_bytes()
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ----------- | ----- | -------------------- |
|
||||||
|
| **RETURNS** | bytes | The serialized data. |
|
||||||
|
|
||||||
|
## Sentencizer.from_bytes {#from_bytes tag="method"}
|
||||||
|
|
||||||
|
Load the pipe from a bytestring. Modifies the object in place and returns it.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> sentencizer_bytes = sentencizer.to_bytes()
|
||||||
|
> sentencizer = Sentencizer()
|
||||||
|
> sentencizer.from_bytes(sentencizer_bytes)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Type | Description |
|
||||||
|
| ------------ | ------------- | ---------------------------------- |
|
||||||
|
| `bytes_data` | bytes | The bytestring to load. |
|
||||||
|
| **RETURNS** | `Sentencizer` | The modified `Sentencizer` object. |
|
|
@ -56,10 +56,11 @@ of problems. To handle a wider variety of problems, the `TextCategorizer` object
|
||||||
allows configuration of its model architecture, using the `architecture` keyword
|
allows configuration of its model architecture, using the `architecture` keyword
|
||||||
argument.
|
argument.
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `"ensemble"` | **Default:** Stacked ensemble of a unigram bag-of-words model and a neural network model. The neural network uses a CNN with mean pooling and attention. |
|
| `"ensemble"` | **Default:** Stacked ensemble of a bag-of-words model and a neural network model. The neural network uses a CNN with mean pooling and attention. The "ngram_size" and "attr" arguments can be used to configure the feature extraction for the bag-of-words model. |
|
||||||
| `"simple_cnn"` | A neural network model where token vectors are calculated using a CNN. The vectors are mean pooled and used as features in a feed-forward network. |
|
| `"simple_cnn"` | A neural network model where token vectors are calculated using a CNN. The vectors are mean pooled and used as features in a feed-forward network. This architecture is usually less accurate than the ensemble, but runs faster. |
|
||||||
|
| `"bow"` | An ngram "bag-of-words" model. This architecture should run much faster than the others, but may not be as accurate, especially if texts are short. The features extracted can be controlled using the keyword arguments `ngram_size` and `attr`. For instance, `ngram_size=3` and `attr="lower"` would give lower-cased unigram, trigram and bigram features. 2, 3 or 4 are usually good choices of ngram size. |
|
||||||
|
|
||||||
## TextCategorizer.\_\_call\_\_ {#call tag="method"}
|
## TextCategorizer.\_\_call\_\_ {#call tag="method"}
|
||||||
|
|
||||||
|
|
|
@ -25,21 +25,21 @@ an **annotated document**. It also orchestrates training and serialization.
|
||||||
|
|
||||||
### Processing pipeline {#architecture-pipeline}
|
### Processing pipeline {#architecture-pipeline}
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| --------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
|
| ------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| [`Language`](/api/language) | A text-processing pipeline. Usually you'll load this once per process as `nlp` and pass the instance around your application. |
|
| [`Language`](/api/language) | A text-processing pipeline. Usually you'll load this once per process as `nlp` and pass the instance around your application. |
|
||||||
| [`Tokenizer`](/api/tokenizer) | Segment text, and create `Doc` objects with the discovered segment boundaries. |
|
| [`Tokenizer`](/api/tokenizer) | Segment text, and create `Doc` objects with the discovered segment boundaries. |
|
||||||
| [`Lemmatizer`](/api/lemmatizer) | Determine the base forms of words. |
|
| [`Lemmatizer`](/api/lemmatizer) | Determine the base forms of words. |
|
||||||
| `Morphology` | Assign linguistic features like lemmas, noun case, verb tense etc. based on the word and its part-of-speech tag. |
|
| `Morphology` | Assign linguistic features like lemmas, noun case, verb tense etc. based on the word and its part-of-speech tag. |
|
||||||
| [`Tagger`](/api/tagger) | Annotate part-of-speech tags on `Doc` objects. |
|
| [`Tagger`](/api/tagger) | Annotate part-of-speech tags on `Doc` objects. |
|
||||||
| [`DependencyParser`](/api/dependencyparser) | Annotate syntactic dependencies on `Doc` objects. |
|
| [`DependencyParser`](/api/dependencyparser) | Annotate syntactic dependencies on `Doc` objects. |
|
||||||
| [`EntityRecognizer`](/api/entityrecognizer) | Annotate named entities, e.g. persons or products, on `Doc` objects. |
|
| [`EntityRecognizer`](/api/entityrecognizer) | Annotate named entities, e.g. persons or products, on `Doc` objects. |
|
||||||
| [`TextCategorizer`](/api/textcategorizer) | Assign categories or labels to `Doc` objects. |
|
| [`TextCategorizer`](/api/textcategorizer) | Assign categories or labels to `Doc` objects. |
|
||||||
| [`Matcher`](/api/matcher) | Match sequences of tokens, based on pattern rules, similar to regular expressions. |
|
| [`Matcher`](/api/matcher) | Match sequences of tokens, based on pattern rules, similar to regular expressions. |
|
||||||
| [`PhraseMatcher`](/api/phrasematcher) | Match sequences of tokens based on phrases. |
|
| [`PhraseMatcher`](/api/phrasematcher) | Match sequences of tokens based on phrases. |
|
||||||
| [`EntityRuler`](/api/entityruler) | Add entity spans to the `Doc` using token-based rules or exact phrase matches. |
|
| [`EntityRuler`](/api/entityruler) | Add entity spans to the `Doc` using token-based rules or exact phrase matches. |
|
||||||
| [`SentenceSegmenter`](/api/sentencesegmenter) | Implement custom sentence boundary detection logic that doesn't require the dependency parse. |
|
| [`Sentencizer`](/api/sentencizer) | Implement custom sentence boundary detection logic that doesn't require the dependency parse. |
|
||||||
| [Other functions](/api/pipeline-functions) | Automatically apply something to the `Doc`, e.g. to merge spans of tokens. |
|
| [Other functions](/api/pipeline-functions) | Automatically apply something to the `Doc`, e.g. to merge spans of tokens. |
|
||||||
|
|
||||||
### Other classes {#architecture-other}
|
### Other classes {#architecture-other}
|
||||||
|
|
||||||
|
|
|
@ -1149,9 +1149,14 @@ but it also means you'll need a **statistical model** and accurate predictions.
|
||||||
If your texts are closer to general-purpose news or web text, this should work
|
If your texts are closer to general-purpose news or web text, this should work
|
||||||
well out-of-the-box. For social media or conversational text that doesn't follow
|
well out-of-the-box. For social media or conversational text that doesn't follow
|
||||||
the same rules, your application may benefit from a custom rule-based
|
the same rules, your application may benefit from a custom rule-based
|
||||||
implementation. You can either plug a rule-based component into your
|
implementation. You can either use the built-in
|
||||||
[processing pipeline](/usage/processing-pipelines) or use the
|
[`Sentencizer`](/api/sentencizer) or plug an entirely custom rule-based function
|
||||||
`SentenceSegmenter` component with a custom strategy.
|
into your [processing pipeline](/usage/processing-pipelines).
|
||||||
|
|
||||||
|
spaCy's dependency parser respects already set boundaries, so you can preprocess
|
||||||
|
your `Doc` using custom rules _before_ it's parsed. Depending on your text, this
|
||||||
|
may also improve accuracy, since the parser is constrained to predict parses
|
||||||
|
consistent with the sentence boundaries.
|
||||||
|
|
||||||
### Default: Using the dependency parse {#sbd-parser model="parser"}
|
### Default: Using the dependency parse {#sbd-parser model="parser"}
|
||||||
|
|
||||||
|
@ -1168,13 +1173,35 @@ for sent in doc.sents:
|
||||||
print(sent.text)
|
print(sent.text)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Setting boundaries manually {#sbd-manual}
|
### Rule-based pipeline component {#sbd-component}
|
||||||
|
|
||||||
spaCy's dependency parser respects already set boundaries, so you can preprocess
|
The [`Sentencizer`](/api/sentencizer) component is a
|
||||||
your `Doc` using custom rules _before_ it's parsed. This can be done by adding a
|
[pipeline component](/usage/processing-pipelines) that splits sentences on
|
||||||
[custom pipeline component](/usage/processing-pipelines). Depending on your
|
punctuation like `.`, `!` or `?`. You can plug it into your pipeline if you only
|
||||||
text, this may also improve accuracy, since the parser is constrained to predict
|
need sentence boundaries without the dependency parse.
|
||||||
parses consistent with the sentence boundaries.
|
|
||||||
|
```python
|
||||||
|
### {executable="true"}
|
||||||
|
import spacy
|
||||||
|
from spacy.lang.en import English
|
||||||
|
|
||||||
|
nlp = English() # just the language with no model
|
||||||
|
sentencizer = nlp.create_pipe("sentencizer")
|
||||||
|
nlp.add_pipe(sentencizer)
|
||||||
|
doc = nlp(u"This is a sentence. This is another sentence.")
|
||||||
|
for sent in doc.sents:
|
||||||
|
print(sent.text)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom rule-based strategy {id="sbd-custom"}
|
||||||
|
|
||||||
|
If you want to implement your own strategy that differs from the default
|
||||||
|
rule-based approach of splitting on sentences, you can also create a
|
||||||
|
[custom pipeline component](/usage/processing-pipelines#custom-components) that
|
||||||
|
takes a `Doc` object and sets the `Token.is_sent_start` attribute on each
|
||||||
|
individual token. If set to `False`, the token is explicitly marked as _not_ the
|
||||||
|
start of a sentence. If set to `None` (default), it's treated as a missing value
|
||||||
|
and can still be overwritten by the parser.
|
||||||
|
|
||||||
<Infobox title="Important note" variant="warning">
|
<Infobox title="Important note" variant="warning">
|
||||||
|
|
||||||
|
@ -1187,9 +1214,11 @@ adding it to the pipeline using [`nlp.add_pipe`](/api/language#add_pipe).
|
||||||
|
|
||||||
Here's an example of a component that implements a pre-processing rule for
|
Here's an example of a component that implements a pre-processing rule for
|
||||||
splitting on `'...'` tokens. The component is added before the parser, which is
|
splitting on `'...'` tokens. The component is added before the parser, which is
|
||||||
then used to further segment the text. This approach can be useful if you want
|
then used to further segment the text. That's possible, because `is_sent_start`
|
||||||
to implement **additional** rules specific to your data, while still being able
|
is only set to `True` for some of the tokens – all others still specify `None`
|
||||||
to take advantage of dependency-based sentence segmentation.
|
for unset sentence boundaries. This approach can be useful if you want to
|
||||||
|
implement **additional** rules specific to your data, while still being able to
|
||||||
|
take advantage of dependency-based sentence segmentation.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
### {executable="true"}
|
### {executable="true"}
|
||||||
|
@ -1212,62 +1241,6 @@ doc = nlp(text)
|
||||||
print("After:", [sent.text for sent in doc.sents])
|
print("After:", [sent.text for sent in doc.sents])
|
||||||
```
|
```
|
||||||
|
|
||||||
### Rule-based pipeline component {#sbd-component}
|
|
||||||
|
|
||||||
The `sentencizer` component is a
|
|
||||||
[pipeline component](/usage/processing-pipelines) that splits sentences on
|
|
||||||
punctuation like `.`, `!` or `?`. You can plug it into your pipeline if you only
|
|
||||||
need sentence boundaries without the dependency parse. Note that `Doc.sents`
|
|
||||||
will **raise an error** if no sentence boundaries are set.
|
|
||||||
|
|
||||||
```python
|
|
||||||
### {executable="true"}
|
|
||||||
import spacy
|
|
||||||
from spacy.lang.en import English
|
|
||||||
|
|
||||||
nlp = English() # just the language with no model
|
|
||||||
sentencizer = nlp.create_pipe("sentencizer")
|
|
||||||
nlp.add_pipe(sentencizer)
|
|
||||||
doc = nlp(u"This is a sentence. This is another sentence.")
|
|
||||||
for sent in doc.sents:
|
|
||||||
print(sent.text)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Custom rule-based strategy {#sbd-custom}
|
|
||||||
|
|
||||||
If you want to implement your own strategy that differs from the default
|
|
||||||
rule-based approach of splitting on sentences, you can also instantiate the
|
|
||||||
`SentenceSegmenter` directly and pass in your own strategy. The strategy should
|
|
||||||
be a function that takes a `Doc` object and yields a `Span` for each sentence.
|
|
||||||
Here's an example of a custom segmentation strategy for splitting on newlines
|
|
||||||
only:
|
|
||||||
|
|
||||||
```python
|
|
||||||
### {executable="true"}
|
|
||||||
from spacy.lang.en import English
|
|
||||||
from spacy.pipeline import SentenceSegmenter
|
|
||||||
|
|
||||||
def split_on_newlines(doc):
|
|
||||||
start = 0
|
|
||||||
seen_newline = False
|
|
||||||
for word in doc:
|
|
||||||
if seen_newline and not word.is_space:
|
|
||||||
yield doc[start:word.i]
|
|
||||||
start = word.i
|
|
||||||
seen_newline = False
|
|
||||||
elif word.text == '\\n':
|
|
||||||
seen_newline = True
|
|
||||||
if start < len(doc):
|
|
||||||
yield doc[start:len(doc)]
|
|
||||||
|
|
||||||
nlp = English() # Just the language with no model
|
|
||||||
sentencizer = SentenceSegmenter(nlp.vocab, strategy=split_on_newlines)
|
|
||||||
nlp.add_pipe(sentencizer)
|
|
||||||
doc = nlp(u"This is a sentence\\n\\nThis is another sentence\\nAnd more")
|
|
||||||
for sent in doc.sents:
|
|
||||||
print([token.text for token in sent])
|
|
||||||
```
|
|
||||||
|
|
||||||
## Rule-based matching {#rule-based-matching hidden="true"}
|
## Rule-based matching {#rule-based-matching hidden="true"}
|
||||||
|
|
||||||
<div id="rule-based-matching">
|
<div id="rule-based-matching">
|
||||||
|
|
|
@ -138,7 +138,7 @@ require them in the pipeline settings in your model's `meta.json`.
|
||||||
| `ner` | [`EntityRecognizer`](/api/entityrecognizer) | Assign named entities. |
|
| `ner` | [`EntityRecognizer`](/api/entityrecognizer) | Assign named entities. |
|
||||||
| `textcat` | [`TextCategorizer`](/api/textcategorizer) | Assign text categories. |
|
| `textcat` | [`TextCategorizer`](/api/textcategorizer) | Assign text categories. |
|
||||||
| `entity_ruler` | [`EntityRuler`](/api/entityruler) | Assign named entities based on pattern rules. |
|
| `entity_ruler` | [`EntityRuler`](/api/entityruler) | Assign named entities based on pattern rules. |
|
||||||
| `sentencizer` | [`SentenceSegmenter`](/api/sentencesegmenter) | Add rule-based sentence segmentation without the dependency parse. |
|
| `sentencizer` | [`Sentencizer`](/api/sentencizer) | Add rule-based sentence segmentation without the dependency parse. |
|
||||||
| `merge_noun_chunks` | [`merge_noun_chunks`](/api/pipeline-functions#merge_noun_chunks) | Merge all noun chunks into a single token. Should be added after the tagger and parser. |
|
| `merge_noun_chunks` | [`merge_noun_chunks`](/api/pipeline-functions#merge_noun_chunks) | Merge all noun chunks into a single token. Should be added after the tagger and parser. |
|
||||||
| `merge_entities` | [`merge_entities`](/api/pipeline-functions#merge_entities) | Merge all entities into a single token. Should be added after the entity recognizer. |
|
| `merge_entities` | [`merge_entities`](/api/pipeline-functions#merge_entities) | Merge all entities into a single token. Should be added after the entity recognizer. |
|
||||||
| `merge_subtokens` | [`merge_subtokens`](/api/pipeline-functions#merge_subtokens) | Merge subtokens predicted by the parser into single tokens. Should be added after the parser. |
|
| `merge_subtokens` | [`merge_subtokens`](/api/pipeline-functions#merge_subtokens) | Merge subtokens predicted by the parser into single tokens. Should be added after the parser. |
|
||||||
|
|
|
@ -197,7 +197,7 @@ the existing pages and added some new content:
|
||||||
- **Universe:** [Videos](/universe/category/videos) and
|
- **Universe:** [Videos](/universe/category/videos) and
|
||||||
[Podcasts](/universe/category/podcasts)
|
[Podcasts](/universe/category/podcasts)
|
||||||
- **API:** [`EntityRuler`](/api/entityruler)
|
- **API:** [`EntityRuler`](/api/entityruler)
|
||||||
- **API:** [`SentenceSegmenter`](/api/sentencesegmenter)
|
- **API:** [`Sentencizer`](/api/sentencizer)
|
||||||
- **API:** [Pipeline functions](/api/pipeline-functions)
|
- **API:** [Pipeline functions](/api/pipeline-functions)
|
||||||
|
|
||||||
## Backwards incompatibilities {#incompat}
|
## Backwards incompatibilities {#incompat}
|
||||||
|
|
|
@ -79,7 +79,7 @@
|
||||||
{ "text": "Matcher", "url": "/api/matcher" },
|
{ "text": "Matcher", "url": "/api/matcher" },
|
||||||
{ "text": "PhraseMatcher", "url": "/api/phrasematcher" },
|
{ "text": "PhraseMatcher", "url": "/api/phrasematcher" },
|
||||||
{ "text": "EntityRuler", "url": "/api/entityruler" },
|
{ "text": "EntityRuler", "url": "/api/entityruler" },
|
||||||
{ "text": "SentenceSegmenter", "url": "/api/sentencesegmenter" },
|
{ "text": "Sentencizer", "url": "/api/sentencizer" },
|
||||||
{ "text": "Other Functions", "url": "/api/pipeline-functions" }
|
{ "text": "Other Functions", "url": "/api/pipeline-functions" }
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 30 KiB After Width: | Height: | Size: 24 KiB |
Loading…
Reference in New Issue
Block a user