Fix parser resizing when there is no upper layer (#6460)

* allow resizing of the parser model even when upper=False

* update from spacy.TransitionBasedParser.v1 to v2

* bugfix
This commit is contained in:
Sofie Van Landeghem 2020-12-18 11:56:57 +01:00 committed by GitHub
parent 0a923a7915
commit 282a3b49ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 195 additions and 64 deletions

View File

@ -75,7 +75,7 @@ grad_factor = 1.0
factory = "parser" factory = "parser"
[components.parser.model] [components.parser.model]
@architectures = "spacy.TransitionBasedParser.v1" @architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser" state_type = "parser"
extra_state_tokens = false extra_state_tokens = false
hidden_width = 128 hidden_width = 128
@ -96,7 +96,7 @@ grad_factor = 1.0
factory = "ner" factory = "ner"
[components.ner.model] [components.ner.model]
@architectures = "spacy.TransitionBasedParser.v1" @architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner" state_type = "ner"
extra_state_tokens = false extra_state_tokens = false
hidden_width = 64 hidden_width = 64
@ -226,7 +226,7 @@ width = ${components.tok2vec.model.encode.width}
factory = "parser" factory = "parser"
[components.parser.model] [components.parser.model]
@architectures = "spacy.TransitionBasedParser.v1" @architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser" state_type = "parser"
extra_state_tokens = false extra_state_tokens = false
hidden_width = 128 hidden_width = 128
@ -244,7 +244,7 @@ width = ${components.tok2vec.model.encode.width}
factory = "ner" factory = "ner"
[components.ner.model] [components.ner.model]
@architectures = "spacy.TransitionBasedParser.v1" @architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner" state_type = "ner"
extra_state_tokens = false extra_state_tokens = false
hidden_width = 64 hidden_width = 64

View File

@ -11,7 +11,7 @@ from ...tokens import Doc
@registry.architectures.register("spacy.TransitionBasedParser.v1") @registry.architectures.register("spacy.TransitionBasedParser.v1")
def build_tb_parser_model( def transition_parser_v1(
tok2vec: Model[List[Doc], List[Floats2d]], tok2vec: Model[List[Doc], List[Floats2d]],
state_type: Literal["parser", "ner"], state_type: Literal["parser", "ner"],
extra_state_tokens: bool, extra_state_tokens: bool,
@ -19,6 +19,46 @@ def build_tb_parser_model(
maxout_pieces: int, maxout_pieces: int,
use_upper: bool = True, use_upper: bool = True,
nO: Optional[int] = None, nO: Optional[int] = None,
) -> Model:
return build_tb_parser_model(
tok2vec,
state_type,
extra_state_tokens,
hidden_width,
maxout_pieces,
use_upper,
nO,
)
@registry.architectures.register("spacy.TransitionBasedParser.v2")
def transition_parser_v2(
tok2vec: Model[List[Doc], List[Floats2d]],
state_type: Literal["parser", "ner"],
extra_state_tokens: bool,
hidden_width: int,
maxout_pieces: int,
use_upper: bool,
nO: Optional[int] = None,
) -> Model:
return build_tb_parser_model(
tok2vec,
state_type,
extra_state_tokens,
hidden_width,
maxout_pieces,
use_upper,
nO,
)
def build_tb_parser_model(
tok2vec: Model[List[Doc], List[Floats2d]],
state_type: Literal["parser", "ner"],
extra_state_tokens: bool,
hidden_width: int,
maxout_pieces: int,
use_upper: bool,
nO: Optional[int] = None,
) -> Model: ) -> Model:
""" """
Build a transition-based parser model. Can apply to NER or dependency-parsing. Build a transition-based parser model. Can apply to NER or dependency-parsing.
@ -72,16 +112,100 @@ def build_tb_parser_model(
t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width)) tok2vec = chain(tok2vec, list2array(), Linear(hidden_width, t2v_width))
tok2vec.set_dim("nO", hidden_width) tok2vec.set_dim("nO", hidden_width)
lower = PrecomputableAffine( lower = _define_lower(
nO=hidden_width if use_upper else nO, nO=hidden_width if use_upper else nO,
nF=nr_feature_tokens, nF=nr_feature_tokens,
nI=tok2vec.get_dim("nO"), nI=tok2vec.get_dim("nO"),
nP=maxout_pieces, nP=maxout_pieces,
) )
upper = None
if use_upper: if use_upper:
with use_ops("numpy"): with use_ops("numpy"):
# Initialize weights at zero, as it's a classification layer. # Initialize weights at zero, as it's a classification layer.
upper = Linear(nO=nO, init_W=zero_init) upper = _define_upper(nO=nO, nI=None)
else: return TransitionModel(tok2vec, lower, upper, resize_output)
upper = None
return TransitionModel(tok2vec, lower, upper)
def _define_upper(nO, nI):
return Linear(nO=nO, nI=nI, init_W=zero_init)
def _define_lower(nO, nF, nI, nP):
return PrecomputableAffine(nO=nO, nF=nF, nI=nI, nP=nP)
def resize_output(model, new_nO):
if model.attrs["has_upper"]:
return _resize_upper(model, new_nO)
return _resize_lower(model, new_nO)
def _resize_upper(model, new_nO):
upper = model.get_ref("upper")
if upper.has_dim("nO") is None:
upper.set_dim("nO", new_nO)
return model
elif new_nO == upper.get_dim("nO"):
return model
smaller = upper
nI = smaller.maybe_get_dim("nI")
with use_ops("numpy"):
larger = _define_upper(nO=new_nO, nI=nI)
# it could be that the model is not initialized yet, then skip this bit
if smaller.has_param("W"):
larger_W = larger.ops.alloc2f(new_nO, nI)
larger_b = larger.ops.alloc1f(new_nO)
smaller_W = smaller.get_param("W")
smaller_b = smaller.get_param("b")
# Weights are stored in (nr_out, nr_in) format, so we're basically
# just adding rows here.
if smaller.has_dim("nO"):
old_nO = smaller.get_dim("nO")
larger_W[: old_nO] = smaller_W
larger_b[: old_nO] = smaller_b
for i in range(old_nO, new_nO):
model.attrs["unseen_classes"].add(i)
larger.set_param("W", larger_W)
larger.set_param("b", larger_b)
model._layers[-1] = larger
model.set_ref("upper", larger)
return model
def _resize_lower(model, new_nO):
lower = model.get_ref("lower")
if lower.has_dim("nO") is None:
lower.set_dim("nO", new_nO)
return model
smaller = lower
nI = smaller.maybe_get_dim("nI")
nF = smaller.maybe_get_dim("nF")
nP = smaller.maybe_get_dim("nP")
with use_ops("numpy"):
larger = _define_lower(nO=new_nO, nI=nI, nF=nF, nP=nP)
# it could be that the model is not initialized yet, then skip this bit
if smaller.has_param("W"):
larger_W = larger.ops.alloc4f(nF, new_nO, nP, nI)
larger_b = larger.ops.alloc2f(new_nO, nP)
larger_pad = larger.ops.alloc4f(1, nF, new_nO, nP)
smaller_W = smaller.get_param("W")
smaller_b = smaller.get_param("b")
smaller_pad = smaller.get_param("pad")
# Copy the old weights and padding into the new layer
if smaller.has_dim("nO"):
old_nO = smaller.get_dim("nO")
larger_W[:, 0:old_nO, :, :] = smaller_W
larger_pad[:, :, 0:old_nO, :] = smaller_pad
larger_b[0:old_nO, :] = smaller_b
for i in range(old_nO, new_nO):
model.attrs["unseen_classes"].add(i)
larger.set_param("W", larger_W)
larger.set_param("b", larger_b)
larger.set_param("pad", larger_pad)
model._layers[1] = larger
model.set_ref("lower", larger)
return model

View File

@ -2,7 +2,7 @@ from thinc.api import Model, noop, use_ops, Linear
from .parser_model import ParserStepModel from .parser_model import ParserStepModel
def TransitionModel(tok2vec, lower, upper, dropout=0.2, unseen_classes=set()): def TransitionModel(tok2vec, lower, upper, resize_output, dropout=0.2, unseen_classes=set()):
"""Set up a stepwise transition-based model""" """Set up a stepwise transition-based model"""
if upper is None: if upper is None:
has_upper = False has_upper = False
@ -45,42 +45,3 @@ def init(model, X=None, Y=None):
statevecs = model.ops.alloc2f(2, lower.get_dim("nO")) statevecs = model.ops.alloc2f(2, lower.get_dim("nO"))
model.get_ref("upper").initialize(X=statevecs) model.get_ref("upper").initialize(X=statevecs)
def resize_output(model, new_nO):
lower = model.get_ref("lower")
upper = model.get_ref("upper")
if not model.attrs["has_upper"]:
if lower.has_dim("nO") is None:
lower.set_dim("nO", new_nO)
return
elif upper.has_dim("nO") is None:
upper.set_dim("nO", new_nO)
return
elif new_nO == upper.get_dim("nO"):
return
smaller = upper
nI = None
if smaller.has_dim("nI"):
nI = smaller.get_dim("nI")
with use_ops("numpy"):
larger = Linear(nO=new_nO, nI=nI)
larger.init = smaller.init
# it could be that the model is not initialized yet, then skip this bit
if nI:
larger_W = larger.ops.alloc2f(new_nO, nI)
larger_b = larger.ops.alloc1f(new_nO)
smaller_W = smaller.get_param("W")
smaller_b = smaller.get_param("b")
# Weights are stored in (nr_out, nr_in) format, so we're basically
# just adding rows here.
if smaller.has_dim("nO"):
larger_W[: smaller.get_dim("nO")] = smaller_W
larger_b[: smaller.get_dim("nO")] = smaller_b
for i in range(smaller.get_dim("nO"), new_nO):
model.attrs["unseen_classes"].add(i)
larger.set_param("W", larger_W)
larger.set_param("b", larger_b)
model._layers[-1] = larger
model.set_ref("upper", larger)
return model

View File

@ -14,11 +14,12 @@ from ..training import validate_examples
default_model_config = """ default_model_config = """
[model] [model]
@architectures = "spacy.TransitionBasedParser.v1" @architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser" state_type = "parser"
extra_state_tokens = false extra_state_tokens = false
hidden_width = 64 hidden_width = 64
maxout_pieces = 2 maxout_pieces = 2
use_upper = true
[model.tok2vec] [model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1" @architectures = "spacy.HashEmbedCNN.v1"

View File

@ -12,11 +12,12 @@ from ..training import validate_examples
default_model_config = """ default_model_config = """
[model] [model]
@architectures = "spacy.TransitionBasedParser.v1" @architectures = "spacy.TransitionBasedParser.v2"
state_type = "ner" state_type = "ner"
extra_state_tokens = false extra_state_tokens = false
hidden_width = 64 hidden_width = 64
maxout_pieces = 2 maxout_pieces = 2
use_upper = true
[model.tok2vec] [model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1" @architectures = "spacy.HashEmbedCNN.v1"

View File

@ -301,10 +301,13 @@ def test_block_ner():
assert [token.ent_type_ for token in doc] == expected_types assert [token.ent_type_ for token in doc] == expected_types
def test_overfitting_IO(): @pytest.mark.parametrize(
"use_upper", [True, False]
)
def test_overfitting_IO(use_upper):
# Simple test to try and quickly overfit the NER component - ensuring the ML models work correctly # Simple test to try and quickly overfit the NER component - ensuring the ML models work correctly
nlp = English() nlp = English()
ner = nlp.add_pipe("ner") ner = nlp.add_pipe("ner", config={"model": {"use_upper": use_upper}})
train_examples = [] train_examples = []
for text, annotations in TRAIN_DATA: for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations)) train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
@ -334,6 +337,15 @@ def test_overfitting_IO():
assert len(ents2) == 1 assert len(ents2) == 1
assert ents2[0].text == "London" assert ents2[0].text == "London"
assert ents2[0].label_ == "LOC" assert ents2[0].label_ == "LOC"
# Ensure that the predictions are still the same, even after adding a new label
ner2 = nlp2.get_pipe("ner")
assert ner2.model.attrs["has_upper"] == use_upper
ner2.add_label("RANDOM_NEW_LABEL")
doc3 = nlp2(test_text)
ents3 = doc3.ents
assert len(ents3) == 1
assert ents3[0].text == "London"
assert ents3[0].label_ == "LOC"
# Make sure that running pipe twice, or comparing to call, always amounts to the same predictions # Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
texts = [ texts = [

View File

@ -117,13 +117,35 @@ width = ${components.tok2vec.model.width}
""" """
parser_config_string = """ parser_config_string_upper = """
[model] [model]
@architectures = "spacy.TransitionBasedParser.v1" @architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser" state_type = "parser"
extra_state_tokens = false extra_state_tokens = false
hidden_width = 66 hidden_width = 66
maxout_pieces = 2 maxout_pieces = 2
use_upper = true
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 333
depth = 4
embed_size = 5555
window_size = 1
maxout_pieces = 7
subword_features = false
"""
parser_config_string_no_upper = """
[model]
@architectures = "spacy.TransitionBasedParser.v2"
state_type = "parser"
extra_state_tokens = false
hidden_width = 66
maxout_pieces = 2
use_upper = false
[model.tok2vec] [model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1" @architectures = "spacy.HashEmbedCNN.v1"
@ -154,6 +176,7 @@ def my_parser():
extra_state_tokens=True, extra_state_tokens=True,
hidden_width=65, hidden_width=65,
maxout_pieces=5, maxout_pieces=5,
use_upper=True,
) )
return parser return parser
@ -241,12 +264,15 @@ def test_serialize_custom_nlp():
nlp2 = spacy.load(d) nlp2 = spacy.load(d)
model = nlp2.get_pipe("parser").model model = nlp2.get_pipe("parser").model
model.get_ref("tok2vec") model.get_ref("tok2vec")
upper = model.get_ref("upper")
# check that we have the correct settings, not the default ones # check that we have the correct settings, not the default ones
assert upper.get_dim("nI") == 65 assert model.get_ref("upper").get_dim("nI") == 65
assert model.get_ref("lower").get_dim("nI") == 65
def test_serialize_parser(): @pytest.mark.parametrize(
"parser_config_string", [parser_config_string_upper, parser_config_string_no_upper]
)
def test_serialize_parser(parser_config_string):
""" Create a non-default parser config to check nlp serializes it correctly """ """ Create a non-default parser config to check nlp serializes it correctly """
nlp = English() nlp = English()
model_config = Config().from_str(parser_config_string) model_config = Config().from_str(parser_config_string)
@ -259,9 +285,11 @@ def test_serialize_parser():
nlp2 = spacy.load(d) nlp2 = spacy.load(d)
model = nlp2.get_pipe("parser").model model = nlp2.get_pipe("parser").model
model.get_ref("tok2vec") model.get_ref("tok2vec")
upper = model.get_ref("upper")
# check that we have the correct settings, not the default ones # check that we have the correct settings, not the default ones
assert upper.get_dim("nI") == 66 if model.attrs["has_upper"]:
assert model.get_ref("upper").get_dim("nI") == 66
assert model.get_ref("lower").get_dim("nI") == 66
def test_config_nlp_roundtrip(): def test_config_nlp_roundtrip():
@ -408,7 +436,10 @@ def test_config_auto_fill_extra_fields():
load_model_from_config(nlp.config) load_model_from_config(nlp.config)
def test_config_validate_literal(): @pytest.mark.parametrize(
"parser_config_string", [parser_config_string_upper, parser_config_string_no_upper]
)
def test_config_validate_literal(parser_config_string):
nlp = English() nlp = English()
config = Config().from_str(parser_config_string) config = Config().from_str(parser_config_string)
config["model"]["state_type"] = "nonsense" config["model"]["state_type"] = "nonsense"

View File

@ -428,17 +428,18 @@ one component.
## Parser & NER architectures {#parser} ## Parser & NER architectures {#parser}
### spacy.TransitionBasedParser.v1 {#TransitionBasedParser source="spacy/ml/models/parser.py"} ### spacy.TransitionBasedParser.v2 {#TransitionBasedParser source="spacy/ml/models/parser.py"}
> #### Example Config > #### Example Config
> >
> ```ini > ```ini
> [model] > [model]
> @architectures = "spacy.TransitionBasedParser.v1" > @architectures = "spacy.TransitionBasedParser.v2"
> state_type = "ner" > state_type = "ner"
> extra_state_tokens = false > extra_state_tokens = false
> hidden_width = 64 > hidden_width = 64
> maxout_pieces = 2 > maxout_pieces = 2
> use_upper = true
> >
> [model.tok2vec] > [model.tok2vec]
> @architectures = "spacy.HashEmbedCNN.v1" > @architectures = "spacy.HashEmbedCNN.v1"