mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-24 23:13:04 +03:00
Add flag to toggle GPU to DyNet code
This commit is contained in:
parent
3a31c3a961
commit
8f053fd943
|
@ -7,12 +7,15 @@ import plac
|
||||||
import random
|
import random
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from itertools import count
|
from itertools import count
|
||||||
|
|
||||||
#import _gdynet as dynet
|
if os.environ.get('DYNET_GPU') == '1':
|
||||||
#from _gdynet import cg
|
import _gdynet as dynet
|
||||||
|
from _gdynet import cg
|
||||||
|
else:
|
||||||
import dynet
|
import dynet
|
||||||
from dynet import cg
|
from dynet import cg
|
||||||
|
|
||||||
|
@ -180,7 +183,7 @@ def main(train_loc, dev_loc, model_dir):
|
||||||
|
|
||||||
tagged = loss = 0
|
tagged = loss = 0
|
||||||
|
|
||||||
for ITER in xrange(50):
|
for ITER in xrange(1):
|
||||||
random.shuffle(train)
|
random.shuffle(train)
|
||||||
for i, s in enumerate(train,1):
|
for i, s in enumerate(train,1):
|
||||||
if i % 5000 == 0:
|
if i % 5000 == 0:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user