Use DependencyParser and EntityRecognizer in the Language class.

This commit is contained in:
Matthew Honnibal 2016-10-16 17:58:12 +02:00
parent 4574fe87c6
commit ca51f3b77e

View File

@@ -24,8 +24,6 @@ from .tagger import Tagger
from .matcher import Matcher from .matcher import Matcher
from . import attrs from . import attrs
from . import orth from . import orth
from .syntax.ner import BiluoPushDown
from .syntax.arc_eager import ArcEager
from . import util from . import util
from .lemmatizer import Lemmatizer from .lemmatizer import Lemmatizer
from .train import Trainer from .train import Trainer
@@ -33,6 +31,7 @@ from .train import Trainer
from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP
from .syntax.parser import get_templates from .syntax.parser import get_templates
from .syntax.nonproj import PseudoProjectivity from .syntax.nonproj import PseudoProjectivity
from .pipeline import DependencyParser, EntityRecognizer
class BaseDefaults(object): class BaseDefaults(object):
@@ -100,23 +99,19 @@ class BaseDefaults(object):
def Parser(self, vocab, **cfg): def Parser(self, vocab, **cfg):
if self.path and (self.path / 'deps').exists(): if self.path and (self.path / 'deps').exists():
return Parser.load(self.path / 'deps', vocab, ArcEager) return DependencyParser.load(self.path / 'deps', vocab)
else: else:
if 'features' not in cfg: if 'features' not in cfg:
cfg['features'] = self.parser_features cfg['features'] = self.parser_features
if 'actions' not in cfg: return DependencyParser.blank(vocab, **cfg)
cfg['actions'] = self.parser_labels
return Parser.blank(vocab, ArcEager, **cfg)
def Entity(self, vocab, **cfg): def Entity(self, vocab, **cfg):
if self.path and (self.path / 'ner').exists(): if self.path and (self.path / 'ner').exists():
return Parser.load(self.path / 'ner', vocab, BiluoPushDown) return EntityRecognizer.load(self.path / 'ner', vocab)
else: else:
if 'features' not in cfg: if 'features' not in cfg:
cfg['features'] = self.entity_features cfg['features'] = self.entity_features
if 'actions' not in cfg: return EntityRecognizer.blank(vocab, **cfg)
cfg['actions'] = self.entity_labels
return Parser.blank(vocab, BiluoPushDown, **cfg)
def Matcher(self, vocab, **cfg): def Matcher(self, vocab, **cfg):
if self.path: if self.path:
@@ -147,19 +142,6 @@ class BaseDefaults(object):
tokenizer_exceptions = {} tokenizer_exceptions = {}
parser_labels = {0: {'': True}, 1: {'': True}, 2: {'ROOT': True, 'nmod': True},
3: {'ROOT': True, 'nmod': True}, 4: {'ROOT': True}}
entity_labels = {
0: {'': True},
1: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True},
2: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True},
3: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True},
4: {'PER': True, 'LOC': True, 'ORG': True, 'MISC': True},
5: {'': True}
}
parser_features = get_templates('parser') parser_features = get_templates('parser')
entity_features = get_templates('ner') entity_features = get_templates('ner')