Caffe2 - Python API
A deep learning, cross-platform ML framework
download.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 import argparse
8 import os
9 import sys
10 import signal
11 import re
12 
13 # Import urllib
14 try:
15  import urllib.error as urlliberror
16  import urllib.request as urllib
17  HTTPError = urlliberror.HTTPError
18  URLError = urlliberror.URLError
19 except ImportError:
20  import urllib2 as urllib
21  HTTPError = urllib.HTTPError
22  URLError = urllib.URLError
23 
DOWNLOAD_BASE_URL = "https://s3.amazonaws.com/caffe2/models/"
DOWNLOAD_COLUMNS = 70  # width of the download progress bar, in characters


# Don't let urllib hang up on big downloads: stop as soon as the user
# presses Ctrl-C.
def signalHandler(signum, frame):
    """Abort the running download when SIGINT is received.

    Args:
        signum: number of the received signal (unused beyond the exit code).
        frame: current stack frame (unused).

    Exits the process with status 128 + SIGINT (130), the conventional status
    of a process interrupted by Ctrl-C, so an aborted download is not
    reported as a success.
    """
    print("")  # terminate a possibly half-drawn progress bar line
    print("Killing download...")
    # sys.exit is always available; the bare ``exit`` builtin comes from the
    # ``site`` module and is missing under ``python -S``.
    sys.exit(128 + signal.SIGINT)


signal.signal(signal.SIGINT, signalHandler)
35 
36 
def deleteDirectory(top_dir):
    """Recursively delete the directory ``top_dir`` and everything in it.

    Delegates to ``shutil.rmtree``, which, unlike a hand-rolled ``os.walk``
    loop, unlinks symbolic links to directories found inside the tree rather
    than failing on them with ``os.rmdir``, and never follows them.

    Args:
        top_dir: path of the directory to remove.

    Raises:
        OSError: if ``top_dir`` or any entry under it cannot be removed
            (including when ``top_dir`` itself is a symlink).
    """
    import shutil  # local import keeps the file's import block untouched

    shutil.rmtree(top_dir)
44 
45 
def progressBar(percentage, columns=None):
    """Redraw a one-line text progress bar on stdout.

    Args:
        percentage: progress to display, in percent. Values outside
            [0, 100] are clamped so the bar never exceeds its width.
        columns: width of the bar in characters; defaults to
            ``DOWNLOAD_COLUMNS``.

    The output starts with the ANSI escape ``ESC[1000D`` (move the cursor
    1000 columns left), so successive calls overwrite the same line. No
    newline is written; the caller ends the line when done.
    """
    if columns is None:
        columns = DOWNLOAD_COLUMNS
    percentage = max(0, min(100, percentage))
    full = int(columns * percentage / 100)
    bar = full * "#" + (columns - full) * " "
    sys.stdout.write(u"\u001b[1000D[" + bar + "] " + str(percentage) + "%")
    sys.stdout.flush()
51 
52 
def downloadFromURLToFile(url, filename):
    """Download ``url`` into the local file ``filename``, showing progress.

    Args:
        url: URL to fetch.
        filename: path of the local file to create or overwrite.

    Raises:
        Exception: with a readable message when the server answers with an
            HTTP error or the URL cannot be reached. Any other error (for
            example a local I/O failure) propagates unchanged.
    """
    chunk_size = 8192
    try:
        print("Downloading from {url}".format(url=url))
        response = urllib.urlopen(url)
        try:
            # ``.get`` exists on both the Python 2 and Python 3 header
            # objects; ``getheader`` is Python 2 only.
            length = response.info().get('Content-Length')
            # Without a Content-Length header the total size is unknown.
            size = int(length.strip()) if length else None
            downloaded_size = 0
            print("Writing to {filename}".format(filename=filename))
            progressBar(0)
            with open(filename, "wb") as local_file:
                while True:
                    data_chunk = response.read(chunk_size)
                    if not data_chunk:
                        break
                    local_file.write(data_chunk)
                    # Count the bytes actually received: the last read is
                    # usually shorter than ``chunk_size``.
                    downloaded_size += len(data_chunk)
                    if size:
                        progressBar(int(100 * downloaded_size / size))
            if not size:
                # Size was unknown (or zero); the stream is exhausted now.
                progressBar(100)
            print("")  # New line to fix for progress bar
        finally:
            response.close()
    except HTTPError as e:
        raise Exception("Could not download model. [HTTP Error] {code}: {reason}."
                        .format(code=e.code, reason=e.reason))
    except URLError as e:
        raise Exception("Could not download model. [URL Error] {reason}."
                        .format(reason=e.reason))
