diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index b669e95ec..d3018ffd7 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -38,21 +38,47 @@ from .parts_of_speech import X class TokenVectorEncoder(object): - '''Assign position-sensitive vectors to tokens, using a CNN or RNN.''' + """Assign position-sensitive vectors to tokens, using a CNN or RNN.""" name = 'tok2vec' @classmethod def Model(cls, width=128, embed_size=5000, **cfg): + """Create a new statistical model for the class. + + width (int): Output size of the model. + embed_size (int): Number of vectors in the embedding table. + **cfg: Config parameters. + RETURNS (Model): A `thinc.neural.Model` or similar instance. + """ width = util.env_opt('token_vector_width', width) embed_size = util.env_opt('embed_size', embed_size) return Tok2Vec(width, embed_size, preprocess=None) def __init__(self, vocab, model=True, **cfg): + """Construct a new statistical model. Weights are not allocated on + initialisation. + + vocab (Vocab): A `Vocab` instance. The model must share the same `Vocab` + instance with the `Doc` objects it will process. + model (Model): A `Model` instance or `True` allocate one later. + **cfg: Config parameters. + + EXAMPLE: + >>> from spacy.pipeline import TokenVectorEncoder + >>> tok2vec = TokenVectorEncoder(nlp.vocab) + >>> tok2vec.model = tok2vec.Model(128, 5000) + """ self.vocab = vocab self.doc2feats = doc2feats() self.model = model def __call__(self, docs, state=None): + """Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM + model. Vectors are set to the `Doc.tensor` attribute. + + docs (Doc or iterable): One or more documents to add vectors to. + RETURNS (dict or None): Intermediate computations. + """ if isinstance(docs, Doc): docs = [docs] tokvecs = self.predict(docs) @@ -62,6 +88,13 @@ class TokenVectorEncoder(object): return state def pipe(self, stream, batch_size=128, n_threads=-1): + """Process `Doc` objects as a stream. + + stream (iterator): A sequence of `Doc` objects to process. + batch_size (int): Number of `Doc` objects to group. + n_threads (int): Number of threads. + YIELDS (tuple): Tuples of `(Doc, state)`. + """ for batch in cytoolz.partition_all(batch_size, stream): docs, states = zip(*batch) tokvecs = self.predict(docs) @@ -71,18 +104,35 @@ class TokenVectorEncoder(object): yield from zip(docs, states) def predict(self, docs): + """Return a single tensor for a batch of documents. + + docs (iterable): A sequence of `Doc` objects. + RETURNS (object): Vector representations for each token in the documents. + """ feats = self.doc2feats(docs) tokvecs = self.model(feats) return tokvecs def set_annotations(self, docs, tokvecs): + """Set the tensor attribute for a batch of documents. + + docs (iterable): A sequence of `Doc` objects. + tokvecs (object): Vector representation for each token in the documents. + """ start = 0 for doc in docs: doc.tensor = tokvecs[start : start + len(doc)] start += len(doc) - def update(self, docs, golds, state=None, - drop=0., sgd=None): + def update(self, docs, golds, state=None, drop=0., sgd=None): + """Update the model. + + docs (iterable): A batch of `Doc` objects. + golds (iterable): A batch of `GoldParse` objects. + drop (float): The droput rate. + sgd (function): An optimizer. + RETURNS (dict): Results from the update. + """ if isinstance(docs, Doc): docs = [docs] golds = [golds] @@ -95,14 +145,26 @@ class TokenVectorEncoder(object): return state def get_loss(self, docs, golds, scores): + # TODO: implement raise NotImplementedError def begin_training(self, gold_tuples, pipeline=None): + """Allocate models, pre-process training data and acquire a trainer and + optimizer. + + gold_tuples (iterable): Gold-standard training data. + pipeline (list): The pipeline the model is part of. + """ self.doc2feats = doc2feats() if self.model is True: self.model = self.Model() def use_params(self, params): + """Replace weights of models in the pipeline with those provided in the + params dictionary. + + params (dict): A dictionary of parameters keyed by model ID. + """ with self.model.use_params(params): yield