Commit 7148c1f

Luke Metz committed

First commit of learning_unsupervised_learning

1 parent d640ab9 · commit 7148c1f

21 files changed · +2623 −0 lines changed

CODEOWNERS

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 /research/inception/ @shlens @vincentvanhoucke
 /research/learned_optimizer/ @olganw @nirum
 /research/learning_to_remember_rare_events/ @lukaszkaiser @ofirnachum
+/research/learning_unsupervised_learning/ @lukemetz @nirum
 /research/lexnet_nc/ @vered1986 @waterson
 /research/lfads/ @jazcollins @susillo
 /research/lm_1b/ @oriolvinyals @panyx0718

research/README.md

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ installation](https://www.tensorflow.org/install).
 - [inception](inception): deep convolutional networks for computer vision.
 - [learning_to_remember_rare_events](learning_to_remember_rare_events): a
   large-scale life-long memory module for use in deep learning.
+- [learning_unsupervised_learning](learning_unsupervised_learning): a
+  meta-learned unsupervised learning update rule.
 - [lexnet_nc](lexnet_nc): a distributed model for noun compound relationship
   classification.
 - [lfads](lfads): sequential variational autoencoder for analyzing
research/learning_unsupervised_learning/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
*.pyc
research/learning_unsupervised_learning/README.md

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# Learning Unsupervised Learning Rules

This repository contains code and weights for the learned update rule
presented in "Learning Unsupervised Learning Rules." At this time, this
code cannot meta-train the update rule.


### Structure

`run_eval.py` contains the main training loop. This constructs an op
that runs one iteration of the learned update rule and assigns the
results to variables. Additionally, it loads the weights from our
pre-trained model.
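As a rough sketch of that pattern (illustrative only, not the actual `run_eval.py`; `build_update_op` and the toy decay rule below are hypothetical stand-ins), such an op can be built by grouping one assignment per variable, so that each `sess.run` call performs one unsupervised update step:

```python
import tensorflow as tf

# Hypothetical sketch -- not the actual run_eval.py. `make_deltas` stands in
# for the learned update rule, which produces an update for every variable.
def build_update_op(variables, make_deltas):
  """Builds a single op that applies one update step to `variables`."""
  deltas = make_deltas(variables)
  assigns = [v.assign_add(d) for v, d in zip(variables, deltas)]
  # Grouping the assignments yields one op; each sess.run(update_op)
  # is then one iteration of the update rule.
  return tf.group(*assigns)

v = tf.get_variable("w", shape=[2, 2], initializer=tf.ones_initializer())
# Toy stand-in rule: decay every variable toward zero.
update_op = build_update_op([v], lambda vs: [-0.1 * x for x in vs])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(10):
    sess.run(update_op)  # one update iteration per run
```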
The base model and the update rule architecture definition can be found in
`architectures/more_local_weight_update.py`. For a complete description
of the model, see our [paper](https://arxiv.org/abs/1804.00222).

### Dependencies

[absl](https://github.com/abseil/abseil-py), [tensorflow](https://tensorflow.org), [sonnet](https://github.com/deepmind/sonnet)
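This commit does not pin versions. Assuming the standard PyPI package names for these three dependencies (an assumption, not specified here), one way to set up an environment is:

```bash
# Assumed PyPI package names; this commit does not pin versions.
pip install absl-py tensorflow dm-sonnet
```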
### Usage

First, download the [pre-trained optimizer model weights](https://storage.googleapis.com/learning_unsupervised_learning/200_tf_graph.zip) and extract them.

```bash
# move to the folder above this folder
cd path_to/research/learning_unsupervised_learning/../

# launch the eval script
python -m learning_unsupervised_learning.run_eval \
  --train_log_dir="/tmp/learning_unsupervised_learning" \
  --checkpoint_dir="/path/to/downloaded/model/tf_graph_data.ckpt"
```

### Contact

Luke Metz, Niru Maheswaranathan. GitHub: @lukemetz, @nirum. Email: {lmetz, nirum}@google.com

research/learning_unsupervised_learning/__init__.py

Whitespace-only changes.
research/learning_unsupervised_learning/architectures/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import more_local_weight_update
research/learning_unsupervised_learning/architectures/common.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sonnet as snt
import tensorflow as tf
import numpy as np
import collections
from learning_unsupervised_learning import utils

from tensorflow.python.util import nest

from learning_unsupervised_learning import variable_replace

class LinearBatchNorm(snt.AbstractModule):
  """Module that applies a linear layer, then batch norm, then an activation fn."""

  def __init__(self, size, activation_fn=tf.nn.relu, name="LinearBatchNorm"):
    self.size = size
    self.activation_fn = activation_fn
    super(LinearBatchNorm, self).__init__(name=name)

  def _build(self, x):
    x = tf.to_float(x)
    initializers = {"w": tf.truncated_normal_initializer(stddev=0.01)}
    lin = snt.Linear(self.size, use_bias=False, initializers=initializers)
    z = lin(x)

    # The scale is fixed at 1; only the offset "b" is learned.
    scale = tf.constant(1., dtype=tf.float32)
    offset = tf.get_variable(
        "b",
        shape=[1, z.shape.as_list()[1]],
        initializer=tf.truncated_normal_initializer(stddev=0.1),
        dtype=tf.float32
    )

    # Normalize over the batch dimension, then apply scale and offset.
    mean, var = tf.nn.moments(z, [0], keep_dims=True)
    z = ((z - mean) * tf.rsqrt(var + 1e-6)) * scale + offset

    x_p = self.activation_fn(z)

    # Return both the pre-activation and post-activation values.
    return z, x_p

  # This needs to work by string name sadly due to how the variable replace
  # works and would also work even if the custom getter approach was used.
  # This is verbose, but it should at least be clear as to what is going on.
  # TODO(lmetz) a better way to do this (the next 3 functions:
  # _raw_name, w(), b() )
  def _raw_name(self, var_name):
    """Return just the name of the variable, not the scopes."""
    return var_name.split("/")[-1].split(":")[0]

  @property
  def w(self):
    var_list = snt.get_variables_in_module(self)
    w = [x for x in var_list if self._raw_name(x.name) == "w"]
    assert len(w) == 1
    return w[0]

  @property
  def b(self):
    var_list = snt.get_variables_in_module(self)
    b = [x for x in var_list if self._raw_name(x.name) == "b"]
    assert len(b) == 1
    return b[0]

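# A hypothetical illustration of the `w`/`b` lookup above (not part of this
# commit; names are stand-ins): once the module has been connected, the
# properties fetch the variables by their raw names.
#
#   lbn = LinearBatchNorm(size=32)
#   z, h = lbn(tf.zeros([8, 16]))  # connect the module so variables exist
#   lbn.w  # the snt.Linear kernel, raw name "w"
#   lbn.b  # the learned offset, raw name "b"
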
class Linear(snt.AbstractModule):

  def __init__(self, size, use_bias=True, init_const_mag=True):
    self.size = size
    self.use_bias = use_bias
    self.init_const_mag = init_const_mag
    super(Linear, self).__init__(name="commonLinear")

  def _build(self, x):
    if self.init_const_mag:
      initializers = {"w": tf.truncated_normal_initializer(stddev=0.01)}
    else:
      initializers = {}
    lin = snt.Linear(self.size, use_bias=self.use_bias, initializers=initializers)
    z = lin(x)
    return z

  # This needs to work by string name sadly due to how the variable replace
  # works and would also work even if the custom getter approach was used.
  # This is verbose, but it should at least be clear as to what is going on.
  # TODO(lmetz) a better way to do this (the next 3 functions:
  # _raw_name, w(), b() )
  def _raw_name(self, var_name):
    """Return just the name of the variable, not the scopes."""
    return var_name.split("/")[-1].split(":")[0]

  @property
  def w(self):
    var_list = snt.get_variables_in_module(self)
    if self.use_bias:
      assert len(var_list) == 2, "Found not 2 but %d" % len(var_list)
    else:
      assert len(var_list) == 1, "Found not 1 but %d" % len(var_list)
    w = [x for x in var_list if self._raw_name(x.name) == "w"]
    assert len(w) == 1
    return w[0]

  @property
  def b(self):
    var_list = snt.get_variables_in_module(self)
    assert len(var_list) == 2, "Found not 2 but %d" % len(var_list)
    b = [x for x in var_list if self._raw_name(x.name) == "b"]
    assert len(b) == 1
    return b[0]

def transformer_at_state(base_model, new_variables):
  """Get the base_model that has been transformed to use the variables
  in final_state.

  Args:
    base_model: snt.Module
      Goes from batch to features
    new_variables: list
      New list of variables to use
  Returns:
    func: callable of same api as base_model.
  """
  assert not variable_replace.in_variable_replace_scope()

  def _feature_transformer(input_data):
    """Feature transformer at the end of training."""
    initial_variables = base_model.get_variables()
    replacement = collections.OrderedDict(
        utils.eqzip(initial_variables, new_variables))
    with variable_replace.variable_replace(replacement):
      features = base_model(input_data)
    return features

  return _feature_transformer
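
# Hypothetical usage sketch for transformer_at_state (names are illustrative,
# not from this commit):
#
#   base_model = ...                      # an snt.Module: batch -> features
#   _ = base_model(batch)                 # connect once so variables exist
#   theta_final = [...]                   # tensors aligned with
#                                         #   base_model.get_variables()
#   transformed = transformer_at_state(base_model, theta_final)
#   features = transformed(batch)         # same API as base_model, but every
#                                         # variable read is replaced by the
#                                         # corresponding tensor in theta_final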