79 
80 
def getURLFromName(name, filename):
    """Return the download URL of ``filename`` belonging to model ``name``.

    The URL is ``DOWNLOAD_BASE_URL`` immediately followed by
    ``<name>/<filename>``.
    """
    relative_path = "{0}/{1}".format(name, filename)
    return "{0}{1}".format(DOWNLOAD_BASE_URL, relative_path)
84 
85 
def _askYesNo(query):
    """Prompt with ``query``; return True only for a "y"/"yes" answer."""
    try:
        read_line = raw_input  # Python 2
    except NameError:
        read_line = input  # Python 3
    try:
        answer = read_line(query)
    except EOFError:
        # Non-interactive stdin: fall back to the default answer ("N").
        return False
    return answer.strip().lower() in ('y', 'yes')


def downloadModel(model, args):
    """Download (and optionally install) the pretrained model ``model``.

    Fetches ``predict_net.pb`` and ``init_net.pb`` into a folder named after
    the model. The folder is created in the current directory, or next to
    this script when ``args.install`` is set. In install mode an
    ``__init__.py`` symlink to this script's ``__sym_init__.py`` is also
    created inside it.

    Args:
        model: model name; used both as the folder name and in the URL.
        args: parsed command line; its boolean ``install`` and ``force``
            attributes are read.

    Raises:
        Exception: if a plain file already occupies the folder name and
            ``args.force`` is not set.

    Exits the process with status 0 if the user declines to overwrite an
    existing folder. If a download fails, it removes the partially
    populated folder and exits with status 1.
    """
    # Figure out where to store the model.
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_folder = os.path.join(dir_path, model) if args.install else model

    # A plain file with the model's name would block creating the folder.
    if os.path.exists(model_folder) and not os.path.isdir(model_folder):
        if not args.force:
            raise Exception("Cannot create folder for storing the model, "
                            "there exists a file of the same name.")
        print("Overwriting existing file! ({filename})"
              .format(filename=model_folder))
        os.remove(model_folder)

    if os.path.isdir(model_folder):
        # Only an explicit "y"/"yes" may discard an existing model.
        if not args.force and not _askYesNo(
                "Model already exists, continue? [y/N] "):
            print("Cancelling download...")
            sys.exit(0)
        print("Overwriting existing folder! ({filename})"
              .format(filename=model_folder))
        deleteDirectory(model_folder)

    # Now we can safely create the folder and download the model.
    os.makedirs(model_folder)
    for f in ['predict_net.pb', 'init_net.pb']:
        try:
            downloadFromURLToFile(getURLFromName(model, f),
                                  os.path.join(model_folder, f))
        except Exception as e:
            print("Abort: {reason}".format(reason=str(e)))
            print("Cleaning up...")
            deleteDirectory(model_folder)
            sys.exit(1)  # a failed download must not report success

    if args.install:
        os.symlink(os.path.join(dir_path, "__sym_init__.py"),
                   os.path.join(model_folder, "__init__.py"))
133 
134 
def validModelName(name):
    """Return True if ``name`` is acceptable as a model name.

    A valid name is not the reserved name ``__init__`` and consists only of
    ASCII letters, digits and underscores (e.g. ``resnet50``,
    ``bvlc_alexnet``). Because it can contain neither path separators nor
    dots, the name always maps to a single folder inside the download
    location.
    """
    invalid_names = ['__init__']
    if name in invalid_names:
        return False
    # ``\Z`` (unlike ``$``) rejects a trailing newline.
    return re.match(r"[0-9A-Za-z_]+\Z", name) is not None
142 
143 
if __name__ == "__main__":
    # Command-line entry point: download (or install) each requested model.
    parser = argparse.ArgumentParser(
        description='Download or install pretrained models.')
    parser.add_argument('model', nargs='+',
                        help='Model to download/install.')
    parser.add_argument('-i', '--install', action='store_true',
                        help='Install the model.')
    parser.add_argument('-f', '--force', action='store_true',
                        help='Force a download/installation.')
    args = parser.parse_args()
    for model in args.model:
        if validModelName(model):
            downloadModel(model, args)
        else:
            # ``{model}`` is a named field, so the value must be passed by
            # keyword; a positional argument would raise KeyError.
            print("'{model}' is not a valid model name.".format(model=model))