Add cfg attr to pipeline components

This commit is contained in:
Matthew Honnibal 2017-07-23 00:52:47 +02:00
parent d8aa721664
commit 4fe77bced2

View File

@ -160,6 +160,7 @@ class TokenVectorEncoder(BaseThincComponent):
self.vocab = vocab self.vocab = vocab
self.doc2feats = doc2feats() self.doc2feats = doc2feats()
self.model = model self.model = model
self.cfg = dict(cfg)
def __call__(self, doc): def __call__(self, doc):
"""Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM """Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM
@ -239,9 +240,10 @@ class TokenVectorEncoder(BaseThincComponent):
class NeuralTagger(BaseThincComponent): class NeuralTagger(BaseThincComponent):
name = 'tagger' name = 'tagger'
def __init__(self, vocab, model=True): def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
self.cfg = dict(cfg)
def __call__(self, doc): def __call__(self, doc):
tags = self.predict([doc.tensor]) tags = self.predict([doc.tensor])
@ -416,10 +418,18 @@ class NeuralTagger(BaseThincComponent):
class NeuralLabeller(NeuralTagger): class NeuralLabeller(NeuralTagger):
name = 'nn_labeller' name = 'nn_labeller'
def __init__(self, vocab, model=True): def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
self.labels = {} self.cfg = dict(cfg)
@property
def labels(self):
return self.cfg.get('labels', {})
@labels.setter
def labels(self, value):
self.cfg['labels'] = value
def set_annotations(self, docs, dep_ids): def set_annotations(self, docs, dep_ids):
pass pass
@ -478,9 +488,10 @@ class SimilarityHook(BaseThincComponent):
Where W is a vector of dimension weights, initialized to 1. Where W is a vector of dimension weights, initialized to 1.
""" """
name = 'similarity' name = 'similarity'
def __init__(self, vocab, model=True): def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
self.cfg = dict(cfg)
@classmethod @classmethod
def Model(cls, length): def Model(cls, length):
@ -527,7 +538,7 @@ class TextCategorizer(BaseThincComponent):
def __init__(self, vocab, model=True, **cfg): def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
self.cfg = cfg self.cfg = dict(cfg)
@property @property
def labels(self): def labels(self):