from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import os
import re
import signal
import sys

# Python 3 splits urllib2 into urllib.request / urllib.error; fall back to
# urllib2 on Python 2 and expose the same names either way.
try:
    import urllib.error as urlliberror
    import urllib.request as urllib
    HTTPError = urlliberror.HTTPError
    URLError = urlliberror.URLError
except ImportError:
    import urllib2 as urllib
    HTTPError = urllib.HTTPError
    URLError = urllib.URLError
DOWNLOAD_BASE_URL = "https://s3.amazonaws.com/caffe2/models/"

# Width (in characters) of the bar drawn by progressBar(). Taken from
# $COLUMNS when it holds a positive integer, otherwise 80.
try:
    DOWNLOAD_COLUMNS = int(os.environ.get("COLUMNS", "80"))
except ValueError:
    DOWNLOAD_COLUMNS = 80
if DOWNLOAD_COLUMNS <= 0:
    DOWNLOAD_COLUMNS = 80


def signalHandler(signal, frame):
    """Stop the program when the user presses Ctrl-C (SIGINT).

    Installed with ``signal.signal`` below. Both arguments are supplied by
    the interpreter and are unused (``signal`` shadows the module only
    inside this function).
    """
    print("Killing download...")
    # Returning from the handler would let the interrupted download carry
    # on, so terminate explicitly with a failure status.
    sys.exit(1)


signal.signal(signal.SIGINT, signalHandler)
def deleteDirectory(top_dir):
    """Recursively delete the directory *top_dir* and everything under it.

    *top_dir* itself is removed too: callers recreate it with
    ``os.makedirs`` right afterwards, which would fail if it still existed.

    :raises OSError: if *top_dir* does not exist or cannot be removed.
    """
    import shutil  # function-local so the module's import block is unchanged

    # rmtree deletes bottom-up and unlinks symlinks without following them,
    # which a hand-written os.walk/os.rmdir loop gets wrong for symlinked
    # sub-directories.
    shutil.rmtree(top_dir)
def progressBar(percentage, columns=None):
    """Redraw a single-line text progress bar on stdout.

    :param percentage: completion in percent. Clamped to ``[0, 100]`` so an
        over-count can never draw a bar wider than *columns*.
    :param columns: bar width in characters; defaults to the module-level
        ``DOWNLOAD_COLUMNS``.
    """
    if columns is None:
        columns = DOWNLOAD_COLUMNS
    percentage = max(0, min(100, percentage))
    full = int(columns * percentage / 100)
    bar = full * "#" + (columns - full) * " "
    # ESC[1000D moves the cursor 1000 columns left (i.e. to the line start),
    # so each call overwrites the previous bar instead of appending.
    sys.stdout.write(u"\u001b[1000D[" + bar + "] " + str(percentage) + "%")
    # No newline is written, so line-buffered stdout would otherwise hold
    # the update back; flush to make it visible immediately.
    sys.stdout.flush()
def downloadFromURLToFile(url, filename, show_progress=True):
    """Stream the resource at *url* into the local file *filename*.

    :param url: URL opened with ``urllib``'s ``urlopen``.
    :param filename: destination path; created or truncated.
    :param show_progress: draw a progress bar while downloading. The bar is
        only drawn when the response reports a positive ``Content-Length``.
    :raises Exception: with a descriptive message on HTTP or URL errors.
        Other errors (e.g. failing to open *filename*) propagate unchanged.
    """
    import contextlib  # function-local so the module's import block is unchanged

    chunk_size = 8192
    try:
        print("Downloading from {url}".format(url=url))
        # closing() rather than ``with urlopen(...)``: Python 2 response
        # objects are not context managers.
        with contextlib.closing(urllib.urlopen(url)) as response:
            # ``get`` exists on both the Python 2 and Python 3 header
            # objects; ``getheader`` is Python 2 only.
            length = response.info().get('Content-Length')
            size = int(length.strip()) if length else 0
            draw = show_progress and size > 0
            print("Writing to {filename}".format(filename=filename))
            downloaded_size = 0
            with open(filename, "wb") as local_file:
                while True:
                    # A fixed chunk size: deriving it from a missing or zero
                    # Content-Length would request 0 bytes and stop at once.
                    data_chunk = response.read(chunk_size)
                    if not data_chunk:
                        break
                    local_file.write(data_chunk)
                    # Count bytes actually received; the final chunk is
                    # usually shorter than chunk_size.
                    downloaded_size += len(data_chunk)
                    if draw:
                        progressBar(int(100 * downloaded_size / size))
            if draw:
                print("")  # end the progress-bar line
    # HTTPError subclasses URLError, so it must be handled first.
    except HTTPError as e:
        raise Exception(
            "Could not download model. [HTTP Error] {code}: {reason}."
            .format(code=e.code, reason=e.reason))
    except URLError as e:
        raise Exception(
            "Could not download model. [URL Error] {reason}."
            .format(reason=e.reason))
