from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import collections
import copy
import enum
import logging
import time

import networkx as nx

from caffe2.python import workspace
from caffe2.proto import caffe2_pb2
log = logging.getLogger("memonger")
log.setLevel(logging.INFO)

LiveRange = collections.namedtuple('LiveRange', ["defined", "used", "size"])
def share_grad_blobs(net, losses, param_grads, namescope):
    '''
    Implements similar optimization as Torch's shareGradInput():
    for the gradients that are passed between layers, share blobs between
    operators when possible. This yields significant memory savings with
    deep networks.

    Returns an optimized protobuf (assign to net._net).
    '''
    def is_grad_blob(b):
        name = str(b)
        return "_grad" in name and (name.startswith(namescope) or
            name.startswith("_" + namescope)) and name not in param_grads

    def is_grad_op(op):
        for b in list(op.input) + list(op.output):
            if is_grad_blob(b):
                return True
        return False

    log.warn("NOTE: Executing memonger to optimize gradient memory")

    if not namescope.endswith("/"):
        namescope += "/"

    netproto = copy.deepcopy(net.Proto())
    grad_ops = [op for op in netproto.op if is_grad_op(op)]
    return _compute_blob_recycling_for_dag(
        netproto, losses, grad_ops, is_grad_blob, namescope
    )
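
# A minimal usage sketch (hypothetical model and blob names; assumes a
# model_helper-style model whose gradients live under the "gpu_0" namescope):
#
#   proto = share_grad_blobs(
#       model.net,
#       ["loss"],
#       set(str(g) for g in model.param_to_grad.values()),
#       namescope="gpu_0",
#   )
#   model.net._net = proto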
def optimize_inference_for_dag(net, input_blobs, namescope=""):
    netproto = copy.deepcopy(net.Proto())
    external_input = set(net.Proto().external_input)
    external_output = set(net.Proto().external_output)

    def is_activation_blob(b):
        return b not in external_input and b not in external_output

    seen_as_output = set()
    ops = list(net.Proto().op)

    # Sanity check: every activation blob must be produced by a preceding op,
    # and the net must not contain gradient ops.
    for op in ops:
        for b in op.input:
            if is_activation_blob(b) and b not in seen_as_output:
                assert False, "{} not in external input".format(b)
        seen_as_output = seen_as_output.union(set(op.output))
        assert not op.is_gradient_op, \
            "You can only pass inference-only nets to optimize_inference_for_dag"

    return _compute_blob_recycling_for_dag(
        netproto, input_blobs, ops, is_activation_blob, namescope
    )
def _compute_blob_recycling_for_dag(
    netproto, heads, ops, is_shareable, namescope
):
    '''
    Computes a blob recycling by traversing the computation DAG. The resulting
    model can be executed safely on a DAGNet.
    '''
    start_time = time.time()

    # Map each blob to the ops that consume it, and keep counters so we know
    # when all consumers of a blob, or all inputs of an op, have been seen.
    blobs_to_ops = collections.defaultdict(lambda: [])
    blob_input_count = collections.defaultdict(lambda: 0)
    op_inputs = collections.defaultdict(lambda: 0)
    op_visit_count = collections.defaultdict(lambda: 0)

    for i, op in enumerate(ops):
        for inp in op.input:
            if is_shareable(inp) or inp in heads:
                if inp not in op.output:
                    blobs_to_ops[inp].append(i)
                    op_inputs[i] += 1

    output_blobs = set()
    mapping = {}

    def descend(op_idx, free_blobs):
        cur_op = ops[op_idx]
        new_free_blobs = set()
        for inp in cur_op.input:
            if is_shareable(inp):
                blob_input_count[inp] += 1
                if blob_input_count[inp] == len(blobs_to_ops[inp]):
                    # All consumers of this blob have been visited, so its
                    # storage can be recycled.
                    actual_blob = inp if inp not in mapping else mapping[inp]
                    new_free_blobs.add(actual_blob)

        for outp in cur_op.output:
            if is_shareable(outp):
                if outp not in output_blobs:
                    # First time this blob is seen as an output; assign it
                    # to a free blob if one is available.
                    for freeb in free_blobs:
                        mapping[outp] = freeb
                        free_blobs.remove(freeb)
                        break

                output_blobs.add(outp)

        free_blobs.update(new_free_blobs)

        # Descend to ops that consume the outputs, but only once all their
        # inputs are satisfied. Pass free blobs only down the first branch
        # to avoid sharing across parallel branches.
        first_branch = True
        for outp in cur_op.output:
            for inp_op_idx in blobs_to_ops[outp]:
                op_visit_count[inp_op_idx] += 1

                if op_visit_count[inp_op_idx] == op_inputs[inp_op_idx]:
                    free_blobs_fwd = free_blobs if first_branch else set()
                    first_branch = False
                    descend(inp_op_idx, free_blobs_fwd)

    # Start the traversal from the head blobs (losses or inputs).
    for head_blob in heads:
        for op_idx in blobs_to_ops[head_blob]:
            descend(op_idx, set())

    # Rename the shared blobs.
    shared_blobs = set(mapping.values())
    renamed = {}
    for j, b in enumerate(shared_blobs):
        renamed[b] = namescope + "__m{}_".format(j)

    # Compute the final mapping.
    for k, v in mapping.items():
        mapping[k] = renamed[v]

    mapping.update(renamed)
    log.info("Remapping {} blobs, using {} shared".format(
        len(mapping), len(renamed),
    ))
    log.debug("Assignments: {}".format(mapping))

    apply_assignments(netproto, mapping)
    log.info("Memonger optimization took {} secs".format(
        time.time() - start_time),
    )
    return netproto
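
# Note on the naming scheme above: with namescope "gpu_0/", the recycled
# buffers are named "gpu_0/__m0_", "gpu_0/__m1_", ..., and every blob in
# 'mapping' is rewritten to point at one of these shared buffers.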
def _find_source_nodes(g):
    ''' Return nodes without predecessors '''
    ret = []
    for cn in g:
        cur_pred = g.predecessors(cn)
        if not cur_pred:
            ret.append(cn)
    return ret
def _find_target_nodes(g):
    ''' Return nodes without successors '''
    ret = []
    for cn in g:
        cur_succ = g.successors(cn)
        if not cur_succ:
            ret.append(cn)
    return ret
def _add_single_target_ifneeded(g):
    targets = _find_target_nodes(g)
    assert len(targets) >= 1
    if len(targets) == 1:
        return g
    ret = copy.deepcopy(g)

    def _next_available_idx(g):
        return max(list(g.nodes())) + 1

    target_node_idx = _next_available_idx(g)
    ret.add_node(target_node_idx)
    for cn in targets:
        ret.add_edge(cn, target_node_idx)

    return ret
def _get_path(pred_list, dist_list):
    ''' Get the path from nx.bellman_ford()'s output '''
    # Distances are negative (all weights are -1), so the node with the
    # smallest distance is the end of the longest path.
    assert all(dist_list[x] <= 0 for x in dist_list)
    target = min(dist_list, key=lambda x: dist_list[x])

    ret = []
    cur = target
    while cur is not None:
        ret.append(cur)
        cur = pred_list[cur]

    return list(reversed(ret))
def _get_longest_paths(g, source_nodes):
    ''' Get the longest path for nodes in 'source_nodes'.
        Found with bellman_ford() by setting weight = -1.
    '''
    ng = copy.deepcopy(g)
    for u, v in ng.edges():
        ng[u][v]["weight"] = -1

    ret = {}
    for cn in source_nodes:
        pred, dist = nx.bellman_ford(ng, cn, weight="weight")
        path = _get_path(pred, dist)
        assert path[0] == cn
        assert len(path) - 1 == -dist[path[-1]]
        ret[cn] = path

    return ret
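
# Negating the edge weights turns Bellman-Ford's shortest paths into longest
# paths; this is safe here because the graph is a DAG, so there are no
# negative cycles. For instance, with edges 0->1, 1->2 and 0->2, the distance
# to node 2 is -2 and the recovered path is [0, 1, 2].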
def _build_tree(paths):
    ''' Build a tree for the given paths based on common elements.
        The last elements of all paths are the same: the root of the tree.
    '''
    assert all(cp[-1] == paths[0][-1] for cp in paths)
    g = nx.DiGraph()
    node_set = {y for x in paths for y in x}
    g.add_nodes_from(node_set)
    for cp in paths:
        for ce in zip(cp[0:-1], cp[1:]):
            g.add_edge(ce[1], ce[0])

    root = paths[0][-1]
    _compute_tree_height(g, root)

    return (g, root)
def _compute_tree_height(g, root):
    ''' Compute the heights of the tree for all nodes.
        The height of a leaf is 0.
    '''
    def _get_height(root):
        children = g.successors(root)
        height = 0
        if children:
            child_heights = [_get_height(x) for x in children]
            height = max(child_heights) + 1
        g.node[root]["height"] = height
        return height

    _get_height(root)
def _sort_tree_leaves(g, root):
    ''' For each node, sort its child nodes based on their heights.
        Return the leaf nodes of the tree after sorting.
    '''
    def _get_height(root):
        return g.node[root]["height"]

    def _get_sorted_leaves(root):
        children = g.successors(root)
        if not children:
            return [root]
        child_heights = [_get_height(x) for x in children]
        order = sorted(range(len(children)), key=lambda x: child_heights[x])
        ret = []
        for co in order:
            ret += _get_sorted_leaves(children[co])

        return ret

    return _get_sorted_leaves(root)
def topological_sort_traversal_longest_path(g):
    ''' The graph 'g' may contain several source nodes (nodes without incoming
        edges), which could appear in any order and still give a valid
        topological sorting result. We would like to arrange these source
        nodes so that the average live spans of the computed blobs are
        shorter. The idea is to sort the source nodes based on the length of
        their path to the target node, so that the one with the longer path
        is used first:
        - Add a single target node if there are multiple target nodes in 'g'.
        - Find the longest path between each source and the target node.
        - Convert the longest paths to a tree with the target node being the
          root and the source nodes being the leaves.
        - Sort the nodes of the tree based on the height of the tree.
    '''
    gt = _add_single_target_ifneeded(g)
    source_nodes = _find_source_nodes(gt)
    lpaths = _get_longest_paths(gt, source_nodes)
    tree, root = _build_tree(list(lpaths.values()))
    sorted_sources = _sort_tree_leaves(tree, root)
    assert(sorted(sorted_sources) == sorted(source_nodes))

    ret = nx.topological_sort(g, sorted_sources)
    assert(len(ret) == len(g.node))
    return ret
def topological_sort_traversal(g):
    return nx.topological_sort(g)
def compute_ranges(linearized_ops, blob_sizes=None):
    if not blob_sizes:
        log.warning('Provide blob sizes to get more accurate assignments.')

    blobs = collections.defaultdict(
        lambda: LiveRange(defined=None, used=None, size=None))
    for i, op in enumerate(linearized_ops):
        for blob in op.input:
            used = blobs[blob].used
            if used is None:
                used = i
            else:
                used = max(used, i)
            blobs[blob] = blobs[blob]._replace(used=used)
            blob_size = blob_sizes[blob] if blob_sizes else None
            assert not blob_sizes or blob_size is not None
            blobs[blob] = blobs[blob]._replace(size=blob_size)
        for blob in op.output:
            defined = blobs[blob].defined
            if defined is None:
                defined = i
            else:
                defined = min(defined, i)
            blobs[blob] = blobs[blob]._replace(defined=defined)
            blob_size = blob_sizes[blob] if blob_sizes else None
            assert not blob_sizes or blob_size is not None
            blobs[blob] = blobs[blob]._replace(size=blob_size)

    return blobs
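
# Worked example (hypothetical two-op net): if op0 computes "hidden" from
# "data" and op1 computes "pred" from "hidden", the ranges are roughly
#   data:   LiveRange(defined=None, used=0, size=None)
#   hidden: LiveRange(defined=0,    used=1, size=None)
#   pred:   LiveRange(defined=1,    used=None, size=None)
# i.e. 'defined' is the first op index producing the blob and 'used' the
# last op index consuming it.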
def is_compatible(candidate_range, assignment, static_blobs):
    (name, range_) = assignment[-1]
    if name in static_blobs:
        return False
    if candidate_range.defined is None or range_.defined is None \
            or range_.used is None:
        return False
    return candidate_range.defined > range_.used
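
# Example: a blob defined at op 5 can reuse an assignment whose last member
# died at op 3, because the live ranges do not overlap:
#   is_compatible(LiveRange(defined=5, used=7, size=4),
#                 [("b", LiveRange(defined=1, used=3, size=4))], [])  # True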
def compute_blob_assignments(assignments):
    blob_assignments = {}
    for assignment in assignments:
        if len(assignment) == 1:
            continue
        last_blob, _ = assignment[-1]
        for (blob, _) in assignment:
            blob_assignments[blob] = last_blob
    return blob_assignments
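
# Example: for the assignment [("a", r_a), ("b", r_b)], both "a" and "b" are
# mapped to "b" (the last blob of the assignment), so the two blobs end up
# sharing a single buffer.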
def _get_max_size(assignment):
    if not assignment:
        return 0
    ret = max([x[1].size for x in assignment])
    ret = 0 if ret is None else ret
    return ret
def get_memory_usage(assignments):
    ret = 0
    for cur in assignments:
        ret += _get_max_size(cur)
    return ret
def compute_assignments_greedy(ranges_sorted, init_assignments=None):
    assignments = init_assignments or []
    visited = {y[0] for x in assignments for y in x}

    for (name, range_) in ranges_sorted:
        if name in visited:
            continue
        assigned = False
        best_assignment = 0
        min_dist = float("inf")
        candidate_size = range_.size or 0
        for idx, assignment in enumerate(assignments):
            if is_compatible(range_, assignment, []):
                # Among compatible assignments, prefer the one whose max size
                # is closest to the candidate's size.
                assigned = True
                dist = abs(_get_max_size(assignment) - candidate_size)
                if dist < min_dist:
                    min_dist = dist
                    best_assignment = idx
        if assigned:
            assignment = assignments[best_assignment]
            assignment.append((name, range_))
        else:
            assignments.append([(name, range_)])

    return assignments
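
# A small sketch of the greedy policy with hypothetical ranges: "b" is placed
# into "a"'s assignment because a.used < b.defined:
#
#   ranges = [("a", LiveRange(defined=0, used=1, size=4)),
#             ("b", LiveRange(defined=2, used=3, size=4))]
#   compute_assignments_greedy(ranges)
#   # -> [[("a", ...), ("b", ...)]], i.e. one shared buffer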
def _get_count(assignments):
    ''' Return the number of blobs in assignments '''
    return sum([len(x) for x in assignments])
def compute_assignments_dp(ranges_sorted, init_assignment, counter=None):
    ''' Compute assignment for blobs in 'ranges_sorted' on top of
        'init_assignment' using dynamic programming + recursion.

        ranges_sorted: blobs sorted by 'used'
        init_assignment: assignment to start with; blobs in 'ranges_sorted'
                         should not be used in 'init_assignment'

        Using f(b, k, init) to represent the best assignment for blobs b[0:k]
        given the initial assignment 'init', we have
            f(b, k, init) = f(b, j, init) +
                            find_best(b[j:k], f(b, j, init))
        where j is the index of the last best assignment that is independent
        of blob b[k - 1] (b[k - 1] is compatible with all assignments in
        f(b, j, init)), and find_best(b1, init1) gives the best assignment
        for blobs in 'b1' based on the initial assignment 'init1', where
        blobs b1[0:-1] should be incompatible with b1[-1]. f(b, len(b), [])
        gives the best assignment for blobs 'b'.

        For find_best(b, init), since b[0:-1] are not compatible with b[-1],
        we can reduce it to the smaller problem of finding the best
        assignment for b[0:-1]:
            find_best(b, init) = min {
                f(b[0:-1], len(b) - 1, init - x) + [x, b[-1]] for x in init, or
                f(b[0:-1], len(b) - 1, init) + [b[-1]]
            }
        where min{} gives the assignment with the minimum memory usage.
    '''

    def _get_compatible_prev(candidate_range, best_assignments, cur_idx):
        ''' Find the closest position k in best_assignments that is
            independent of candidate_range, i.e. candidate_range is
            compatible with all assignments in best_assignments[k].
            Return -1 if not found.
        '''
        def is_compatible_all(candidate_range, assignments):
            ''' Return True if compatible with all assignments '''
            return all([is_compatible(candidate_range[1], x, [])
                        for x in assignments])

        ii = cur_idx - 1
        while ii >= 0:
            cba = best_assignments[ii]
            if is_compatible_all(candidate_range, cba):
                return ii
            ii -= 1
        return -1

    def _find_best(ranges, init_assignment, prev_best_assignment, counter):
        ''' Find the best assignment for blobs 'ranges' given an initial
            assignment 'init_assignment'.

            Blobs in ranges[0:-1] should be incompatible with blob ranges[-1].
            'prev_best_assignment': best assignment for blobs in ranges[:-1].

            By assigning ranges[-1] to each assignment k in 'init_assignment'
            or to a new assignment, the problem becomes a smaller problem of
            finding the best assignment for ranges[0:-1] given the initial
            assignment init_assignment[0:k, (k+1):-1].
        '''
        # The blob to be assigned
        find_range = ranges[-1]
        # Each of ranges[0:-1] is incompatible with ranges[-1], so the
        # problem can be reduced.
        assert all(not is_compatible(x[1], [find_range], [])
                   for x in ranges[0:-1])

        sz = len(init_assignment)
        best_candidates = []
        # Try to assign 'find_range' to each assignment in init_assignment.
        for ii in range(sz):
            if not is_compatible(find_range[1], init_assignment[ii], []):
                continue
            cur_best = copy.deepcopy(init_assignment)
            cur_best[ii].append(find_range)
            if len(ranges) > 1:
                cur_best_tmp = [x for i, x in enumerate(cur_best) if i != ii]
                # Reduce to a smaller dp problem for the remaining blobs.
                cur_best_tmp = compute_assignments_dp(
                    ranges[:-1], cur_best_tmp, counter)
                cur_best = cur_best_tmp + [cur_best[ii]]
            best_candidates.append(cur_best)
        # Try also putting 'find_range' into a new assignment.
        best_candidates.append(prev_best_assignment + [[find_range]])

        ret = min(best_candidates, key=lambda x: get_memory_usage(x))
        return ret

    if not counter:
        counter = [0]
    counter[0] += 1

    if counter and counter[0] % 5000 == 0:
        rs = [ranges_sorted[0][1].defined, ranges_sorted[-1][1].used]
        log.info('Finding assignments {} ({} -> {})...'.format(
            counter[0], rs[0], rs[1]))

    init_assignment = init_assignment or []
    # best_assignments[k]: best assignment for blobs ranges_sorted[0:(k+1)]
    best_assignments = []
    for ii, cur_range in enumerate(ranges_sorted):
        # The closest best assignment that is independent of ranges_sorted[ii]
        prev_idx = _get_compatible_prev(cur_range, best_assignments, ii)
        prev_best = copy.deepcopy(init_assignment) if prev_idx < 0 else \
            copy.deepcopy(best_assignments[prev_idx])
        # Find the best assignment for the blobs in 'ranges_part'
        ranges_part = ranges_sorted[(prev_idx + 1):(ii + 1)]
        cur_best = _find_best(
            ranges_part, prev_best,
            best_assignments[-1] if best_assignments else init_assignment,
            counter)
        assert _get_count(cur_best) == _get_count(prev_best) + len(ranges_part)
        best_assignments.append(copy.deepcopy(cur_best))

    assert len(best_assignments) == len(ranges_sorted)

    best = best_assignments[-1]
    return best
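
# Usage mirrors compute_assignments_greedy: for 'ranges_sharable' sorted by
# 'used', compute_assignments_dp(ranges_sharable, []) explores alternative
# placements recursively and returns the assignment with the minimum total
# memory usage, at a (potentially much) higher compute cost.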
def get_updated_ranges(ranges, max_live=None):
    ''' Set LiveRange.defined = -1 if it is None
        Set LiveRange.used = max_live if it is None
        Set LiveRange.size = 1 if it is None
    '''

    def _get_max_live(ranges):
        max_live = max(x[1].used for x in ranges if x[1].used) + 1
        return max_live

    def _update_range(x, max_live, size):
        cx = x
        if x[1].defined is None:
            cx = (cx[0], cx[1]._replace(defined=-1))
        if x[1].used is None:
            cx = (cx[0], cx[1]._replace(used=max_live))
        if x[1].size is None:
            cx = (cx[0], cx[1]._replace(size=size))
        return cx

    if max_live is None:
        max_live = _get_max_live(ranges)
    ranges = [_update_range(x, max_live, 1) for x in ranges]

    return ranges
def compute_assignments(ranges, static_blobs, algo):
    '''
    algo: Method used to find assignments (AssignmentAlgorithm.GREEDY or
          AssignmentAlgorithm.DYNAMIC_PROGRAMMING).
          AssignmentAlgorithm.DYNAMIC_PROGRAMMING gives the optimal solution
          at the cost of more computation.
          AssignmentAlgorithm.GREEDY may be better in the case 'blob_sizes'
          is not provided.
    '''

    # Sort the ranges based on when they are last used. Blobs with
    # used=None are never consumed inside the net; sort them to the end of
    # the list so that they can be shared as well.
    ranges = sorted(
        list(ranges.items()),
        key=lambda p: (p[1].used is None, p[1].used),
    )
    # Update None values
    ranges = get_updated_ranges(ranges)

    # Split into sharable and static (non-sharable) blobs
    ranges_sharable = [x for x in ranges if x[0] not in static_blobs]
    ranges_static = [x for x in ranges if x[0] in static_blobs]

    log.info("Total sharable blobs {}".format(len(ranges_sharable)))

    best_assignment = []
    if algo == AssignmentAlgorithm.DYNAMIC_PROGRAMMING:
        best_assignment = compute_assignments_dp(ranges_sharable, [])
    elif algo == AssignmentAlgorithm.GREEDY:
        best_assignment = compute_assignments_greedy(ranges_sharable, [])
    else:
        assert False, "Invalid algo name {}".format(algo)
    best_assignment += [[x] for x in ranges_static]

    return best_assignment
def verify_assignments(assignments):
    for cur in assignments:
        for x, y in zip(cur[0:-1], cur[1:]):
            assert x[1].used < y[1].defined
def compute_interference_graph(ops):
    g = nx.DiGraph()
    for i, op in enumerate(ops):
        g.add_node(i, op=op)
    for i, parent_op in enumerate(ops):
        for j, child_op in enumerate(ops):
            if i >= j:
                continue
            if any(output in child_op.input for output in parent_op.output):
                deps = set(child_op.input).intersection(parent_op.output)
                g.add_edge(i, j, deps=deps)
                assert nx.is_directed_acyclic_graph(g), child_op
    return g
Optimization = collections.namedtuple(
    'Optimization', ['net', 'assignments', 'blob_assignments'])
def apply_assignments(net, blob_assignments):
    def canonical_name(blob):
        if blob not in blob_assignments:
            return blob
        return blob_assignments[blob]

    for op in net.op:
        # Descend into the step nets of recurrent networks as well.
        if op.type.startswith('RecurrentNetwork'):
            apply_recurrent_blob_assignments(
                op, blob_assignments, canonical_name)

        for i, input_ in enumerate(op.input):
            op.input[i] = canonical_name(input_)
        for i, output in enumerate(op.output):
            op.output[i] = canonical_name(output)
def apply_recurrent_blob_assignments(op, blob_assignments, canonical_name):
    log.debug("Applying assignments to recurrent op: {}".format(op.type))
    import google.protobuf.text_format as protobuftx
    step_args = [a for a in op.arg if a.name.endswith("step_net")]
    for step_arg in step_args:
        step_proto = caffe2_pb2.NetDef()
        protobuftx.Merge(step_arg.s, step_proto)
        apply_assignments(step_proto, blob_assignments)
        for i, einp in enumerate(step_proto.external_input):
            if einp in blob_assignments:
                step_proto.external_input[i] = canonical_name(einp)
        step_arg.s = str(step_proto)
    # Store the renamings as arguments on the recurrent op
    for blob, renamed in blob_assignments.items():
        if blob in list(op.input) + list(op.output):
            a = caffe2_pb2.Argument()
            a.name = blob + ".rename"
            a.s = str(renamed)
            op.arg.extend([a])


class AssignmentAlgorithm(enum.Enum):
    GREEDY = 0
    DYNAMIC_PROGRAMMING = 1
def optimize_interference(net, static_blobs,
                          ordering_function=topological_sort_traversal,
                          blob_sizes=None,
                          algo=AssignmentAlgorithm.GREEDY):
    '''
    ordering_function: topological_sort_traversal or
                       topological_sort_traversal_longest_path.
                       topological_sort_traversal_longest_path gives better
                       results but needs a bit more computation.
    algo: Method used to find assignments (AssignmentAlgorithm.GREEDY or
          AssignmentAlgorithm.DYNAMIC_PROGRAMMING).
          AssignmentAlgorithm.DYNAMIC_PROGRAMMING gives the optimal solution
          at the cost of more computation.
          AssignmentAlgorithm.GREEDY may be better in the case 'blob_sizes'
          is not provided.

    1) Use a BFS traversal of the execution graph to generate an
       ordering of the node executions.
    2) Generate use-def ranges for each `blob` in the BFS traversal
       order.
    3) Assign blobs to `canonical blobs`.
    4) Rename blobs to canonical blobs.
    '''
    net = copy.deepcopy(net)
    g = compute_interference_graph(net.op)
    ordering = ordering_function(g)
    linearized_ops = [net.op[i] for i in ordering]

    # Reorder the ops in the net to match the linearized order.
    del net.op[:]
    net.op.extend(linearized_ops)

    ranges = compute_ranges(linearized_ops, blob_sizes)
    assignments = compute_assignments(ranges, static_blobs, algo)
    blob_assignments = compute_blob_assignments(assignments)
    apply_assignments(net, blob_assignments)
    return Optimization(
        net=net,
        blob_assignments=blob_assignments,
        assignments=assignments)
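
# A hedged end-to-end sketch ('net' is a caffe2_pb2.NetDef and 'static' the
# blobs whose names must survive, e.g. external inputs/outputs; assumes the
# blobs exist in the workspace when the statistics are computed):
#
#   result = optimize_interference(net, static_blobs=static)
#   workspace.RunNetOnce(result.net)
#   stats = compute_statistics(result.assignments)
#   print(stats.optimized_nbytes, "vs baseline", stats.baseline_nbytes)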
Statistics = collections.namedtuple(
    'Statistics', ['baseline_nbytes', 'optimized_nbytes'])
def compute_statistics(assignments):
    def blob_nbytes(blob):
        return workspace.FetchBlob(blob).nbytes

    blob_bytes = {
        blob: blob_nbytes(blob) for assignment in assignments
        for (blob, _) in assignment}
    baseline_nbytes = sum(v for _, v in blob_bytes.items())
    optimized_nbytes = sum(
        max(blob_bytes[blob] for (blob, _) in assignment)
        for assignment in assignments)
    return Statistics(
        baseline_nbytes=baseline_nbytes,
        optimized_nbytes=optimized_nbytes)
def collect_blob_sizes(net):
    ''' Collect blob sizes from the workspace '''
    def blob_nbytes(blob):
        return workspace.FetchBlob(blob).nbytes

    blobs = {}
    for op in net.op:
        for blob in op.input:
            blobs[blob] = blob_nbytes(blob)
        for blob in op.output:
            blobs[blob] = blob_nbytes(blob)
    return blobs
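
# Sketch: collected sizes make the assignment size-aware (assuming the net
# has been run once so that every blob exists in the workspace):
#
#   workspace.RunNetOnce(net)
#   sizes = collect_blob_sizes(net)
#   result = optimize_interference(
#       net, static_blobs, blob_sizes=sizes,
#       algo=AssignmentAlgorithm.DYNAMIC_PROGRAMMING)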