mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-15 10:42:34 +03:00
formatting
This commit is contained in:
parent
d7406fffb0
commit
e48a662e46
|
@ -526,8 +526,8 @@ def debug_data(
|
||||||
msg.info(f"{len(label_list)} label(s) in train data")
|
msg.info(f"{len(label_list)} label(s) in train data")
|
||||||
p = np.array(counts)
|
p = np.array(counts)
|
||||||
p = p / p.sum()
|
p = p / p.sum()
|
||||||
entropy = np.round((-p*np.log2(p)).sum(), 2)
|
norm_entropy = (-p * np.log2(p)).sum() / np.log2(len(label_list))
|
||||||
msg.info(f"{entropy} is the train data label entropy")
|
msg.info(f"{norm_entropy} is the normalised label entropy")
|
||||||
model_labels = _get_labels_from_model(nlp, "tagger")
|
model_labels = _get_labels_from_model(nlp, "tagger")
|
||||||
labels = set(label_list)
|
labels = set(label_list)
|
||||||
missing_labels = model_labels - labels
|
missing_labels = model_labels - labels
|
||||||
|
|
|
@ -70,8 +70,12 @@ PARTIAL_DATA = [
|
||||||
def test_label_smoothing():
|
def test_label_smoothing():
|
||||||
util.fix_random_seed()
|
util.fix_random_seed()
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
tagger_no_ls = nlp.add_pipe("tagger", "no_label_smoothing", config=dict(label_smoothing=False))
|
tagger_no_ls = nlp.add_pipe(
|
||||||
tagger_ls = nlp.add_pipe("tagger", "label_smoothing", config=dict(label_smoothing=True))
|
"tagger", "no_label_smoothing", config=dict(label_smoothing=False)
|
||||||
|
)
|
||||||
|
tagger_ls = nlp.add_pipe(
|
||||||
|
"tagger", "label_smoothing", config=dict(label_smoothing=True)
|
||||||
|
)
|
||||||
train_examples = []
|
train_examples = []
|
||||||
losses = {}
|
losses = {}
|
||||||
for tag in TAGS:
|
for tag in TAGS:
|
||||||
|
@ -83,7 +87,10 @@ def test_label_smoothing():
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
losses = {}
|
losses = {}
|
||||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||||
assert losses == {'no_label_smoothing': 1.4892945885658264, 'label_smoothing': 1.1432453989982605}
|
assert losses == {
|
||||||
|
"no_label_smoothing": 1.4892945885658264,
|
||||||
|
"label_smoothing": 1.1432453989982605,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_no_label():
|
def test_no_label():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user