mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 01:34:30 +03:00
Fix tok2vec arch after refactor
This commit is contained in:
parent
f8d740bfb1
commit
165e378082
|
@ -17,6 +17,6 @@ def FeedForward(config):
|
||||||
def LayerNormalizedMaxout(config):
|
def LayerNormalizedMaxout(config):
|
||||||
width = config["width"]
|
width = config["width"]
|
||||||
pieces = config["pieces"]
|
pieces = config["pieces"]
|
||||||
layer = chain(Maxout(width, pieces=pieces), LayerNorm(nO=width))
|
layer = LayerNorm(Maxout(width, pieces=pieces))
|
||||||
layer.nO = width
|
layer.nO = width
|
||||||
return layer
|
return layer
|
||||||
|
|
|
@ -14,10 +14,12 @@ from .common import *
|
||||||
|
|
||||||
@register_architecture("spacy.Tok2Vec.v1")
|
@register_architecture("spacy.Tok2Vec.v1")
|
||||||
def Tok2Vec(config):
|
def Tok2Vec(config):
|
||||||
|
print(config)
|
||||||
doc2feats = make_layer(config["@doc2feats"])
|
doc2feats = make_layer(config["@doc2feats"])
|
||||||
embed = make_layer(config["@embed"])
|
embed = make_layer(config["@embed"])
|
||||||
encode = make_layer(config["@encode"])
|
encode = make_layer(config["@encode"])
|
||||||
tok2vec = chain(doc2feats, with_flatten(chain(embed, encode)))
|
depth = config["@encode"]["config"]["depth"]
|
||||||
|
tok2vec = chain(doc2feats, with_flatten(chain(embed, encode), pad=depth))
|
||||||
tok2vec.cfg = config
|
tok2vec.cfg = config
|
||||||
tok2vec.nO = encode.nO
|
tok2vec.nO = encode.nO
|
||||||
tok2vec.embed = embed
|
tok2vec.embed = embed
|
||||||
|
@ -81,8 +83,7 @@ def MaxoutWindowEncoder(config):
|
||||||
|
|
||||||
cnn = chain(
|
cnn = chain(
|
||||||
ExtractWindow(nW=nW),
|
ExtractWindow(nW=nW),
|
||||||
Maxout(nO, nO * ((nW * 2) + 1), pieces=nP),
|
LayerNorm(Maxout(nO, nO * ((nW * 2) + 1), pieces=nP)),
|
||||||
LayerNorm(nO=nO),
|
|
||||||
)
|
)
|
||||||
model = clone(Residual(cnn), depth)
|
model = clone(Residual(cnn), depth)
|
||||||
model.nO = nO
|
model.nO = nO
|
||||||
|
|
Loading…
Reference in New Issue
Block a user