Update docstrings and remove deprecated load classmethod

This commit is contained in:
ines 2017-05-21 13:27:52 +02:00
parent c9f04f3cd0
commit 885e82c9b0

View File

@ -1,7 +1,6 @@
# coding: utf8 # coding: utf8
from __future__ import unicode_literals from __future__ import unicode_literals
import ujson
from collections import defaultdict from collections import defaultdict
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
@ -15,7 +14,6 @@ from .tokens.doc cimport Doc
from .attrs cimport TAG from .attrs cimport TAG
from .gold cimport GoldParse from .gold cimport GoldParse
from .attrs cimport * from .attrs cimport *
from . import util
cpdef enum: cpdef enum:
@ -108,55 +106,15 @@ cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil:
cdef class Tagger: cdef class Tagger:
""" """Annotate part-of-speech tags on Doc objects."""
Annotate part-of-speech tags on Doc objects.
"""
@classmethod
def load(cls, path, vocab, require=False):
"""
Load the statistical model from the supplied path.
Arguments:
path (Path):
The path to load from.
vocab (Vocab):
The vocabulary. Must be shared by the documents to be processed.
require (bool):
Whether to raise an error if the files are not found.
Returns (Tagger):
The newly created object.
"""
# TODO: Change this to expect config.json when we don't have to
# support old data.
path = util.ensure_path(path)
if (path / 'templates.json').exists():
with (path / 'templates.json').open('r', encoding='utf8') as file_:
templates = ujson.load(file_)
elif require:
raise IOError(
"Required file %s/templates.json not found when loading Tagger" % str(path))
else:
templates = cls.feature_templates
self = cls(vocab, model=None, feature_templates=templates)
if (path / 'model').exists():
self.model.load(str(path / 'model'))
elif require:
raise IOError(
"Required file %s/model not found when loading Tagger" % str(path))
return self
def __init__(self, Vocab vocab, TaggerModel model=None, **cfg): def __init__(self, Vocab vocab, TaggerModel model=None, **cfg):
""" """Create a Tagger.
Create a Tagger.
Arguments: vocab (Vocab): The vocabulary object. Must be shared with documents to
vocab (Vocab): be processed.
The vocabulary object. Must be shared with documents to be processed. model (thinc.linear.AveragedPerceptron): The statistical model.
model (thinc.linear.AveragedPerceptron): RETURNS (Tagger): The newly constructed object.
The statistical model.
Returns (Tagger):
The newly constructed object.
""" """
if model is None: if model is None:
model = TaggerModel(cfg.get('features', self.feature_templates), model = TaggerModel(cfg.get('features', self.feature_templates),
@ -186,13 +144,9 @@ cdef class Tagger:
tokens._py_tokens = [None] * tokens.length tokens._py_tokens = [None] * tokens.length
def __call__(self, Doc tokens): def __call__(self, Doc tokens):
""" """Apply the tagger, setting the POS tags onto the Doc object.
Apply the tagger, setting the POS tags onto the Doc object.
Arguments: doc (Doc): The tokens to be tagged.
doc (Doc): The tokens to be tagged.
Returns:
None
""" """
if tokens.length == 0: if tokens.length == 0:
return 0 return 0
@ -215,34 +169,25 @@ cdef class Tagger:
tokens._py_tokens = [None] * tokens.length tokens._py_tokens = [None] * tokens.length
def pipe(self, stream, batch_size=1000, n_threads=2): def pipe(self, stream, batch_size=1000, n_threads=2):
""" """Tag a stream of documents.
Tag a stream of documents.
Arguments: Arguments:
stream: The sequence of documents to tag. stream: The sequence of documents to tag.
batch_size (int): batch_size (int): The number of documents to accumulate into a working set.
The number of documents to accumulate into a working set. n_threads (int): The number of threads with which to work on the buffer
n_threads (int): in parallel, if the Matcher implementation supports multi-threading.
The number of threads with which to work on the buffer in parallel, YIELDS (Doc): Documents, in order.
if the Matcher implementation supports multi-threading.
Yields:
Doc Documents, in order.
""" """
for doc in stream: for doc in stream:
self(doc) self(doc)
yield doc yield doc
def update(self, Doc tokens, GoldParse gold, itn=0): def update(self, Doc tokens, GoldParse gold, itn=0):
""" """Update the statistical model, with tags supplied for the given document.
Update the statistical model, with tags supplied for the given document.
Arguments: doc (Doc): The document to update on.
doc (Doc): gold (GoldParse): Manager for the gold-standard tags.
The document to update on. RETURNS (int): Number of tags predicted correctly.
gold (GoldParse):
Manager for the gold-standard tags.
Returns (int):
Number of tags correct.
""" """
gold_tag_strs = gold.tags gold_tag_strs = gold.tags
assert len(tokens) == len(gold_tag_strs) assert len(tokens) == len(gold_tag_strs)