mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-22 14:03:05 +03:00
Add cfg attr to pipeline components
This commit is contained in:
parent
d8aa721664
commit
4fe77bced2
|
@ -160,6 +160,7 @@ class TokenVectorEncoder(BaseThincComponent):
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.doc2feats = doc2feats()
|
self.doc2feats = doc2feats()
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.cfg = dict(cfg)
|
||||||
|
|
||||||
def __call__(self, doc):
|
def __call__(self, doc):
|
||||||
"""Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM
|
"""Add context-sensitive vectors to a `Doc`, e.g. from a CNN or LSTM
|
||||||
|
@ -239,9 +240,10 @@ class TokenVectorEncoder(BaseThincComponent):
|
||||||
|
|
||||||
class NeuralTagger(BaseThincComponent):
|
class NeuralTagger(BaseThincComponent):
|
||||||
name = 'tagger'
|
name = 'tagger'
|
||||||
def __init__(self, vocab, model=True):
|
def __init__(self, vocab, model=True, **cfg):
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.cfg = dict(cfg)
|
||||||
|
|
||||||
def __call__(self, doc):
|
def __call__(self, doc):
|
||||||
tags = self.predict([doc.tensor])
|
tags = self.predict([doc.tensor])
|
||||||
|
@ -416,10 +418,18 @@ class NeuralTagger(BaseThincComponent):
|
||||||
|
|
||||||
class NeuralLabeller(NeuralTagger):
|
class NeuralLabeller(NeuralTagger):
|
||||||
name = 'nn_labeller'
|
name = 'nn_labeller'
|
||||||
def __init__(self, vocab, model=True):
|
def __init__(self, vocab, model=True, **cfg):
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
self.labels = {}
|
self.cfg = dict(cfg)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def labels(self):
|
||||||
|
return self.cfg.get('labels', {})
|
||||||
|
|
||||||
|
@labels.setter
|
||||||
|
def labels(self, value):
|
||||||
|
self.cfg['labels'] = value
|
||||||
|
|
||||||
def set_annotations(self, docs, dep_ids):
|
def set_annotations(self, docs, dep_ids):
|
||||||
pass
|
pass
|
||||||
|
@ -478,9 +488,10 @@ class SimilarityHook(BaseThincComponent):
|
||||||
Where W is a vector of dimension weights, initialized to 1.
|
Where W is a vector of dimension weights, initialized to 1.
|
||||||
"""
|
"""
|
||||||
name = 'similarity'
|
name = 'similarity'
|
||||||
def __init__(self, vocab, model=True):
|
def __init__(self, vocab, model=True, **cfg):
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.cfg = dict(cfg)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def Model(cls, length):
|
def Model(cls, length):
|
||||||
|
@ -527,7 +538,7 @@ class TextCategorizer(BaseThincComponent):
|
||||||
def __init__(self, vocab, model=True, **cfg):
|
def __init__(self, vocab, model=True, **cfg):
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
self.cfg = cfg
|
self.cfg = dict(cfg)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user