diff --git a/python/caffe/net_spec.py b/python/caffe/net_spec.py index 1b4814a45c6..b1cc0130b94 100644 --- a/python/caffe/net_spec.py +++ b/python/caffe/net_spec.py @@ -90,6 +90,16 @@ def to_proto(self): return to_proto(self) +class NoTop: + """ Special class for layers without a top """ + + def __init__(self, fn): + self.fn = fn + + def to_proto(self): + return to_proto(self) + + class Function(object): """A Function specifies a layer, its parameters, and its inputs (which are Tops from other layers).""" @@ -105,7 +115,10 @@ def __init__(self, type_name, inputs, params): self.in_place = self.params.get('in_place', False) if 'in_place' in self.params: del self.params['in_place'] - self.tops = tuple(Top(self, n) for n in range(self.ntop)) + if self.ntop == 0: + self.tops = (NoTop(self),) + else: + self.tops = tuple(Top(self, n) for n in range(self.ntop)) def _get_name(self, top, names, autonames): if top not in names: @@ -129,7 +142,8 @@ def _to_proto(self, layers, names, autonames): layer.top.extend(layer.bottom) else: for top in self.tops: - layer.top.append(self._get_name(top, names, autonames)) + if not isinstance(top,NoTop): + layer.top.append(self._get_name(top, names, autonames)) layer.name = self._get_name(self.tops[0], names, autonames) for k, v in six.iteritems(self.params): @@ -180,7 +194,7 @@ class Layers(object): def __getattr__(self, name): def layer_fn(*args, **kwargs): fn = Function(name, args, kwargs) - if fn.ntop == 1: + if fn.ntop <= 1: return fn.tops[0] else: return fn.tops