def getURLFromName(name, filename):
    """Return the remote URL of file *filename* belonging to model *name*.

    The URL is ``DOWNLOAD_BASE_URL`` followed by ``<name>/<filename>``.
    """
    template = "{0}{1}/{2}"
    return template.format(DOWNLOAD_BASE_URL, name, filename)
def downloadModel(model, args):
    """Download ``predict_net.pb`` and ``init_net.pb`` for *model*.

    :param model: model name; used as the remote directory (via
        ``getURLFromName``) and as the local folder name.
    :param args: parsed command-line options.

        - ``args.install``: store the model next to this script and symlink
          ``__sym_init__.py`` into it as ``__init__.py``.
        - ``args.force``: replace an existing file or folder without asking.

    :raises Exception: if a plain file occupies the target path and
        ``args.force`` is not set.

    Exits with status 0 if the user declines to overwrite an existing model.
    Exits with status 1 if a download fails, after removing the
    partly-populated folder.
    """
    dir_path = os.path.dirname(os.path.realpath(__file__))
    if args.install:
        model_folder = os.path.join(dir_path, model)
    else:
        model_folder = model

    # A plain file sits where the model folder should go.
    if os.path.exists(model_folder) and not os.path.isdir(model_folder):
        if not args.force:
            # Adjacent literals instead of a backslash-continued string, which
            # embedded the next line's indentation into the message.
            raise Exception(
                "Cannot create folder for storing the model, "
                "there exists a file of the same name.")
        print("Overwriting existing file! ({filename})"
              .format(filename=model_folder))
        os.remove(model_folder)

    if os.path.isdir(model_folder):
        if not args.force:
            query = "Model already exists, continue? [y/N] "
            try:
                response = raw_input(query)  # Python 2
            except NameError:
                response = input(query)  # Python 3
            # The prompt defaults to "No": only an explicit yes may destroy
            # the existing model (previously "no" or "x" proceeded).
            if response.strip().lower() not in ('y', 'yes'):
                print("Cancelling download...")
                sys.exit(0)
        print("Overwriting existing folder! ({filename})".format(filename=model_folder))
        deleteDirectory(model_folder)

    # The target path is now free: create it and fetch every file.
    os.makedirs(model_folder)
    for f in ['predict_net.pb', 'init_net.pb']:
        try:
            downloadFromURLToFile(getURLFromName(model, f),
                                  os.path.join(model_folder, f))
        except Exception as e:
            print("Abort: {reason}".format(reason=str(e)))
            print("Cleaning up...")
            deleteDirectory(model_folder)
            # The folder is gone, so the remaining files cannot be written.
            sys.exit(1)

    if args.install:
        os.symlink(os.path.join(dir_path, "__sym_init__.py"),
                   os.path.join(model_folder, "__init__.py"))
def validModelName(name):
    """Return True if *name* may be used as a model name, else False.

    A valid name consists only of ASCII letters and underscores and is not
    ``__init__``. The name becomes a folder, and for installs a package
    containing an ``__init__.py`` symlink.
    """
    invalid_names = ['__init__']
    if name in invalid_names:
        return False
    # \Z instead of $: "$" also matches just before a trailing newline, so
    # a name like "abc\n" would otherwise be accepted.
    if not re.match(r"^[a-zA-Z_]+\Z", name):
        return False
    return True
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Download or install pretrained models.')
    parser.add_argument('model', nargs='+',
                        help='Model to download/install.')
    parser.add_argument('-i', '--install', action='store_true',
                        help='Install the model.')
    parser.add_argument('-f', '--force', action='store_true',
                        help='Force a download/installation.')
    args = parser.parse_args()
    for model in args.model:
        if validModelName(model):
            downloadModel(model, args)
        else:
            # The named {model} placeholder needs a keyword argument; passing
            # it positionally raised KeyError instead of printing this line.
            print("'{model}' is not a valid model name.".format(model=model))